In [None]:
import os
import logging
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Beta
from torch import optim
import torchvision.models as models
from tqdm import tqdm

import gymnasium as gym
from gymnasium.utils.save_video import save_video

logging.getLogger().setLevel(logging.ERROR)

In [None]:
class CNN_Net(nn.Module):
    """Image encoder that turns raw frames into a 128-dimensional feature vector.

    The top 84 rows of each frame go through a convolutional stack; ten scalar
    "gauge" features are read from fixed pixel windows in the rows below
    (presumably the CarRacing indicator bar -- confirm against the environment).
    The gauges drive two small MLPs that rescale the conv features per channel
    and per spatial cell before the final fully connected projection.
    """

    def __init__(
        self,
        device: torch.device,
    ) -> None:

        super().__init__()
        self.device = device

        # (in_channels, out_channels, kernel_size, stride, padding) of each conv stage.
        # NOTE(review): the last stage is a 1x1 convolution with padding=1, which adds a
        # one-pixel border to the feature map -- confirm this is intended.
        conv_specs = (
            (3, 8, 3, 1, 1),
            (8, 8, 3, 1, 1),
            (8, 16, 3, 2, 1),
            (16, 16, 3, 1, 1),
            (16, 32, 3, 2, 1),
            (32, 32, 3, 1, 1),
            (32, 32, 1, 1, 1),
        )

        stages = []
        for c_in, c_out, kernel, stride, pad in conv_specs:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU())
        stages.append(nn.AdaptiveAvgPool2d((4, 4)))        # -> (n, 32, 4, 4)

        self.CNN = nn.Sequential(*stages)

        # Channel gate: one multiplier for each of the 32 feature channels.
        self.c_g = self._gate_mlp(32)
        # Spatial gate: one multiplier for each of the 4*4 pooled cells.
        self.s_g = self._gate_mlp(16)

        self.FC = nn.Sequential(
            nn.Linear(32*16, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
        )

        self.to(device)

    @staticmethod
    def _gate_mlp(width: int) -> nn.Sequential:
        """Build a 10 -> width -> width -> width MLP whose last layer starts at zero.

        A zeroed output layer makes tanh(output) == 0, so the gate built from it
        is exactly 1 (a no-op) at initialisation.
        """
        mlp = nn.Sequential(
            nn.Linear(10, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )
        nn.init.zeros_(mlp[4].weight)
        nn.init.zeros_(mlp[4].bias)
        return mlp

    @torch.no_grad()
    def _preprocess(self, x: np.array) -> (torch.tensor, torch.tensor):
        """Split a batch of frames into the conv input and the gauge features.

        x is a numpy array indexed as (n, row, col, channel); the code reads rows
        up to 93 and columns up to 86, and the original comment states (n, 96, 96, 3).
        Returns (x_map, x_gauge) with x_map of shape (n, 3, 84, W) and x_gauge of
        shape (n, 10). Runs without gradient tracking.
        """
        frames = torch.from_numpy(x).float().to(self.device)

        def window_sum(rows: slice, cols: slice, channel=None) -> torch.Tensor:
            # Sum of pixel values inside a fixed window, one scalar per frame.
            if channel is None:
                return torch.sum(frames[:, rows, cols, :], dim=(1, 2, 3))
            return torch.sum(frames[:, rows, cols, channel], dim=(1, 2))

        # The divisors are fixed scaling constants; presumably the window sum at full
        # deflection of each indicator -- TODO confirm.
        speed = window_sum(slice(91, 94), slice(12, 15))/4317              #(n, )
        ABS = window_sum(slice(91, 94), slice(17, 27))/4851                #(n, )
        left_steer = window_sum(slice(86, 92), slice(38, 49), 1)/12563     #(n, )
        right_steer = window_sum(slice(86, 92), slice(48, 59), 1)/12543    #(n, )
        left_gyro = window_sum(slice(86, 92), slice(58, 73), 0)/14000      #(n, )
        right_gyro = window_sum(slice(86, 92), slice(72, 87), 0)/14000     #(n, )

        # Six raw gauges followed by four pairwise interaction terms.
        gauges = (
            speed, ABS, left_steer, right_steer, left_gyro, right_gyro,
            speed*left_steer, speed*right_steer,
            ABS*left_steer, ABS*right_steer,
        )
        x_gauge = torch.stack(gauges, dim=1)                               #(n, 10)

        # Keep only the top 84 rows and move channels first for Conv2d.
        x_map = frames[:, :84, :, :].permute(0, 3, 1, 2)                   #(n, 3, 84, 96)

        return x_map, x_gauge

    def forward(self, x: np.array) -> torch.tensor:
        """Encode a batch of raw frames into features of shape (n, 128)."""
        x_map, x_gauge = self._preprocess(x)
        features = self.CNN(x_map)                                         #(n, 32, 4, 4)

        # Both gates lie in [0.8, 1.2] because tanh is bounded by +-1.
        channel_gate = (1 + 0.2*torch.tanh(self.c_g(x_gauge))).reshape(-1, 32, 1, 1)
        spatial_gate = (1 + 0.2*torch.tanh(self.s_g(x_gauge))).reshape(-1, 1, 4, 4)

        gated = features*channel_gate*spatial_gate

        return self.FC(torch.flatten(gated, start_dim=1))

class Critic_Net(nn.Module):
    """State-value head: a three-layer MLP mapping 128 features to one scalar."""

    def __init__(
        self,
        device: torch.device,
    ) -> None:

        super().__init__()
        self.device = device

        width = 128
        # Only the Sequential is moved to the device; it holds every parameter of this module.
        self.critic = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one value estimate per input row (the trailing size-1 axis is dropped)."""
        return self.critic(x).squeeze(-1)

class Actor_Net(nn.Module):
    """Policy head producing Beta concentration parameters for steering and acceleration.

    A shared two-layer MLP feeds two linear heads. Each head emits two raw values
    that are mapped to concentrations of at least 1 + 1e-3. Both output layers start
    with zero weights and a bias of -3, so every concentration initially equals
    1 + 1e-3 + softplus(-3) (about 1.05) regardless of the input.
    """

    def __init__(
        self,
        device: torch.device,
    ) -> None:

        super().__init__()
        self.device = device

        actor_shared_layers = [
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        ]

        self.actor_shared = nn.Sequential(*actor_shared_layers)
        self.actor_steers = nn.Sequential(nn.Linear(128, 2))
        self.actor_accels = nn.Sequential(nn.Linear(128, 2))

        # Input-independent, symmetric starting policy for both action dimensions.
        for head in (self.actor_steers, self.actor_accels):
            nn.init.zeros_(head[0].weight)
            nn.init.constant_(head[0].bias, -3)

        self.to(device)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Map features of shape [n_envs, 128] to Beta parameters.

        Returns:
            (actions_steers_ab, actions_accels_ab), each of shape [n_envs, 2], holding
            the two concentration parameters of the corresponding Beta distribution.
        """
        actions_shared = self.actor_shared(x)                   # shape: [n_envs, 128]

        actions_steers_ab = self.actor_steers(actions_shared)   # shape: [n_envs, 2]
        actions_accels_ab = self.actor_accels(actions_shared)   # shape: [n_envs, 2]

        # softplus > 0, so every concentration is strictly greater than 1 + 1e-3.
        actions_steers_ab = 1 + 1e-3 + F.softplus(actions_steers_ab)
        actions_accels_ab = 1 + 1e-3 + F.softplus(actions_accels_ab)

        return actions_steers_ab, actions_accels_ab


class PPO_Net(nn.Module):
    """Container module bundling the image encoder with the critic and actor heads.

    Registering the three networks as submodules lets a single train()/eval(),
    parameters() or state_dict() call cover all of them.
    """

    def __init__(self, device: torch.device) -> None:
        super().__init__()

        # Each sub-network moves its own parameters to `device` in its constructor.
        self.CNN_Net = CNN_Net(device)
        self.Critic_Net = Critic_Net(device)
        self.Actor_Net = Actor_Net(device)

    
class PPO():
    """Proximal Policy Optimization agent with Beta-distributed steering/acceleration.

    Besides its arguments, the class reads three module-level names at call time:
    `c_max` (in select_action / eval_action), `clip_eps` (in get_losses) and `cv`
    (in update_parameters). They must be defined before those methods are called.
    """

    def __init__(
        self,
        device: torch.device,
        n_envs: int,
    ):

        self.device = device
        self.n_envs = n_envs

        self.PPO_Net = PPO_Net(device)

    def optimizer(self, critic_lr, actor_lr, cnn_lr) -> None:
        """Create `self.PPO_optimizer`: one AdamW with a separate learning rate per sub-network."""
        net = self.PPO_Net
        self.PPO_optimizer = torch.optim.AdamW([
            {"params": net.CNN_Net.parameters(), "lr": cnn_lr, "weight_decay": 0.0},
            {"params": net.Critic_Net.parameters(), "lr": critic_lr, "weight_decay": 0.0},
            {"params": net.Actor_Net.parameters(), "lr": actor_lr, "weight_decay": 0.0},
        ])

    @staticmethod
    def _clamp_open_unit(x: torch.Tensor) -> torch.Tensor:
        """Clamp x into [1e-8, largest representable value below 1].

        The upper bound is derived from the tensor dtype: a literal such as 1 - 1e-8
        rounds to exactly 1.0 in float32, which would leave the value 1.0 unclamped
        and make Beta.log_prob return -inf. Values already below 1 are not moved.
        """
        upper = 1.0 - torch.finfo(x.dtype).eps/2
        return x.clamp(1e-8, upper)

    def _policy(self, x: np.ndarray):
        """Run encoder, critic and actor once.

        Returns (state_values, steerings_dist, accels_dist, regulation), where
        regulation is the squared excess of each Beta's total concentration
        (alpha + beta) over the module-level `c_max`, summed over both distributions.
        """
        features = self.PPO_Net.CNN_Net(x)

        state_values = self.PPO_Net.Critic_Net(features)
        actions_steers_ab, actions_accels_ab = self.PPO_Net.Actor_Net(features)

        c_actions_steers = actions_steers_ab[:, 0] + actions_steers_ab[:, 1]
        c_actions_accels = actions_accels_ab[:, 0] + actions_accels_ab[:, 1]

        regulation = F.relu(c_actions_steers - c_max).pow(2) + F.relu(c_actions_accels - c_max).pow(2)

        steerings_dist = Beta(actions_steers_ab[:, 0], actions_steers_ab[:, 1])
        accels_dist = Beta(actions_accels_ab[:, 0], actions_accels_ab[:, 1])

        return state_values, steerings_dist, accels_dist, regulation

    def select_action(
        self, x: np.ndarray
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample one action per environment.

        Returns (actions, action_log_probs, state_values, entropy, regulation).
        `actions` has shape [n_envs, 3] with columns (steering, gas, brake):
        steering lies in [-1, 1]; gas and brake are the positive and negative parts
        of a single sampled acceleration value in [-1, 1].
        """
        state_values, steerings_dist, accels_dist, regulation = self._policy(x)

        steerings = steerings_dist.rsample()
        accels = accels_dist.rsample()

        # total log prob per environment (evaluated on the [0, 1] samples)
        action_log_probs = (
            steerings_dist.log_prob(steerings) + accels_dist.log_prob(accels)
        )

        # entropy (no correction needed)
        entropy = (
            steerings_dist.entropy() + accels_dist.entropy()
        )

        # rescale from [0, 1] to [-1, 1]
        steerings = 2*steerings - 1
        accels = 2*accels - 1

        gases = accels*(accels >= 0)
        brakes = -accels*(accels < 0)

        actions = torch.stack([steerings, gases, brakes], dim = 1) # shape: [n_envs, 3]

        return (actions, action_log_probs, state_values, entropy, regulation)

    def eval_action(
        self, x: np.ndarray, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score previously taken actions under the current policy.

        `actions` uses the (steering, gas, brake) layout produced by select_action.
        Returns (action_log_probs, state_values, entropy, regulation).
        """
        state_values, steerings_dist, accels_dist, regulation = self._policy(x)

        # Invert the rescaling done in select_action: gas - brake recovers the
        # signed acceleration, and (v + 1)/2 maps [-1, 1] back to [0, 1].
        steerings = (actions[:, 0] + 1)/2
        accels = (actions[:, 1] - actions[:, 2] + 1)/2

        action_log_probs = (
            steerings_dist.log_prob(self._clamp_open_unit(steerings))
            + accels_dist.log_prob(self._clamp_open_unit(accels))
        )

        entropy = (
            steerings_dist.entropy() + accels_dist.entropy()
        )

        return action_log_probs, state_values, entropy, regulation

    def get_advs_returns(
        self,
        rewards: torch.Tensor,
        value_preds: torch.Tensor,
        masks: torch.Tensor,
        masks_d: torch.Tensor,
        gamma: float,
        lam: float,
        active_indices: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute GAE advantages and returns for one rollout.

        Args:
            rewards: [T, n_envs] rewards.
            value_preds: [T + 1, n_envs] value estimates (last row is the bootstrap value).
            masks: [T, n_envs] multiplier on the bootstrap term of the TD error.
            masks_d: [T, n_envs] multiplier on the carried-over GAE accumulator.
            gamma: discount factor.
            lam: GAE lambda.
            active_indices: indices into the flattened [T * n_envs] axis to keep.

        Returns:
            (advantages, adv_n, returns), all detached and restricted to
            `active_indices`; adv_n is standardised and clamped to [-3, 3].
        """
        T = len(rewards)
        advantages = torch.zeros(T, self.n_envs, device=self.device)

        gae = torch.zeros(self.n_envs, device=self.device)
        for t in reversed(range(T)):
            td_error = (
                rewards[t] + gamma*masks[t]*value_preds[t+1] - value_preds[t]
            )
            gae = td_error + gamma*lam*masks_d[t]*gae
            advantages[t] = gae

        returns = advantages + value_preds[:-1]

        advantages = advantages.flatten().detach()[active_indices]
        returns = returns.flatten().detach()[active_indices]

        adv_n = (advantages - advantages.mean())/(advantages.std() + 1e-4).detach()
        adv_n = adv_n.clamp(-3.0, 3.0)

        return advantages, adv_n, returns

    def get_losses(
        self,
        action_log_probs: torch.Tensor,
        active_indices: torch.Tensor,
        ent_coef: float,
        reg_coef: float,
        states: np.ndarray,
        actions: torch.Tensor,
        adv_n: torch.Tensor,
        returns: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Re-evaluate the whole rollout and build the PPO loss terms.

        Returns (critic_loss, actor_loss, entropy_loss, regulation_loss); the last two
        are already multiplied by `ent_coef` / `reg_coef`. Only entries selected by
        `active_indices` contribute. The clip range is the module-level `clip_eps`.
        """
        states = states.reshape(-1, 96, 96, 3)                                #(n_step_per_update*n_env, 96, 96, 3)
        actions = actions.reshape(-1, 3).detach()                             #(n_step_per_update*n_env, 3)
        action_log_probs_old = action_log_probs.flatten().detach()            #(n_step_per_update*n_env, )

        action_log_probs_new, state_values, entropy, regulation = self.eval_action(states, actions)           #(n_step_per_update*n_env, )

        critic_loss = F.smooth_l1_loss(state_values[active_indices], returns.detach(), beta = 1.0)

        adv_n = adv_n.detach()

        # clipped surrogate objective
        ratio = torch.exp((action_log_probs_new - action_log_probs_old)[active_indices])
        surr1 = ratio*adv_n
        surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv_n

        actor_loss = -torch.min(surr1, surr2).mean()

        entropy_loss = - ent_coef*entropy[active_indices].mean()
        regulation_loss = reg_coef*regulation[active_indices].mean()

        return critic_loss, actor_loss, entropy_loss, regulation_loss

    def update_parameters(
        self,
        n_iter,
        rewards: torch.Tensor,
        value_preds: torch.Tensor,
        action_log_probs: torch.Tensor,
        masks: torch.Tensor,
        masks_d: torch.Tensor,
        active_indices: torch.Tensor,
        gamma: float,
        lam: float,
        ent_coef: float,
        reg_coef: float,
        states: np.ndarray,
        actions: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run `n_iter` full-batch PPO optimisation steps on one rollout.

        The critic loss is weighted by the module-level `cv`, and gradients are
        clipped to a global norm of 0.5.

        Returns:
            (mean critic loss, mean policy loss) over the `n_iter` steps, where the
            policy loss is actor + entropy + regulation. If fewer than two samples are
            active the update is skipped and two zero tensors are returned, because the
            advantage standard deviation would be NaN and poison the weights.
        """
        if active_indices.numel() < 2:
            zero = torch.zeros((), device=self.device)
            return zero, zero.clone()

        advantages, adv_n, returns = self.get_advs_returns(rewards, value_preds, masks, masks_d, gamma, lam, active_indices)

        critic_loss_r = 0
        policy_loss_r = 0

        for _ in range(n_iter):

            critic_loss, actor_loss, entropy_loss, regulation_loss = self.get_losses(action_log_probs, active_indices, ent_coef, reg_coef,
                                                                                     states, actions, adv_n, returns)

            total_loss = cv*critic_loss + actor_loss + entropy_loss + regulation_loss

            self.PPO_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.PPO_Net.parameters(), max_norm=0.5)
            self.PPO_optimizer.step()

            critic_loss_r += critic_loss.detach()
            policy_loss_r += (actor_loss + entropy_loss + regulation_loss).detach()

        return critic_loss_r/n_iter, policy_loss_r/n_iter
        


In [None]:
# ---- rollout / environment ----
n_envs = 8                    # number of parallel environments
n_updates = 3000              # number of collect-then-update phases
n_steps_per_update = 256      # environment steps collected per phase
randomize_domain = False      # in this notebook only echoed in the plot title

# ---- optimiser learning rates ----
Critic_lr = 0.001
Actor_lr = 0.0001
CNN_lr = 0.0005

# ---- PPO update ----
n_iter = 4                    # optimisation steps per collected rollout
clip_eps = 0.1                # clip range of the probability ratio
cv = 1                        # weight of the critic loss in the total loss

# ---- concentration regulariser ----
c_max = 20                    # Beta concentration sums above this are penalised
reg_coef = 1e-05              # weight of that penalty

# ---- entropy bonus schedule ----
# The coefficient moves linearly from base_ent_coef*mag_ent to base_ent_coef
# over the first early_step_max updates.
base_ent_coef = 0.01
mag_ent = 1
early_step_max = 500

# ---- return / advantage estimation ----
gamma = 0.99                  # discount factor
lam = 0.95                    # GAE lambda



In [None]:
# Seed the global RNGs so weight initialisation and action sampling are
# repeatable across "Restart & Run All".
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# set the device: CUDA only when requested *and* available, otherwise CPU
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# init the agent and its optimizer (one learning rate per sub-network)
agent = PPO(device, n_envs)
agent.optimizer(Critic_lr, Actor_lr, CNN_lr)

In [None]:
# environment setup: n_envs CarRacing instances, each stepped in its own process
envs = gym.make_vec("CarRacing-v3", num_envs=n_envs, vectorization_mode="async")
states, info = envs.reset(seed=42)

# per-update training logs, appended to once per update by the training loop
rewards_update, critic_losses, policy_losses, entropies = [], [], [], []

# number of updates performed so far (drives the entropy-coefficient schedule)
prog_i = 0

# Remaining warm-up steps per environment. While an entry is > 0 the training loop
# sends a zero action to that environment and excludes the step from the update.
skip_nums = np.full(n_envs, 50, dtype=np.int32)


In [None]:
agent.PPO_Net.train()


# use tqdm to get a progress bar for training
for sample_phase in tqdm(range(n_updates)):
    # we don't have to reset the envs, they just continue playing
    # until the episode is over and then reset automatically

    prog_i += 1

    # entropy coefficient: linear from base_ent_coef*mag_ent down to base_ent_coef
    progress = min(prog_i / early_step_max, 1.0)
    mult = mag_ent - (mag_ent - 1.0) * progress
    ent_coef = base_ent_coef * mult

    ep_rewards = torch.zeros(n_steps_per_update, n_envs, device = device)
    ep_action_log_probs = torch.zeros(n_steps_per_update, n_envs, device = device)
    ep_entropy = torch.zeros(n_steps_per_update, n_envs, device = device)
    masks = torch.zeros(n_steps_per_update, n_envs, device = device)
    masks_d = torch.zeros(n_steps_per_update, n_envs, device = device)
    ep_skip_masks = torch.zeros(n_steps_per_update, n_envs, device = device)
    # Store frames in the dtype the environment delivers instead of NumPy's float64
    # default; the encoder converts to float itself, so no information is lost.
    ep_states = np.zeros((n_steps_per_update, n_envs, 96, 96, 3), dtype = states.dtype)
    ep_actions = torch.zeros(n_steps_per_update, n_envs, 3, device = device)
    ep_value_preds = torch.zeros(n_steps_per_update+1, n_envs, device = device)


    # play n steps in our parallel environments to collect data
    for step in range(n_steps_per_update):

        ep_states[step] = states

        # select an action A_{t} using S_{t} as input for the agent.
        # Everything collected here is detached before it is used in the update,
        # so skip autograd instead of retaining one graph per step.
        with torch.no_grad():
            actions, action_log_probs, state_value_preds, entropy, regulation = agent.select_action(
                states
            )

        ep_actions[step] = actions

        actions = actions.detach().cpu().numpy()

        # environments still in their warm-up phase receive a zero action
        active_envs = np.array((skip_nums == 0), dtype = np.int32)
        skip_masks = torch.tensor(active_envs, device=device)
        ins_inactive_envs = np.where(active_envs == 0)

        actions[ins_inactive_envs] = [0.0, 0.0, 0.0]

        # perform the action A_{t} in the environment to get S_{t+1} and R_{t+1}
        states, rewards, terminated, truncated, infos = envs.step(
            actions
        )

        ep_value_preds[step] = state_value_preds
        ep_rewards[step] = torch.as_tensor(rewards, dtype=torch.float32, device=device)
        ep_action_log_probs[step] = action_log_probs
        ep_entropy[step] = entropy
        ep_skip_masks[step] = skip_masks

        # 1 where the episode did not terminate (bootstrapping from S_{t+1} allowed)
        masks[step] = torch.as_tensor(np.logical_not(terminated), dtype=torch.float32, device=device)

        done = np.array((terminated | truncated), dtype = np.int32)   # (n_envs, )
        ids_done = np.where(done == 1)
        ins_skip = np.where(skip_nums >= 1)

        # count down running warm-ups, and start a new one for finished episodes
        skip_nums[ins_skip] -= 1
        skip_nums[ids_done] = 50

        # 1 where the episode continues (GAE accumulator carried across the step)
        masks_d[step] = torch.as_tensor(np.logical_not(done), dtype=torch.float32, device=device)

    # bootstrap value of the state reached after the last collected step
    with torch.no_grad():

        last_states = agent.PPO_Net.CNN_Net(states)
        last_value_preds = agent.PPO_Net.Critic_Net(last_states)

    ep_value_preds[-1] = last_value_preds

    # flat indices of the steps taken outside a warm-up phase
    active_indices = torch.where(ep_skip_masks.flatten() == 1)[0]


    # update the actor and critic network
    critic_loss, policy_loss = agent.update_parameters(
        n_iter,
        ep_rewards,
        ep_value_preds,
        ep_action_log_probs,
        masks,
        masks_d,
        active_indices,
        gamma,
        lam,
        ent_coef,
        reg_coef,
        ep_states,
        ep_actions,
    )

    # Log the losses and entropy
    rewards_update.append(np.mean(ep_rewards.detach().cpu().numpy()))
    critic_losses.append(critic_loss.detach().cpu().numpy())
    policy_losses.append(policy_loss.detach().cpu().numpy())
    entropies.append(ep_entropy.detach().mean().cpu().numpy())




In [None]:
""" plot the results """

def _moving_average(values, window):
    """Return the moving average of `values` over `window` points ('valid' mode).

    The window is shrunk to the series length when fewer points are available, and
    an empty series yields an empty array, so the plot cell also works after a
    short or interrupted training run.
    """
    series = np.asarray(values, dtype=float)
    if series.size == 0:
        return series
    window = max(1, min(window, series.size))
    return np.convolve(series, np.ones(window), mode="valid") / window


rolling_length = 50
fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 4))
fig.suptitle(
    f"Training plots for {agent.__class__.__name__} in the CarRacing-v3 environment \n \
    (n_envs={n_envs}, n_steps_per_update={n_steps_per_update}, randomize_domain={randomize_domain})"
)

# one smoothed curve per logged quantity
rewards_update_moving_average = _moving_average(rewards_update, rolling_length)
entropy_moving_average = _moving_average(entropies, rolling_length)
critic_losses_moving_average = _moving_average(critic_losses, rolling_length)
policy_losses_moving_average = _moving_average(policy_losses, rolling_length)

panels = (
    ("Rewards", rewards_update_moving_average),
    ("Entropy", entropy_moving_average),
    ("Critic Loss", critic_losses_moving_average),
    ("policy Loss", policy_losses_moving_average),
)
for ax, (title, curve) in zip(axs, panels):
    ax.set_title(title)
    ax.plot(curve)
    ax.set_xlabel("Number of updates")

plt.tight_layout()
plt.show()
    
    

In [None]:
save_weights = True
load_weights = False

PPO_weights_path = "weights/PPO_weights_reward_0.9.hS"

# create the target directory if needed (safe to re-run)
os.makedirs(os.path.dirname(PPO_weights_path), exist_ok=True)

# save network weights
# NOTE: saving runs before loading, so with both flags set the file is overwritten
# with the current weights first.
if save_weights:
    torch.save(agent.PPO_Net.state_dict(), PPO_weights_path)

# load network weights
# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
if load_weights:
    state_dict = torch.load(PPO_weights_path, map_location=agent.device)
    agent.PPO_Net.load_state_dict(state_dict)
    agent.PPO_Net.eval()



In [None]:
""" play a showcase episode """

video_folder = "video"

# create the output directory if needed (safe to re-run)
os.makedirs(video_folder, exist_ok=True)

env_v = gym.make("CarRacing-v3", render_mode="rgb_array_list")

agent.PPO_Net.eval()

try:
    # get an initial state, then idle for 50 steps with a zero action
    # (the same warm-up length the training loop uses after every reset)
    state_v, info_v = env_v.reset()
    for _ in range(50):
        state_v, reward_v, terminated_v, truncated_v, info_v = env_v.step(np.array([0.0, 0.0, 0.0]))

    # play one episode
    done_v = False
    while not done_v:
        # select an action A_{t} using S_{t} as input for the agent
        # (a batch axis is added because the agent expects a batch of frames)
        with torch.no_grad():
            action_v,_,_,_,_ = agent.select_action(state_v[None,:])

        # perform the action A_{t} in the environment to get S_{t+1} and R_{t+1}
        state_v, reward_v, terminated_v, truncated_v, info_v = env_v.step(action_v.detach().cpu().numpy().squeeze())

        # update if the environment is done
        done_v = terminated_v or truncated_v

    # make a video
    save_video(frames=env_v.render(), video_folder=video_folder, fps=env_v.metadata["render_fps"])
finally:
    # release the environment even if the episode or the video export fails
    env_v.close()
