In [0]:
!pip install gym
!pip install gym[atari]
!pip install attrdict



In [0]:
import numpy as np
import torch
import gym
import torch.optim as optim
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import matplotlib
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
# Use the first CUDA device when PyTorch reports one as available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [0]:
class Policy_Network(nn.Module):
  """Fully connected network that maps an observation vector to action logits.

  Layout: input_dim -> 300 -> 150 -> action_space, with ReLU after the first
  two linear layers and no activation on the output.
  """

  def __init__(self, input_dim, action_space):
    """Create the three stages of the network.

    Args:
      input_dim: number of features in one observation.
      action_space: number of logits produced (one per discrete action).
    """
    super().__init__()
    hidden_wide, hidden_narrow = 300, 150
    # Every stage stays wrapped in nn.Sequential so the parameter names in the
    # state dict keep the form "fc1.0.weight", "fc2.0.bias", ...
    self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_wide), nn.ReLU())
    self.fc2 = nn.Sequential(nn.Linear(hidden_wide, hidden_narrow), nn.ReLU())
    self.fc3 = nn.Sequential(nn.Linear(hidden_narrow, action_space))

  def forward(self, x):
    """Run ``x`` through fc1, fc2 and fc3 in order and return the raw logits."""
    for stage in (self.fc1, self.fc2, self.fc3):
      x = stage(x)
    return x

In [0]:
def get_action(logit):
  """Sample one action index from the categorical distribution given by ``logit``.

  Args:
    logit: tensor of unnormalised action scores. The drawn sample is converted
      with ``.item()``, so it must contain exactly one element (i.e. the scores
      of a single observation).

  Returns:
    The sampled action as a plain Python int.
  """
  sampled = Categorical(logits=logit).sample()
  return sampled.item()

def compute_loss(NN, obs, actions,rewards, batch_size):
  """Policy-gradient surrogate loss for a batch of collected steps.

  Args:
    NN: callable that maps ``obs`` to action logits.
    obs: batch of observations passed straight to ``NN``.
    actions: the action index that was taken for each observation.
    rewards: weight applied to each step's log-probability; it is multiplied
      element-wise (with broadcasting) against the log-probabilities.
    batch_size: divisor applied to the summed, weighted log-probabilities.

  Returns:
    A scalar tensor: minus the sum of ``rewards * log_prob(actions)``,
    divided by ``batch_size``.
  """
  policy = Categorical(logits=NN(obs))
  weighted_logprob = rewards * policy.log_prob(actions)
  return -weighted_logprob.sum() / batch_size

def train(env_name = "Breakout-ram-v0", batch_size = 50, num_epoch = 300):
  """Train a Policy_Network on a gym environment with a simple policy gradient.

  Every epoch plays complete episodes with the current policy until more than
  ``batch_size`` observations have been collected, weights each step by the
  total return of the episode it belongs to, and takes one Adam step on the
  loss returned by ``compute_loss``.

  Args:
    env_name: environment id handed to ``gym.make``.
    batch_size: collection for an epoch stops at the first episode boundary
      after more than this many observations have been gathered. The same value
      is passed to ``compute_loss`` as its divisor, even though the collected
      batch is usually somewhat larger.
    num_epoch: number of collect-then-update iterations.

  Returns:
    The trained Policy_Network, living on ``device``.

  Notes:
    Uses the classic gym calling convention: ``env.reset()`` returns only the
    observation and ``env.step()`` returns a 4-tuple.
    NOTE(review): newer gym / gymnasium releases changed both signatures -
    confirm the installed version matches.
  """
  env = gym.make(env_name)
  try:
    dim_obs = env.observation_space.shape[0]
    a_space = env.action_space.n
    NN = Policy_Network(dim_obs, a_space).to(device)
    optimizer = Adam(NN.parameters())
    for i in range(num_epoch):
      obs = env.reset()
      batch_obs = []      # every observation seen this epoch
      batch_acts = []     # action taken for each observation
      batch_rewards = []  # per-step weight: total return of the step's episode
      ep_rewards = []     # rewards of the episode currently being played
      epoch_ret = []      # total return of each finished episode (for logging)
      while True:
        batch_obs.append(obs.copy())
        # Acting needs no gradients: compute_loss recomputes the logits for the
        # whole batch, so building an autograd graph here is wasted work.
        with torch.no_grad():
          curr_action = get_action(NN(torch.as_tensor(obs, dtype=torch.float32, device = device)))
        obs, rew, done, _ = env.step(curr_action)
        batch_acts.append(curr_action)
        ep_rewards.append(rew)
        if done:
          ep_ret , ep_len = sum(ep_rewards) , len(ep_rewards)
          epoch_ret.append(ep_ret)
          # Credit every step of the episode with the episode's full return.
          batch_rewards += [ep_ret] * ep_len
          ep_rewards = []
          obs = env.reset()
          # Only stop on an episode boundary so no partial episode is used.
          if len(batch_obs) > batch_size:
            break
      # sklearn's normalize rescales the single row of returns (unit L2 norm by
      # default); the result has shape (1, n_steps).
      batch_rewards = normalize([batch_rewards])
      optimizer.zero_grad()
      # Stack the observations into one ndarray first: building a tensor from a
      # Python list of ndarrays is very slow and triggers a PyTorch warning.
      batch_loss = compute_loss(NN,
                                torch.as_tensor(np.asarray(batch_obs), dtype=torch.float32, device = device),
                                torch.as_tensor(batch_acts, dtype=torch.int32, device = device),
                                torch.as_tensor(batch_rewards, dtype=torch.float32, device = device),
                                batch_size)
      batch_loss.backward()
      optimizer.step()
      mean_ret = sum(epoch_ret) / len(epoch_ret)
      print("Epoch: {} Epoch Avg Return: {}".format(i, mean_ret))
  finally:
    # Release the environment's resources even if training is interrupted.
    env.close()
  return NN

In [0]:
# Train the policy on CartPole-v1, then write the learned weights to model.pt.
NN = train(
    env_name="CartPole-v1",
    batch_size=10000,
    num_epoch=250,
)
torch.save(NN.state_dict(), "model.pt")

Epoch: 0 Epoch Avg Return: 23.585882352941177
Epoch: 1 Epoch Avg Return: 26.48941798941799
Epoch: 2 Epoch Avg Return: 28.971098265895954
Epoch: 3 Epoch Avg Return: 34.71626297577855
Epoch: 4 Epoch Avg Return: 38.81467181467181
Epoch: 5 Epoch Avg Return: 40.61290322580645
Epoch: 6 Epoch Avg Return: 44.72
Epoch: 7 Epoch Avg Return: 51.58762886597938
Epoch: 8 Epoch Avg Return: 53.90860215053763
Epoch: 9 Epoch Avg Return: 60.775757575757574
Epoch: 10 Epoch Avg Return: 62.422360248447205
Epoch: 11 Epoch Avg Return: 73.05839416058394
Epoch: 12 Epoch Avg Return: 80.9274193548387
Epoch: 13 Epoch Avg Return: 77.73643410852713
Epoch: 14 Epoch Avg Return: 84.88983050847457
Epoch: 15 Epoch Avg Return: 96.21153846153847
Epoch: 16 Epoch Avg Return: 107.07446808510639
Epoch: 17 Epoch Avg Return: 124.22222222222223
Epoch: 18 Epoch Avg Return: 130.27272727272728
Epoch: 19 Epoch Avg Return: 131.71052631578948
Epoch: 20 Epoch Avg Return: 158.28125
Epoch: 21 Epoch Avg Return: 143.5
Epoch: 22 Epoch Avg Ret