# DQN Tutorial

This notebook is for me to continue my practice with PyTorch--while learning how to train a Deep Q-Network! Exciting stuff. I just finished reviewing the chapters on reinforcement learning (chapters 17 & 21) in the textbook "Artificial Intelligence: A Modern Approach", 3rd ed. (Russell & Norvig). While it isn't the most up-to-date reference on the subject, it does give the foundations of reinforcement learning--in particular, it covers the temporal-difference update equation, whose target is what a DQN's loss function is built from.

So: here goes!

In [None]:
%matplotlib inline

import collections
import itertools
import copy
import random
import typing

import gym
import math
import numpy
import matplotlib
import matplotlib.pyplot as pyplot
import PIL
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as functional
import torch.autograd as autograd
import torchvision.transforms as transforms

In [None]:
env = gym.make('CartPole-v0').unwrapped

# Tensor-type aliases used throughout the notebook. Use the CUDA types when a
# GPU is available and fall back to the CPU types otherwise, so that these
# aliases also work on machines without CUDA.
USE_CUDA = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor
Tensor = FloatTensor

So *this* is how easy it is to set up a gym environment. That's super cool. 

In [None]:
class Transition(typing.NamedTuple):
    """One step of experience stored in the replay memory.

    ``next_state`` is ``None`` when the step ended the episode: the training
    loop stores terminal steps that way, and ``optimize_model`` treats
    ``None`` as the terminal marker. ``optimize_model`` also builds a
    Transition whose fields are tuples (one entry per sampled transition) via
    ``Transition(*zip(*transitions))``.
    """
    state: torch.Tensor
    action: torch.Tensor
    # Optional: terminal transitions are stored with next_state=None.
    next_state: typing.Optional[torch.Tensor]
    reward: torch.Tensor

class ReplayMemory(object):
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    While fewer than ``capacity`` transitions are stored the buffer grows;
    after that, every push overwrites the oldest stored entry.
    """
    capacity: int  # maximum number of transitions kept
    memory: typing.List[Transition]  # backing list, grows up to `capacity`
    position: int  # index that the next push writes to

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, transition: Transition):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            # Still filling up: the write position is always the end of the list.
            self.memory.append(transition)
        else:
            # Full: overwrite the oldest slot.
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size: int) -> typing.Iterable[Transition]:
        """Return ``batch_size`` stored transitions drawn uniformly without replacement.

        Raises ValueError (from ``random.sample``) if fewer than
        ``batch_size`` transitions are stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network with three conv -> batch-norm -> ReLU stages and a linear head.

    The head maps 448 flattened features to 2 outputs, one value per action.
    Because the head size is fixed, the input's spatial size is constrained.
    For example, a 3x40x80 input yields a 32x2x7 feature map, which is 448
    features. NOTE(review): confirm this matches the screen preprocessing.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the same order as always (conv/bn pairs,
        # then the head) so parameter names and initialisation are unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(num_features=32)
        self.head = nn.Linear(in_features=448, out_features=2)

    def forward(self, x):
        """Return a (batch, 2) tensor of per-action outputs for input batch ``x``."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = functional.relu(norm(conv(x)))
        flat = x.view(x.size(0), -1)  # flatten everything except the batch dim
        return self.head(flat)

Before I forget to write it down, I'm trying to use type hints throughout the code. It makes everything clearer in my head (because, contrary to popular belief in the Python community, you *can't* stop worrying about types in Python). Here are some things I learned so far:
* `typing` has its own `NamedTuple` class; you create a subclass of it, which lets you put type hints on its fields;
* user-defined classes can be used as-is in type hints, without creating a special type variable;
* a function that returns nothing is annotated `-> None` (leaving the return annotation off just means the return type is unspecified);
* return types for functions are written like `def function(...) -> returnType: blablacodeblabla`.

And now, back to the tutorial. 

In [None]:
pyplot.ion()  # interactive mode, so figures can update while training runs

# Screen preprocessing pipeline: tensor -> PIL image -> resized so that its
# smaller edge is 40 px -> tensor again.
# `Resize` replaces `Scale`, which is deprecated and removed in later
# torchvision releases. `BICUBIC` replaces `CUBIC`, an alias that Pillow 10
# removed.
resize = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.Resize(40, interpolation=PIL.Image.BICUBIC),
     transforms.ToTensor()])

# Width in pixels of the frame rendered with env.render(mode='rgb_array').
# NOTE(review): assumed to match the CartPole renderer; confirm against the env.
screen_width = 600

def get_cart_location():
    """Return the cart's horizontal position in screen pixels.

    The world spans ``2 * env.x_threshold`` units with world position 0 at
    the screen centre. ``env.state[0]`` is mapped onto a ``screen_width``-pixel
    frame, and the result is truncated to an int.
    """
    pixels_per_unit = screen_width / (env.x_threshold * 2)
    screen_center = screen_width / 2.0
    return int(env.state[0] * pixels_per_unit + screen_center)

def get_screen():
    """Render the environment and return a cropped, resized screen tensor.

    Processing steps:
    - Transpose the rendered frame with ``(2, 0, 1)``.
    - Keep rows 160-319.
    - Keep a 320-pixel-wide column window centred on the cart, clamped to
      the left and right edges of the frame.
    - Divide by 255 and pass through ``resize``.
    - Return with a leading batch dimension of 1, cast to ``Tensor``.
    """
    # Presumably (height, width, channels) -> (channels, height, width); confirm
    # against the renderer.
    frame = env.render(mode='rgb_array').transpose((2, 0, 1))
    frame = frame[:, 160:320]  # keep only rows 160..319

    view_width = 320
    half_view = view_width // 2
    cart_x = get_cart_location()
    if cart_x < half_view:
        # Cart near the left edge: take the leftmost `view_width` columns.
        columns = slice(view_width)
    elif cart_x > screen_width - half_view:
        # Cart near the right edge: take the rightmost `view_width` columns.
        columns = slice(-view_width, None)
    else:
        columns = slice(cart_x - half_view, cart_x + half_view)

    frame = numpy.ascontiguousarray(frame[:, :, columns], dtype=numpy.float32) / 255
    batched = resize(torch.from_numpy(frame)).unsqueeze(0)
    return batched.type(Tensor)
# Sanity check: reset the environment and show one preprocessed screen.
env.reset()
example = get_screen().cpu().squeeze(0)  # drop the leading batch dimension
pyplot.imshow(example.permute(1, 2, 0).numpy(),  # channels last for imshow
              interpolation='none')
pyplot.title('Example extracted screen')
pyplot.show()

In [None]:
# Training hyperparameters.
BATCH_SIZE = 128  # transitions sampled from replay memory per optimisation step
GAMMA = 0.999     # discount factor applied to the next-state value in the TD target

# Epsilon-greedy schedule: epsilon decays exponentially from EPS_START toward
# EPS_END, with time constant EPS_DECAY measured in action selections.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200

model = DQN()

In [None]:
# Move the network to the GPU only when one is present: `model.cuda()`
# raises on machines without CUDA.
if torch.cuda.is_available():
    model.cuda()

In [None]:
# RMSprop over all network weights, using the optimizer's default settings.
optimizer = optim.RMSprop(model.parameters())
# Ring buffer that keeps the 10000 most recent transitions.
memory = ReplayMemory(10000)

# Number of select_action calls so far; drives the epsilon decay.
steps_done = 0

def select_action(state: torch.Tensor) -> LongTensor:
    """Choose an action for ``state`` with an epsilon-greedy policy.

    Epsilon starts at EPS_START and decays exponentially toward EPS_END, with
    time constant EPS_DECAY measured in calls to this function. The count is
    kept in the global ``steps_done``, which every call increments.

    Returns:
        A 1x1 LongTensor holding the chosen action index. With probability
        about 1 - epsilon this is the argmax of the network output; otherwise
        it is a uniformly random action in {0, 1}.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
        -1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        # volatile=True: inference only, no autograd graph (legacy PyTorch API).
        # `.data` unwraps the Variable, so this branch also returns a LongTensor.
        return model(
            autograd.Variable(state, volatile=True).type(FloatTensor)
        ).data.max(1)[1].view(1, 1)
    return LongTensor([[random.randrange(2)]])

episode_durations: typing.List[int] = []  # steps lasted in each finished episode


def plot_durations():
    """Redraw figure 2 with the per-episode durations.

    Once at least 100 episodes exist, also plot their 100-episode moving
    average. The average is left-padded with 99 zeros so it lines up with the
    raw durations on the x axis. The output area is then cleared and the
    figure re-displayed through IPython.
    """
    # Import the submodule explicitly: `import IPython` on its own does not
    # guarantee that the `IPython.display` attribute has been populated.
    from IPython import display

    pyplot.figure(2)
    pyplot.clf()
    durations_t = torch.FloatTensor(episode_durations)
    pyplot.title('Training...')
    pyplot.xlabel('Episode')
    pyplot.ylabel('Duration')
    pyplot.plot(durations_t.numpy())
    if len(durations_t) >= 100:
        # Mean over every window of 100 consecutive episodes.
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        pyplot.plot(means.numpy())

    pyplot.pause(0.001)  # let the GUI event loop draw the updated figure
    display.clear_output(wait=True)
    display.display(pyplot.gcf())

In [None]:
last_sync = 0  # NOTE(review): not read by any visible code; kept so the module-level name still exists.


def optimize_model():
    """Run one Q-learning gradient step on a random minibatch from ``memory``.

    Does nothing until the replay memory holds at least BATCH_SIZE
    transitions. The target is reward + GAMMA * max_a Q(next_state, a), with
    the max term set to 0 for terminal transitions (stored with
    ``next_state=None``). The same ``model`` produces both the prediction and
    the target. Gradients are clamped element-wise to [-1, 1] before
    ``optimizer.step()``.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    non_final_mask = ByteTensor(tuple(s is not None for s in batch.next_state))
    non_final_states = [s for s in batch.next_state if s is not None]

    state_batch = autograd.Variable(torch.cat(batch.state))
    action_batch = autograd.Variable(torch.cat(batch.action))
    reward_batch = autograd.Variable(torch.cat(batch.reward))

    # Q(s, a) for the actions actually taken. gather() yields (BATCH_SIZE, 1);
    # squeeze to (BATCH_SIZE,) so it matches the target element-for-element.
    # Otherwise newer PyTorch broadcasts the loss to (BATCH_SIZE, BATCH_SIZE).
    state_action_values = model(state_batch).gather(1, action_batch).squeeze(1)

    next_state_values = autograd.Variable(torch.zeros(BATCH_SIZE).type(Tensor))
    # torch.cat([]) raises, so only evaluate the network when at least one
    # sampled transition is non-terminal.
    if non_final_states:
        non_final_next_states = autograd.Variable(
            torch.cat(non_final_states), volatile=True)
        # detach(): the target must not backpropagate into the network.
        # `volatile` only guarantees this on legacy PyTorch.
        next_state_values[non_final_mask] = (
            model(non_final_next_states).max(1)[0].detach())
    next_state_values.volatile = False  # legacy PyTorch: make the target a plain Variable
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = functional.smooth_l1_loss(state_action_values, expected_state_action_values)
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [None]:
num_episodes = 1000
# Per-episode step cap. NOTE(review): `.unwrapped` presumably strips gym's
# built-in time limit, so this cap is what bounds an episode here.
max_steps_per_episode = 1000

for i_episode in range(num_episodes):
    env.reset()
    # The state is the difference between two consecutive screens.
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    for t in range(max_steps_per_episode):
        action = select_action(state)
        # int(): give gym a plain Python int. On PyTorch >= 0.4, indexing a
        # tensor returns a 0-dim tensor rather than a number.
        _, reward, done, _ = env.step(int(action[0, 0]))
        reward = Tensor([reward])

        last_screen = current_screen
        current_screen = get_screen()
        # None marks a terminal transition for optimize_model.
        next_state = None if done else current_screen - last_screen

        memory.push(Transition(state, action, next_state, reward))
        state = next_state

        optimize_model()
        if done:
            break
    # Record the duration after the inner loop, so episodes that hit the step
    # cap are counted as well as those that ended with `done`.
    episode_durations.append(t + 1)
    plot_durations()

print('Complete')
# env.close() is relied on to shut down rendering. The former
# `env.render(close=True)` call is rejected by newer gym releases.
env.close()
pyplot.ioff()
pyplot.show()

AND THAT'S IT! Going through this tutorial was mostly a way to prove to myself that this stuff is easier than I thought (as in, I have enough technical background to understand the details of what's going on). I think the next thing to do is try to really do well in this simple game. Also, try a few modifications:
* maybe I'll just give it the image instead of the difference between two subsequent images;
* try incorporating a recurrent network...somehow (I'll probably have to read up on that again before I try it);
* try different network architectures;
* try slight training variations.