In [19]:
import gymnasium as gym
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Normal

import numpy as np
import random as rd
import math

# Profiler
import cProfile
import re

torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x31b4d0e90>

In [20]:
class ValueFunction(nn.Module):
    def __init__(self, in_features, out_features, hidden_features, critic_learning_rate):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.critic_learning_rate = critic_learning_rate
        
        self.critic = nn.Sequential(
            nn.Linear(self.in_features, self.hidden_features),
            nn.ReLU(),
            nn.Linear(self.hidden_features, self.hidden_features),
            nn.ReLU(),
            nn.Linear(self.hidden_features, 1),
            nn.ReLU()
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.critic_learning_rate)


    def forward(self, input):
        return self.critic(input)
    
class Policy(nn.Module):
    def __init__(self, in_features, out_features, hidden_features, actor_learning_rate, std, device):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.actor_learning_rate = actor_learning_rate
        self.std = std
        self.device = device

        self.actor = nn.Sequential(
            nn.Linear(self.in_features, self.hidden_features),
            nn.ReLU(),
            nn.Linear(self.hidden_features, self.hidden_features),
            nn.ReLU(),
            nn.Linear(self.hidden_features, self.out_features),
            nn.Tanh()
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.actor_learning_rate)

    def forward(self, in_):

        if in_.dim() == 1:
            in_ = in_.unsqueeze(0)
        h = self.actor(in_)

        epsilon = torch.randn(h.size(0), h.size(1)).to(self.device)
        z = h + self.std * epsilon
        return z, h, self.std
    
    def get_log_probability(self, z, mu, std):

        coeff =  1 / (std*math.sqrt(2*math.pi))
        normal_dist = coeff * torch.exp(-0.5 * (((z - mu) / std) ** 2) )
        assert(normal_dist.all() >= 0)

        return torch.log(normal_dist).sum(dim=-1)

In [None]:
class PPO(nn.Module):
    def __init__(self, epochs, training_iterations, batch_size, trajectory_length, n_actors, env, in_features, out_features, hidden_features, device, actor_learning_rate, critic_learning_rate, gamma, lambda_, epsilon, std, beta, d_targ, mode, n_nets, omega, omega12):
        super().__init__()
        
        self.epochs = epochs
        self.training_iterations = training_iterations
        self.batch_size = batch_size
        self.trajectory_length = trajectory_length
        self.n_actors = n_actors
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.gamma = gamma
        self.lambda_ = lambda_
        self.epsilon = epsilon
        self.beta = beta
        self.omega = omega
        self.omega12 = omega12
        self.d_targ = d_targ
        self.env = env
        self.device = device
        self.std = std
        self.actor_learning_rate = actor_learning_rate
        self.mode = mode
        self.n_nets = n_nets
        self.loss_coeff = beta * (omega ** torch.arange(0, self.n_nets)).to(device)

        if n_nets > 1:
            self.actor_list = [Policy(in_features, 
                                out_features, 
                                hidden_features, 
                                (self.omega**(-i))*actor_learning_rate,
                                std,
                                device
                                ).to(self.device) for i in range(self.n_nets)]

        else:
            self.actor = Policy(in_features, out_features, hidden_features, actor_learning_rate, std, device)
        
        self.critic = ValueFunction(in_features, out_features, hidden_features, critic_learning_rate)

    
    def train_model(self):

        N = self.n_actors  #number of actors
        T = self.trajectory_length # trajectory length
        
        for i in range(self.training_iterations):
            dataset = []
            print(f"[train]: starting dataset creation at iteration n {i}")

            with torch.no_grad():
                adv_list = []
                cum_reward = 0
                for _ in range(N): #for each actor

                    # initialize first state
                    s_prime, _ = self.env.reset()
                    s_prime = torch.tensor(s_prime, dtype=torch.float32).to(self.device)

                    trajectory = []
                    done = False

                    for t in range(T):

                        action, mu, std = self.actor(s_prime)
                        log_policy = self.actor.get_log_probability(action, mu, std)

                        s, reward, terminated, truncated, _ = self.env.step(action.squeeze(0).cpu().detach().numpy())
                        s = torch.tensor(s, dtype=torch.float32).to(self.device)
                        reward = torch.tensor([[reward]], dtype=torch.float32).to(self.device)
                        s_prime = s_prime.unsqueeze(0)
                        trajectory.append([s_prime, action, reward, log_policy])
                        s_prime = s
                        cum_reward += reward

                        done = terminated or truncated
                        if done:
                            break


                    dynamic_target = 0 if done else self.critic(s)
                    for t in range(len(trajectory)-1, -1, -1): #I want the range from [T-1 to 0]
                        
                        dynamic_target = dynamic_target*self.gamma + trajectory[t][2] #taking the reward
                        advantage = dynamic_target - self.critic(trajectory[t][0])
                        trajectory[t] = tuple(trajectory[t] + [dynamic_target.unsqueeze(0), advantage.unsqueeze(0)])

                        dataset.append(trajectory[t])
                        adv_list.append(advantage)

                adv_std, adv_mean = torch.std_mean(torch.tensor(adv_list))
                print(f"[training]: cum reward {cum_reward}")
            
            print(f"[training]: ending dataset creation with dataset size {len(dataset)}")

            self.actor.zero_grad()
            self.critic.zero_grad()
            # Starts the training process
            for e in range(self.epochs):
                
                print(f"[train]: epoch n {e}")
                avg_loss_value = 0
                avg_loss_ppo = 0
                rd.shuffle(dataset) #shuffle in-place
                
                assert(self.batch_size <= len(dataset))

                for mini_idx in range(0, len(dataset), self.batch_size):
                    
                    # form mini_batch
                    mini_batch = dataset[mini_idx: mini_idx+self.batch_size]

                    state_mini = torch.stack(list(map(lambda elem: elem[0].squeeze(), mini_batch)))
                    action_mini = torch.stack(list(map(lambda elem: elem[1].squeeze(), mini_batch)))
                    log_policy_mini = torch.stack(list(map(lambda elem: elem[3].squeeze(), mini_batch)))
                    advantage_mini = torch.stack(list(map(lambda elem: elem[4].squeeze(), mini_batch)))
                    target_mini = torch.stack(list(map(lambda elem: elem[5].squeeze(), mini_batch)))
                    
                    # Normalize advantage_mini
                    advantage_mini = ((advantage_mini-adv_mean) / (adv_std+0.00001))

                    _, mu_mini, std_mini = self.actor(state_mini) # std is a scalar!
                    new_log_policy_mini = self.actor.get_log_probability(action_mini, mu_mini, std_mini)   

                    new_value_mini = self.critic(state_mini)
                    
                    self.actor.optimizer.zero_grad()
                    self.critic.optimizer.zero_grad()
                    
                    if self.mode == 'clip':
                        loss_ppo = self.loss_clip(new_log_policy_mini, log_policy_mini, advantage_mini)
                    elif (self.mode == 'kl_fixed') or (self.mode == 'kl_adaptive'):
                        loss_ppo = self.loss_kl(new_log_policy_mini, log_policy_mini, advantage_mini)

                    loss_value = self.loss_value(new_value_mini, target_mini)

                    avg_loss_ppo += loss_ppo
                    avg_loss_value += avg_loss_value

                    loss_ppo.backward()
                    loss_value.backward()
                    
                    self.actor.optimizer.step()
                    self.critic.optimizer.step()


                total_minibatch = math.floor(len(dataset) // self.batch_size)
                print(f"[avg actor loss]: {avg_loss_ppo / total_minibatch} \t[critic loss]: {loss_value / total_minibatch}")


            self.save_parameters("model"+str(i)+".pt")


    def train_model_pc(self):

        N = self.n_actors  #number of actors
        T = self.trajectory_length # trajectory length
        
        for i in range(self.training_iterations):
            dataset = []
            print(f"[train]: starting dataset creation at iteration n {i}")

            with torch.no_grad():
                adv_list = []
                cum_reward = 0
                for _ in range(N): #for each actor

                    # initialize first state
                    s_prime, _ = self.env.reset()
                    s_prime = torch.tensor(s_prime, dtype=torch.float32).to(self.device)
                    
                    trajectory = []
                    done = False

                    for t in range(T):
                        
                        actions_list_return = [self.actor_list[i](s_prime) for i in range(self.n_nets)] # [(action, mu, std), ..., ]
                        log_policy_list = [self.actor_list[i].get_log_probability(actions_list_return[i][0],
                                                                                actions_list_return[i][1],
                                                                                actions_list_return[i][2]) for i in range(self.n_nets)] # [log_policy, ..., ]
                        action_zero = actions_list_return[0][0]

                        s, reward, terminated, truncated, _ = self.env.step(action_zero.squeeze(0).cpu().detach().numpy())
                        s = torch.tensor(s, dtype=torch.float32).to(self.device)
                        reward = torch.tensor([[reward]], dtype=torch.float32).to(self.device)
                        s_prime = s_prime.unsqueeze(0)
                        trajectory.append([s_prime, action_zero, reward, log_policy_list])
                        s_prime = s
                        cum_reward += reward

                        done = terminated or truncated
                        if done:
                            break

                    dynamic_target = 0 if done else self.critic(s)
                    for t in range(len(trajectory)-1, -1, -1): #I want the range from [T-1 to 0]
                        
                        dynamic_target = dynamic_target*self.gamma + trajectory[t][2] #taking the reward
                        advantage = dynamic_target - self.critic(trajectory[t][0])
                        trajectory[t] = tuple(trajectory[t] + [dynamic_target.unsqueeze(0), advantage.unsqueeze(0)])

                        dataset.append(trajectory[t])
                        adv_list.append(advantage)

                adv_std, adv_mean = torch.std_mean(torch.tensor(adv_list))
                print(f"[training]: cum reward {cum_reward}")
            
            print(f"[training]: ending dataset creation with dataset size {len(dataset)}")

            [actor.zero_grad() for actor in self.actor_list]
            self.critic.zero_grad()
            # Starts the training process
            for e in range(self.epochs):
                
                print(f"[train]: epoch n {e}")
                avg_loss_value = 0
                avg_loss_ppo = 0
                rd.shuffle(dataset) #shuffle in-place
                
                assert(self.batch_size <= len(dataset))

                for mini_idx in range(0, len(dataset), self.batch_size):
                    
                    # form mini_batch
                    mini_batch = dataset[mini_idx: mini_idx+self.batch_size]

                    state_mini = torch.stack(list(map(lambda elem: elem[0].squeeze(), mini_batch)))
                    action_mini = torch.stack(list(map(lambda elem: elem[1].squeeze(), mini_batch)))
                    total_log_policy_list = torch.t(torch.stack(list(map(lambda elem: torch.cat(elem[3]), mini_batch)))) # size (nets, batch)
                    advantage_mini = torch.stack(list(map(lambda elem: elem[4].squeeze(), mini_batch)))
                    target_mini = torch.stack(list(map(lambda elem: elem[5].squeeze(), mini_batch)))
                    
                    # Normalize advantage_mini
                    advantage_mini = ((advantage_mini-adv_mean) / (adv_std+0.00001))
                    new_actions = [actor(state_mini) for actor in self.actor_list]

                    total_new_log_policy_list = torch.stack([self.actor_list[i].get_log_probability(action_mini, new_actions[i][1], new_actions[i][2]) for i in range(self.n_nets)]) # size (nets, batch)

                    new_value_mini = self.critic(state_mini)
                    
                    [actor.zero_grad() for actor in self.actor_list]
                    self.critic.optimizer.zero_grad()
                    
                    # returns a loss for every net
                    loss_ppo = self.loss_kl(total_new_log_policy_list, total_log_policy_list, advantage_mini)
                    loss_value = self.loss_value(new_value_mini, target_mini)

                    avg_loss_ppo += loss_ppo
                    avg_loss_value += loss_value

                    loss_ppo.backward()
                    loss_value.backward()
                    
                    [actor.optimizer.step() for actor in self.actor_list]
                    self.critic.optimizer.step()


                total_minibatch = math.floor(len(dataset) // self.batch_size)
                print(f"[avg actor loss]: {avg_loss_ppo / total_minibatch} \t[critic loss]: {avg_loss_value / total_minibatch}")


            self.save_parameters("partial_models/model"+str(i)+".pt")

                

    def loss_value(self, value, target):
        #MSE
        return torch.mean((value-target)**2)

    def loss_clip(self, new_log_policy_mini, log_policy_mini, advantage_mini):

        prob_mini = torch.exp(new_log_policy_mini - log_policy_mini)
        prob_adv = prob_mini*advantage_mini
        clip_ = torch.clip(prob_mini, 1-self.epsilon, 1+self.epsilon)*advantage_mini
        return -torch.min(prob_adv, clip_).mean()
    
    def loss_kl1(self, new_log_policy_mini, log_policy_mini, advantage_mini):
        
        prob_mini = torch.exp(new_log_policy_mini - log_policy_mini)
        prob_adv = prob_mini * advantage_mini
        d = log_policy_mini - new_log_policy_mini

        if self.mode == 'kl_adaptive':
            if d.detach().mean() < (self.d_targ / 1.5):
                self.beta = self.beta / 2
            elif d.detach().mean() > (self.d_targ * 1.5):
                self.beta = self.beta * 2

        return -(prob_adv - self.beta*d).mean()
    
    def loss_kl(self, stack_new, stack_old, advantage_mini):
        
        # print("total_new_log_policy_list", stack_new)
        # print("total_log_policy_list", stack_old)
        # stack new and stack old have shape (n_net, batch_size)
        # We compute the policy gradient based on first net prob and adv
        new_log_policy_mini = stack_new[0, :]
        log_policy_mini = stack_old[0, :]
        prob_mini = torch.exp(new_log_policy_mini - log_policy_mini)

        L_pg = prob_mini * advantage_mini

        kl_stack = stack_new - stack_old
        L_ppo = torch.sum(self.loss_coeff * torch.t(kl_stack), dim=1)

        # print("stack_new", torch.t(stack_new))
        # print("stack_old", torch.t(stack_old))
        # print("stack_new_sliced", torch.t(stack_new)[:, 1:])
        # print("stack_old_sliced", torch.t(stack_old)[:, 0:stack_old.shape[0]-1])


        L_casc_init = self.omega12 * (torch.t(stack_new)[:, 1] - torch.t(stack_old)[:, 2])
        kl_sub_previous = torch.t(stack_new)[:, 1:] - torch.t(stack_old)[:, 0:stack_old.shape[0]-1]
        # I'm appending to the matrix a row which is equal to the last row. At the end i will have a matrix with
        # We also don't need the first two columns of old
        kl_sub_successive = torch.t(stack_new)[:, 1:] - torch.t(torch.cat((stack_old, stack_new[stack_new.shape[0]-1, :].unsqueeze(0)), 0))[:, 2:]
        L_casc = L_casc_init + torch.sum(self.omega*kl_sub_previous + kl_sub_successive, dim=1) # summing on net dimension

        return -(L_pg - L_ppo - L_casc).mean()
        

    def extract_states_prime(self, trajectory):
        return list(map(lambda x: x[0], trajectory))
    
    def save_parameters(self, path):
        torch.save(self.state_dict(), path)



In [22]:
env = gym.make('HalfCheetah-v5', ctrl_cost_weight=0.1)

epochs = 10
training_iterations = 20
batch_size = 64
trajectory_length = 500
n_actors = 10
in_features = env.observation_space.shape[0]
out_features = env.action_space.shape[0]
hidden_features = 64
actor_learning_rate = 5e-4
critic_learning_rate = 5e-4
gamma = 0.99
lambda_ = 0.95
epsilon = 0.2
beta = 0.5
omega = 1
omega12 = 1
d_targ = 0.01
std = 0.5
n_nets = 7

device = "mps"
mode = "pc"
if mode == "pc":
    assert(n_nets > 1)
modes = ["kl_fixed", "kl_adaptive", "clip", "pc"]
assert(mode in modes)



In [23]:
ppo = PPO(epochs=epochs, 
          training_iterations=training_iterations,
          batch_size=batch_size,
          trajectory_length=trajectory_length, 
          n_actors=n_actors,
          env=env,
          in_features=in_features,
          out_features=out_features,
          hidden_features=hidden_features,
          device=device,
          actor_learning_rate=actor_learning_rate,
          critic_learning_rate=critic_learning_rate,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon,
          beta = beta,
          d_targ=d_targ,
          std=std,
          mode=mode,
          n_nets=n_nets,
          omega=omega,
          omega12=omega12,
        )
ppo.to(device)

PPO(
  (critic): ValueFunction(
    (critic): Sequential(
      (0): Linear(in_features=17, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=1, bias=True)
      (5): ReLU()
    )
  )
)

In [24]:
#ppo.load_state_dict(torch.load("final_pc.pt"))
ppo.train_model_pc()
ppo.save_parameters("final_pc.pt")

[train]: starting dataset creation at iteration n 0
[training]: cum reward tensor([[-1631.0850]], device='mps:0')
[training]: ending dataset creation with dataset size 5000
[train]: epoch n 0
loss pg:  tensor(-0.0004, device='mps:0')
loss ppo:  tensor(-6.8986, device='mps:0')
loss casc:  tensor(-26.8044, device='mps:0')
loss pg:  tensor(-0.1911, device='mps:0')
loss ppo:  tensor(-8.7455, device='mps:0')
loss casc:  tensor(-32.6794, device='mps:0')
loss pg:  tensor(0.0141, device='mps:0')
loss ppo:  tensor(-8.7591, device='mps:0')
loss casc:  tensor(-33.8067, device='mps:0')
loss pg:  tensor(0.0104, device='mps:0')
loss ppo:  tensor(-7.8589, device='mps:0')
loss casc:  tensor(-30.1538, device='mps:0')
loss pg:  tensor(0.1816, device='mps:0')
loss ppo:  tensor(-10.1987, device='mps:0')
loss casc:  tensor(-40.2746, device='mps:0')
loss pg:  tensor(-0.1065, device='mps:0')
loss ppo:  tensor(-11.8191, device='mps:0')
loss casc:  tensor(-45.0157, device='mps:0')
loss pg:  tensor(0.0426, devi

KeyboardInterrupt: 

In [None]:
device = 'cpu'

ppo = PPO(epochs=epochs, 
          training_iterations=training_iterations,
          batch_size=batch_size,
          trajectory_length=trajectory_length, 
          n_actors=n_actors,
          env=env,
          in_features=in_features,
          out_features=out_features,
          hidden_features=hidden_features,
          device=device,
          actor_learning_rate=actor_learning_rate,
          critic_learning_rate=critic_learning_rate,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon,
          beta = beta,
          d_targ=d_targ,
          std=std,
          mode=mode,
          n_nets=n_nets,
          omega=omega,
          omega12=omega12,
        )

ppo.load_state_dict(torch.load("model12.pt"))

env = gym.make('HalfCheetah-v5', ctrl_cost_weight=0.1, render_mode="human")
#env = gym.make('HalfCheetah-v5', ctrl_cost_weight=0.1)
rewards = []
for episode in range(10):
    print(f"ep n {episode}", "\r")
    total_reward = 0
    done = False
    s, _ = env.reset()
    while not done:
        s = torch.tensor(s, dtype=torch.float32)
        z, mu, std = ppo.actor_list[0](s)
        s, reward, terminated, truncated, info = env.step(z.squeeze().cpu().detach().numpy())
        s = torch.tensor(s, dtype=torch.float32)
        done = terminated or truncated
        total_reward += reward

ep n 0 


2026-02-08 15:39:33.937 Python[17355:9387308] +[IMKClient subclass]: chose IMKClient_Legacy
2026-02-08 15:39:33.937 Python[17355:9387308] +[IMKInputSession subclass]: chose IMKInputSession_Legacy
  s = torch.tensor(s, dtype=torch.float32)


ep n 1 


KeyboardInterrupt: 

: 