Note: DQN training is notoriously unstable, and in practice it fails to learn on Atari games most of the time. This notebook is not a fully tested example.

In [34]:
# Install gym with its Atari extra, plus the `accept-rom-license` extra
# (presumably what makes the Atari ROMs available — confirm for your gym version).
# NOTE(review): versions are unpinned; newer gym releases (>= 0.26) changed the
# return values of env.reset() and env.step().
!pip install -q gym[atari]
!pip install -q gym[accept-rom-license]

In [35]:
# Reference: https://github.com/davidreiman/pytorch-atari-dqn/blob/master/dqn.ipynb

%matplotlib inline

import os
import re
import gym
import time
import copy
import random
import warnings
import numpy as np

import torch
import torch.nn as nn

from IPython import display
from skimage.color import rgb2gray
from skimage.transform import rescale
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm
from collections import deque, namedtuple

In [36]:
# matplotlib >= 3.6 renamed the bundled seaborn style to 'seaborn-v0_8' (the old
# 'seaborn' name was deprecated there and removed in 3.8), so prefer the new name
# and fall back to the old one on older matplotlib versions.
_SEABORN_STYLE = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_SEABORN_STYLE)

# Silence all warnings for the rest of the notebook session.
warnings.filterwarnings('ignore')

In [37]:
class DeepQNetwork(nn.Module):
    """Convolutional Q-network mapping a stack of frames to one Q-value per action.

    Architecture: two ReLU conv layers (16 then 32 channels), a 256-unit ReLU
    hidden layer, and a linear output layer with ``num_actions`` units.

    Args:
        num_frames: Number of stacked frames, i.e. input channels.
        num_actions: Number of discrete actions, i.e. output size.
        frame_size: Spatial size of each input frame, either an int for square
            frames or a ``(height, width)`` pair. The default of 80 yields the
            original flattened size of 32 * 10 * 10 = 3200 features.
    """

    def __init__(self, num_frames, num_actions, frame_size=80):
        super().__init__()
        self.num_frames = num_frames
        self.num_actions = num_actions
        self.frame_size = frame_size

        # Layers
        self.conv1 = nn.Conv2d(
            in_channels=num_frames,
            out_channels=16,
            kernel_size=8,
            stride=4,
            padding=2,
        )
        self.conv2 = nn.Conv2d(
            in_channels=16,
            out_channels=32,
            kernel_size=4,
            stride=2,
            padding=1,
        )

        # Derive the flattened conv output size from the frame size instead of
        # hard-coding 3200, so the network also works at other resolutions.
        if isinstance(frame_size, (tuple, list)):
            height, width = (int(s) for s in frame_size)
        else:
            height = width = int(frame_size)
        for conv in (self.conv1, self.conv2):
            height = self._conv_out(height, conv, axis=0)
            width = self._conv_out(width, conv, axis=1)

        self.fc1 = nn.Linear(
            in_features=self.conv2.out_channels * height * width,
            out_features=256,
        )
        self.fc2 = nn.Linear(
            in_features=256,
            out_features=num_actions,
        )

        # Activation Functions
        self.relu = nn.ReLU()

    @staticmethod
    def _conv_out(size, conv, axis):
        """Output length of ``conv`` along spatial ``axis`` (0 = height, 1 = width)."""
        padding = conv.padding[axis]
        dilation = conv.dilation[axis]
        kernel = conv.kernel_size[axis]
        stride = conv.stride[axis]
        return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def flatten(self, x):
        """Flatten every dimension except the leading batch dimension."""
        batch_size = x.size(0)
        return x.view(batch_size, -1)

    def forward(self, x):
        """Return Q-values of shape (N, num_actions) for input of shape (N, num_frames, H, W).

        Shape comments below are for the default 80x80 frame size.
        """
        x = self.relu(self.conv1(x))  # (N, num_frames, 80, 80) -> (N, 16, 20, 20)
        x = self.relu(self.conv2(x))  # (N, 16, 20, 20)         -> (N, 32, 10, 10)
        x = self.flatten(x)           # (N, 32, 10, 10)         -> (N, 3200)
        x = self.relu(self.fc1(x))    # (N, 3200)               -> (N, 256)
        x = self.fc2(x)               # (N, 256)                -> (N, num_actions)

        return x

In [38]:
# One replay-buffer record: (state, action, reward, terminal flag, next state).
# Field order is significant: records are built and unpacked positionally.
Transition = namedtuple(
    'Transition',
    'state action reward terminal next_state',
)

In [45]:
class Agent:
    """DQN agent with epsilon-greedy exploration, experience replay and a
    periodically cloned (frozen) target network.

    Args:
        model: Online Q-network; moved to the GPU when one is available.
        buffer_size: Maximum number of transitions kept in the replay buffer
            (oldest are evicted first).
        lr: Adam learning rate.
        gamma: Discount factor used in the TD target.
        epsilon_i: Initial exploration rate.
        epsilon_f: Final exploration rate, reached after ``anneal_time`` steps.
        anneal_time: Number of steps over which epsilon is linearly annealed.
        ckptdir: Checkpoint directory; created if it does not exist.
    """

    def __init__(self, model, buffer_size, lr, gamma, epsilon_i, epsilon_f, anneal_time, ckptdir):

        self.cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.cuda else "cpu")
        self.model = model.to(self.device)

        self.buffer_size = buffer_size
        self.gamma = torch.tensor([gamma], device=self.device)
        self.eps_i = epsilon_i
        self.eps_f = epsilon_f
        self.anneal_time = anneal_time
        self.ckptdir = ckptdir

        os.makedirs(ckptdir, exist_ok=True)

        self.replay_buffer = deque(maxlen=buffer_size)
        self.clone()

        self.loss = nn.SmoothL1Loss()
        self.opt = torch.optim.Adam(self.model.parameters(), lr=lr)

    def clone(self):
        """Replace the target network with a frozen copy of the online model."""
        # Rebinding drops the previous clone; no explicit delete is needed.
        self.clone_model = copy.deepcopy(self.model).to(self.device)

        for p in self.clone_model.parameters():
            p.requires_grad = False

    def remember(self, *args):
        """Append one transition (state, action, reward, terminal, next_state)."""
        self.replay_buffer.append(Transition(*args))

    def retrieve(self, batch_size):
        """Sample ``batch_size`` transitions uniformly and concatenate each field."""
        transitions = random.sample(self.replay_buffer, batch_size)
        batch = Transition(*zip(*transitions))
        state, action, reward, terminal, next_state = map(torch.cat, batch)
        return state, action, reward, terminal, next_state

    def act(self, state):
        """Return the greedy action: the index of the largest predicted Q-value."""
        with torch.no_grad():  # inference only; no autograd graph needed
            q_values = self.model(state)
        return torch.argmax(q_values).item()

    def process(self, state):
        """Turn a raw frame into a (1, 1, H, W) float tensor on the agent's device.

        Keeps rows 35:195 (presumably cropping Pong's score bar and border —
        confirm for other games), then downscales by 0.5 with skimage's
        ``rescale``. Assumes a 2-D (grayscale) frame — TODO confirm.
        """
        state = state[35:195]
        state = rescale(state, scale=0.5)
        state = state[np.newaxis, np.newaxis, :, :]
        return torch.tensor(state, device=self.device, dtype=torch.float)

    def exploration_rate(self, t):
        """Linearly annealed epsilon: ``eps_i`` for t < 0, ``eps_f`` for t >= anneal_time."""
        if t < 0:
            return self.eps_i
        if t >= self.anneal_time:
            return self.eps_f
        return self.eps_i - t * (self.eps_i - self.eps_f) / self.anneal_time

    def save(self, t):
        """Save the online model's state_dict as ``<ckptdir>/model-<t>``."""
        save_path = os.path.join(self.ckptdir, 'model-{}'.format(t))
        torch.save(self.model.state_dict(), save_path)

    def load(self):
        """Load the ``model-<step>`` checkpoint with the highest step.

        Sets ``self.t`` to that step and re-syncs the target network with the
        loaded weights.

        Raises:
            FileNotFoundError: if ``ckptdir`` has no ``model-<step>`` checkpoint.
        """
        ckpts = {}
        for name in os.listdir(self.ckptdir):
            match = re.search(r'model-(\d+)', name)
            if match:
                ckpts[int(match.group(1))] = name

        if not ckpts:
            raise FileNotFoundError("No checkpoints found in {}".format(self.ckptdir))

        self.t = max(ckpts)
        latest_ckpt = ckpts[self.t]

        print("Loading checkpoint: {}".format(latest_ckpt))

        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        state_dict = torch.load(os.path.join(self.ckptdir, latest_ckpt), map_location=self.device)
        self.model.load_state_dict(state_dict)
        # Otherwise the target net would keep its pre-load weights until the next clone.
        self.clone()

    @staticmethod
    def _td_target(reward, terminal, qmax, gamma):
        """One-step TD target ``reward + gamma * qmax``, without the bootstrap
        term for terminal transitions.

        ``terminal`` may be a bool or 0/1 integer tensor. Converting it with
        ``.float()`` (rather than using ``~``) matters: ``~`` on a uint8 tensor
        is a bitwise NOT that yields 254/255, not 0/1.
        """
        not_terminal = 1.0 - terminal.float()
        return reward + gamma * qmax * not_terminal

    def update(self, batch_size):
        """Perform one gradient step on a random minibatch from the replay buffer."""
        state, action, reward, terminal, next_state = self.retrieve(batch_size)

        q = self.model(state).gather(1, action.view(batch_size, 1)).view(-1)
        with torch.no_grad():  # the target network is never trained
            qmax = self.clone_model(next_state).max(dim=1)[0]

        target = self._td_target(reward, terminal, qmax, self.gamma)

        loss = self.loss(q, target)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return only the observation.

        Handles both ``reset() -> obs`` and the gym >= 0.26 form
        ``reset() -> (obs, info)``.
        """
        out = env.reset()
        return out[0] if isinstance(out, tuple) else out

    @staticmethod
    def _step_env(env, action):
        """Step ``env`` and return ``(observation, reward, done)``.

        Accepts the 4-tuple ``(obs, reward, done, info)`` and the gym >= 0.26
        5-tuple ``(obs, reward, terminated, truncated, info)``. In the latter
        case the episode is done when either flag is set.
        """
        out = env.step(action)
        if len(out) == 5:
            obs, reward, terminated, truncated, _ = out
            return obs, reward, terminated or truncated
        obs, reward, done, _ = out
        return obs, reward, done

    def _select_action(self, state, num_actions, train, learning_start):
        """Epsilon-greedy action while training; always greedy otherwise."""
        if train and np.random.uniform() < self.exploration_rate(self.t - learning_start):
            return np.random.choice(num_actions)
        return self.act(state)

    @staticmethod
    def _plot_returns(metadata, episodes):
        """Redraw the per-episode return scatter plot in the notebook output."""
        plt.scatter(metadata['episode'], metadata['reward'])
        plt.xlim(0, episodes)
        plt.xlabel("Episode")
        plt.ylabel("Return")
        display.clear_output(wait=True)
        display.display(plt.gcf())

    def learn(
        self, env,
        num_frames,
        num_actions,
        episodes,
        batch_size,
        learning_start,

        update_interval=4,
        clone_interval=1e4,
        save_interval=1e5,

        train=True,
        render=False,
        plot=False,
        load=False,
    ):
        """Run the agent for ``episodes`` episodes, optionally training it.

        Args:
            env: Gym-style environment; both the 4-tuple and the gym >= 0.26
                5-tuple ``step`` APIs are accepted.
            num_frames: Number of frames stacked into one state.
            num_actions: Size of the discrete action space (used for random actions).
            episodes: Number of episodes to run.
            batch_size: Minibatch size for each update.
            learning_start: Step after which updates begin; epsilon annealing is
                measured relative to this step.
            update_interval: Run one update every this many steps.
            clone_interval: Refresh the target network every this many steps.
            save_interval: Save a checkpoint every this many steps.
            train: If False, act greedily, store nothing, and sleep 0.1 s per step.
            render: Call ``env.render()`` before every action.
            plot: Redraw a return-per-episode scatter plot after each episode.
            load: Resume from the latest checkpoint in ``ckptdir`` first.

        Returns:
            dict with parallel lists ``'episode'`` and ``'reward'`` (undiscounted
            return of each finished episode, counting every environment reward).
            On KeyboardInterrupt the episodes finished so far are returned and,
            when training, a checkpoint is saved. The environment is always closed.
        """
        self.t = 0
        metadata = dict(episode=[], reward=[])

        if load:
            self.load()

        try:
            progress_bar = tqdm(range(episodes), unit='episode')

            for episode in progress_bar:

                state = self.process(self._reset_env(env))
                done = False
                total_reward = 0.0

                # Fill the initial frame stack by pressing Fire (action 1).
                while state.size(1) < num_frames and not done:
                    frame, reward, done = self._step_env(env, 1)
                    total_reward += reward
                    state = torch.cat([state, self.process(frame)], 1)

                while not done:

                    if render:
                        env.render()

                    action = self._select_action(state, num_actions, train, learning_start)
                    frame, reward, done = self._step_env(env, action)
                    total_reward += reward

                    # Slide the window: drop the oldest frame, append the newest.
                    # torch.cat allocates a fresh tensor, so no larger buffer is kept alive.
                    new_state = torch.cat([state[:, 1:, :, :], self.process(frame)], 1)

                    if train:
                        self.remember(
                            state,
                            torch.tensor([action], device=self.device, dtype=torch.long),
                            torch.tensor([reward], device=self.device, dtype=torch.float),
                            torch.tensor([done], device=self.device, dtype=torch.uint8),
                            new_state,
                        )

                    state = new_state
                    self.t += 1

                    if not train:
                        time.sleep(0.1)

                    if train and self.t > learning_start and len(self.replay_buffer) >= batch_size:

                        if self.t % update_interval == 0:
                            self.update(batch_size)

                        if self.t % clone_interval == 0:
                            self.clone()

                        if self.t % save_interval == 0:
                            self.save(self.t)

                    if self.t % 1000 == 0:
                        progress_bar.set_description("t = {}".format(self.t))

                metadata['episode'].append(episode)
                metadata['reward'].append(float(total_reward))

                if episode % 50 == 0 and episode != 0:
                    avg_return = np.mean(metadata['reward'][-50:])
                    print("Average return (last 50 episodes): {:.2f}".format(avg_return), "Eps: {:.2f}".format(self.exploration_rate(self.t - learning_start)))

                if plot:
                    self._plot_returns(metadata, episodes)

        except KeyboardInterrupt:
            if train:
                print("Saving model before quitting...")
                self.save(self.t)

        finally:
            env.close()

        return metadata

In [40]:
# Pong environment configured to return grayscale observations.
ENV_ID = 'PongDeterministic-v4'
env = gym.make(ENV_ID, obs_type="grayscale")

In [46]:
# Hyperparameters

# --- Optimisation ---
learning_rate = 1e-4
gamma = 0.99
batch_size = 32
update_interval = 4  # one gradient step every N environment steps

# --- Target network and checkpoints (in environment steps) ---
clone_interval = int(1e4)
save_interval = int(1e5)

# --- Environment and run length ---
num_frames = 4
num_actions = env.action_space.n
episodes = int(1e4)
learning_start = 0

# --- Replay buffer ---
# NOTE(review): each stored transition keeps two float32 stacks of num_frames
# frames on the agent's device. Assuming 80x80 frames, that is
# 2 * 4 * 80 * 80 * 4 bytes ~= 200 KB per transition, so 1e5 transitions is on
# the order of 20 GB. Lower this if GPU/host memory is tight.
buffer_size = int(1e5)

# --- Exploration: epsilon annealed linearly from epsilon_i to epsilon_f ---
epsilon_i = 1.0
epsilon_f = 0.001
anneal_time = int(1e6)

In [47]:
# Online Q-network: num_frames stacked frames in, one Q-value per action out.
model = DeepQNetwork(num_frames=num_frames, num_actions=num_actions)

In [48]:
agent = Agent(
    model=model,
    buffer_size=buffer_size,
    lr=learning_rate,
    gamma=gamma,
    epsilon_i=epsilon_i,
    epsilon_f=epsilon_f,
    anneal_time=anneal_time,
    ckptdir='ckpt',
)

In [None]:
# Train from scratch; pass load=True to resume from the latest checkpoint instead.
metadata = agent.learn(
    env=env,
    num_frames=num_frames,
    num_actions=num_actions,
    episodes=episodes,
    batch_size=batch_size,
    learning_start=learning_start,
    update_interval=update_interval,
    clone_interval=clone_interval,
    save_interval=save_interval,
    train=True,
    load=False,
)