In [2]:
#import gym
import gymnasium as gym
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline

import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

### The architecture of the Policy

In [3]:
# Create the CartPole-v1 environment (bound to the module-level name `env`)
# and print its observation and action spaces for reference.
env = gym.make("CartPole-v1")
print(f"Observation space {env.observation_space}")
print(f"Actions space {env.action_space}")

Observation space Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)
Actions space Discrete(2)


In [4]:
class Policy(nn.Module):
    """Stochastic policy: a two-layer MLP that outputs action probabilities.

    The network is ``Linear(s_size, h_size) -> ReLU -> Linear(h_size, a_size)
    -> softmax``.

    Args:
        s_size: Number of input (state) features.
        h_size: Width of the hidden layer.
        a_size: Number of discrete actions (size of the output distribution).
    """

    def __init__(self, s_size=4, h_size=16, a_size=2):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(s_size, h_size)
        self.fc2 = nn.Linear(h_size, a_size)

    def forward(self, x):
        """Return action probabilities for the given state tensor.

        Args:
            x: Float tensor whose last dimension has ``s_size`` entries.
                Both a single state ``(s_size,)`` and a batch
                ``(batch, s_size)`` are accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``a_size`` that sums to 1.
        """
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Normalise over the action (last) dimension. For the batched 2-D
        # input produced by act() this is the same as dim=1, but it also
        # works for an unbatched 1-D state instead of raising.
        return F.softmax(x, dim=-1)

    def act(self, state):
        """Sample an action for a single state.

        Args:
            state: NumPy array holding one state (it is passed to
                ``torch.from_numpy``).

        Returns:
            A tuple ``(action, log_prob)`` where ``action`` is a Python int
            and ``log_prob`` is the log-probability of that action as a
            tensor of shape ``(1,)`` that remains attached to the autograd
            graph.
        """
        # Convert state into torch tensor and add batch dimension
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

### Train the Agent using the REINFORCE Algorithm

In [5]:
# Module-level policy network (default layer sizes) and the Adam optimizer
# that updates its parameters; reinforce() below reads both as globals.
policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)

def reinforce(n_episodes, max_t, gamma, print_every):
    """Train the module-level ``policy`` with the REINFORCE algorithm.

    Uses the module-level ``env``, ``policy`` and ``optimizer``. Each
    iteration rolls out one episode, computes the discounted return of the
    whole episode, and takes one optimizer step on the sum of
    ``-log_prob * return`` over every action taken in that episode.

    Args:
        n_episodes: Maximum number of episodes to run.
        max_t: Maximum number of steps per episode.
        gamma: Discount factor applied to the rewards.
        print_every: Print the running average every this many episodes.

    Returns:
        List with the undiscounted total reward of each episode played.
        Training stops early once the mean over the most recent (up to 100)
        episodes reaches 195.0.
    """
    recent_scores = deque(maxlen=100)
    history = []

    for episode_idx in range(1, n_episodes + 1):
        log_probs, episode_rewards = [], []
        obs = env.reset()[0]

        # Roll out one episode, recording the log-probability of every
        # sampled action together with the reward it produced.
        for _ in range(max_t):
            action, log_prob = policy.act(obs)
            log_probs.append(log_prob)
            obs, reward, terminated, truncated, _ = env.step(action)
            episode_rewards.append(reward)
            if terminated or truncated:
                break

        episode_score = sum(episode_rewards)
        recent_scores.append(episode_score)
        history.append(episode_score)

        # Discounted return of the full episode; the same value weights
        # every action's log-probability.
        episode_return = sum(
            gamma ** step * r for step, r in enumerate(episode_rewards)
        )
        policy_loss = torch.cat(
            [-lp * episode_return for lp in log_probs]
        ).sum()

        # Perform one gradient step on the policy parameters.
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        running_avg = np.mean(recent_scores)
        if episode_idx % print_every == 0:
            print(f"Episode {episode_idx}, Average score : {running_avg}")
        if running_avg >= 195.0:
            print(f"Environement solved in {episode_idx} , Average score : {running_avg}")
            break

    return history

In [6]:
# Train for at most 2000 episodes of up to 1000 steps each. With gamma=1.0
# every discount factor equals 1, so the episode return is the plain sum of
# rewards. `scores` receives the per-episode totals returned by reinforce().
scores = reinforce(n_episodes=2000, max_t=1000, gamma=1.0, print_every=100)

Episode 100, Average score : 30.78
Episode 200, Average score : 64.76
Episode 300, Average score : 56.7
Episode 400, Average score : 47.55
Episode 500, Average score : 91.37
Episode 600, Average score : 67.88
Episode 700, Average score : 71.82
Episode 800, Average score : 64.93
Episode 900, Average score : 59.16
Episode 1000, Average score : 99.3
Episode 1100, Average score : 73.44
Episode 1200, Average score : 56.0
Episode 1300, Average score : 71.92
Episode 1400, Average score : 86.04
Episode 1500, Average score : 73.25
Episode 1600, Average score : 85.89
Environement solved in 1653 , Average score : 196.21


### Watch the Agent play

In [8]:
from IPython import display

# Use a dedicated name for the rendering environment so the module-level
# `env` that reinforce() trains on is not replaced by a closed,
# human-rendered one if the training cell is re-run afterwards.
render_env = gym.make("CartPole-v1", render_mode="human")
try:
    state = render_env.reset()[0]
    # Pure inference: no autograd graph is needed for the sampled actions.
    with torch.no_grad():
        for t in range(1000):
            # With render_mode="human" the environment draws itself on
            # reset()/step(), so no explicit render() call is needed.
            action, _ = policy.act(state)
            state, reward, terminated, truncated, _ = render_env.step(action)
            # Stop on either end-of-episode signal; stepping an environment
            # after it has been truncated by its time limit is not valid.
            if terminated or truncated:
                break
finally:
    # Always release the render window, even if the loop raised.
    render_env.close()