<a href="https://colab.research.google.com/github/NoobCoder-dweeb/AI-HandsOn-Journey/blob/main/notes/PPO_using_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# NOTE(review): the recorded output of this cell shows the install failing with
# "metadata-generation-failed" (a `python setup.py egg_info` step exited with
# code 1), so these pins were presumably not installed in that session.
# TODO: confirm which pinned package fails to build and update its pin.
!pip install -q torch==1.13.0 swig gym==0.18.0

  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the package mentioned above, not pip.
[1;36mhint[0m: See above for details.


In [34]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

class PPOMemory:
  """Rollout buffer that stores transitions between PPO updates.

  Transitions are appended one at a time with ``store_memory``; the six
  parallel lists always have the same length. ``generate_batches`` returns the
  stored data as NumPy arrays together with shuffled mini-batch index arrays,
  and ``clear_memory`` empties the buffer.

  Args:
    batch_size: Maximum number of transitions in one mini-batch.
  """

  def __init__(self, batch_size):
    self.batch_size = batch_size
    # Creates the (empty) per-field lists in one place so that __init__ and
    # clear_memory can never disagree about which fields exist.
    self.clear_memory()

  def generate_batches(self):
    """Return the stored transitions plus shuffled mini-batch indices.

    Returns:
      A 7-tuple ``(states, actions, probs, vals, rewards, dones, batches)``.
      The first six items are NumPy arrays built from the stored lists, in
      insertion order. ``batches`` is a list of int32 index arrays that
      together cover every stored index exactly once, in random order; each
      has ``batch_size`` entries except possibly the last. The list is empty
      when nothing is stored.
    """
    n_states = len(self.states)
    indices = np.arange(n_states, dtype=np.int32)
    np.random.shuffle(indices)
    batches = [indices[start:start + self.batch_size]
               for start in range(0, n_states, self.batch_size)]

    return (np.array(self.states),
            np.array(self.actions),
            np.array(self.probs),
            np.array(self.vals),
            np.array(self.rewards),
            np.array(self.dones),
            batches)

  def store_memory(self, state, action, probs, vals, reward, done):
    """Append one transition to the buffer.

    Args:
      state: Observation the action was chosen from.
      action: Action that was taken.
      probs: Value stored alongside the action (the caller passes the
        log-probability of the action under the sampling policy).
      vals: Critic value estimate for ``state``.
      reward: Reward received after taking ``action``.
      done: Whether the episode ended on this step.
    """
    self.states.append(state)
    self.actions.append(action)
    self.probs.append(probs)
    self.vals.append(vals)
    self.rewards.append(reward)
    self.dones.append(done)

  def clear_memory(self):
    """Drop every stored transition."""
    self.states = []
    self.probs = []
    self.vals = []
    self.actions = []
    self.rewards = []
    # `dones` must be reset with the other lists; otherwise stale done flags
    # pile up and no longer line up index-for-index with the rewards.
    self.dones = []


class ActorNetwork(nn.Module):
  """Policy network: maps a state to a categorical distribution over actions.

  Args:
    n_actions: Number of discrete actions (size of the output layer).
    input_dims: Iterable that is unpacked into the input size of the first
      ``nn.Linear`` layer, e.g. ``(4,)``.
    alpha: Learning rate of the Adam optimizer.
    fc1_dims: Width of the first hidden layer.
    fc2_dims: Width of the second hidden layer.
    chkpt_dir: Directory in which the checkpoint file is stored.
  """

  def __init__(self, n_actions, input_dims, alpha,
               fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
    super(ActorNetwork, self).__init__()
    self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
    self.actor = nn.Sequential(
        nn.Linear(*input_dims, fc1_dims),
        nn.ReLU(),
        nn.Linear(fc1_dims, fc2_dims),
        nn.ReLU(),
        nn.Linear(fc2_dims, n_actions),
        nn.Softmax(dim=-1)
    )
    self.optimizer = optim.Adam(self.parameters(), lr=alpha)
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    self.to(self.device)

  def forward(self, state):
    """Return a ``Categorical`` built from the softmax action probabilities."""
    action_probs = self.actor(state)
    return Categorical(action_probs)

  def save_checkpoint(self):
    """Write the network's ``state_dict`` to ``self.checkpoint_file``."""
    # torch.save does not create missing parent directories, so make sure the
    # checkpoint directory exists before writing.
    parent_dir = os.path.dirname(self.checkpoint_file)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    torch.save(self.state_dict(), self.checkpoint_file)

  def load_checkpoint(self):
    """Load the network's ``state_dict`` from ``self.checkpoint_file``."""
    # map_location lets a checkpoint saved on a GPU be restored on a machine
    # whose selected device is the CPU (and vice versa).
    self.load_state_dict(torch.load(self.checkpoint_file, map_location=self.device))

class CriticNetwork(nn.Module):
  """Value network: maps a state to a single scalar value estimate.

  Args:
    input_dims: Iterable that is unpacked into the input size of the first
      ``nn.Linear`` layer, e.g. ``(4,)``.
    alpha: Learning rate of the Adam optimizer.
    fc1_dims: Width of the first hidden layer.
    fc2_dims: Width of the second hidden layer.
    chkpt_dir: Directory in which the checkpoint file is stored.
  """

  def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
    super(CriticNetwork, self).__init__()

    self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
    self.critic = nn.Sequential(
        nn.Linear(*input_dims, fc1_dims),
        nn.ReLU(),
        nn.Linear(fc1_dims, fc2_dims),
        nn.ReLU(),
        nn.Linear(fc2_dims, 1)
    )
    self.optimizer = optim.Adam(self.parameters(), lr=alpha)
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    self.to(self.device)

  def forward(self, state):
    """Return the value estimate; the last dimension of the output has size 1."""
    return self.critic(state)

  def save_checkpoint(self):
    """Write the network's ``state_dict`` to ``self.checkpoint_file``."""
    # torch.save does not create missing parent directories, so make sure the
    # checkpoint directory exists before writing.
    parent_dir = os.path.dirname(self.checkpoint_file)
    if parent_dir:
      os.makedirs(parent_dir, exist_ok=True)
    torch.save(self.state_dict(), self.checkpoint_file)

  def load_checkpoint(self):
    """Load the network's ``state_dict`` from ``self.checkpoint_file``."""
    # map_location lets a checkpoint saved on a GPU be restored on a machine
    # whose selected device is the CPU (and vice versa).
    self.load_state_dict(torch.load(self.checkpoint_file, map_location=self.device))


class Agent:
  """PPO agent bundling an actor, a critic and a rollout buffer.

  Args:
    input_dims: Observation shape, forwarded to both networks.
    n_actions: Number of discrete actions.
    gamma: Reward discount factor.
    alpha: Learning rate used for both networks.
    gae_lambda: Lambda of generalized advantage estimation (GAE).
    policy_clip: Clipping range epsilon of the PPO surrogate objective.
    batch_size: Mini-batch size used by the rollout buffer.
    N: Not used by this class; kept so existing callers keep working.
    n_epochs: Number of passes over the stored rollout per ``learn`` call.
  """

  def __init__(self, input_dims, n_actions, gamma=0.99, alpha=0.0003, gae_lambda=0.95,
               policy_clip=0.2, batch_size=64, N=2048, n_epochs=10):
    self.gamma = gamma
    self.policy_clip = policy_clip
    self.n_epochs = n_epochs
    # Honour the constructor argument instead of a hard-coded 0.95.
    self.gae_lambda = gae_lambda
    self.batch_size = batch_size

    self.actor = ActorNetwork(n_actions, input_dims, alpha)
    self.critic = CriticNetwork(input_dims, alpha)
    self.memory = PPOMemory(batch_size)

  def remember(self, state, action, probs, vals, reward, done):
    """Store one transition in the rollout buffer."""
    self.memory.store_memory(state, action, probs, vals, reward, done)

  def save_models(self):
    """Save the actor and critic checkpoints."""
    print("Saving....")
    self.actor.save_checkpoint()
    self.critic.save_checkpoint()

  def load_models(self):
    """Load the actor and critic checkpoints."""
    print("Loading....")
    self.actor.load_checkpoint()
    self.critic.load_checkpoint()

  def choose_action(self, observation):
    """Sample an action for a single observation.

    Args:
      observation: One environment observation (array-like of floats).

    Returns:
      ``(action, log_prob, value)`` as plain Python numbers: the sampled
      action, its log-probability under the current policy, and the critic's
      value estimate for the observation.
    """
    # Convert to one float32 array first, then add the batch dimension.
    state = torch.tensor(np.asarray(observation, dtype=np.float32))
    state = state.unsqueeze(0).to(self.actor.device)

    # Only numbers leave this method, so no autograd graph is needed.
    with torch.no_grad():
      dist = self.actor(state)
      value = self.critic(state)
      action = dist.sample()
      log_prob = dist.log_prob(action)

    return action.item(), log_prob.item(), value.item()

  def _compute_advantages(self, rewards, values, dones):
    """Compute GAE advantages for the stored rollout in one backward pass.

    Uses ``delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t`` and
    ``A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}``, so nothing is
    carried across an episode boundary.

    Args:
      rewards: Sequence of rewards, one per stored step.
      values: Sequence of critic value estimates, one per stored step.
      dones: Sequence of episode-termination flags, one per stored step.

    Returns:
      A float32 NumPy array with one advantage per stored step.
    """
    n_steps = len(rewards)
    advantage = np.zeros(n_steps, dtype=np.float32)
    if n_steps == 0:
      return advantage

    # No value estimate exists for the state after the final stored step. If
    # that step ended the episode the bootstrap term is zero and the advantage
    # is exact; otherwise it is left at 0, truncating the rollout there.
    running = 0.0
    if dones[n_steps - 1]:
      running = float(rewards[n_steps - 1]) - float(values[n_steps - 1])
      advantage[n_steps - 1] = running

    for t in range(n_steps - 2, -1, -1):
      not_done = 1.0 - float(dones[t])
      delta = float(rewards[t]) + self.gamma * float(values[t + 1]) * not_done - float(values[t])
      running = delta + self.gamma * self.gae_lambda * not_done * running
      advantage[t] = running
    return advantage

  def learn(self):
    """Run ``n_epochs`` of clipped-PPO updates on the stored rollout, then clear it."""
    device = self.actor.device
    advantage = None
    values = None

    for _ in range(self.n_epochs):
      # Called once per epoch so every epoch sees a fresh shuffle.
      state_arr, action_arr, old_probs_arr, vals_arr, \
        reward_arr, dones_arr, batches = self.memory.generate_batches()

      if advantage is None:
        # The stored rollout does not change between epochs, so advantages and
        # old value estimates are computed only once.
        advantage = torch.tensor(
            self._compute_advantages(reward_arr, vals_arr, dones_arr),
            dtype=torch.float).to(device)
        values = torch.tensor(vals_arr, dtype=torch.float).to(device)

      for batch in batches:
        idx = torch.as_tensor(batch, dtype=torch.long, device=device)
        states = torch.tensor(state_arr[batch], dtype=torch.float).to(device)
        old_probs = torch.tensor(old_probs_arr[batch], dtype=torch.float).to(device)
        actions = torch.tensor(action_arr[batch]).to(device)
        batch_advantage = advantage[idx]

        dist = self.actor(states)
        # Squeeze only the trailing size-1 dimension so a batch of one
        # transition keeps its batch dimension.
        critic_value = torch.squeeze(self.critic(states), dim=-1)

        new_probs = dist.log_prob(actions)
        # pi_new / pi_old computed in log space for numerical stability.
        prob_ratio = (new_probs - old_probs).exp()
        weighted_probs = batch_advantage * prob_ratio
        weighted_clipped_probs = torch.clamp(
            prob_ratio, 1 - self.policy_clip, 1 + self.policy_clip) * batch_advantage
        actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean()

        returns = batch_advantage + values[idx]
        critic_loss = ((returns - critic_value) ** 2).mean()
        total_loss = actor_loss + 0.5 * critic_loss

        # Both optimizers must be zeroed, otherwise critic gradients
        # accumulate from one mini-batch to the next.
        self.actor.optimizer.zero_grad()
        self.critic.optimizer.zero_grad()
        total_loss.backward()
        self.actor.optimizer.step()
        self.critic.optimizer.step()

    # Clear only after all epochs have run; clearing inside the epoch loop
    # would leave every epoch after the first with an empty buffer.
    self.memory.clear_memory()


import gym

if __name__ == '__main__':
  # Both networks default to chkpt_dir='tmp/ppo' and agent.save_models() runs
  # as soon as the first episode finishes. torch.save does not create missing
  # directories, so create it up front to keep the notebook runnable on a
  # fresh runtime.
  CHECKPOINT_DIR = 'tmp/ppo'
  os.makedirs(CHECKPOINT_DIR, exist_ok=True)

  env = gym.make("CartPole-v1")
  N = 20  # run a PPO update every N environment steps
  batch_size = 5
  n_epochs = 4
  alpha = 0.0003
  agent = Agent(n_actions=env.action_space.n, batch_size=batch_size,
                alpha=alpha, n_epochs=n_epochs, input_dims=env.observation_space.shape)
  n_games = 300

  best_score = env.reward_range[0]
  score_history = []

  learn_iters = 0
  avg_score = 0
  n_steps = 0

  for i in range(n_games):
    # NOTE(review): this loop unpacks reset() as a bare observation and step()
    # as a 4-tuple -- confirm the installed gym version provides that API.
    observation = env.reset()
    done = False
    score = 0
    while not done:
      action, prob, val = agent.choose_action(observation)
      observation_, reward, done, info = env.step(action)
      n_steps += 1
      score += reward
      agent.remember(observation, action, prob, val, reward, done)
      if n_steps % N == 0:
        agent.learn()
        learn_iters += 1
      observation = observation_
    score_history.append(score)
    # Running mean over (at most) the last 100 episodes.
    avg_score = np.mean(score_history[-100:])

    # Checkpoint whenever the running average reaches a new best.
    if avg_score > best_score:
      best_score = avg_score
      agent.save_models()

    print('episode', i, f'score {score:.1f}', f"avg score {avg_score: .1f}",
          "time steps", n_steps, "learning_steps", learn_iters)

Saving....
episode 0 score 41.0 avg score  41.0 time steps 41 learning_steps 2
episode 1 score 34.0 avg score  37.5 time steps 75 learning_steps 3
episode 2 score 15.0 avg score  30.0 time steps 90 learning_steps 4
episode 3 score 23.0 avg score  28.2 time steps 113 learning_steps 5
episode 4 score 36.0 avg score  29.8 time steps 149 learning_steps 7
episode 5 score 16.0 avg score  27.5 time steps 165 learning_steps 8
episode 6 score 11.0 avg score  25.1 time steps 176 learning_steps 8
episode 7 score 11.0 avg score  23.4 time steps 187 learning_steps 9
episode 8 score 22.0 avg score  23.2 time steps 209 learning_steps 10
episode 9 score 19.0 avg score  22.8 time steps 228 learning_steps 11
episode 10 score 32.0 avg score  23.6 time steps 260 learning_steps 13
episode 11 score 9.0 avg score  22.4 time steps 269 learning_steps 13
episode 12 score 18.0 avg score  22.1 time steps 287 learning_steps 14
episode 13 score 16.0 avg score  21.6 time steps 303 learning_steps 15
episode 14 score 