In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import gym
from preprocessing import AtariPreprocessing
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from IPython import display
plt.ion()

# if gpu is to be used
device = torch.device("cpu")

---
### Helper functions

In [2]:
def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    plt.pause(0.001)  # pause a bit so that plots are updated
    display.clear_output(wait=True)

In [3]:
def get_screen():   
    # transpose it into torch order (CHW)
    screen = env._get_obs().transpose((2, 0, 1))
    screen = torch.from_numpy(screen)
    
    # add a batch dimension (BCHW)
    return screen.unsqueeze(0).to(device)

"""
env.reset()
env.step(3)
plt.figure()
plt.imshow(np.average(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), axis=2), interpolation='none')
plt.title('Example extracted screen')
plt.show()

env.step(3)
plt.figure()
plt.imshow(np.average(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), axis=2), interpolation='none')
plt.title('Example extracted screen')
plt.show()

env.step(3)
plt.figure()
plt.imshow(np.average(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), axis=2), interpolation='none')
plt.title('Example extracted screen')
plt.show()
"""

"\nenv.reset()\nenv.step(3)\nplt.figure()\nplt.imshow(np.average(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), axis=2), interpolation='none')\nplt.title('Example extracted screen')\nplt.show()\n\nenv.step(3)\nplt.figure()\nplt.imshow(np.average(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), axis=2), interpolation='none')\nplt.title('Example extracted screen')\nplt.show()\n\nenv.step(3)\nplt.figure()\nplt.imshow(np.average(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), axis=2), interpolation='none')\nplt.title('Example extracted screen')\nplt.show()\n"

---
### Code

In [4]:
# (state, action) -> (next_state, reward, done)
transition = namedtuple('transition', ('state', 'action', 'next_state', 'reward', 'done'))

# replay memory D with capacity N
class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    # implemented as a cyclical queue
    def store(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        
        self.memory[self.position] = transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [7]:
BATCH_SIZE = 128
GAMMA = 0.95
FRAME_SKIP = 4

# resize to 84, make greyscale and scale observations from 0 to 1
env = AtariPreprocessing(gym.make('BreakoutNoFrameskip-v0'), frame_skip=FRAME_SKIP, screen_size=84, grayscale_newaxis=True, scale_obs=True)

def one_hot(n, v):
    a = np.zeros(n)
    a[v] = 1.0
    return np.expand_dims(a, axis=0)

def rev_one_hot(a):
    return np.where(a[0] > 0)[0][0]

class DQN(nn.Module):
    def __init__(self, h, w, outputs, mem_len = 100000):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(FRAME_SKIP, 16, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(32)

        # Number of Linear input connections depends on output of conv2d layers
        # and therefore the input image size, so compute it.
        def conv2d_size_out(size, kernel_size, stride):
            return (size - (kernel_size - 1) - 1) // stride  + 1
        
        convw = conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2)
        convh = conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2)
        linear_input_size = convw * convh * 32
        
        self.fc = nn.Linear(linear_input_size, 256)
        self.head = nn.Linear(256, outputs)
        
        self.memory = ReplayMemory(mem_len)
        self.optimizer = None
        self.target = None # to keep parameters frozen while propogating losses
        
        self.n_actions = outputs
        self.steps_done = 0
        
        self.EPS_START = 0.9
        self.EPS_END = 0.05
        self.EPS_DECAY = 50000 # in number of steps

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.fc(x.view(x.size(0), -1)))
        return self.head(x)
    
    def act(self, state, step_size=1):
        eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * (1. - min(1., self.steps_done / self.EPS_DECAY))
        self.steps_done += step_size #FRAME_SKIP

        # With probability eps select a random action
        if random.random() < eps_threshold:
            return torch.tensor([[random.randrange(self.n_actions)]], device=device, dtype=torch.long)

        # otherwise select action = maxa Q∗(φ(st), a; θ)
        with torch.no_grad():
            return self(state).max(1)[1].view(1, 1)
    
    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        
        # in the form (state, action) -> (next_state, reward, done)
        transitions = self.memory.sample(BATCH_SIZE)
        batch = transition(*zip(*transitions))
        
        state_batch = torch.cat(batch.state)
        next_state_batch = torch.cat(batch.next_state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        done_mask = np.array(batch.done)
        not_done_mask = torch.from_numpy(1 - done_mask).float().to(device)
        
        current_Q_values = self(state_batch).gather(1, action_batch)
        # Compute next Q value based on which goal gives max Q values
        # Detach variable from the current graph since we don't want gradients for next Q to propagated
        next_max_q = self.target(next_state_batch).detach().max(1)[0]
        next_Q_values = not_done_mask * next_max_q
        # Compute the target of the current Q values
        target_Q_values = reward_batch + (GAMMA * next_Q_values)
        # Compute Bellman error (using Huber loss)
        loss = F.smooth_l1_loss(current_Q_values, target_Q_values.unsqueeze(1))

        # Copy Q to target Q before updating parameters of Q
        self.target.load_state_dict(self.state_dict(), strict=False)
        
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

In [9]:
# Get number of actions and observations from gym action space
n_actions = env.action_space.n

init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape

# Initialize action-value function Q with random weights
dqnAgent = DQN(screen_height, screen_width, n_actions).to(device)
dqnAgent.target = DQN(screen_height, screen_width, n_actions).to(device)

# Optimizer
learning_rate = 2.5e-4
dqnAgent.optimizer = optim.RMSprop(dqnAgent.parameters(), lr=learning_rate)

num_episodes = 10000 # M
episode_durations = []

for i_episode in range(num_episodes):
    env.reset()
    state = get_screen()
    
    overall_reward = 0
    done = False
    while not done:
        #env.render()
        
        # Execute action a_t in emulator and observe reward r_t and image x_{t+1}
        action = dqnAgent.act(state)
        _, reward, done, _ = env.step(action.item())
        extrinsic_reward = torch.tensor([reward], device=device)
        
        overall_reward += reward

        # preprocess φ_{t+1} = φ(s_{t+1})
        next_state = get_screen()

        # Store transition (φt, at, rt, φt+1) in D
        dqnAgent.memory.store(state, action, next_state, extrinsic_reward, done)

        state = next_state

        dqnAgent.experience_replay()
        
    episode_durations.append(overall_reward)
    plot_durations()

print('Complete')
plot_durations()

KeyboardInterrupt: 