In [0]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import torchvision
import torchvision.transforms as tfms
import matplotlib.pyplot as plt

from collections import namedtuple

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [0]:
# Feedback: same field layout as the 4-tuple unpacked from env.step().
Feedback = namedtuple(
  'feedback',
  ('state', 'reward', 'done', 'info'),
)
# Record: one replay-buffer entry -- the action taken and the state/reward/done
# that resulted from it.
Record = namedtuple(
  'record',
  ('action', 'state', 'reward', 'done'),
)
# Transition: a (batched) training sample as returned by ExpBuffer.sample().
Transition = namedtuple(
  'transition',
  ('state', 'action', 'next_state', 'reward', 'done'),
)

class Transform:
  """Turn a stack of raw frames into a single tensor of small grayscale planes.

  Every frame goes through the same torchvision pipeline: convert to a PIL
  image, grayscale, resize to 42 x 55, center-crop to 42 x 42, convert to a
  tensor.
  """

  def __init__(self):
    pipeline_steps = [
      tfms.ToPILImage(),
      tfms.Grayscale(),
      tfms.Resize((42, 55)),
      tfms.CenterCrop(42),
      tfms.ToTensor(),
    ]
    self.img_transform = tfms.Compose(pipeline_steps)

  def __call__(self, framestack):
    """Preprocess each frame in `framestack` and stack the results.

    Only the first channel of each transformed frame is kept, so the result
    has one plane per input frame, stacked along a new leading dimension.
    """
    planes = [self.img_transform(frame)[0] for frame in framestack]
    return torch.stack(planes)

class ExpBuffer:
  """Fixed-capacity FIFO replay buffer of environment records.

  Each record holds the action that was taken, the state *reached* by that
  action, the reward and the done flag. The state an action was taken *from*
  is therefore the state of the preceding record. The first state of an
  episode is stored through `add_state` as a record whose action is None, so
  it can serve as a predecessor but is never sampled as a transition itself.
  """

  def __init__(self, max_size = 100000):
    """Create an empty buffer that keeps at most `max_size` records.

    Raises:
      ValueError: if `max_size` is below 2 (a transition needs two records).
    """
    if max_size < 2:
      raise ValueError('max_size must be at least 2')
    self.max_size = max_size
    self.records = []
    # shape of the most recently added state; used to allocate sample batches
    self.state_shape = None

  def add_record(self, action, state, reward, done):
    """Append one record, evicting the oldest record when the buffer is full."""
    if len(self.records) >= self.max_size: self.records.pop(0)
    self.state_shape = state.shape
    self.records.append(Record(action, state, reward, done))

  def add_state(self, state):
    """Store an initial state (no action or reward led to it)."""
    self.add_record(None, state, None, False)

  def _has_transition(self):
    """Return True if at least one record can be sampled as a transition.

    A record is usable when it has a predecessor (index >= 1) and an action.
    `any` stops at the first usable record, so this is cheap in normal use.
    """
    return any(self.records[i].action is not None
               for i in range(1, len(self.records)))

  def sample(self, batch_size, device='cpu'):
    """Draw `batch_size` transitions uniformly at random, with replacement.

    Args:
      batch_size: number of transitions in the returned batch.
      device: device the returned tensors are moved to.

    Returns:
      A `Transition` of batched tensors (state, action, next_state, reward,
      done). For terminal transitions `done` is True and `next_state` is left
      as zeros.

    Raises:
      RuntimeError: if the buffer does not contain any complete transition.
    """
    if not self._has_transition():
      raise RuntimeError('Sampling before buffer is filled')
    states = torch.zeros(batch_size, *self.state_shape)
    actions = torch.zeros(batch_size, dtype=torch.long)
    next_states = torch.zeros(batch_size, *self.state_shape)
    rewards = torch.zeros(batch_size)
    done = torch.zeros(batch_size, dtype=torch.bool)

    for i in range(batch_size):
      # rejection-sample an index that has a predecessor (idx >= 1) and is not
      # an episode-start record (those carry no action)
      while True:
        idx = np.random.randint(1, len(self.records))
        if self.records[idx].action is not None: break
      record = self.records[idx]
      states[i] = self.records[idx-1].state
      actions[i] = record.action
      rewards[i] = record.reward
      if record.done: done[i] = True
      else: next_states[i] = record.state

    return Transition(states.to(device), actions.to(device),
                      next_states.to(device), rewards.to(device),
                      done.to(device))

  def __len__(self): return len(self.records)

def get_DQN(nb_actions):
  """Build the convolutional Q-network.

  The network takes a 4-channel 42 x 42 input and produces one value per
  action (`nb_actions` outputs per sample).
  """
  # spatial size is tracked in the trailing comments, starting from 42 x 42
  feature_layers = [
    nn.Conv2d(4, 16, kernel_size=7, stride=2, padding=3),   # -> 16 x 21 x 21
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.MaxPool2d(2),                                        # -> 16 x 10 x 10
    nn.Conv2d(16, 64, kernel_size=3, stride=2, padding=1),  # -> 64 x 5 x 5
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(2),                                        # -> 64 x 2 x 2
  ]
  flat_features = 2 * 2 * 64
  head_layers = [
    nn.Flatten(),                                           # -> 2 * 2 * 64 = 256
    nn.Linear(flat_features, 256, bias=True),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, nb_actions),
  ]
  return nn.Sequential(*feature_layers, *head_layers)

In [0]:
def get_action():
  """Choose the next action with an epsilon-greedy policy.

  Reads the notebook globals `args.epsilon`, `nb_actions`, `model`, `state`
  and `device`. With probability `args.epsilon` a uniformly random action is
  drawn; otherwise the action with the highest predicted Q-value for the
  current `state` is taken.

  The choice is both returned and stored in the module-level `action`,
  because call sites invoke `get_action()` as a bare statement and then read
  `action`.

  Returns:
    int: index of the chosen action.
  """
  global action
  if np.random.rand() < args.epsilon:
    action = int(np.random.randint(nb_actions))
  else:
    # eval mode so normalization layers use running statistics for this
    # single-sample batch; no_grad because this forward pass is inference only
    model.eval()
    with torch.no_grad():
      Q = model(state.unsqueeze(0).to(device))
    action = torch.max(Q, 1)[1].item()
  return action

def optimize():
  """Run one Q-learning gradient step on a minibatch drawn from `buffer`.

  Reads the notebook globals `model`, `optimizer`, `buffer`, `args`,
  `device` and `loss_func`.

  The regression target for each sampled transition is
  `reward + gamma * max_a' Q(next_state, a')`, with the bootstrap term
  dropped for terminal transitions.

  Returns:
    float: the loss value of this step.
  """
  # train using past transitions
  model.train()
  optimizer.zero_grad()
  # randomly sample a batch of transitions from buffer
  batch = buffer.sample(args.batch_size, device)

  # predicted value of the action actually taken in each sampled state
  Qs = model(batch.state)
  outputs = Qs.gather(1, batch.action.unsqueeze(1)).squeeze(1)

  # bootstrapped target; computed without gradient tracking so the loss only
  # backpropagates through the prediction, not through the target
  with torch.no_grad():
    next_Qs = torch.max(model(batch.next_state), 1)[0]
    not_done = (~batch.done).to(next_Qs.dtype)
    # gamma discounts the future value, not the immediate reward
    targets = batch.reward + args.gamma * next_Qs * not_done

  # calculate loss
  loss = loss_func(outputs, targets)
  loss.backward()
  optimizer.step()
  return loss.item()

In [0]:
class Args:
  """Plain attribute container for the notebook's hyperparameters."""

args = Args()
vars(args).update(
  lr = 0.003,         # optimizer learning rate
  l2reg = 0.0003,     # L2 regularization strength
  episodes = 10,      # number of episodes to run
  max_steps = 150,    # cap on environment steps per episode
  epsilon = 0.1,      # probability of a random action in epsilon-greedy
  batch_size = 32,    # transitions sampled per optimization step
  gamma = 0.99,       # discount factor
)

In [0]:
# frame preprocessing and replay memory shared by the cells below
transform = Transform()
buffer = ExpBuffer()

# Pong, wrapped so every observation is a stack of 4 frames
env = gym.make('Pong-v0')
env = gym.wrappers.FrameStack(env, 4)
nb_actions = env.action_space.n

model = get_DQN(nb_actions)
model.to(device)
# args.l2reg is applied as SGD weight decay (the L2 penalty on the weights);
# without this the configured regularization strength has no effect
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                            weight_decay=args.l2reg)
# Huber Loss: equivalent to MSE when difference is small, but less punishing
# when the difference is large, resilient to outliers
loss_func = F.smooth_l1_loss

In [162]:
%%prun
for episode_idx in range(args.episodes):
  # start a new episode; the initial state is stored without an action
  state = transform(env.reset())
  buffer.add_state(state)
  score = 0

  for t in range(args.max_steps):
    # calculate next action using epsilon-greedy; the choice is read from
    # the global `action` below
    get_action()

    # step
    state, reward, done, info = env.step(action)
    state = transform(state)
    buffer.add_record(action, state, reward, done)
    score += reward

    optimize()

    # the episode has ended: stop stepping, the environment has to be reset
    # before it can be used again
    if done: break

  print('episode %d, score %f' % (episode_idx, score))
    

episode 0, score -2.000000
episode 1, score -2.000000
episode 2, score -2.000000
episode 3, score -2.000000
episode 4, score -2.000000
episode 5, score -2.000000
episode 6, score -2.000000
episode 7, score -2.000000
episode 8, score -2.000000
episode 9, score -2.000000
 