In [None]:
%matplotlib inline

import argparse
from collections import namedtuple
from copy import deepcopy
from itertools import count
import math
import random
import time

from IPython import display
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms


# Interactive mode so figures can be redrawn while training is running.
plt.ion()
# True when matplotlib is using an inline (notebook) backend; plot_durations
# uses this flag to decide whether to refresh the cell output.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    # NOTE(review): `display` is already imported unconditionally at the top of
    # this cell, so this conditional import is redundant.
    from IPython import display
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Hyper-parameters are declared as command-line options so the same code can be
# driven from a script; inside the notebook only the defaults (plus the
# override at the bottom of this cell) are used.
parser = argparse.ArgumentParser()

# (flag, type, default) triples, grouped by purpose.
_option_specs = (
    # data
    ('--resize_w', int, 40),
    ('--resize_h', int, 40),
    # model
    ('--replay_memory_size', int, 10000),
    # training
    ('--n_episodes', int, 10),
    ('--batch_size', int, 128),
    ('--gamma', float, 0.999),
    ('--eps_start', float, 0.9),
    ('--eps_end', float, 0.05),
    ('--eps_decay', int, 200),
    ('--update_target_net_every_x_episodes', int, 10),
)
for _flag, _type, _default in _option_specs:
    parser.add_argument(_flag, type=_type, default=_default)


# parse_known_args tolerates the extra argv entries a notebook kernel passes in.
args, _ = parser.parse_known_args()
# Notebook override: always train for 500 episodes, whatever was parsed.
args.n_episodes = 500

In [None]:
def preprocess(observation, background_values=(144, 109)):
    """Turn a raw frame into a binary float32 foreground mask.

    Rows 35..194 of the frame are kept, the result is downsampled by a factor
    of two along both spatial axes, and only channel 0 is used. Pixels equal
    to 0 or to one of `background_values` become 0.0; every other pixel
    becomes 1.0.

    The input array is never modified. The previous implementation wrote into
    slices (views) of `observation`, silently overwriting the caller's frame.

    Args:
        observation: array-like indexed as [row, column, channel].
        background_values: channel-0 values treated as background. The
            defaults are the two background values the original code erased.

    Returns:
        A new 2-D np.float32 array containing only 0.0 and 1.0.
    """
    # Crop + 2x downsample + channel select in a single (read-only) slice.
    frame = np.asarray(observation)[35:195:2, ::2, 0]
    # Build the mask with comparisons instead of in-place assignment so the
    # caller's buffer stays untouched.
    foreground = (frame != 0) & ~np.isin(frame, background_values)
    return foreground.astype(np.float32)


# Length (in steps) of every finished episode; filled by the training loop.
episode_durations = []


def plot_durations():
    """Redraw figure 2 with the episode durations recorded so far.

    Once at least 100 episodes exist, a 100-episode running mean is drawn on
    top, left-padded with zeros so it lines up with the episode axis. Under an
    inline backend the cell output is cleared and the figure re-displayed.
    """
    window = 100
    durations_t = torch.tensor(episode_durations, dtype=torch.float)

    plt.figure(2)
    plt.clf()
    plt.title('Training')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    if len(durations_t) >= window:
        # Mean over every sliding window of `window` consecutive episodes.
        rolling = durations_t.unfold(0, window, 1).mean(1).view(-1)
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    # Short pause so the figure gets a chance to update.
    plt.pause(0.001)
    if not is_ipython:
        return
    display.clear_output(wait=True)
    display.display(plt.gcf())

In [None]:
# Visual sanity check: show one raw frame next to its preprocessed version.
# `.unwrapped` strips the wrappers gym.make adds around the environment.
env = gym.make('Pong-v0').unwrapped

# assumes reset() returns just the first frame (older gym API) — TODO confirm
observation = env.reset()
plt.figure(figsize=(2 * 4, 3))

# Left panel: the frame exactly as returned by the environment.
plt.subplot(1, 2, 1)
plt.imshow(observation)
# Right panel: the cropped, downsampled, binarised frame.
plt.subplot(1, 2, 2)
plt.imshow(preprocess(observation), cmap='gray')
plt.show()

In [None]:
# Environment instance used by the rest of the notebook. This rebinds `env`,
# replacing the instance created for the preview cell above.
env = gym.make('Pong-v0').unwrapped

In [None]:
# One step of experience: (state, action, next_state, reward).
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of `Transition` records.

    The buffer grows until it holds `capacity` entries; after that each push
    overwrites the oldest entry.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        # Index of the slot the next push will write to.
        self.position = 0

    def push(self, *args):
        """Store a transition built from `args` (the Transition fields, in order)."""
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve a slot; it is filled right below.
            self.memory.append(None)

        slot = self.position
        self.memory[slot] = Transition(*args)
        # Advance and wrap around at capacity.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Three conv + batch-norm stages followed by one linear output layer.

    Args:
        h: height of the single-channel input image.
        w: width of the single-channel input image.
        n_outputs: size of the output vector (one value per action).
    """

    def __init__(self, h, w, n_outputs):
        super(DQN, self).__init__()
        kernel_size, stride = 5, 2

        self.conv1 = nn.Conv2d(1, 16, kernel_size=kernel_size, stride=stride)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=kernel_size, stride=stride)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=kernel_size, stride=stride)
        self.bn3 = nn.BatchNorm2d(32)

        # Spatial size left after the three un-padded convolutions; each one
        # maps size -> (size - kernel_size) // stride + 1.
        feat_h, feat_w = h, w
        for _ in range(3):
            feat_h = (feat_h - kernel_size) // stride + 1
            feat_w = (feat_w - kernel_size) // stride + 1

        self.fc = nn.Linear(feat_w * feat_h * 32, n_outputs)

    def forward(self, x):
        """Map a (batch, 1, h, w) tensor to a (batch, n_outputs) tensor."""
        batch = x.size(0)
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        # Flatten everything except the batch dimension for the linear head.
        return self.fc(x.view(batch, -1))

In [None]:
# Resizing pipeline applied to every preprocessed frame:
# tensor -> PIL image -> bicubic resize -> float tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    # Resize expects (height, width), so height must come first. Image.BICUBIC
    # is used because the Image.CUBIC alias was removed from Pillow.
    transforms.Resize((args.resize_h, args.resize_w), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
])

In [None]:
def get_observation(obs):
    """Convert a raw environment frame into a batched network input on `device`.

    The frame goes through `preprocess`, is wrapped as a torch tensor, passed
    through the module-level `transform`, and gets a leading batch dimension.
    """
    frame = torch.from_numpy(preprocess(obs))
    resized = transform(frame)
    # (bs, c, h, w)
    return resized.unsqueeze(0).to(device)

# Sanity check: display one frame exactly as the network receives it.
# The old cell called an undefined `get_screen` helper and permuted the result
# as if it had three dimensions; get_observation yields a single-channel
# (1, 1, h, w) tensor, so squeezing leaves a plain 2-D image.
plt.figure()
_im = get_observation(env.reset()).cpu().squeeze().numpy()
plt.imshow(_im, cmap='gray', interpolation='none')
plt.show()

In [None]:
# Push one frame through the input pipeline to discover the spatial size the
# networks have to be built for.
init_screen = get_observation(env.reset())
_, _, screen_h, screen_w = init_screen.shape

# Emulator action ids the agent may choose from; the network outputs an index
# into this list. NOTE(review): presumably no-op / up / down for Pong — confirm
# against the environment's action meanings.
_action_space = [0, 2, 3]
n_actions = len(_action_space)  # used instead of env.action_space.n

# model
# Two networks with the same architecture: policy_net is trained, target_net
# starts as a copy of its weights and is kept in eval mode.
policy_net = DQN(screen_h, screen_w, n_actions).to(device)
target_net = DQN(screen_h, screen_w, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
# optim
# Only the policy network's parameters are handed to the optimizer.
optimizer = optim.RMSprop(policy_net.parameters())
# Replay buffer that the training loop pushes transitions into.
memory = ReplayMemory(args.replay_memory_size)

In [None]:
# Number of action selections made so far; drives the epsilon decay.
steps_done = 0


def select_action(state, args):
    """Pick an action index epsilon-greedily.

    Epsilon decays exponentially from `args.eps_start` towards `args.eps_end`
    with time constant `args.eps_decay`, based on the global `steps_done`
    counter, which is incremented on every call.

    Returns:
        A (1, 1) long tensor holding an index in [0, n_actions).
    """
    global steps_done

    roll = random.random()
    decay = math.exp(-1. * steps_done / args.eps_decay)
    eps_threshold = args.eps_end + (args.eps_start - args.eps_end) * decay
    steps_done += 1

    if roll <= eps_threshold:
        # Explore: uniformly random action index.
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

    # Exploit: index of the largest output of the policy network.
    with torch.no_grad():
        return policy_net(state).max(dim=1)[1].view(1, 1)

In [None]:
def train(args):
    """Run one optimisation step of `policy_net` on a replay minibatch.

    Does nothing until `memory` holds at least `args.batch_size` transitions.
    Otherwise it samples a batch, builds the one-step TD target
    `reward + args.gamma * max_a target_net(next_state)` (the bootstrap term
    is zero for terminal transitions, i.e. those whose next_state is None),
    minimises the smooth-L1 loss between that target and the policy network's
    value for the action taken, and clips every gradient element to [-1, 1]
    before the optimizer step.

    Args:
        args: namespace providing `batch_size` and `gamma`.
    """
    if len(memory) < args.batch_size:
        return

    transitions = memory.sample(args.batch_size)
    # Transpose: list of Transitions -> one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state],
        device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Value of the action that was actually taken in each sampled state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(args.batch_size, device=device)
    # torch.cat raises on an empty list, so skip the bootstrap entirely when
    # every sampled transition is terminal (all targets are then the reward).
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(
                torch.cat(non_final_next_states)).max(1)[0]
    expected_state_action_values = reward_batch + args.gamma * next_state_values

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Element-wise gradient clipping to [-1, 1].
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

In [None]:
# Per-episode statistics consumed by the plotting cells below.
total_rewards = []
durations = []
for episode_i in range(args.n_episodes):
    g = 0  # undiscounted return accumulated over this episode

    # Reset once per episode. The state fed to the network is the difference
    # between consecutive frames, so the initial state is all zeros.
    current_obs = get_observation(env.reset())
    last_obs = current_obs
    state = current_obs - last_obs

    for t in count():
        env.render()
        action = select_action(state, args)
        # Translate the network's output index into the emulator action id.
        _action = _action_space[action.item()]
        obs, reward, done, _ = env.step(_action)
        g += reward
        # Pin the dtype so a numpy float64 reward cannot promote the TD target.
        reward = torch.tensor([reward], device=device, dtype=torch.float32)

        last_obs = current_obs
        current_obs = get_observation(obs)
        if done:
            # Terminal transitions are stored with next_state=None.
            next_state = None
        else:
            next_state = current_obs - last_obs

        memory.push(state, action, next_state, reward)
        state = next_state

        # One optimisation step per environment step.
        train(args)
        if done:
            episode_durations.append(t + 1)
            # plot_durations()
            break
    # Periodically sync the target network with the trained policy network.
    if (episode_i + 1) % args.update_target_net_every_x_episodes == 0:
        target_net.load_state_dict(policy_net.state_dict())

    total_rewards.append(g)
    durations.append(t + 1)
    print('Episode {} - Return {} Duration {}'.format(episode_i + 1, g, t + 1))

print('Finished')
env.render()
env.close()
plt.ioff()
plt.show()

In [None]:
# Both statistics lists must contain exactly one entry per trained episode.
assert len(total_rewards) == len(durations)
assert len(durations) == args.n_episodes

In [None]:
# Summary figure: per-episode return (left) and episode length (right).
fig, (ax_rewards, ax_durations) = plt.subplots(1, 2, figsize=(2 * 6, 4))

ax_rewards.plot(total_rewards)
ax_rewards.set_title('Total rewards')  # was misspelled 'Toral rewards'
ax_rewards.set_xlabel('Episode')
ax_rewards.set_ylabel('Return')

ax_durations.plot(durations)
ax_durations.set_title('Durations')
ax_durations.set_xlabel('Episode')
ax_durations.set_ylabel('Steps')

plt.show()