In [1]:
import numpy as np
import torch
import gym
from torch import nn as nn
from torch.utils import data
from torch.distributions import Bernoulli
from torch.distributions import Categorical

In [12]:
# Classic-control CartPole-v0 task; the training code below refers to this
# module-level `env`.
env = gym.make(
  'CartPole-v0',
)

In [19]:
class policy(nn.Module):
  """Residual MLP policy mapping a state vector to action probabilities.

  Args:
    state_dim: size of the input state vector (default 4, CartPole's observation size).
    action_dim: number of discrete actions (default 2, CartPole's action count).
    hidden_dim: width of the hidden layers (default 64).

  With the defaults, `policy()` builds exactly the network used for CartPole.
  """
  def __init__(self, state_dim=4, action_dim=2, hidden_dim=64):
    super(policy,self).__init__()
    # Layers are created in the same order as before so seeded inits are unchanged.
    self.fc1 = nn.Linear(state_dim, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, hidden_dim)
    self.fc3 = nn.Linear(hidden_dim, hidden_dim)
    self.fc4 = nn.Linear(hidden_dim, action_dim)

  def forward(self,s):
    """Return action probabilities for state(s) `s`.

    `s` is a float tensor whose last dimension is `state_dim`; leading
    (batch) dimensions are preserved and the softmax is taken over the last axis.
    """
    h = nn.functional.relu(self.fc1(s))
    # Two residual blocks: h <- h + leaky_relu(W h).
    h = h + nn.functional.leaky_relu(self.fc2(h))
    h = h + nn.functional.leaky_relu(self.fc3(h))
    return nn.functional.softmax(self.fc4(h), dim=-1)

In [20]:
class MCPGC():
  """Monte-Carlo policy-gradient (REINFORCE) agent.

  Each call to `MC_update` plays one full episode with the current stochastic
  policy, then takes a single gradient step on
  loss = -sum_t log pi(a_t | s_t) * G_t,
  where G_t is the discounted return-to-go from step t.

  Args:
    gamma: discount factor for the returns-to-go (default 0.99).
    lr: Adam learning rate (default 1e-3, which is Adam's own default).
    pi: optional policy network mapping a float32 state tensor to action
      probabilities; when None a fresh `policy()` is built.
    env: optional environment to train on; when None, `MC_update` uses the
      notebook-level global `env`.
  """
  def __init__(self, gamma=0.99, lr=1e-3, pi=None, env=None):
    self.pi = pi if pi is not None else policy()
    self.gamma = gamma
    self.optimizer = torch.optim.Adam(self.pi.parameters(), lr=lr)
    self.env = env

  def act(self,s):
    """Sample an action for observation `s`.

    Returns:
      (action, log_prob): the sampled action tensor and its log-probability
      under the current policy (kept in the autograd graph for the update).
    """
    # as_tensor accepts numpy arrays as well as lists/tuples, unlike from_numpy.
    state = torch.as_tensor(s, dtype=torch.float32)
    dist = Categorical(self.pi(state))
    a = dist.sample()
    return a, dist.log_prob(a)

  def MC_update(self):
    """Play one episode, apply one REINFORCE step, and return the episode's
    undiscounted sum of rewards (useful for logging)."""
    # Fall back to the notebook-level `env` when none was passed to __init__.
    environment = self.env if self.env is not None else env
    s = self._reset(environment)
    rewards, log_probs = [], []
    done = False
    while not done:
      a, log_p = self.act(s)
      s, r, done = self._step(environment, a.item())
      rewards.append(r)
      log_probs.append(log_p)
    returns = self._discounted_returns(rewards)
    loss = -(torch.stack(log_probs).view(-1) * returns).sum()
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return sum(rewards)

  def _discounted_returns(self, rewards):
    """Return a float32 tensor of G_t = r_t + gamma * G_{t+1}, in episode order."""
    returns = []
    running = 0.0
    for r in reversed(rewards):
      running = r + self.gamma * running
      returns.append(running)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)

  @staticmethod
  def _reset(environment):
    """Reset `environment` and return the initial observation.

    Handles both the classic gym API (reset() -> obs) and gym>=0.26 /
    gymnasium (reset() -> (obs, info)). NOTE: an env whose observation is
    itself a tuple would be misread under the classic API.
    """
    out = environment.reset()
    return out[0] if isinstance(out, tuple) else out

  @staticmethod
  def _step(environment, action):
    """Step `environment` and return (next_obs, reward, done) for either gym API."""
    out = environment.step(action)
    if len(out) == 5:
      # gym>=0.26 / gymnasium: (obs, reward, terminated, truncated, info)
      s1, r, terminated, truncated, _ = out
      return s1, r, terminated or truncated
    # classic gym: (obs, reward, done, info)
    s1, r, done, _ = out
    return s1, r, done

In [21]:
# Train the REINFORCE agent: each MC_update() call plays one full episode and
# applies one policy-gradient step, returning the undiscounted episode return.
episodes = 300
SEED = 0
# Seed torch so weight init and action sampling are reproducible under
# Restart & Run All. The environment keeps its own RNG: seed it too when the
# classic gym API (`env.seed`) exists; gym>=0.26 seeds via reset(seed=...) instead.
torch.manual_seed(SEED)
if hasattr(env, 'seed'):
  env.seed(SEED)
agent = MCPGC()
episode_returns = []  # per-episode returns, kept for later plotting / summaries
for episode in range(episodes):
  G = agent.MC_update()
  episode_returns.append(G)
  print("Episode: {} | Final return: {}".format(episode,G))


Episode: 0 | Final return: 20.0
Episode: 1 | Final return: 13.0
Episode: 2 | Final return: 17.0
Episode: 3 | Final return: 15.0
Episode: 4 | Final return: 14.0
Episode: 5 | Final return: 16.0
Episode: 6 | Final return: 61.0
Episode: 7 | Final return: 25.0
Episode: 8 | Final return: 17.0
Episode: 9 | Final return: 12.0
Episode: 10 | Final return: 18.0
Episode: 11 | Final return: 35.0
Episode: 12 | Final return: 14.0
Episode: 13 | Final return: 9.0
Episode: 14 | Final return: 18.0
Episode: 15 | Final return: 57.0
Episode: 16 | Final return: 14.0
Episode: 17 | Final return: 43.0
Episode: 18 | Final return: 16.0
Episode: 19 | Final return: 18.0
Episode: 20 | Final return: 11.0
Episode: 21 | Final return: 11.0
Episode: 22 | Final return: 31.0
Episode: 23 | Final return: 45.0
Episode: 24 | Final return: 12.0
Episode: 25 | Final return: 37.0
Episode: 26 | Final return: 15.0
Episode: 27 | Final return: 60.0
Episode: 28 | Final return: 15.0
Episode: 29 | Final return: 28.0
Episode: 30 | Final r