In [1]:
!pip install gym
!pip install gym[atari]
!pip install attrdict

Collecting cloudpickle~=1.2.0
  Downloading cloudpickle-1.2.2-py2.py3-none-any.whl (25 kB)
Installing collected packages: cloudpickle
  Attempting uninstall: cloudpickle
    Found existing installation: cloudpickle 1.5.0
    Uninstalling cloudpickle-1.5.0:
      Successfully uninstalled cloudpickle-1.5.0
[31mERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.

We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.

spyder 3.3.6 requires pyqt5<5.13; python_version >= "3", which is not installed.
spyder 3.3.6 requires pyqtwebengine<5.13; python_version >= "3", which is not installed.
spinup 0.2.0 requires cloudpickle==1.2.1, but you'll have cloudpickle 1.2.2 which is incompatible.
spinup 0.2.0 requires matplotlib==3.1.1, but you'll have matplotlib 3.0.3 which is incompatible.
spinup 0.2.0 requires torch=

In [1]:
import numpy as np
import torch
import gym
import torch.optim as optim
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import matplotlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def reward_to_go(reward_list):
    """
        Return the undiscounted reward-to-go for every step of one episode.

        Element t of the result is sum(reward_list[t:]), i.e. the reward
        collected from step t until the end of the episode. The rewards are
        assumed to be given in chronological order. An empty input yields [].
    """
    running = 0
    suffix_sums = [0] * len(reward_list)
    # Walk backwards so each entry reuses the sum of everything after it.
    for idx in range(len(reward_list) - 1, -1, -1):
        running += reward_list[idx]
        suffix_sums[idx] = running
    return suffix_sums

In [3]:
class Policy_Network(nn.Module):
  """MLP mapping an observation vector to unnormalised action logits.

  Layout: Linear(input_dim, 300) -> ReLU -> Linear(300, 150) -> ReLU
  -> Linear(150, action_space). Each stage is an ``nn.Sequential`` stored as
  ``fc1``/``fc2``/``fc3``, so ``state_dict`` keys look like ``fc1.0.weight``;
  keep these names to stay compatible with previously saved checkpoints.
  """

  def __init__(self, input_dim, action_space):
    super().__init__()
    hidden_a, hidden_b = 300, 150
    # Layers are created in the same order (fc1, fc2, fc3) so parameter
    # initialisation draws from the RNG in the same sequence.
    self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_a), nn.ReLU())
    self.fc2 = nn.Sequential(nn.Linear(hidden_a, hidden_b), nn.ReLU())
    self.fc3 = nn.Sequential(nn.Linear(hidden_b, action_space))

  def forward(self, x):
    """Return raw logits; no softmax is applied here."""
    return self.fc3(self.fc2(self.fc1(x)))

In [4]:
def get_action(logit):
  """Sample one discrete action from a categorical policy.

  Args:
    logit: unnormalised log-probabilities describing a single distribution
      (e.g. a 1-D tensor over actions). ``.item()`` only works on one-element
      tensors, so a batch of distributions is not supported.

  Returns:
    The sampled action index as a Python int.
  """
  policy = Categorical(logits=logit)
  sampled = policy.sample()
  return sampled.item()

def compute_loss(NN, obs, actions, rewards, batch_size=None):
  """Policy-gradient surrogate loss for a batch of transitions.

  Computes ``-(sum_t w_t * log pi(a_t | s_t)) / batch_size``, where the
  log-probabilities come from a categorical distribution over ``NN(obs)``.
  Minimising this loss pushes up the log-probability of actions with large
  positive weights.

  Args:
    NN: module mapping ``obs`` to action logits (last dim indexes actions).
    obs: batch of observations accepted by ``NN``.
    actions: integer action indices, one per observation.
    rewards: per-transition weights (e.g. rewards-to-go). Any tensor with
      exactly one element per transition is accepted (e.g. ``(N,)`` or
      ``(1, N)``); it is reshaped to match the log-probabilities.
    batch_size: divisor for the summed objective. Defaults to the number of
      transitions, which makes the loss a mean.

  Returns:
    A scalar tensor.

  Raises:
    RuntimeError: if ``rewards`` does not hold one element per transition
      (raised by ``Tensor.reshape``).
  """
  logprob = Categorical(logits=NN(obs)).log_prob(actions)
  # Align weights one-to-one with the log-probs. Without this, a (N, 1) weight
  # tensor would silently broadcast against (N,) into an (N, N) outer product.
  weights = rewards.reshape(logprob.shape)
  if batch_size is None:
    batch_size = logprob.numel()
  return -(weights * logprob).sum() / batch_size

def train(env_name = "Breakout-ram-v0", batch_size = 50, num_epoch = 300):
  """Train a categorical policy with vanilla policy gradient (rewards-to-go).

  Each epoch rolls out complete episodes with the current policy until more
  than ``batch_size`` transitions are collected, weights every transition's
  log-probability by its L2-scaled reward-to-go, and takes one Adam step. The
  mean undiscounted episode return of the epoch is printed.

  Written against the classic Gym API, where ``reset()`` returns only the
  observation and ``step()`` returns a 4-tuple.

  Args:
    env_name: Gym environment id. Its action space must expose ``n`` (discrete
      actions). ``observation_space.shape[0]`` is used as the network input
      size, so observations are assumed to be flat vectors.
    batch_size: minimum number of transitions per epoch. Rollouts stop at the
      first episode end after this many steps have been exceeded.
    num_epoch: number of gradient updates.

  Returns:
    The trained ``Policy_Network`` (on ``device``).
  """
  env = gym.make(env_name)
  try:
    dim_obs = env.observation_space.shape[0]
    a_space = env.action_space.n
    NN = Policy_Network(dim_obs, a_space).to(device)
    optimizer = Adam(NN.parameters())
    # Reset once up front. Every episode end below resets the env again, so
    # each later epoch already starts from a fresh episode.
    obs = env.reset()
    for i in range(num_epoch):
      batch_obs = []
      batch_acts = []
      batch_rewards = []
      ep_rewards = []
      epoch_ret = []
      # Roll out whole episodes until the batch holds more than batch_size steps.
      while True:
        batch_obs.append(obs.copy())
        # Acting needs no gradients; avoid recording an autograd graph per step.
        with torch.no_grad():
          logits = NN(torch.as_tensor(obs, dtype=torch.float32, device=device))
        curr_action = get_action(logits)
        obs, rew, done, _ = env.step(curr_action)
        batch_acts.append(curr_action)
        ep_rewards.append(rew)
        if done:
          epoch_ret.append(sum(ep_rewards))
          batch_rewards += reward_to_go(ep_rewards)
          ep_rewards = []
          obs = env.reset()
          if len(batch_obs) > batch_size:
            break
      # L2-scale the reward-to-go weights. Taking [0] gives a flat (N,) array
      # that lines up with the per-step log-probs instead of a (1, N) row.
      # NOTE(review): this only rescales and does not centre the weights, so it
      # provides no baseline; standardising (subtract mean, divide by std) is
      # the usual variance-reduction choice and is worth trying.
      weights = normalize([batch_rewards])[0]
      optimizer.zero_grad()
      batch_loss = compute_loss(
          NN,
          # Stack into one ndarray first; converting a list of ndarrays to a
          # tensor element by element is very slow.
          torch.as_tensor(np.asarray(batch_obs), dtype=torch.float32, device=device),
          torch.as_tensor(batch_acts, dtype=torch.int64, device=device),
          torch.as_tensor(weights, dtype=torch.float32, device=device),
          # Average over the transitions actually collected (always more than
          # batch_size), not the nominal batch_size.
          len(batch_acts))
      batch_loss.backward()
      optimizer.step()
      mean_ret = sum(epoch_ret) / len(epoch_ret)
      print("Epoch: {} Epoch Avg Return: {}".format(i, mean_ret))
  finally:
    # Release whatever resources the environment holds, even if training fails.
    env.close()
  return NN

In [5]:
# Train on CartPole-v1: each of the 250 epochs collects a bit more than 10,000
# environment steps (whole episodes) before a single policy-gradient update.
NN = train(env_name="CartPole-v1",batch_size = 10000, num_epoch = 250)
# Persist only the learned weights. Reload them with load_state_dict into a
# Policy_Network built with the same observation/action sizes.
torch.save(NN.state_dict(), "model.pt")

Epoch: 0 Epoch Avg Return: 23.60141509433962
Epoch: 1 Epoch Avg Return: 26.69333333333333
Epoch: 2 Epoch Avg Return: 29.925373134328357
Epoch: 3 Epoch Avg Return: 35.041811846689896
Epoch: 4 Epoch Avg Return: 36.845588235294116
Epoch: 5 Epoch Avg Return: 37.62406015037594
Epoch: 6 Epoch Avg Return: 42.97424892703863
Epoch: 7 Epoch Avg Return: 44.644444444444446
Epoch: 8 Epoch Avg Return: 51.18877551020408
Epoch: 9 Epoch Avg Return: 53.4468085106383
Epoch: 10 Epoch Avg Return: 62.99375
Epoch: 11 Epoch Avg Return: 66.24503311258277
Epoch: 12 Epoch Avg Return: 78.578125
Epoch: 13 Epoch Avg Return: 76.65648854961832
Epoch: 14 Epoch Avg Return: 85.52991452991454
Epoch: 15 Epoch Avg Return: 97.625
Epoch: 16 Epoch Avg Return: 111.15555555555555
Epoch: 17 Epoch Avg Return: 129.92207792207793
Epoch: 18 Epoch Avg Return: 140.69444444444446
Epoch: 19 Epoch Avg Return: 141.2112676056338
Epoch: 20 Epoch Avg Return: 165.21311475409837
Epoch: 21 Epoch Avg Return: 165.5737704918033
Epoch: 22 Epoch Avg

Epoch: 184 Epoch Avg Return: 436.5652173913044
Epoch: 185 Epoch Avg Return: 441.82608695652175
Epoch: 186 Epoch Avg Return: 461.40909090909093
Epoch: 187 Epoch Avg Return: 487.7142857142857
Epoch: 188 Epoch Avg Return: 497.04761904761904
Epoch: 189 Epoch Avg Return: 496.42857142857144
Epoch: 190 Epoch Avg Return: 495.2857142857143
Epoch: 191 Epoch Avg Return: 480.95238095238096
Epoch: 192 Epoch Avg Return: 497.5238095238095
Epoch: 193 Epoch Avg Return: 489.9047619047619
Epoch: 194 Epoch Avg Return: 487.04761904761904
Epoch: 195 Epoch Avg Return: 500.0
Epoch: 196 Epoch Avg Return: 500.0
Epoch: 197 Epoch Avg Return: 488.6190476190476
Epoch: 198 Epoch Avg Return: 467.59090909090907
Epoch: 199 Epoch Avg Return: 487.14285714285717
Epoch: 200 Epoch Avg Return: 466.54545454545456
Epoch: 201 Epoch Avg Return: 405.32
Epoch: 202 Epoch Avg Return: 475.22727272727275
Epoch: 203 Epoch Avg Return: 475.77272727272725
Epoch: 204 Epoch Avg Return: 461.59090909090907
Epoch: 205 Epoch Avg Return: 471.590