### Implementation of the Actor Critic Policy Gradient Algorithm on Cartpole ###

In [None]:
## Importing the necessary packages ##

import gym
import torch
import torch.nn as nn

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

First off, we are going to define our Actor Critic Network.

In [None]:
## Defining the actor critic network ##

class ActorCriticNetwork(nn.Module):
    '''
    Defines the AC network with a common base and
    two heads.

    The base is a two-layer ReLU MLP shared by both heads. The critic
    head maps the base features to a single state value; the actor head
    maps them to a probability distribution over the actions.
    '''

    def __init__(self, obs_dim, action_dim, hidden_layer=128):
        '''
        Parameters:
        obs_dim : Dimension of the observation space.
        action_dim : Dimension of the action space.
        hidden_layer : Number of hidden neurons.
        '''
        super().__init__()

        self.base = nn.Sequential(nn.Linear(obs_dim, hidden_layer),
                                  nn.ReLU(),
                                  nn.Linear(hidden_layer, hidden_layer),
                                  nn.ReLU())

        self.critic = nn.Linear(hidden_layer, 1)

        # Normalise explicitly over the last (action) axis. Leaving `dim`
        # unset is deprecated, emits a warning, and picks the wrong axis
        # for inputs that are not 2-D.
        self.actor = nn.Sequential(nn.Linear(hidden_layer, action_dim),
                                   nn.Softmax(dim=-1))

    def forward(self, obs):
        '''
        Parameters:
        obs : Tensor whose last axis has size `obs_dim`.

        Returns:
        (state_val , action_prob) where `state_val` has a trailing axis of
        size 1 and `action_prob` sums to 1 along its last axis.
        '''
        base_out = self.base(obs)

        state_val = self.critic(base_out)
        action_prob = self.actor(base_out)

        return state_val, action_prob

With the network defined, we need to set up our Agent class.

This is going to be the most important part of this implementation.

In [None]:
## Defining the Agent ##

class ACAgent:
    '''
    Defines the Actor Critic Agent.

    The agent samples an action with `choose_action`, remembers it in
    `self.action`, and then performs a one-step (TD(0)) actor-critic
    update for that action when `learn` is called.
    '''

    def __init__(self, env, gamma=0.99, learning_rate=1e-5):
        '''
        Parameters:
        env : Environment exposing `observation_space.shape` and
              `action_space.n`.
        gamma : Discount factor applied to the next state's value.
        learning_rate : Learning rate passed to the Adam optimiser.
        '''
        self.env = env

        # Last action sampled by `choose_action`; consumed by `learn`.
        self.action = None

        self.observation_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.n

        self.ac_net = ActorCriticNetwork(self.observation_dim, self.action_dim)

        self.gamma = gamma
        self.learning_rate = learning_rate

        self.optim = torch.optim.Adam(self.ac_net.parameters(), lr=self.learning_rate)

    @staticmethod
    def _to_batch(observation):
        '''Convert one observation into a float32 tensor with a leading batch axis.'''
        # `as_tensor` avoids the slow list-of-ndarray path of `torch.Tensor([obs])`.
        return torch.as_tensor(observation, dtype=torch.float32).unsqueeze(0)

    def choose_action(self, observation):
        '''
        Sample an action from the current policy.

        Parameters:
        observation : A single observation (sequence or array of floats).

        Returns:
        The sampled action as a Python int. The sampled tensor is also
        stored in `self.action` for the following `learn` call.
        '''
        obs_tensor = self._to_batch(observation)

        # Sampling needs no gradients; `learn` recomputes the log-probability.
        with torch.no_grad():
            _, action_prob = self.ac_net(obs_tensor)

        action_distribution = torch.distributions.categorical.Categorical(action_prob)
        action = action_distribution.sample()

        self.action = action

        return action.item()

    def learn(self, state, reward, done, next_state):
        '''
        Perform one actor-critic update for the transition
        (state , self.action , reward , next_state).

        Parameters:
        state : Observation the last action was chosen in.
        reward : Scalar reward received for that action.
        done : Truthy when `next_state` is terminal (no bootstrapping).
        next_state : Observation reached after the action.

        Raises:
        RuntimeError : If called before any `choose_action` call.
        '''
        if self.action is None:
            raise RuntimeError('choose_action must be called before learn')

        state_tensor = self._to_batch(state)
        next_state_tensor = self._to_batch(next_state)

        state_value, action_prob = self.ac_net(state_tensor)

        # The bootstrapped target is treated as a constant (semi-gradient
        # TD): only V(state) receives gradient from the critic loss.
        with torch.no_grad():
            next_state_val, _ = self.ac_net(next_state_tensor)
            not_terminal = 1.0 - float(done)
            td_target = float(reward) + self.gamma * next_state_val * not_terminal

        delta = td_target - state_value

        critic_loss = delta.pow(2)

        action_distribution = torch.distributions.categorical.Categorical(action_prob)
        log_p = action_distribution.log_prob(self.action)

        # The TD error only weights the policy gradient; detach it so the
        # actor loss does not push gradients into the critic.
        actor_loss = -log_p * delta.detach()

        # Reduce to a true scalar before back-propagating.
        total_loss = (actor_loss + critic_loss).sum()

        self.optim.zero_grad()
        total_loss.backward()
        self.optim.step()

In [None]:
## Now setting up the environment and running the training loop ##

env = gym.make('CartPole-v0')

# Number of training episodes: each outer iteration plays one full episode.
num_step = 1000

# One entry per finished episode: the total reward collected in it.
total_mean_reward = []

agent = ACAgent(env)

# NOTE(review): this loop unpacks `env.reset()` as a bare observation and
# `env.step()` as a 4-tuple, i.e. it assumes the classic gym API - confirm
# against the installed gym version.
for i in range(num_step):

    episode_reward = []

    obs = env.reset()

    while True:

        action = agent.choose_action(obs)

        next_obs, reward, done, _ = env.step(action)

        episode_reward.append(reward)

        # Online update on every transition, including the terminal one.
        agent.learn(obs, reward, done, next_obs)

        if done:
            break

        obs = next_obs

    # Record the episode's return; previously the collected rewards were
    # discarded and this list stayed empty.
    total_mean_reward.append(sum(episode_reward))

And done!!