In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical
import warnings
warnings.simplefilter("ignore")
%matplotlib inline
print(torch.__version__)

0.4.1


In [2]:
env = gym.make('CartPole-v1')
env.seed(1); torch.manual_seed(1);

In [3]:
# Hyperparameters for model
l_rate = 0.01
gamma_value = 0.99

In [4]:
class PolicyGradient(nn.Module):
    def __init__(self):
        super(PolicyGradient, self).__init__()
        
        # Define the action space and state space
        self.action_space = env.action_space.n
        self.state_space = env.observation_space.shape[0]
        
        self.l1 = nn.Linear(self.state_space, 128, bias=False)
        self.l2 = nn.Linear(128, self.action_space, bias=False)
        
        self.gamma_value = gamma_value
        
        # Episode policy and reward history 
        self.history_policy = Variable(torch.Tensor()) 
        self.reward_episode = []
        
        # Overall reward and loss history
        self.history_reward = []
        self.history_loss = []

    def forward(self, x):    
        model = torch.nn.Sequential(
            self.l1,
            nn.Dropout(p=0.5),
            nn.ReLU(),
            self.l2,
            nn.Softmax(dim=-1)
        )
        return model(x)
    
policy = PolicyGradient()
optimizer = optim.Adam(policy.parameters(), lr=l_rate)

In [5]:
def choose_action(state):
    # Run the policy model and choose an action based on the probabilities in state
    state = torch.from_numpy(state).type(torch.FloatTensor)
    state = policy(Variable(state))
    c = Categorical(state)
    action = c.sample()   
    if policy.history_policy.dim() != 0:
        try:
            policy.history_policy = torch.cat([policy.history_policy, c.log_prob(action)])
        except:
            policy.history_policy = (c.log_prob(action))
    else:
        policy.history_policy = (c.log_prob(action))
    return action

def update_policy():
    R = 0
    rewards = []
    
    # Discount future rewards back to the present using gamma
    for r in policy.reward_episode[::-1]:
        R = r + policy.gamma_value * R
        rewards.insert(0,R)
        
    # Scale rewards
    rewards = torch.FloatTensor(rewards)
    x = np.finfo(np.float32).eps
    x = np.array(x)
    x = torch.from_numpy(x)
    rewards = (rewards - rewards.mean()) / (rewards.std() + x)

    # Calculate the loss loss
    loss = (torch.sum(torch.mul(policy.history_policy, Variable(rewards)).mul(-1), -1))
    
    # Update the weights of the network
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    #Save and intialize episode history counters
    policy.history_loss.append(loss.data[0])
    policy.history_reward.append(np.sum(policy.reward_episode))
    policy.history_policy = Variable(torch.Tensor())
    policy.reward_episode= []
    
def main_function(episodes):
    running_total_reward = 10
    for e in range(episodes):
        # Reset the environment and record the starting state
        state = env.reset() 
        done = False       
    
        for time in range(1000):
            action = choose_action(state)
            # Step through environment using chosen action
            state, reward, done, _ = env.step(action.data.item())

            # Save reward
            policy.reward_episode.append(reward)
            if done:
                break
        
        # Used to determine when the environment is solved.
        running_total_reward = (running_total_reward * 0.99) + (time * 0.01)

        update_policy()

        if e % 50 == 0:
            print('Episode number {}, Last length: {:5d}, Average length: {:.2f}'.format(e, time, running_total_reward))

        if running_total_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and the last episode runs to {} time steps!".format(running_total_reward, time))
            break

episodes = 1000
main_function(episodes)

Episode number 0, Last length:    19, Average length: 10.09
Episode number 50, Last length:    13, Average length: 12.81
Episode number 100, Last length:    10, Average length: 11.77
Episode number 150, Last length:    10, Average length: 11.42
Episode number 200, Last length:     8, Average length: 10.69
Episode number 250, Last length:    13, Average length: 11.32
Episode number 300, Last length:    11, Average length: 11.07
Episode number 350, Last length:     9, Average length: 10.76
Episode number 400, Last length:     8, Average length: 10.34
Episode number 450, Last length:     8, Average length: 10.06
Episode number 500, Last length:    10, Average length: 9.68
Episode number 550, Last length:     9, Average length: 9.57
Episode number 600, Last length:     9, Average length: 9.31
Episode number 650, Last length:     9, Average length: 9.28
Episode number 700, Last length:     8, Average length: 9.10
Episode number 750, Last length:     9, Average length: 9.05
Episode number 80