# PPO

## Policy Network

In [1]:
import gymnasium as gym
import highway_env
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

def init_layer(layer, gain = np.sqrt(2)):
  """Orthogonally initialise ``layer.weight`` and zero its bias, in place.

  Args:
    layer: Module exposing a ``weight`` tensor (e.g. ``nn.Linear``). Its
      ``bias`` may be ``None`` when the layer was created with ``bias=False``.
    gain: Scaling factor forwarded to ``nn.init.orthogonal_``.

  Returns:
    The same ``layer`` object, so the call can be nested directly inside
    an ``nn.Sequential(...)`` definition.
  """
  nn.init.orthogonal_(layer.weight, gain)
  # Bias-free layers carry ``bias=None``; constant_ cannot be applied to it.
  if getattr(layer, "bias", None) is not None:
    nn.init.constant_(layer.bias, 0)
  return layer

# Policy network (MLP)
class MLPPolicyNetwork(nn.Module):
    """Actor-critic MLP with two independent heads of identical depth.

    Each head is Linear -> Tanh -> Linear -> Tanh -> Linear. The actor emits
    ``out_actions`` logits, the critic a single state value.
    """

    def __init__(self, in_states, h1_nodes, out_actions):
        """Build the actor and critic stacks.

        Args:
            in_states: Size of the flattened observation fed to both heads.
            h1_nodes: Width of both hidden layers.
            out_actions: Number of action logits produced by the actor.
        """
        super().__init__()

        def build_head(out_features, out_gain):
            # Layers are created and initialised strictly in order so the
            # random stream is consumed exactly as a hand-written stack would.
            return nn.Sequential(
                init_layer(nn.Linear(in_states, h1_nodes)),
                nn.Tanh(),
                init_layer(nn.Linear(h1_nodes, h1_nodes)),
                nn.Tanh(),
                init_layer(nn.Linear(h1_nodes, out_features), gain=out_gain),
            )

        # Actor network (small output gain keeps the initial logits near zero)
        self.actor = build_head(out_actions, 0.01)
        # Critic network
        self.critic = build_head(1, 1.0)

    def forward(self, x):
        """Return ``(logits, value)`` for a batch; non-batch dims are flattened."""
        flat = torch.flatten(x, start_dim=1)
        return self.actor(flat), self.critic(flat)


class CNNPolicyNetwork(nn.Module):
    """Actor-critic CNN: a shared convolutional trunk feeding two linear heads."""

    def __init__(self, input_shape, num_actions):
        """Build the trunk and size the heads from a dummy forward pass.

        Args:
            input_shape: ``(stack, height, width)`` of a single observation.
            num_actions: Number of action logits produced by the actor head.
        """
        super().__init__()
        channels, rows, cols = input_shape

        trunk_layers = [
            nn.Conv2d(channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=2),
            nn.ReLU(),
        ]
        self.shared_conv = nn.Sequential(*trunk_layers)

        # Push one all-zero observation through the trunk to learn how many
        # features come out, instead of deriving the size by hand.
        with torch.no_grad():
            probe = torch.zeros(1, channels, rows, cols)
            conv_size = self.shared_conv(probe).numel()

        self.actor_fc = init_layer(nn.Linear(conv_size, num_actions), gain=0.01)
        self.critic_fc = init_layer(nn.Linear(conv_size, 1), gain=1.)

    def forward(self, x):
        """Return ``(logits, value)`` for a batch of stacked frames."""
        feats = self.shared_conv(x).flatten(start_dim=1)
        return self.actor_fc(feats), self.critic_fc(feats)

In [2]:
class TrainData:
  """Rollout buffer made of parallel per-step lists."""

  # Names of the list attributes every instance carries, in creation order.
  _FIELDS = (
    "states",
    "actions",
    "rewards",
    "values",
    "next_states",
    "dones",
    "log_probs",
  )

  def __init__(self):
    self.clear()

  def clear(self):
    """Rebind every field to a brand-new empty list.

    Lists handed out earlier are left untouched; only the attributes on
    this object are replaced.
    """
    for field_name in self._FIELDS:
      setattr(self, field_name, [])

## PPO Agent

In [3]:
import gymnasium as gym
import highway_env
import numpy as np
import random
import torch
import torch.optim as optim
import os
from tqdm import tqdm

import sys
sys.path.append(os.path.abspath('..'))
from metrics import Metrics

class PPOAgent():
  def __init__(self, params):
      """Read hyper-parameters from ``params`` and set up the empty agent.

      Args:
        params: Mapping of option name to value. Missing keys fall back to
          the defaults listed below. The network itself is not created
          here; ``policy_net`` and ``num_actions`` start as ``None``.
      """
      self.params = params
      self.device = params.get("device", torch.device("cpu"))
      self.policy = params.get("policy", "CnnPolicy")

      # (option name, default) pairs stored on the agent under the same name.
      option_defaults = (
          # Learn
          ("num_iterations", 200),
          # Collect data
          ("num_steps", 20),
          ("step_reward", 0.0),
          # GAE
          ("gamma", 0.99),
          ("lamda", 0.95),
          # Update policy
          ("clip_epsilon", 0.2),
          ("loss_coeff", 0.5),
          ("epochs", 5),
          ("batch_size", 32),
          ("learning_rate", 3e-4),
          ("discount", 0.2),
          ("entropy_coeff", 0.01),
      )
      for option, default in option_defaults:
          setattr(self, option, params.get(option, default))

      # The "save_model" option is stored under a different attribute name.
      self.to_save_model = params.get("save_model", False)

      self.policy_net = None
      self.num_actions = None
      self.train_data = TrainData()

      # Metrics
      metrics_enabled = params.get("use_metrics", False)
      self.metrics = Metrics(self.policy, "training_results", metrics_enabled)
      if params.get("save_params", False):
          self.metrics.save_params(params)


  def create_network(self, env):
    self.num_actions = env.action_space.n

    if self.policy == "CnnPolicy":
        self.create_CNN(env)

    if self.policy == "MlpPolicy":
        self.create_MLP_Network(env)

    self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)


  def create_CNN(self, env):
      """Instantiate the CNN policy for ``env`` and move it to ``self.device``.

      ``self.num_states`` is set to the full observation shape tuple and
      ``self.num_actions`` to the size of the discrete action space.
      """
      obs_shape = env.observation_space.shape
      action_count = env.action_space.n

      self.num_states = obs_shape
      self.num_actions = action_count

      network = CNNPolicyNetwork(obs_shape, action_count)
      self.policy_net = network.to(self.device)

  def create_MLP_Network(self, env):
      """Instantiate the MLP policy for ``env`` and move it to ``self.device``.

      ``self.num_states`` becomes the total number of scalar features in one
      observation (the product of every observation dimension), and is used
      both as the input size and as the hidden-layer width of the MLP.
      """
      # Product over all dims instead of shape[0] * shape[1], so 1-D and
      # >2-D observation spaces work too; 2-D spaces give the same value.
      obs_shape = env.observation_space.shape
      self.num_states = int(np.prod(obs_shape))
      self.num_actions = env.action_space.n

      self.policy_net = MLPPolicyNetwork(self.num_states, self.num_states, self.num_actions).to(self.device)

  def collect_data(self, num_steps, env):
        '''
        Collect data from the environment for num_steps times, if episode ends during for loop, reset the environment.
        '''
        state = env.reset()[0]
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        state = state.unsqueeze(0)

        done = False
        episode_reward = []
        total_reward = 0.0
        num_episodes = 0
        episode_len = 0
        total_steps = 0

        self.train_data.clear()


        for _ in range(num_steps):

            action, log_prob, value = self.get_action_value(state, eval_mode=False)

            next_state, reward, done, truncated, info = env.step(action)

            next_state = torch.tensor(np.array([next_state]), dtype=torch.float32, device=self.device)

            reward += self.step_reward

            self.train_data.states.append(state)
            self.train_data.actions.append(action)
            self.train_data.log_probs.append(log_prob.item())
            self.train_data.values.append(value.item())
            self.train_data.rewards.append(reward)
            self.train_data.dones.append(done)

            episode_reward.append(reward)
            episode_len += 1

            state = next_state
            if done or truncated:
                state, _ = env.reset()
                state = torch.tensor(np.array([state]), dtype=torch.float32, device=self.device)
                done = False
                truncated = False

                num_episodes += 1
                total_steps +=  episode_len
                total_reward += sum(episode_reward) / len(episode_reward)
                episode_len = 0
                episode_reward = []

        try:
          avg_episode_len = total_steps / num_episodes
          avg_episode_reward = total_reward / num_episodes
        except:
          avg_episode_len = 0
          avg_episode_reward = 0

        return avg_episode_reward, avg_episode_len

  def gae(self, next_value=0.0):
      """
      Compute generalized advantage estimation. If last step of the episode, bootstrapping.
      """
      num_steps = len(self.train_data.rewards)
      advantages = [0] * num_steps
      gae = 0.0

      for i in reversed(range(num_steps)):
          if i == num_steps - 1:
              next_state_value = next_value
          else:
              next_state_value = self.train_data.values[i+1]

          mask = 1.0 - float(self.train_data.dones[i]) # if not done, mask = 1.0
          delta = self.train_data.rewards[i] + self.gamma * next_state_value * mask - self.train_data.values[i]
          gae = delta + self.gamma * self.lamda * mask * gae
          advantages[i] = gae

      returns = [v + a for v, a in zip(self.train_data.values, advantages)]
      return advantages, returns


  def update_policy(self, advantages, returns):
    '''
    Update policy network based on the collected data.
    '''

    states = torch.stack(self.train_data.states).to(self.device) # Stack the list of tensors
    states = states.squeeze(1)
    actions = torch.LongTensor(self.train_data.actions).to(self.device)
    old_log_probs = torch.FloatTensor(self.train_data.log_probs).to(self.device)
    advantages = torch.FloatTensor(advantages).to(self.device)
    returns = torch.FloatTensor(returns).to(self.device)

    # Normalise advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    num_states = len(states)
    step = 0

    for epoch in range(self.epochs):
        indexs = np.random.permutation(num_states)
        for i in range(0, num_states, self.batch_size):
            batch_index = indexs[i : i + self.batch_size]

            batch_states = states[batch_index]
            batch_actions = actions[batch_index]
            batch_old_log_probs = old_log_probs[batch_index]
            batch_advantages = advantages[batch_index]
            batch_returns = returns[batch_index]

            logits, value = self.policy_net(batch_states)
            dist = torch.distributions.Categorical(logits=logits)
            new_log_probs = dist.log_prob(batch_actions)

            # ratio
            ratio = torch.exp(new_log_probs - batch_old_log_probs)

            # clipped surrogate
            #L_CLIP: equation 7 in "Proximal Policy Optimization Algorithms"
            policy_loss = -torch.min(ratio * batch_advantages, torch.clip(ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon)* batch_advantages).mean()


            # value loss
            value_loss = (value.squeeze() - batch_returns).pow(2).mean()

            # entropy
            entropy = dist.entropy().mean()

            loss = policy_loss + self.loss_coeff * value_loss - self.entropy_coeff * entropy

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.metrics.add("update_policy/policy_loss", policy_loss.item(), step)
            self.metrics.add("update_policy/value_loss", value_loss.item(), step)
            self.metrics.add("update_policy/entropy", entropy.item(), step)
            self.metrics.add("update_policy/total_loss", loss.item(), step)

            step += 1


  def learn(self, env):
      """Train the agent on ``env`` for ``self.num_iterations`` iterations.

      Each iteration collects ``self.num_steps`` transitions, computes GAE
      advantages, updates the policy and logs rollout statistics. Every
      fifth iteration the agent is evaluated for 10 episodes. After the
      loop the model is saved when ``self.to_save_model`` is set, and the
      metrics writer is closed.
      """
      self.create_network(env)

      for iteration in tqdm(range(self.num_iterations), desc="Training Model"):
          # TODO: run A actors for T steps each and average their advantages
          # (currently a single agent runs for num_steps; if device == cuda use mp)
          episode_reward, avg_episode_len = self.collect_data(self.num_steps, env)

          if self.train_data.dones[-1]:
              # The rollout stopped exactly at an episode end: nothing follows.
              next_value = 0.0
          else:
              # The rollout stopped mid-episode: bootstrap from the critic.
              # Prefer the observation that followed the final step when the
              # buffer holds it; otherwise fall back to the last acted-on
              # state, which only approximates the successor's value.
              successors = getattr(self.train_data, "next_states", None)
              last_state = successors[-1] if successors else self.train_data.states[-1]

              with torch.no_grad():
                  _, last_value_t = self.policy_net(last_state)
              next_value = last_value_t.item()

          advantages, returns = self.gae(next_value)
          self.update_policy(advantages, returns)

          self.metrics.add("collect_data/episode_reward", episode_reward, iteration)
          self.metrics.add("collect_data/avg_episode_len", avg_episode_len, iteration)

          if iteration % 5 == 0:
              self.evaluate(env, 10, iteration)

      if self.to_save_model:
          self.save_model()

      # Call the method; the bare attribute access was a no-op.
      self.metrics.close()

  def get_action(self, state, eval_mode=False):
    """
    if eval_mode, return the best action
    else, return the action based on the policy
    """
    if not isinstance(state, torch.Tensor):
        state = torch.tensor(np.array([state]), dtype=torch.float32, device=self.device)
    if len(state.shape) == 1:
        state = state.unsqueeze(0)

    with torch.no_grad():
      logits, _ = self.policy_net(state)
      dist = torch.distributions.Categorical(logits=logits)

      if eval_mode:
          action = torch.argmax(dist.probs, dim=1)
          action = action.item()
      else:
          action_sample = dist.sample()
          action = action_sample.item()

    return action


  def get_action_value(self, state, eval_mode=False):
    """
    return action log_prob and value
    """

    if not isinstance(state, torch.Tensor):
        state = torch.tensor(np.array([state]), dtype=torch.float32, device=self.device)
    if len(state.shape) == 1:
        state = state.unsqueeze(0)

    with torch.no_grad():
        logits, value = self.policy_net(state)
        dist = torch.distributions.Categorical(logits=logits)

        if eval_mode: # select the best action
            action_tensor = torch.argmax(dist.probs, dim=1)
        else: # sample action from policy
            action_tensor = dist.sample()
        log_prob = dist.log_prob(action_tensor)

    action = action_tensor.item()

    return action, log_prob, value.squeeze(1)


  def evaluate(self, env, episode_num, iteration = -1):
    total_reward = 0.0
    total_steps = 0
    r = tqdm(range(episode_num), desc="Evaluating Agent") if iteration == -1 else range(episode_num)

    for episode in r:
        state = env.reset()[0]
        done = False
        truncated = False

        episode_reward = []
        steps = 0

        while (not done and not truncated):
            # Select best action
            action, _, _ = self.get_action_value(state, eval_mode=True)
            next_state, reward, done, truncated, info = env.step(action)

            episode_reward.append(reward)
            steps += 1
            state = next_state
            env.render()

        total_reward += sum(episode_reward) / len(episode_reward)
        total_steps += steps

        if iteration == -1:
            self.metrics.add(f"evaluate/episode/episode_reward(episode_reward/episode)", sum(episode_reward) / len(episode_reward), episode)
            self.metrics.add(f"evaluate/episode/episode_steps(steps/episode)", steps, episode)

    avg_reward = total_reward / episode_num
    avg_steps = total_steps / episode_num

    if iteration != -1:
      self.metrics.add(f"evaluate/iteration/iteration_avg_reward(avg_reward/iteration)",avg_reward, iteration)
      self.metrics.add(f"evaluate/iteration/iteration_avg_steps(avg_steps/iteration)",avg_steps, iteration)

    if iteration == -1: print(f"\nEvaluation after training: Num of Episodes: {episode_num}, Avg Reward: {avg_reward:.2f}, Avg Steps: {avg_steps:.2f}")
    return avg_reward, avg_steps


  def create_folder(self, directory_name):
    try:
        os.mkdir(directory_name)
        print(f"Directory '{directory_name}' created successfully.")
    except FileExistsError:
        return
    except PermissionError:
        print(f"Permission denied: Unable to create '{directory_name}'.")
    except Exception as e:
        print(f"An error occurred: {e}")

  def save_model(self):
      folder_name = self.policy + "_save_models"
      self.create_folder(folder_name)
      new_model_num = str(len(os.listdir("./" +folder_name)) + 1)
      file_name = f'{folder_name}/PPO_{new_model_num}_{self.time}.pth'
      state = {
          "policy_net": self.policy_net.state_dict(),
          "optimizer": self.optimizer.state_dict()
      }
      torch.save(state, file_name)

      self.file_name = f"PPO_{new_model_num}_{self.time}"
      print(f"Model saved to {file_name}")

  def load_model(self, env, file_name):
      """Rebuild the network for ``env`` and restore a saved checkpoint.

      Args:
        env: Environment used to size the freshly created network.
        file_name: Checkpoint name without the ``.pth`` extension, looked
          up inside ``<policy>_save_models/``.
      """
      checkpoint_dir = self.policy + "_save_models"
      checkpoint_path = "/".join([checkpoint_dir, file_name + ".pth"])

      # The network and optimizer must exist before their states can load.
      self.create_network(env)

      checkpoint = torch.load(checkpoint_path, map_location=self.device)
      for part, target in (("policy_net", self.policy_net), ("optimizer", self.optimizer)):
          target.load_state_dict(checkpoint[part])

## Agent Initialization

In [4]:
# Choose a policy; it decides which observation type the environment emits.
policy = "CnnPolicy"
# policy = "MlpPolicy"

# Settings shared by both observation types.
config = {"lanes_count": 3}

if policy == "CnnPolicy":
    # Stacked grayscale frames for the CNN policy.
    config["observation"] = {
        "type": "GrayscaleObservation",
        "observation_shape": (128, 64),
        "stack_size": 4,
        # RGB-to-grayscale conversion weights, kept as given on the
        # highway-env documentation page.
        "weights": [0.2989, 0.5870, 0.1140],
        "scaling": 1.75,
    }
else:
    # Per-vehicle kinematic feature table for the MLP policy.
    config["observation"] = {
        "type": "Kinematics",
        "vehicles_count": 5,
        "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"],
        "features_range": {
            "x": [-100, 100],
            "y": [-100, 100],
            "vx": [-20, 20],
            "vy": [-20, 20]
        },
        "absolute": False,
        "order": "sorted"
    }

In [5]:
# Hyper-parameters handed to PPOAgent; keys it does not find fall back to
# the agent's own defaults.
params = {
    'policy': policy,
    # Use Apple MPS when present, otherwise CUDA, otherwise the CPU, so the
    # notebook also runs on machines without an MPS backend.
    'device': torch.device(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    ),

    'num_iterations': 20,
    'num_steps': 1024,
    "step_reward" : 0.0,
    'gamma': 0.99,
    'lamda': 0.95,
    'clip_epsilon': 0.2,
    'loss_coeff': 0.5,
    'epochs': 10,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'entropy_coeff': 0.02,
    'use_metrics': True,
    'save_model': True,
    'save_params': True
}

## Agent Training

In [None]:
seed = 72 # Our group number
for i in range(1):
    seed += i
    # Seed every random source the run touches: torch drives action
    # sampling and weight init, NumPy drives the minibatch shuffling in
    # update_policy, and the environment has its own generator.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    ppo = PPOAgent(params)
    env = gym.make('highway-v0', render_mode='rgb_array', config=config)
    # Seeding once is enough; later reset() calls continue the same stream.
    env.reset(seed=seed)
    ppo.learn(env)
    ppo.evaluate(env, 10)
    print("\n")

## Agent Evaluation

In [None]:
env = gym.make('highway-v0', render_mode='rgb_array', config=config)
ppo = PPOAgent(params)
# save_model() writes checkpoints named "PPO_<n>_<timestamp>.pth" inside
# "<policy>_save_models/"; replace the placeholder below with the name
# printed after "Model saved to ..." (without folder and extension).
ppo.load_model(env, "PPO_1_20250101000000")
ppo.evaluate(env, 20)

## Run Tensorboard

In [10]:
# (Re)load the TensorBoard notebook extension, then start TensorBoard on
# localhost:6013 reading event files from the "training_results" directory.
# NOTE(review): presumably the directory the agent's Metrics object writes to; confirm.
%reload_ext tensorboard
%tensorboard --logdir training_results --host localhost --port 6013