In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

np.set_printoptions(suppress=True, precision=4)

In [2]:
class Policy(nn.Module):

    def __init__(self, num_classes, lr):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(4, 8)
#         self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(8, num_classes)
        self.relu = nn.ReLU(inplace=True)
        self.softmax = nn.Softmax(dim=0) #important param to set --> dim
        self.optimizer = optim.SGD(self.parameters(), lr=lr)

    def forward(self, x):
        
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.softmax(x)
        
        return x

def makePolicyNN(num_actions=2, lr=0.01):
    ''' Initialize the policy class '''
    assert isinstance(num_actions, int) and num_actions>0
    
    return Policy(num_actions, lr)

In [3]:
class PolicyGradient:
    def __init__(self, gamma, N=200, max_steps=500):
        ''' Initialize the cart-pole environment '''
        assert isinstance(gamma, float) and 0.0<gamma<1.0, 'Invalid gamma'
        assert isinstance(N, int) and N>0
        assert isinstance(max_steps, int) and max_steps>0
        
        self.env = gym.make ("CartPole-v1")
        self.gamma = gamma
        self.N = N
        self.max_steps = max_steps
        self.numActions = self.env.action_space.n
        self.num_steps = 0
        self.max_steps_reached = False
    
    def getAction(self, policy_network, state):
        ''' Return an action from a stochastic policy '''
        assert isinstance(state, np.ndarray) and len(state) == 4
        
        state = torch.from_numpy(state).float()
        probs = policy_network(state)
        m = Categorical(probs)
        action = m.sample() #type tensor
        log_prob_of_action = m.log_prob(action) #type tensor
        action = action.item() #extracting value from tensor
        
        assert isinstance(action, int) and action in [0,1]
        
        return action, log_prob_of_action
    
    def runEpisode(self, policy_network):
        ''' Generate [s_t, a_t, r_t] pairs for one episode '''
        
        state = self.env.reset()
        done = False
        rewards, log_prob_of_actions = [], []
        
        while not(done):
            a, lpa = self.getAction(policy_network, state)
            state, r, done, info = self.env.step(a)
            r = rewards.append(r)
            log_prob_of_actions.append(lpa)
            self.num_steps += 1 #incrementing total number of steps in one iteration
            if self.num_steps >= self.max_steps: #checking
                self.max_steps_reached = True
                break;
        
        return rewards, log_prob_of_actions    
    
    def improvePolicy(self, policy_network):
        ''' Improve policy by implementing vanilla version of Reinforce algo '''
        gamma = self.gamma
        n = 0
        objective = 0
        self.max_steps_reached = False
        
        while not(self.max_steps_reached):
            # Run an episode with policy and count number of steps taken
            r, log_prob_of_actions = self.runEpisode(policy_network)
            n += 1 #increasing the episode count by 1
            
            # Calculate discounted return (G_tau) and summation of log probs of actions from the episode
            G_tau = sum([ r[t] * gamma**t for t in range(len(r))])
            sum_lpa = sum(log_prob_of_actions)
            objective += -G_tau * sum_lpa
            assert isinstance(objective, torch.Tensor)
        
        print('objective, n = ', objective.item(), n)
        
        policy_network.optimizer.zero_grad()
        objective = objective/n #averaging objective over n episodes and flipping sign so it does gradient ascent
        objective.backward()
        policy_network.optimizer.step()
        
        return policy_network, objective.item()
    
    def doVanillaReinforce(self, policy_network):
        ''' Implement vanilla REINFORCE algo '''
        arr_objective = []
        
        for i in range(self.N):
            print('---------------Iteration %d---------------' %i)
            policy_network, objective = self.improvePolicy(policy_network)
            arr_objective.append(objective)
        
        return policy_network, arr_objective

# Need to calculate $ J(\theta) \approx 1/N \sum_{i=0}^{N} G(\tau_{i}) \sum_{t=0}^{T} log\pi_{\theta}(a_{t} | s_{t}) $
## $ G(\tau_{i}) = \sum_{t=0}^{T} \gamma^{t}r_{t}$

In [4]:
if __name__ == '__main__':
    render_env = False
    verbose = False
    # Initialize NN for policy
    policy_network = makePolicyNN(num_actions=2, lr=0.01)
    policy_network = policy_network
    
    # Do vanilla REINFORCE
    pg = PolicyGradient(gamma=0.9)
    policy_network, arr_objective = pg.doVanillaReinforce(policy_network)
    
    if verbose:
        fig, ax = plt.subplots(figsize=(11,7))
        ax.plot(arr_objective)
        plt.show()
    
    if render_env:
        state = pg.env.reset()
        done = False
        while not(done):
            a, _ = pg.getAction(policy_network, state)
            state, r, done, info = pg.env.step(a)
            pg.env.render()
            

---------------Iteration 0---------------
objective, n =  2744.764404296875 28
---------------Iteration 1---------------
objective, n =  0.8803046941757202 1
---------------Iteration 2---------------
objective, n =  0.871464192867279 1
---------------Iteration 3---------------
objective, n =  0.5460184812545776 1
---------------Iteration 4---------------
objective, n =  0.5420905947685242 1
---------------Iteration 5---------------
objective, n =  0.5381073355674744 1
---------------Iteration 6---------------
objective, n =  0.878662109375 1
---------------Iteration 7---------------
objective, n =  0.8743869662284851 1
---------------Iteration 8---------------
objective, n =  0.5491612553596497 1
---------------Iteration 9---------------
objective, n =  0.5451421737670898 1
---------------Iteration 10---------------
objective, n =  0.5410952568054199 1
---------------Iteration 11---------------
objective, n =  0.8785911202430725 1
---------------Iteration 12---------------
objective, n