#Setup

Install Dependencies

In [1]:
!pip install gymnasium



Import dependencies

In [2]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import time
import matplotlib.pyplot as plt
import tqdm as tqdm


#Preperation

Define the Policy Network

In [3]:
class Policy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)

Create the environment, instantiate the policy network and define the optimizer

In [4]:
# Create the environment
env = gym.make('CartPole-v1',render_mode="rgb_array")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Initialize the policy network
policy = Policy(state_dim, action_dim)

# Define the optimizer
optimizer = optim.Adam(policy.parameters(), lr=0.01)

#Algorithm

Pick an action based on policy

In [5]:
def select_action(state):
    state = np.array(state)
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    return action.item(), m.log_prob(action)

Policy Gradiant Algorithm, the actual training loop

In [6]:
def policy_gradient():
    num_episodes = 1000000
    gamma = 0.99

    rewards_per_episode = []  # List to store rewards for each episode
    policy_losses = []  # List to store policy loss for each episode

    # for 10000 episodes
    for episode in range(num_episodes):
        observations = env.reset()
        state = np.array(observations[0])
        episode_reward = 0
        log_probs = []
        rewards = []

        # loop through each time step in one episode
        while True:
            action, log_prob = select_action(state)
            next_state, reward, done, truncated, _ = env.step(action)

            log_probs.append(log_prob)
            rewards.append(reward)
            episode_reward += reward

            if done or truncated:
                break

            state = next_state


        # Compute the discounted rewards
        discounts = [gamma**i for i in range(len(rewards))]
        discounted_rewards = [discount * reward for discount, reward in zip(discounts, rewards)]
        
        # Convert the discounted_rewards into a Tensor
        discounted_rewards = torch.Tensor(discounted_rewards)

        # Normalize the discounted rewards
        discounted_rewards -= torch.mean(discounted_rewards)
        discounted_rewards /= torch.std(discounted_rewards)

        # Calculate the loss
        policy_loss = []
        for log_prob, reward in zip(log_probs, discounted_rewards):
            policy_loss.append(-log_prob * reward)
        policy_loss = torch.cat(policy_loss).sum()

        # Update the policy network
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        # Print the episode statistics
        if episode % 100 == 0:
            print('Episode {}: reward = {}'.format(episode, episode_reward))
        
        rewards_per_episode.append(episode_reward)
        policy_losses.append(policy_loss.item())
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(rewards_per_episode)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('Reward per Episode')

    plt.subplot(1, 2, 2)
    plt.plot(policy_losses)
    plt.xlabel('Episode')
    plt.ylabel('Policy Loss')
    plt.title('Policy Loss per Episode')

    plt.tight_layout()
    plt.show()

#Run Trials

In [None]:
policy_gradient()

Episode 0: reward = 27.0
Episode 100: reward = 9.0
Episode 200: reward = 103.0
Episode 300: reward = 68.0
Episode 400: reward = 139.0
Episode 500: reward = 9.0
Episode 600: reward = 10.0
Episode 700: reward = 10.0
Episode 800: reward = 9.0
Episode 900: reward = 9.0
Episode 1000: reward = 9.0
Episode 1100: reward = 10.0
Episode 1200: reward = 9.0
Episode 1300: reward = 10.0
Episode 1400: reward = 9.0
Episode 1500: reward = 10.0
Episode 1600: reward = 9.0
Episode 1700: reward = 10.0
Episode 1800: reward = 9.0
Episode 1900: reward = 9.0
Episode 2000: reward = 10.0
Episode 2100: reward = 10.0
Episode 2200: reward = 10.0
Episode 2300: reward = 9.0
Episode 2400: reward = 10.0
Episode 2500: reward = 9.0
Episode 2600: reward = 10.0
Episode 2700: reward = 8.0
Episode 2800: reward = 10.0
Episode 2900: reward = 10.0
Episode 3000: reward = 9.0
Episode 3100: reward = 10.0
Episode 3200: reward = 500.0
Episode 3300: reward = 240.0
Episode 3400: reward = 122.0
Episode 3500: reward = 85.0
Episode 3600:

Episode 28000: reward = 171.0
Episode 28100: reward = 164.0
Episode 28200: reward = 173.0
Episode 28300: reward = 163.0
Episode 28400: reward = 171.0
Episode 28500: reward = 167.0
Episode 28600: reward = 217.0
Episode 28700: reward = 226.0
Episode 28800: reward = 237.0
Episode 28900: reward = 200.0
Episode 29000: reward = 183.0
Episode 29100: reward = 177.0
Episode 29200: reward = 156.0
Episode 29300: reward = 171.0
Episode 29400: reward = 158.0
Episode 29500: reward = 182.0
Episode 29600: reward = 182.0
Episode 29700: reward = 162.0
Episode 29800: reward = 208.0
Episode 29900: reward = 248.0
Episode 30000: reward = 240.0
Episode 30100: reward = 219.0
Episode 30200: reward = 204.0
Episode 30300: reward = 204.0
Episode 30400: reward = 273.0
Episode 30500: reward = 314.0
Episode 30600: reward = 500.0
Episode 30700: reward = 500.0
Episode 30800: reward = 500.0
Episode 30900: reward = 500.0
Episode 31000: reward = 500.0
Episode 31100: reward = 500.0
Episode 31200: reward = 500.0
Episode 31