Steps:
- import libraries
- setup hyperparameters
- define a policy (actor) network. Notice how small this network is. Its input dimension is the size of the state (observation) vector; its output dimension is the number of discrete actions.
- define a value (critic) network.
- PPO is an on-policy algorithm, meaning it updates the policy using experience collected by that same policy.
    - One might ask: why does an on-policy algorithm need an experience buffer at all? We use it to:
        - perform batch updates
        - compute discounted (Monte Carlo) returns and the advantages derived from them (advantage = normalized return − value estimate; the `lmbda` hyperparameter is reserved for GAE, which is not implemented yet)
- Run training

Next steps:
- Run inference
- Save video of one episode

In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np

# Hyperparameters (read as module globals by PPO and train_ppo)
learning_rate = 0.0003   # Adam step size, shared by actor and critic optimizers
gamma = 0.99             # discount factor for Monte Carlo returns
lmbda = 0.95             # GAE lambda; not referenced by the current update
eps_clip = 0.2           # clipping range for the PPO probability ratio
K_epochs = 4             # optimization passes over each collected batch
update_timestep = 2_000  # environment steps collected between PPO updates

# Create the CartPole environment. This uses the Gymnasium API as consumed
# below: reset() returns (observation, info) and step() returns a 5-tuple
# (observation, reward, terminated, truncated, info).
env = gym.make('CartPole-v1')

# Define the policy network (actor)
class PolicyNetwork(nn.Module):
    """Actor network: state vector -> probabilities over discrete actions.

    Two fully connected hidden layers of width 64 with ReLU activations,
    followed by a softmax over the last dimension so each row sums to 1.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 64
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, action_dim)

    def forward(self, x):
        hidden = self.fc2(self.fc1(x).relu()).relu()
        return self.fc3(hidden).softmax(dim=-1)

# Define the value network (critic)
class ValueNetwork(nn.Module):
    """Critic network: state vector -> scalar state-value estimate.

    Two fully connected hidden layers of width 64 with ReLU activations and a
    linear head with a single output unit (trailing dimension of size 1).
    """

    def __init__(self, state_dim):
        super().__init__()
        width = 64
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, x):
        features = self.fc1(x).relu()
        features = self.fc2(features).relu()
        return self.fc3(features)

# PPO Agent
class PPO:
    """Proximal Policy Optimization agent with separate actor and critic.

    `policy` is the actor being optimized with the clipped surrogate objective.
    `policy_old` is a copy of it that is used to collect experience and that
    supplies the "old" log-probabilities for the importance ratio; it is
    synchronized with `policy` after every update. `value_function` is the
    critic, regressed onto normalized Monte Carlo returns and used as a
    baseline for the advantage.

    Reads the module-level hyperparameters `learning_rate`, `gamma`,
    `eps_clip` and `K_epochs`.
    """

    def __init__(self, state_dim, action_dim):
        self.policy = PolicyNetwork(state_dim, action_dim)
        self.policy_old = PolicyNetwork(state_dim, action_dim)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

        self.value_function = ValueNetwork(state_dim)
        self.vf_optimizer = optim.Adam(self.value_function.parameters(), lr=learning_rate)

        self.MseLoss = nn.MSELoss()
        self.action_dim = action_dim

    def select_action(self, state):
        """Sample an action from `policy_old` for a single observation.

        Args:
            state: one observation (array-like of length `state_dim`).

        Returns:
            (action, log_prob): the sampled action as a Python int and its
            log-probability as a 0-d tensor with no autograd history.
        """
        state = torch.as_tensor(state, dtype=torch.float32)
        # Rollout data is only ever used as constants in `update`; recording an
        # autograd graph here would keep one graph per stored step alive until
        # the memory is cleared.
        with torch.no_grad():
            action_probs = self.policy_old(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob

    def update(self, memory):
        """Run `K_epochs` full-batch PPO updates on the experience in `memory`.

        Args:
            memory: buffer with parallel lists `states`, `actions`, `logprobs`,
                `rewards` and `is_terminals` collected with `policy_old`.
                An empty buffer is a no-op.
        """
        if not memory.rewards:
            return

        # Discounted Monte Carlo returns, accumulated backwards. The running
        # sum is reset at each episode boundary so returns do not leak from
        # one episode into the previous one.
        returns = []
        discounted_reward = 0.0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0.0
            discounted_reward = reward + gamma * discounted_reward
            returns.append(discounted_reward)
        returns.reverse()  # O(n) instead of repeated insert(0, ...)

        # Normalize returns. float32 matches the network outputs regardless of
        # the reward dtype. std() of a single sample is NaN, so only rescale
        # when there are at least two samples.
        returns = torch.tensor(returns, dtype=torch.float32)
        returns = returns - returns.mean()
        if returns.numel() > 1:
            returns = returns / (returns.std() + 1e-5)

        # Batch the rollout. Stack states with numpy first: building a tensor
        # directly from a list of ndarrays is very slow.
        old_states = torch.as_tensor(np.array(memory.states), dtype=torch.float32)
        old_actions = torch.as_tensor(memory.actions, dtype=torch.long)
        old_logprobs = torch.tensor([float(lp) for lp in memory.logprobs], dtype=torch.float32)

        for _ in range(K_epochs):
            # squeeze(-1) keeps a 1-D batch even when the buffer holds one step.
            values = self.value_function(old_states).squeeze(-1)
            # The critic acts as a baseline only; no gradient flows into it
            # through the policy loss.
            advantages = returns - values.detach()

            # Log-probabilities of the stored actions under the current policy.
            dist = Categorical(self.policy(old_states))
            new_logprobs = dist.log_prob(old_actions)

            # Importance-sampling ratio pi_new(a|s) / pi_old(a|s).
            ratios = torch.exp(new_logprobs - old_logprobs)

            # Clipped surrogate objective (negated for gradient descent).
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            self.optimizer.zero_grad()
            policy_loss.backward()
            self.optimizer.step()

            # Critic regression onto the normalized returns.
            value_loss = self.MseLoss(values, returns)
            self.vf_optimizer.zero_grad()
            value_loss.backward()
            self.vf_optimizer.step()

        # The next rollout is collected with the freshly updated weights.
        self.policy_old.load_state_dict(self.policy.state_dict())

# Memory to store experiences
class Memory:
    """Rollout buffer: parallel per-step lists filled by the training loop."""

    def __init__(self):
        self.states = []        # observations seen at each step
        self.actions = []       # actions sampled at each step
        self.logprobs = []      # log-probabilities of those actions
        self.rewards = []       # rewards received at each step
        self.is_terminals = []  # per-step episode-end flags

    def clear_memory(self):
        """Empty every buffer in place; the same list objects are kept."""
        buffers = (self.states, self.actions, self.logprobs, self.rewards, self.is_terminals)
        for buffer in buffers:
            buffer.clear()

# Main training loop
def train_ppo(max_episodes=10000, max_timesteps=9999, log_interval=500):
    """Train a PPO agent on the module-level `env` and return it.

    Experience is collected with the agent's `policy_old`, and a PPO update is
    run every `update_timestep` environment steps (a module-level constant).

    Args:
        max_episodes: number of episodes to run.
        max_timesteps: upper bound on steps per episode; the environment may
            end an episode earlier through termination or truncation.
        log_interval: print the length of every `log_interval`-th episode.

    Returns:
        The trained `PPO` agent (e.g. for inference afterwards).
    """
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    ppo = PPO(state_dim, action_dim)
    memory = Memory()

    timestep = 0
    for i_episode in range(max_episodes):
        state, _ = env.reset()
        t = 0
        for t in range(1, max_timesteps + 1):
            timestep += 1

            # Select action using the rollout policy.
            action, logprob = ppo.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Gymnasium reports episode end as termination OR truncation (e.g.
            # CartPole-v1's time limit). Ignoring truncation lets episodes run
            # far past the limit. Truncation is treated as an episode boundary
            # here because the Monte Carlo returns do not bootstrap.
            done = terminated or truncated

            memory.states.append(state)
            memory.actions.append(action)
            memory.logprobs.append(logprob)
            memory.rewards.append(reward)
            memory.is_terminals.append(done)

            state = next_state

            # Check for an update BEFORE breaking on `done`, so an update that
            # falls on an episode's final step is not skipped.
            if timestep >= update_timestep:
                ppo.update(memory)
                memory.clear_memory()
                timestep = 0

            if done:
                break

        if i_episode % log_interval == 0:
            print(f'Episode {i_episode}\tLength: {t}')

    return ppo

# Entry point: start training when this code is executed as a script.
if __name__ == '__main__':
    train_ppo()


Episode 0	Length: 15
Episode 500	Length: 26
Episode 1000	Length: 22
Episode 1500	Length: 37
Episode 2000	Length: 72
Episode 2500	Length: 377
Episode 3000	Length: 9999
