In [1]:
import numpy as np
import gym
import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch.optim import Adam
import matplotlib.pyplot as plt

In [2]:
env = gym.make('CartPole-v1')

In [3]:
class PolicyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.observation_space = env.observation_space.shape[0]
        self.action_space = env.action_space.n 
        self.linear_relu_stack = nn.Sequential(
        nn.Linear(self.observation_space,64), #first hidden layer
        nn.ReLU(),
        nn.Linear(64,64), #2nd hidden layer
        nn.ReLU(),
        nn.Linear(64,self.action_space) #output layer
        )
        
    def forward(self,x):
        logits = self.linear_relu_stack(x)
        return logits
        
def get_policy(obs, net):
    '''Return a categorical distribution over actions for a given observation'''
    logits = net(obs)
    return Categorical(logits=logits)
    
def get_action(obs,net):
    '''Samples an action from the categorical distribution resulting from an observation'''
    return get_policy(obs,net).sample().item()

def compute_loss(obs,act,weights,net):
    '''Computes a 'pseudo' loss function whose gradient is equal to the policy gradient
        this isn't really a loss function, but it allows us to use pytorch's autograd tools'''
    logp = get_policy(obs,net).log_prob(act)
    return -(logp * weights).mean()

In [4]:
def sample_trajectories(batch_size, net):
    '''Sample a number of trajectores, and return the data stored'''
    batch_obs = []
    batch_acts = []
    batch_rets = []
    batch_weights = [] # In this case, is just the full return R(tau)
    batch_lens = []
    
    # Render first episode of each batch
    rendered_first_episode = False
    
    for ep in range(batch_size):
        ep_rews = []
        
        
        obs = env.reset()
        done = False
        
        while not done:
            batch_obs.append(obs.copy())
            
            if not rendered_first_episode:
                env.render()
            
            act = get_action(torch.as_tensor(obs,dtype=torch.float32), net)
            obs, rew, done, _ = env.step(act)
            
            batch_acts.append(act)
            ep_rews.append(rew)
        

        rendered_first_episode = True
        ep_ret, ep_len = sum(ep_rews), len(ep_rews)
        batch_rets.append(ep_ret)
        batch_lens.append(ep_len)
        
        batch_weights += [ep_ret] * ep_len
    
    return batch_obs, batch_acts, batch_weights, batch_rets, batch_lens

In [5]:
def train_one_epoch(batch_size, optimizer, net):
    
    batch_obs, batch_acts, batch_weights, batch_rets, batch_lens = sample_trajectories(batch_size, net)
    
    
    optimizer.zero_grad()
    batch_loss = compute_loss(obs=torch.as_tensor(batch_obs, dtype=torch.float32),
                              act=torch.as_tensor(batch_acts, dtype=torch.int32),
                              weights=torch.as_tensor(batch_weights, dtype=torch.float32),
                              net=net
                            )
    batch_loss.backward()
    optimizer.step()
    
    return batch_loss, batch_rets, batch_lens    

In [6]:
p_net = PolicyNetwork()
batch_size = 100
n_epochs = 50
learning_rate = 0.01
optimizer = Adam(p_net.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    loss,rets,lens = train_one_epoch(batch_size,optimizer,p_net)
    if epoch%10==0:
        print(f'epoch: {epoch} \t length: {np.mean(lens)} \t returns {np.mean(rets)}')
env.close()

epoch: 0 	 length: 20.74 	 returns 20.74
epoch: 10 	 length: 68.63 	 returns 68.63
epoch: 20 	 length: 235.56 	 returns 235.56
epoch: 30 	 length: 495.99 	 returns 495.99
epoch: 40 	 length: 500.0 	 returns 500.0
