In [None]:
# Install SWIG together with Gymnasium and its Box2D extra.
# NOTE(review): consider `%pip install ...` with pinned versions so the install
# targets the running kernel's environment and the notebook stays reproducible.
! pip install swig "gymnasium[box2d]"

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# **IMPORTING LIBRARIES**

In [None]:
import torch
import random
import os
import torch.nn as nn
import numpy as np
import time
from collections import deque
import wandb
from tqdm import tqdm

import gymnasium as gym
import ale_py
# from gymnasium.vector import SyncVectorEnv
from gymnasium.wrappers import RecordVideo

In [None]:
# os.environ["MUJOCO_GL"] = "egl"

In [None]:
# def create_environment(cfgs, eval = False):

#   def _init():
#       env = gym.make( id=cfgs.id , render_mode="rgb_array", max_episode_steps=cfg.max_steps)
#       return env

#   return _init

In [None]:
def create_environment(cfgs, eval = False):
  """Build one Gymnasium environment for the id stored in ``cfgs.id``.

  The environment is created with ``render_mode="rgb_array"``.
  ``eval`` is accepted for call-site compatibility but is not consulted here.
  """
  return gym.make(render_mode="rgb_array", id=cfgs.id)

# **WANDB RUN**

In [None]:
def wandb_runs(cfg):
  """Log in to Weights & Biases and start a run that records ``cfg``.

  Args:
    cfg: configuration object whose public, non-callable attributes are
      uploaded as the run's config.

  Returns:
    The run object returned by ``wandb.init``.
  """
  # Never paste the API key into the notebook: read it from the environment.
  # When WANDB_API_KEY is unset, ``None`` lets wandb use its own stored
  # credentials or prompt interactively.
  wandb.login(key=os.environ.get("WANDB_API_KEY") or None)

  # ``vars(cfg)`` only contains attributes stored on the instance. Settings that
  # are declared as plain class attributes would be missing, so start from the
  # class-level values and let instance-level values override them.
  config = {
    name: value
    for name, value in vars(type(cfg)).items()
    if not name.startswith("_") and not callable(value)
  }
  config.update(vars(cfg))

  run = wandb.init(
    entity="ajheshbasnet-kpriet",
    project="ddpg",
    name = "DDPG",
    config=config,
  )

  return run

# **CONFIGURATIONS**

In [None]:
from dataclasses import dataclass

@dataclass
class configuration:
  """Hyper-parameters and mutable run counters shared across the notebook.

  Every attribute carries a type annotation so that ``@dataclass`` registers it
  as a field. Without annotations the values are plain class attributes, the
  instance ``__dict__`` stays empty and ``vars(cfg)`` reports no settings.
  """
  # Gymnasium environment id.
  id: str = "LunarLander-v3"
  # Number of training-loop iterations (one collected episode each).
  n_rollouts: int = 25_000
  # Run an evaluation every `eval_steps` rollouts.
  eval_steps: int = 10
  # Running count of environment steps taken during training.
  global_steps: int = 0
  # Number of episodes averaged per evaluation.
  eval_loops: int = 3
  # Upper bound on the minibatch size used for the PPO update.
  batch_size: int = 96
  # PPO clipping range for the probability ratio (1 +/- this value).
  ppo_r_clamp: float = 0.2
  # Learning rates of the critic and actor optimizers.
  critic_lr: float = 1e-4
  actor_lr: float = 1e-4
  # Environment-step interval used to decide when evaluation records a video.
  record_video: int = 500_000
  # Running count of completed rollouts, used to schedule evaluation.
  global_eval_steps: int = 0
  # Torch device string, resolved once when the class is defined.
  device: str = "cuda" if torch.cuda.is_available() else "cpu"

cfg = configuration()

**SyncVectorEnv lets us run n environments in parallel and make better use of the GPU, because a single environment is very slow. (It is currently commented out below; a single environment is used instead.)**

In [None]:
# envs = SyncVectorEnv([create_environment(cfg) for _ in range(cfg.n_envs)])

# Single (non-vectorised) training environment built from the shared config.
envs = create_environment(cfg)

In [None]:
# Smoke test: reset the environment and display the initial observation.
envs.reset()[0]

**Checking whether the environment is working :)**

# **Actor and Critic Networks**

In [None]:
class Actor(nn.Module):
  """Policy network that maps an observation to logits over discrete actions."""

  def __init__(self, input_dim, action_dim):
    """
    Args:
      input_dim: size of the observation vector.
      action_dim: number of discrete actions (one logit each).
    """
    super().__init__()
    self.sequential = nn.Sequential(
        nn.Linear(input_dim, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, action_dim),
    )

  def forward(self, x):
    """Return the unnormalised action logits for ``x``."""
    return self.sequential(x)

  def get_log_probs_action(self, x, actions = None):
    """Sample actions (or score given ones) under the current policy.

    Args:
      x: observation tensor; the last dimension must equal ``input_dim``.
      actions: optional tensor of action indices. When ``None`` a fresh action
        is sampled from the categorical policy; otherwise these actions are
        scored as-is.

    Returns:
      ``(action, log_prob, entropy)`` where ``log_prob`` is the log-probability
      of ``action`` and ``entropy`` is the policy entropy averaged over the
      batch (a scalar tensor).
    """
    logits = self(x)

    dist = torch.distributions.Categorical(logits=logits)

    # Identity test: `== None` on a tensor goes through tensor comparison
    # instead of simply asking whether an argument was supplied.
    if actions is None:
      action = dist.sample()
    else:
      action = actions

    log_prob = dist.log_prob(action)
    entropy = dist.entropy().mean()

    return action, log_prob, entropy

In [None]:
class Critic(nn.Module):
  """State-value network: maps an observation to a single scalar estimate."""

  def __init__(self, input_dim):
    super().__init__()

    # Assemble the MLP layer by layer: two hidden ReLU layers, then a scalar head.
    hidden_widths = (512, 256)
    layers = []
    in_features = input_dim
    for width in hidden_widths:
      layers += [nn.Linear(in_features, width), nn.ReLU()]
      in_features = width
    layers.append(nn.Linear(in_features, 1))

    self.sequential = nn.Sequential(*layers)

  def forward(self, x):
    """Return the value estimate(s) for ``x`` with a trailing dimension of 1."""
    return self.sequential(x)

In [None]:
# Instantiate both networks on the configured device. The actor emits one logit
# per discrete action of the environment; the critic emits a single value.
actornet = Actor(envs.observation_space.shape[0], envs.action_space.n).to(cfg.device)  #type: ignore
criticnet = Critic(envs.observation_space.shape[0]).to(cfg.device)  #type: ignore

In [None]:
# Report the parameter count of each network, expressed in thousands.
print(f'''Parameters:
==============================
actor-network     : {sum(p.numel() for p in actornet.parameters())/1e3} k
critic-network(s) : {sum(p.numel() for p in criticnet.parameters())/ 1e3} k
==============================
      ''')

**Evaluation Loop**

In [None]:
def evaluation(actornet, record_video = False):
  """Run greedy evaluation episodes and report their averages.

  A fresh environment is created from ``cfg.id`` and ``cfg.eval_loops`` episodes
  are played, always taking the arg-max action of ``actornet``. The reward of
  each episode is printed as it finishes.

  Args:
    actornet: policy network returning action logits for an observation.
    record_video: when true, every episode is recorded into a timestamped
      sub-folder of ``videos/``.

  Returns:
    ``(mean_reward, mean_steps)`` averaged over the evaluation episodes.
  """
  eval_env = gym.make(id = cfg.id, render_mode = 'rgb_array')

  net_reward = 0
  net_step = 0

  # Close the environment (and flush any video recorder) even if an episode
  # raises, otherwise the underlying resources leak.
  try:
    if record_video:
      video_dir = f"videos/{int(time.time())}"
      eval_env = RecordVideo(eval_env,  video_folder=video_dir, episode_trigger=lambda ep: True)

    with torch.no_grad():

      for _ in range(cfg.eval_loops):

        done = False

        episodic_reward = 0
        episodic_step = 0
        state = eval_env.reset()[0]

        while not done:

          stateT = torch.tensor(state, dtype=torch.float32, device=cfg.device)
          # Greedy action: highest logit, no sampling.
          action = actornet(stateT).argmax().item()
          nxt_state, reward, terminated, truncated, _ = eval_env.step(action)
          done = terminated or truncated
          state = nxt_state

          episodic_reward += float(reward)
          episodic_step += 1
        print(episodic_reward)
        net_reward += episodic_reward
        net_step  += episodic_step
  finally:
    eval_env.close()

  net_reward = net_reward / cfg.eval_loops
  net_step = net_step / cfg.eval_loops

  return net_reward, net_step

In [None]:
# Smoke-test the evaluation loop with the untrained actor; passing True also
# records a video of every evaluation episode.
evaluation(actornet, True)

**Optimizers for the actor and critic (the batches are sampled inside the training loop below)**

In [None]:
# Separate AdamW optimizers so the actor and critic are stepped independently,
# each with its own learning rate from the config.
actor_optimizer = torch.optim.AdamW(actornet.parameters(), lr = cfg.actor_lr)
critic_optimizer = torch.optim.AdamW(criticnet.parameters(), lr = cfg.critic_lr)

**W&B RUNS TO LOG THE METRICS**

In [None]:
# Start the W&B run; `runs` is the handle the training loop logs metrics to.
runs = wandb_runs(cfg)

In [None]:
# Discount factor applied to future values in the TD error and the GAE recursion.
gamma = 0.99
# GAE smoothing factor (multiplied with gamma when accumulating advantages).
lambda_ = 0.96
# Weight of the entropy bonus that is subtracted from the policy loss.
entropy_beta = 0.001

# **Heart & Core of the notebook: PPO Algorithm's Training Loop**

In [None]:
def compute_gae(rewards, values, next_value, terminals, gamma, lambda_):
  """Generalised Advantage Estimation over one rollout.

  Args:
    rewards:    [T, 1] reward received at each step.
    values:     [T, 1] critic estimate for the state at each step.
    next_value: [1, 1] critic estimate for the state reached after the last step.
    terminals:  [T, 1] 1.0 where the step terminated the episode (nothing to
                bootstrap from), 0.0 otherwise.
    gamma:      discount factor.
    lambda_:    GAE smoothing factor.

  Returns:
    ``(advantages, returns)``, both [T, 1]. Advantages are not normalised.
  """
  advantages = torch.zeros_like(rewards)
  num_steps = rewards.size(0)
  gae = 0

  for t in reversed(range(num_steps)):
    next_val = next_value.view(1) if t == num_steps - 1 else values[t + 1]
    not_terminal = 1 - terminals[t]

    delta = rewards[t] + gamma * not_terminal * next_val - values[t]
    gae = delta + gamma * lambda_ * not_terminal * gae
    advantages[t] = gae

  return advantages, advantages + values


# Environment-step count at which the next evaluation should record a video.
# `global_steps` advances by whole episodes, so an exact-multiple test would
# almost never fire; crossing a threshold is tracked instead.
next_video_step = (cfg.global_steps // cfg.record_video + 1) * cfg.record_video

for _ in tqdm(range(cfg.n_rollouts)):

  # ----- Collect one full episode with the current policy -----
  states = []
  actions = []
  rewards = []
  log_probs = []
  terminals = []

  state = envs.reset()[0]

  stateT = torch.tensor(state, dtype=torch.float32, device = cfg.device)
  training_reward = 0
  training_step = 0
  done = False

  while not done:

    with torch.no_grad():

      action, logprob, _ = actornet.get_log_probs_action(stateT, None)

    next_state, reward, terminated, truncated, _ = envs.step(action=action.item())

    done = terminated or truncated

    states.append(stateT)
    actions.append(action)
    log_probs.append(logprob)
    rewards.append(torch.tensor(reward, dtype=torch.float32, device = cfg.device))
    # Only a true termination stops bootstrapping. A time-limit truncation ends
    # the rollout, but the state reached still has value, so it must not be
    # masked out of the return.
    terminals.append(torch.tensor(terminated, dtype=torch.float32, device = cfg.device))

    stateT = torch.tensor(next_state, dtype=torch.float32, device = cfg.device)
    training_reward += float(reward)
    training_step += 1
    cfg.global_steps += 1

  runs.log({"training-reward": training_reward, "training-step": training_step, "global-steps": cfg.global_steps})

  all_states = torch.stack(states)
  all_actions = torch.stack(actions).view(-1, 1)
  all_rewards = torch.stack(rewards).view(-1, 1)
  all_terminals = torch.stack(terminals).view(-1, 1)
  old_log_probs = torch.stack(log_probs).view(-1, 1)

  # ----- Compute Values -----
  with torch.no_grad():
    values = criticnet(all_states).view(-1, 1)      # [T, 1]
    # After the loop `stateT` holds the state that followed the last action.
    next_value = criticnet(stateT).view(1, 1)       # [1, 1]

  # ----- GAE -----
  advantages, returns = compute_gae(all_rewards, values, next_value, all_terminals, gamma, lambda_)

  # The standard deviation of a single sample is NaN and would poison the
  # update, so a one-step rollout keeps its raw advantage.
  if advantages.numel() > 1:
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # ----- PPO update: one shuffled pass over the rollout in minibatches -----
  log_actor_loss = 0
  log_policy_loss = 0
  log_critic_loss = 0
  log_entropy = 0
  log_advantage = 0
  step = 0

  num_samples = all_states.size(0)
  shuffled_ids = torch.randperm(num_samples, device=cfg.device)
  batch_size = min(cfg.batch_size, num_samples)

  for i in range(0, num_samples, batch_size):

    batch_ids = shuffled_ids[i: i+batch_size]

    mb_states = all_states[batch_ids]
    mb_actions = all_actions[batch_ids].squeeze(-1)
    mb_advantages = advantages[batch_ids].squeeze(-1)
    mb_returns = returns[batch_ids].squeeze(-1)
    mb_old_log_probs = old_log_probs[batch_ids].squeeze(-1)

    _, mb_new_log_probs, entropy = actornet.get_log_probs_action(mb_states, mb_actions)

    # Probability ratio between the updated policy and the behaviour policy.
    ratio = torch.exp(mb_new_log_probs - mb_old_log_probs)

    # Clipped surrogate objective (negated because the optimizer minimises).
    clipped_ratio = torch.clamp(ratio, 1-cfg.ppo_r_clamp, 1+cfg.ppo_r_clamp)
    policy_loss_ = -torch.min(ratio * mb_advantages, clipped_ratio * mb_advantages).mean()

    policy_loss = policy_loss_ - entropy_beta * entropy

    critic_loss = torch.nn.functional.mse_loss(criticnet(mb_states).squeeze(-1), mb_returns)

    log_actor_loss += policy_loss.item()
    log_critic_loss += critic_loss.item()
    log_entropy += entropy.item()
    log_policy_loss += policy_loss_.item()
    log_advantage += mb_advantages.mean().item()
    step += 1

    actor_optimizer.zero_grad()
    policy_loss.backward()
    actor_optimizer.step()

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

  # Average every logged metric over the minibatches, the advantage included.
  log_actor_loss = log_actor_loss / step
  log_policy_loss = log_policy_loss / step
  log_critic_loss = log_critic_loss / step
  log_entropy = log_entropy / step
  log_advantage = log_advantage / step

  runs.log({"policy-loss": log_policy_loss, "actor-loss": log_actor_loss, "critic-loss": log_critic_loss, "entropy": log_entropy, "advantage": log_advantage})

  # ----- Periodic evaluation -----
  if cfg.global_eval_steps%cfg.eval_steps==0 and cfg.global_eval_steps>1:
    rec_frame = cfg.global_steps >= next_video_step
    if rec_frame:
      next_video_step = (cfg.global_steps // cfg.record_video + 1) * cfg.record_video
    net_reward, net_step = evaluation(actornet, rec_frame)
    runs.log({"eval-reward": net_reward, "eval-steps": net_step})
  cfg.global_eval_steps += 1

# **SAVE MODEL-WEIGHTS**

In [None]:
# Persist only the actor's parameters (its state_dict) next to the notebook.
torch.save(actornet.state_dict(), "policy-weights.pt")