# Modules

## Install

In [None]:
!pip install gymnasium
!pip install swig
!pip install gymnasium[box2d]

## Import

In [1]:
import gymnasium as gym
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import base64, io

import numpy as np
from collections import deque, namedtuple

from gymnasium.wrappers.monitoring import video_recorder
from IPython.display import HTML
from IPython import display 
import glob

# Environment

In [89]:
# Build the CartPole-v1 environment with on-screen ("human") rendering.
env = gym.make("CartPole-v1", render_mode="human")

# Sizes later used for the policy network's input and output layers.
STATE_SIZE = env.observation_space.shape[0]
ACTION_SIZE = env.action_space.n

# Lower bound and total width of the first observation component.
FAR_LEFT_POSITION = env.observation_space.low[0]
X_RANGE = env.observation_space.high[0] - FAR_LEFT_POSITION

print('State shape: ', STATE_SIZE)
print('Number of actions: ', ACTION_SIZE)

State shape:  4
Number of actions:  2


In [90]:
# Show the (low, high) bounds of the first observation component.
env.observation_space.low[0], env.observation_space.high[0]

(-4.8, 4.8)

## Policy-Network

In [91]:
class PolicyNetwork(nn.Module):
    """Fully connected policy: two ReLU hidden layers followed by a softmax over actions."""

    def __init__(self, state_size, action_size, hidden_neurons, seed):
        """Create the three linear layers.

        Args:
            state_size: number of input features per state.
            action_size: number of discrete actions (size of the output).
            hidden_neurons: width of both hidden layers.
            seed: value passed to ``torch.manual_seed`` before the layers are built.
        """
        super().__init__()
        # Seeding happens before layer construction, so the initial weights
        # depend only on `seed`. The returned generator is kept on the instance.
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(in_features=state_size, out_features=hidden_neurons)
        self.fc2 = nn.Linear(in_features=hidden_neurons, out_features=hidden_neurons)
        self.fc3 = nn.Linear(in_features=hidden_neurons, out_features=action_size)

    def forward(self, state):
        """Return action probabilities (softmax over the last dimension) for `state`."""
        hidden = torch.relu(self.fc1(state))
        hidden = torch.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return torch.softmax(logits, dim=-1)

## Replay Buffer

In [92]:
class ReplayBuffer:
    """Append-only store of (state, action, reward) tuples."""

    def __init__(self):
        """Start with no stored transitions."""
        self.memory = list()

    def add(self, state, action, reward):
        """Record one transition at the end of the buffer."""
        transition = (state, action, reward)
        self.memory.append(transition)

    def sample(self):
        """Return all stored transitions in insertion order.

        The internal list itself is returned, not a copy.
        """
        return self.memory

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

## Agent

In [94]:
# Hyperparameters.
GAMMA = 0.99   # discount factor; note policy_gradient() carries its own gamma default
LR = 1e-2      # learning rate handed to the Adam optimizer
NEURONS = 64   # width of the policy network's hidden layers

# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [95]:
class Agent:
    """REINFORCE-style agent: samples actions from a PolicyNetwork and updates
    it from the normalized discounted returns of one complete episode."""

    def __init__(self, state_size, action_size, hidden_neurons, seed):
        """Build the policy network and its Adam optimizer.

        Args:
            state_size: number of features in a state vector.
            action_size: number of discrete actions.
            hidden_neurons: width of the policy network's hidden layers.
            seed: integer seed for Python's `random`, NumPy's global RNG and
                (through PolicyNetwork) torch.
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so keep the seed value itself.
        self.seed = seed
        random.seed(seed)
        # act() samples through NumPy's global RNG, so it must be seeded as well.
        # The modulo keeps the value inside the range np.random.seed accepts.
        np.random.seed(seed % (2 ** 32))

        self.policy_network = PolicyNetwork(state_size, action_size, hidden_neurons, seed).to(device)
        self.optimizer = optim.Adam(self.policy_network.parameters(), lr=LR)

    def act(self, state):
        """Sample an action index for a single state.

        Args:
            state: NumPy array holding one state (no batch dimension).

        Returns:
            The sampled action index, drawn from the policy's probabilities.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        # No gradients are needed while acting; learn() re-runs the forward pass.
        with torch.no_grad():
            action_probs = self.policy_network(state)
        # Work in float64 and renormalize so float32 rounding cannot make the
        # probabilities fail NumPy's sum-to-one check.
        action_probs = action_probs.squeeze(0).cpu().numpy().astype(np.float64)
        action_probs /= action_probs.sum()
        action = np.random.choice(np.arange(self.action_size), p=action_probs)
        return action

    def learn(self, states, actions, rewards, gamma):
        """Run one policy-gradient step on a finished episode.

        Args:
            states: float tensor of shape (T, state_size).
            actions: int64 tensor of shape (T,) with the actions taken.
            rewards: sequence of T per-step rewards.
            gamma: discount factor applied to future rewards.
        """
        if len(rewards) == 0:
            # Nothing to learn from an empty episode.
            return

        action_probs = self.policy_network(states)
        # Clamp before the log so a saturated softmax cannot produce -inf.
        log_probs = torch.log(action_probs.clamp_min(1e-8))
        # squeeze(1) keeps the time dimension even when T == 1.
        log_probs_selected = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Discounted return per step, accumulated back-to-front in linear time.
        discounted_rewards = []
        running_reward = 0.0
        for reward in reversed(rewards):
            running_reward = reward + gamma * running_reward
            discounted_rewards.append(running_reward)
        discounted_rewards.reverse()
        discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32, device=device)

        # std() of a single value is NaN and would poison every weight, so
        # normalization is skipped for one-step episodes.
        if discounted_rewards.numel() > 1:
            discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-9)

        loss = -torch.mean(log_probs_selected * discounted_rewards)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# Training

In [96]:
# Fresh agent sized to the environment; seed fixed at 0.
agent = Agent(
    state_size=STATE_SIZE,
    action_size=ACTION_SIZE,
    hidden_neurons=NEURONS,
    seed=0,
)

## Pretrained

In [86]:
# Load the weights saved after 500 training episodes. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU
# still loads on a CPU-only machine.
agent.policy_network.load_state_dict(
    torch.load('./checkpoints/checkpoint-500.pth', map_location=device)
)

<All keys matched successfully>

## Policy Gradient

In [97]:
def policy_gradient(agent, n_episodes=1000, max_t=200, gamma=0.99):
    """Train `agent` on the global `env` with one update per episode.

    Each episode is rolled out with `agent.act`, then the whole trajectory is
    passed to `agent.learn`. The running average over the last 100 episodes is
    printed, and the policy weights are checkpointed every 250 episodes.

    Args:
        agent: object exposing `act(state)`, `learn(states, actions, rewards, gamma)`
            and a `policy_network` with `state_dict()`.
        n_episodes: number of episodes to run.
        max_t: maximum number of steps per episode.
        gamma: discount factor forwarded to `agent.learn`.

    Returns:
        List with the total (undiscounted) reward of every episode.
    """
    import os  # local import: only needed to create the checkpoint directory

    checkpoint_dir = './checkpoints'
    scores = []
    scores_window = deque(maxlen=100)

    for i_episode in range(1, n_episodes+1):
        state, _ = env.reset()
        episode_states = []
        episode_actions = []
        episode_rewards = []

        for t in range(max_t):
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            episode_states.append(state)
            episode_actions.append(action)
            episode_rewards.append(reward)

            if done:
                break
            state = next_state

        episode_return = sum(episode_rewards)
        scores_window.append(episode_return)
        scores.append(episode_return)

        # Stack into one ndarray first: building a tensor straight from a list
        # of ndarrays is slow and triggers a PyTorch warning.
        states_tensor = torch.from_numpy(np.array(episode_states, dtype=np.float32)).to(device)
        actions_tensor = torch.from_numpy(np.array(episode_actions, dtype=np.int64)).to(device)

        agent.learn(states_tensor, actions_tensor, episode_rewards, gamma)

        # Overwrite the same console line each episode; every 20th episode the
        # line is printed again with a newline so it stays in the log.
        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
        if i_episode % 20 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
        if i_episode % 250 == 0:
            # Create the directory on demand so a missing folder cannot crash
            # the run after hundreds of episodes of training.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(agent.policy_network.state_dict(), f'./checkpoints/checkpoint-{i_episode}.pth')

    return scores

## Run

In [None]:
# Train for 1000 episodes; `scores` receives the per-episode total rewards returned by policy_gradient.
scores = policy_gradient(agent, n_episodes=1000)

# Simulation

## Random Agent

In [None]:
env = gym.make("CartPole-v1", render_mode="human")

# The loop runs until the cell is interrupted; the finally clause makes sure
# the environment is closed when that happens.
try:
    while True:
        done = False
        state, _ = env.reset()

        while not done:
            # Uniformly random action index in [0, ACTION_SIZE - 1].
            next_state, reward, terminated, truncated, _ = env.step(random.randint(0,ACTION_SIZE-1))
            done = terminated or truncated
finally:
    env.close()

## Learned Agent

In [None]:
env = gym.make("CartPole-v1", render_mode="human")

# Everything after env creation sits inside try/finally so the environment is
# closed whether the checkpoint fails to load or the loop is interrupted.
try:
    agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_neurons=NEURONS, seed=0)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    agent.policy_network.load_state_dict(
        torch.load('./checkpoints/checkpoint-250.pth', map_location=device)
    )

    # Runs until the cell is interrupted.
    while True:
        done = False
        state, _ = env.reset()

        while not done:
            action = agent.act(state)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
finally:
    env.close()