In [1]:
import torch
import ale_py
import wandb
import torch.nn as nn
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation, RecordVideo
from gymnasium.vector import SyncVectorEnv
import matplotlib.pyplot as plt
from dataclasses import dataclass
from tqdm import tqdm
import random

In [2]:
# Authenticate with Weights & Biases.
# No key is passed on purpose: an API key must never live in the notebook source
# (it ends up in version control and in shared copies). Without an explicit key,
# wandb resolves credentials itself (WANDB_API_KEY environment variable, an
# existing netrc entry, or an interactive prompt).
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: [wandb.login()] Using explicit session credentials for https://api.wandb.ai.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33majheshbasnet[0m ([33majheshbasnet-kpriet[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
@dataclass
class configs:
  """Hyperparameters and run settings for the PPO experiment.

  Declared as a dataclass so that every instance carries its own copy of the
  fields: `vars(cfg)` then returns the full hyperparameter dict (with plain
  class attributes it was empty). The defaults stay readable on the class
  itself, so `configs.max_step` style access keeps working.
  """

  # --- environment ---
  game_id: str = "RiverraidNoFrameskip-v4"   # Gymnasium environment id
  max_step: int = 4000                       # passed to gym.make as max_episode_steps
  stack_size: int = 4                        # frames stacked per observation (= conv input channels)
  num_envs: int = 6                          # number of environments in the vectorised env

  # --- optimisation ---
  n_episodes: int = 50_000                   # outer training iterations (collect rollout + update)
  rollouts: int = 5_000                      # vector-env steps collected per outer iteration
  policy_lr: float = 3e-3                    # Adam learning rate of the policy network
  value_lr: float = 2.5e-3                   # Adam learning rate of the value network
  discount_factor: float = 0.99              # gamma used when accumulating returns
  epsilon: float = 0.15                      # PPO clipping range for the probability ratio
  ppo_epochs: int = 5                        # passes over the collected batch per update
  minibatch_size: int = 16                   # samples per gradient step

  # --- evaluation ---
  eval_steps: int = 20000                    # run an evaluation every this many global steps
  cam_counter: int = 100_000                 # record video when global_step is a multiple of this
  eval_loops: int = 3                        # episodes averaged per evaluation

  # --- hardware ---
  device: str = "cuda" if torch.cuda.is_available() else "cpu"

cfg = configs()

In [4]:
def create_run(configs):
    """Start a Weights & Biases run called "ppo" in the "ppo" project.

    Args:
        configs: Object (instance or class) whose public, non-callable
            attributes are the hyperparameters to record with the run.

    Returns:
        Whatever `wandb.init` returns (the run handle).
    """
    # `vars(obj)` only sees attributes stored on the instance itself; a config
    # object whose values are class attributes therefore produced an empty
    # dict and no hyperparameters were tracked. `dir` covers both class-level
    # and instance-level attributes.
    hyperparams = {
        name: getattr(configs, name)
        for name in dir(configs)
        if not name.startswith("_") and not callable(getattr(configs, name))
    }
    return wandb.init(
        name="ppo",
        project="ppo",
        # Track hyperparameters and run metadata.
        config=hyperparams,
    )

In [5]:
def createEnvironment(cfg):
  """Return a zero-argument factory that builds one preprocessed Atari env.

  The factory form is what vectorised env constructors expect: each call to
  the returned function creates a fresh, independent environment.

  Args:
    cfg: Settings object providing `game_id`, `max_step` and `stack_size`.

  Returns:
    Callable with no arguments that creates and returns the wrapped env.
  """

  def _init():
    # Every setting comes from the `cfg` argument (the episode limit used to
    # be read from the global `configs` class, ignoring the caller's config).
    env = gym.make(
        cfg.game_id,
        frameskip=1,                 # frame skipping is done by AtariPreprocessing below
        full_action_space=False,
        render_mode="rgb_array",
        max_episode_steps=cfg.max_step,
    )

    # scale_obs is left at its default (False): observations stay uint8 and are
    # normalised inside the networks at run time, because storing float32
    # frames takes far more memory and slowed training to ~11 s/iteration.
    env = AtariPreprocessing(env, frame_skip=4, grayscale_obs=True, screen_size=84)

    # The stack is [frame(t-3), frame(t-2), frame(t-1), frame(t)], i.e. history,
    # not look-ahead. Right after reset() all slots hold the same first frame,
    # then [f0, f0, f0, f1], [f0, f0, f1, f2], and so on.
    env = FrameStackObservation(env, cfg.stack_size)

    return env

  return _init


def evalenv(cfg):
    """Build a single (non-vectorised) preprocessed Atari env for evaluation.

    Uses the same wrapper stack as the training environments so evaluation
    observations match what the policy was trained on.

    Args:
        cfg: Settings object providing `game_id`, `max_step` and `stack_size`.

    Returns:
        The wrapped environment.
    """
    # The episode limit comes from the `cfg` argument (it used to be read from
    # the global `configs` class, ignoring the caller's config).
    env = gym.make(
        cfg.game_id,
        frameskip=1,                 # frame skipping is done by AtariPreprocessing below
        full_action_space=False,
        render_mode="rgb_array",
        max_episode_steps=cfg.max_step,
    )

    # frame_skip is spelled out so it visibly matches the training env.
    # scale_obs is left at its default (False): observations stay uint8 and are
    # normalised inside the networks at run time to save memory.
    env = AtariPreprocessing(env, frame_skip=4, grayscale_obs=True, screen_size=84)

    # The stack is [frame(t-3), frame(t-2), frame(t-1), frame(t)], i.e. history,
    # not look-ahead. Right after reset() all slots hold the same first frame,
    # then [f0, f0, f0, f1], [f0, f0, f1, f2], and so on.
    env = FrameStackObservation(env, cfg.stack_size)

    return env

In [6]:
# Vectorised training environment built from cfg.num_envs independent factories.
env_fns = [createEnvironment(cfg) for _ in range(cfg.num_envs)]
env = SyncVectorEnv(env_fns)

In [7]:
class PolicyNetwork(nn.Module):
  """Convolutional policy that maps stacked frames to action logits.

  Input is a batch of frame stacks with `cfg.stack_size` channels and raw
  pixel values in [0, 255]; scaling to [0, 1] happens inside `forward`.
  The first linear layer expects 512 features, which is what the conv stack
  yields for 84x84 frames (32 channels x 4 x 4).
  """

  def __init__(self, action_space):
    """
    Args:
      action_space: Number of discrete actions (size of the logit vector).
    """

    super().__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(cfg.stack_size, 32, kernel_size=5, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=4, stride=3),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, stride=1),
        nn.ReLU()
    )

    self.ffnn = nn.Sequential(
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 64),
        nn.ReLU(),
        nn.Linear(64, action_space)
    )

  def forward(self, x):
    """Return unnormalised action logits of shape (batch, action_space)."""
    x = self.conv(x/255.)          # scale pixels to [0, 1] at run time
    x = x.view(x.size(0), -1)      # flatten the conv features per sample
    x = self.ffnn(x)
    return x

  def log_probs(self, x, pick_action):
    """Sample actions or score given actions under the current policy.

    Args:
      x: Batch of observations accepted by `forward`.
      pick_action: `None` to sample one action per observation, or an index
        tensor of shape (batch, 1) holding the actions to score.

    Returns:
      `(action_idx, log_prob)`, both of shape (batch, 1), when `pick_action`
      is `None`; otherwise only `log_prob` of shape (batch, 1).
    """
    # log_softmax is computed in log space, so an action whose probability
    # underflows to 0 still gets a finite log-probability. The previous
    # softmax(...).log() returned -inf there, which turns the PPO ratio and
    # its gradients into inf/NaN.
    action_log_probs = torch.nn.functional.log_softmax(self(x), dim = -1)

    # Identity test: `== None` on a tensor goes through tensor comparison
    # instead of checking for the sentinel.
    if pick_action is None:
      action_idx = torch.multinomial(action_log_probs.exp(), 1)
      log_prob = torch.gather(action_log_probs, -1, action_idx)
      return action_idx, log_prob

    log_prob = torch.gather(action_log_probs, -1, pick_action)
    return log_prob

In [8]:
class Value_Network(nn.Module):
  """Convolutional critic producing one scalar value per frame stack.

  Takes batches with `cfg.stack_size` channels and raw pixel values in
  [0, 255]; the division by 255 is done inside `forward`. The head's first
  linear layer takes 512 input features.
  """

  def __init__(self):

    super().__init__()

    # Feature extractor: three strided convolutions, each followed by ReLU.
    feature_layers = [
        nn.Conv2d(cfg.stack_size, 32, kernel_size=5, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=4, stride=3),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, stride=1),
        nn.ReLU(),
    ]
    self.conv = nn.Sequential(*feature_layers)

    # Regression head: 512 -> 256 -> 64 -> 1.
    head_layers = [
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 64),
        nn.ReLU(),
        nn.Linear(64, 1),
    ]
    self.ffnn = nn.Sequential(*head_layers)

  def forward(self, x):
    """Return state-value estimates of shape (batch, 1)."""
    features = self.conv(x / 255.)
    flat_features = features.flatten(start_dim=1)
    return self.ffnn(flat_features)

In [9]:
# The policy head needs one logit per discrete action of a single sub-environment.
num_actions = env.single_action_space.n
currentPolicy = PolicyNetwork(action_space=num_actions).to(cfg.device)
ValueNetwork = Value_Network().to(cfg.device)

In [10]:
# Report parameter counts (in thousands) and the device in use.
# The second row describes ValueNetwork (the critic); it was previously
# mislabelled "Policy-Net", which made both rows look like the policy.
print(f'''
=======================================================================
> Actor-Net:  {sum(p.numel() for p in currentPolicy.parameters())/1e3 :.2f} k
> Critic-Net: {sum(p.numel() for p in ValueNetwork.parameters())/1e3 :.2f} k
-----------------------------------------------------------------------
> {cfg.device.upper()} is being used
=======================================================================
''')


> Actor-Net:  177.84 k
> Policy-Net: 176.74 k
-----------------------------------------------------------------------
> CUDA is being used



In [11]:
# Ignore every SyntaxWarning raised from this point on (process-wide filter).
# NOTE(review): this import sits mid-notebook; consider moving it to the
# imports cell. Which warning it is meant to hide is not visible here; confirm
# it is still needed.
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

In [12]:
def evaluationLoop(policynetwork, recordVideo = False):
  """Run greedy evaluation episodes and return the averaged results.

  A fresh evaluation environment is created for each call and closed again
  before returning, also when an episode raises.

  Args:
    policynetwork: Policy module; the action with the largest logit is taken
      at every step (no sampling).
    recordVideo: When True, every evaluation episode is recorded to "videos/".

  Returns:
    `(mean_reward, mean_steps)` over `cfg.eval_loops` episodes; `mean_steps`
    is truncated to an int.
  """

  eval_env = evalenv(cfg)

  if recordVideo:
    eval_env = RecordVideo(
                           eval_env, video_folder="videos/",
                           episode_trigger=lambda episode_id: True, name_prefix="ppo"
                           )

  total_eval_rewards = 0.0
  total_eval_steps = 0

  # Remember the caller's mode so it can be restored afterwards instead of
  # unconditionally forcing train mode.
  was_training = policynetwork.training
  policynetwork.eval()

  try:
    with torch.no_grad():

      # `cfg` is used throughout (the loop count used to come from the global
      # `configs` class while the averages below divided by `cfg.eval_loops`).
      for _ in range(cfg.eval_loops):

        obs, _ = eval_env.reset()
        done = False

        while not done:
          # Convert straight to float on the target device; the network
          # divides by 255 itself, so an integer dtype detour is not needed.
          obs_batch = torch.as_tensor(obs, dtype=torch.float32, device=cfg.device).unsqueeze(0)
          action = policynetwork(obs_batch).argmax().item()

          obs, reward, terminated, truncated, _ = eval_env.step(action)
          total_eval_rewards += float(reward)
          total_eval_steps += 1
          done = terminated or truncated
  finally:
    # Always release the emulator / video writer and restore the mode.
    eval_env.close()
    policynetwork.train(was_training)

  mean_reward = total_eval_rewards / cfg.eval_loops
  mean_steps = int(total_eval_steps / cfg.eval_loops)
  return mean_reward, mean_steps

In [13]:
# Separate Adam optimisers so actor and critic can use different learning rates.
policy_optimizer = torch.optim.Adam(params=currentPolicy.parameters(), lr=cfg.policy_lr)
value_optimizer = torch.optim.Adam(params=ValueNetwork.parameters(), lr=cfg.value_lr)

In [None]:
# ===================== PPO TRAINING =====================
# Each outer iteration: reset the vector env, collect cfg.rollouts steps from
# every sub-environment, compute per-environment bootstrapped returns, then run
# cfg.ppo_epochs passes of clipped-PPO minibatch updates.

def _value_estimates(value_net, states, batch_size=2048):
    """Evaluate `value_net` on `states` in chunks, without an autograd graph.

    The rollout holds rollouts * num_envs frame stacks; pushing them through
    the critic in a single tracked forward pass keeps every activation alive
    even though only detached values are used afterwards.

    Returns a 1-D tensor with one value per state.
    """
    with torch.no_grad():
        chunks = [
            value_net(states[start:start + batch_size]).view(-1)
            for start in range(0, states.size(0), batch_size)
        ]
    return torch.cat(chunks)


def _bootstrapped_returns(rewards, terminated, truncated, values, last_values, gamma):
    """Discounted returns, accumulated separately for every environment.

    Args:
        rewards, terminated, truncated, values: time-major tensors of shape
            (steps, num_envs); `values[t]` is the critic's estimate for the
            state the action at step t was taken in.
        last_values: shape (num_envs,), critic estimate for the state reached
            after the final collected step (bootstraps the unfinished tail).
        gamma: Discount factor.

    Returns:
        Tensor of shape (steps, num_envs).
    """
    running = last_values.clone()
    returns = torch.empty_like(rewards)
    for t in reversed(range(rewards.size(0))):
        # Truncation: the episode was cut, not finished, so bootstrap. The
        # state after the cut is not stored, so V of the step's own state is
        # used as the stand-in (same approximation as before).
        running = torch.where(truncated[t], values[t], running)
        # Termination wins over truncation: nothing follows a terminal step.
        running = torch.where(terminated[t], torch.zeros_like(running), running)
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


runs = create_run(cfg)
global_step = 0

ValueNetwork.train()
currentPolicy.train()

try:
    for steps in range(cfg.n_episodes):

        buffer = {
            "states": [],
            "rewards": [],
            "actions": [],
            "log_probs": [],
            "terminated": [],
            "truncated": []
        }

        states = env.reset()[0]
        training_rewards = torch.zeros((cfg.num_envs,), device=cfg.device)

        # ===================== ROLLOUT =====================
        for rollouts in range(cfg.rollouts):

            state_cpu = torch.from_numpy(states)      # kept as-is on the CPU for the buffer
            state_tensor = state_cpu.to(cfg.device)

            with torch.no_grad():
                action, log_probs = currentPolicy.log_probs(state_tensor, None)
                action = action.cpu().numpy().reshape(-1)

            next_state, reward, terminated, truncated, _ = env.step(action)

            training_rewards += torch.tensor(reward, dtype=torch.float32, device=cfg.device)

            buffer["states"].append(state_cpu)
            buffer["actions"].append(torch.tensor(action, dtype=torch.long))
            buffer["rewards"].append(torch.tensor(reward, dtype=torch.float32))
            buffer["log_probs"].append(log_probs.view(-1).cpu())
            buffer["terminated"].append(torch.tensor(terminated, dtype=torch.bool))
            buffer["truncated"].append(torch.tensor(truncated, dtype=torch.bool))

            states = next_state
            global_step += 1

            runs.log({"global-step": global_step})

            if global_step % cfg.eval_steps == 0:
                rec = global_step % cfg.cam_counter == 0
                eval_rewards, eval_steps = evaluationLoop(currentPolicy, rec)
                runs.log({"eval-rewards": eval_rewards, "eval-steps": eval_steps})

        # ===================== FLATTEN =====================
        # Everything is flattened time-major, (steps, envs, ...) -> (steps*envs, ...),
        # with the same ordering for every tensor. flatten(0, 1) keeps the
        # per-observation shape, so the frame-stack size is not hard-coded.
        all_states = torch.stack(buffer["states"]).flatten(0, 1).to(cfg.device)
        all_actions = torch.stack(buffer["actions"]).reshape(-1).to(cfg.device)
        old_log_prob = torch.stack(buffer["log_probs"]).reshape(-1, 1).to(cfg.device)
        # Kept 2-D (steps, envs) on the CPU for the return computation.
        all_rewards = torch.stack(buffer["rewards"])
        all_terminated = torch.stack(buffer["terminated"])
        all_truncated = torch.stack(buffer["truncated"])

        # ===================== RETURNS =====================
        # Returns are accumulated per environment. Walking one flattened
        # sequence backwards (as before) let the running return of env k+1
        # leak into the last steps of env k and never bootstrapped the tail.
        Vt = _value_estimates(ValueNetwork, all_states).view(-1, 1)
        last_values = _value_estimates(
            ValueNetwork, torch.from_numpy(states).to(cfg.device)
        )
        Rt = _bootstrapped_returns(
            all_rewards,
            all_terminated,
            all_truncated,
            Vt.view(all_rewards.shape).cpu(),
            last_values.cpu(),
            cfg.discount_factor,
        ).reshape(-1, 1).to(cfg.device)

        # ===================== ADVANTAGE =====================
        # Vt and Rt carry no autograd history, so no detach is required.
        At = Rt - Vt
        At = (At - At.mean()) / (At.std() + 1e-8)

        # ===================== PPO UPDATE =====================
        N = all_states.size(0)

        for _ in range(cfg.ppo_epochs):

            indices = torch.randperm(N, device=all_states.device)

            # Stride by the minibatch size so each sample is used exactly once
            # per epoch (the stride used to be cfg.ppo_epochs, which produced
            # heavily overlapping minibatches).
            for start in range(0, N, cfg.minibatch_size):
                end = start + cfg.minibatch_size
                mb_idx = indices[start:end]

                mb_states   = all_states[mb_idx]
                mb_actions  = all_actions[mb_idx]
                mb_old_logp = old_log_prob[mb_idx]
                mb_adv      = At[mb_idx]
                mb_returns  = Rt[mb_idx]

                mb_new_logp = currentPolicy.log_probs(
                    mb_states, mb_actions.reshape(-1, 1)
                ).reshape(-1, 1)

                ratio = torch.exp(mb_new_logp - mb_old_logp)

                # Clipped surrogate objective (negated: optimisers minimise).
                policy_loss = -torch.mean(
                    torch.min(
                        ratio * mb_adv,
                        torch.clamp(ratio, 1 - cfg.epsilon, 1 + cfg.epsilon) * mb_adv
                    )
                )

                value_pred = ValueNetwork(mb_states).view(-1, 1)
                value_loss = torch.nn.functional.mse_loss(value_pred, mb_returns)

                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()

                value_optimizer.zero_grad()
                value_loss.backward()
                value_optimizer.step()

        # Losses logged here are those of the last minibatch of the last epoch.
        runs.log({
            "policy-loss": policy_loss.item(),
            "value-loss": value_loss.item(),
            "training-rewards": training_rewards.mean().item()
        })
        # Free the large per-iteration tensors before the next rollout.
        del all_states, all_rewards, all_terminated, all_truncated, old_log_prob
        del policy_loss, value_loss
finally:
    # Close the W&B run and the emulators even if training is interrupted.
    runs.finish()
    env.close()


  logger.warn(


In [None]:
# Debug: shapes of the last minibatch left over from the training cell.
mb_states.shape, mb_actions.shape

In [None]:
# Debug: score the last minibatch under the current policy.
# Uses the minibatch tensors because `all_states` is deleted at the end of every
# training iteration, and reshapes the actions to (batch, 1) because the gather
# inside `log_probs` needs an index with the same number of dims as the logits.
with torch.no_grad():
    last_minibatch_log_probs = currentPolicy.log_probs(mb_states, mb_actions.reshape(-1, 1))
last_minibatch_log_probs

In [None]:
# Scratch check of torch.randperm (the training cell uses it to shuffle
# minibatch indices). NOTE(review): leftover experiment; remove before sharing.
torch.randperm(5)


In [None]:
# NOTE(review): the training cell runs `del all_states, ...` at the end of every
# iteration, so this raises NameError once an iteration has completed; it only
# works if training was interrupted mid-iteration. Leftover debug cell.
all_states.size()