In [1]:
# Install gym and pyvirtualdisplay with pip, then the system packages xvfb,
# python-opengl and ffmpeg with apt-get. stdout and stderr of both commands
# are sent to /dev/null, so a failed install is silent here.
# NOTE(review): no versions are pinned; later cells import gym.wrappers.Monitor,
# so confirm the installed gym release still provides it.
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

In [2]:
# Refresh the apt package index and install cmake, then upgrade setuptools and
# install ez_setup and gym's "atari" extra with pip. Every command except the
# setuptools upgrade discards its output; that one shows stdout and stderr.
!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[atari] > /dev/null 2>&1

Requirement already up-to-date: setuptools in /usr/local/lib/python3.7/dist-packages (56.0.0)


In [25]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

In [38]:
import math
import glob
import io
import base64
from IPython.display import HTML

from gym import logger as gymlogger
from gym.wrappers import Monitor
# Set gym's logger level to 40 so that only error messages are reported.
gymlogger.set_level(40) #error only

In [39]:
# Create and start a pyvirtualdisplay Display (not visible, 1400x900).
# NOTE(review): presumably needed so the environment can render on a headless
# machine -- confirm.
# NOTE(review): this binds the notebook-level name `display` to the Display
# object, so a later bare `display(...)` call would not reach IPython's
# display helper.
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

<pyvirtualdisplay.display.Display at 0x7f4c6dc7f2d0>

In [49]:
"""
Utility functions to enable video recording of a gym environment and to display the recording.
To enable video, just do "env = wrap_env(env)".
"""

def show_video():
  """Display a recorded episode video inline in the notebook output.

  Searches ``video/*.mp4``. When at least one file matches, the
  alphabetically first match is read, base64-encoded and rendered as a
  looping HTML5 ``<video>`` element. When nothing matches, the message
  "Could not find video" is printed instead.

  Returns:
    None.
  """
  # Sorted so the chosen file does not depend on glob's arbitrary ordering.
  mp4list = sorted(glob.glob('video/*.mp4'))
  if len(mp4list) > 0:
    # Imported here because the notebook never binds `ipythondisplay`, and the
    # notebook-level name `display` refers to the pyvirtualdisplay object.
    from IPython import display as ipythondisplay

    mp4 = mp4list[0]
    # Read-only mode is enough, and the context manager closes the handle.
    with io.open(mp4, 'rb') as video_file:
      video = video_file.read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
    

def wrap_env(env):
  """Return `env` wrapped in a gym Monitor writing to './video' (force=True)."""
  return Monitor(env, './video', force=True)

In [50]:
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Build the Pong environment and wrap it with wrap_env (defined in an earlier cell).
env = gym.make("Pong-v0")
env = wrap_env(env)
# Record type stored in the replay memory; next_state is None for the last
# step of an episode (see the training loop).
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

cuda


In [51]:
class ReplayMemory(object):
    """Fixed-capacity ring buffer of `Transition` records.

    Once `capacity` entries are stored, each new entry overwrites the oldest
    one. `position` is the index the next entry will be written to.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Build a Transition from `args` and store it at the write position."""
        buffer, slot = self.memory, self.position
        if len(buffer) < self.capacity:
            # Still filling up: reserve a placeholder slot to write into.
            buffer.append(None)
        buffer[slot] = Transition(*args)
        # Advance the write index, wrapping at capacity.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return `batch_size` stored entries drawn without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.memory)

In [52]:
class DQN(nn.Module):
    """Q-network: three conv blocks followed by two linear layers.

    Each conv block is Conv2d -> BatchNorm2d -> LeakyReLU. The first conv
    layer takes a single input channel.

    Args:
        h: height of the input image.
        w: width of the input image.
        outputs: number of values produced by the last linear layer.
    """

    def __init__(self, h, w, outputs):
        super().__init__()

        def conv_out(size, kernel_size, stride):
            # Output length along one axis of an unpadded convolution.
            return (size - kernel_size) // stride + 1

        def conv_block(in_channels, filters, *args, **kwargs):
            layers = [
                nn.Conv2d(in_channels, filters, *args, **kwargs),
                nn.BatchNorm2d(filters),
                nn.LeakyReLU(),
            ]
            return nn.Sequential(*layers)

        def lin_block(in_dim, out_dim):
            return nn.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU())

        self.conv1 = conv_block(1, 32, kernel_size=8, stride=4)
        self.conv2 = conv_block(32, 64, kernel_size=4, stride=2)
        self.conv3 = conv_block(64, 64, kernel_size=3, stride=1)

        # Push the input size through the same (kernel, stride) pairs as the
        # conv layers to find the feature-map size entering the linear layer.
        width, height = w, h
        for kernel, stride in ((8, 4), (4, 2), (3, 1)):
            width = conv_out(width, kernel, stride)
            height = conv_out(height, kernel, stride)
        self.convw = width
        self.convh = height

        self.fc1 = lin_block(self.convw * self.convh * 64, 512)
        self.fc2 = nn.Linear(512, outputs)

    def forward(self, x):
        """Apply the network to `x`; the result has shape (batch, 1, 1, outputs)."""
        features = self.conv3(self.conv2(self.conv1(x)))
        flat = features.view(features.size(0), -1)
        # Insert singleton axes after the batch axis until the tensor is 4-D.
        while flat.dim() < 4:
            flat = flat.unsqueeze(1)
        return self.fc2(self.fc1(flat))

In [53]:
# Frame preprocessing: convert the rendered array to a PIL image, make it
# grayscale, resize to 84x84 with nearest-neighbour interpolation, and
# convert the result to a tensor.
resize = T.Compose([T.ToPILImage(), 
                    T.Grayscale(), 
                    T.Resize([84, 84], interpolation=InterpolationMode.NEAREST), 
                    T.ToTensor()])
# Preprocess one rendered frame now; its shape sizes the networks in a later cell.
screen = resize(env.render(mode='rgb_array'))

In [54]:
# Number of transitions sampled from replay memory per optimisation step.
BATCH_SIZE = 128
# Factor applied to the next-state value when building the training target.
GAMMA = 0.999
# select_action's exploration threshold starts at EPS_START and decays
# exponentially towards EPS_END as exp(-steps_done / EPS_DECAY).
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
# The training loop copies policy_net into target_net whenever the number of
# finished episodes is a multiple of this value.
# NOTE(review): the unit is episodes, not steps -- confirm 1000 is intended.
TARGET_UPDATE = 1000

n_actions = env.action_space.n
# `screen` is indexed as (_, height, width) here.
screen_height = screen.shape[1]
screen_width = screen.shape[2]

# Two networks with the same architecture; target_net starts as a copy of
# policy_net's weights and is put in eval mode.
policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Only policy_net's parameters are optimised; RMSprop uses its default settings.
optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)

# Incremented on every select_action call; drives the exploration decay.
steps_done = 0

In [55]:
def optimize_model():
    """Run one optimisation step on `policy_net` from a sampled replay batch.

    Returns immediately while `memory` holds fewer than BATCH_SIZE
    transitions. Otherwise it samples BATCH_SIZE transitions, builds the
    target ``reward + GAMMA * max_a target_net(next_state)`` (with the
    next-state value taken as 0 where `next_state` is None), minimises the
    Huber loss between that target and `policy_net`'s value for the action
    taken, clamps every gradient to [-1, 1] and steps `optimizer`.

    Returns:
        None.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Turn a list of Transition records into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    # Transitions that ended an episode carry next_state=None.
    non_final_next = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state],
        device=device,
        dtype=torch.bool,
    )
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # The network returns extra singleton axes; collapse to (batch, n_actions)
    # and pick the value of the action that was actually taken.
    q_values = policy_net(state_batch)
    q_values = q_values.view(q_values.shape[0], q_values.shape[-1])
    state_action_values = q_values.gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat rejects an empty list, so skip the target pass entirely when
    # every sampled transition is terminal; the values then stay at zero.
    if non_final_next:
        # No gradient flows into the target network, so do not build a graph.
        with torch.no_grad():
            next_q = target_net(torch.cat(non_final_next))
        next_q = next_q.view(next_q.shape[0], next_q.shape[-1])
        next_state_values[non_final_mask] = next_q.max(1)[0]

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(
        state_action_values, expected_state_action_values.unsqueeze(1)
    )

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        # A parameter that took no part in the loss has no gradient to clamp.
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [56]:
def select_action(state):
    """Choose an action for `state` with an epsilon-greedy rule.

    The exploration threshold decays exponentially from EPS_START towards
    EPS_END as the global `steps_done` grows; `steps_done` is incremented on
    every call. Returns a tensor of shape (1, 1) holding the action index.
    """
    global steps_done
    roll = random.random()
    decay = math.exp(-1.0 * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if roll <= eps_threshold:
        # Explore: uniformly random action index.
        random_action = random.randrange(n_actions)
        return torch.tensor([[random_action]], device=device, dtype=torch.long)

    # Exploit: index of the largest value predicted by the policy network.
    with torch.no_grad():
        q_values = policy_net(state)
        q_values = q_values.view(1, q_values.shape[-1])
        best_action = q_values.max(1)[1]
        return best_action.view(1, 1)

In [57]:
# Train until the mean reward of the last 40 episodes reaches 0.1.
# The state fed to the network is the difference between two consecutive
# preprocessed frames.
num_episodes = 10  # NOTE(review): not read by the loop below
episode_durations = []  # number of steps in each finished episode
rewards = []            # summed reward of each finished episode
rw = 0                  # summed reward of the episode in progress
counter = 0             # environment steps taken across all episodes
mean_reward = -21       # replaced by a real mean once more than 40 episodes exist
while mean_reward < 0.1:
    env.reset()
    # No step is taken between the reset and this point, so a single render
    # serves as both frames and the first state is an all-zero difference.
    current_screen = resize(env.render(mode='rgb_array')).unsqueeze(0).to(device)
    last_screen = current_screen
    state = current_screen - last_screen
    for t in count():
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        # Accumulate the plain number before it is moved to `device`; this
        # avoids a device-to-host copy on every step.
        rw += float(reward)
        reward = torch.tensor([reward], device=device)
        last_screen = current_screen
        current_screen = resize(env.render(mode='rgb_array')).unsqueeze(0).to(device)
        # A terminal transition is stored with next_state=None.
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None

        memory.push(state, action, next_state, reward)

        state = next_state

        optimize_model()
        counter += 1
        if counter % 1000 == 0:
            print('counter', counter, 'eps', len(episode_durations))
            print('mean reward =', mean_reward)
        if done:
            episode_durations.append(t + 1)
            rewards.append(rw)
            rw = 0
            # `rewards` only changes here, so refresh the mean once per
            # episode; the `while` test then sees the up-to-date value.
            if len(rewards) > 40:
                mean_reward = np.mean(rewards[-40:])
            break
    # Sync the target network every TARGET_UPDATE finished episodes.
    if len(episode_durations) % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print("Complete")
env.render()
env.close()

counter 1000 eps 0
mean reward = -21
counter 2000 eps 1
mean reward = -21
counter 3000 eps 2
mean reward = -21
counter 4000 eps 3
mean reward = -21
counter 5000 eps 4
mean reward = -21
counter 6000 eps 5
mean reward = -21
counter 7000 eps 6
mean reward = -21
counter 8000 eps 7
mean reward = -21
counter 9000 eps 7
mean reward = -21
counter 10000 eps 8
mean reward = -21
counter 11000 eps 9
mean reward = -21
counter 12000 eps 10
mean reward = -21
counter 13000 eps 11
mean reward = -21
counter 14000 eps 12
mean reward = -21
counter 15000 eps 13
mean reward = -21
counter 16000 eps 13
mean reward = -21
counter 17000 eps 14
mean reward = -21
counter 18000 eps 15
mean reward = -21
counter 19000 eps 16
mean reward = -21
counter 20000 eps 17
mean reward = -21
counter 21000 eps 18
mean reward = -21
counter 22000 eps 19
mean reward = -21
counter 23000 eps 19
mean reward = -21
counter 24000 eps 20
mean reward = -21
counter 25000 eps 21
mean reward = -21
counter 26000 eps 22
mean reward = -21
counte

KeyboardInterrupt: ignored

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# x = [i for i in range(len(rewards))]
# plt.plot(x, rewards)
# plt.show()

# Count how many entries of `rewards` equal exactly 0, 1 and -1.
# NOTE(review): the training loop appends one summed reward per finished
# episode to `rewards`, so these count whole episodes with that total, not
# individual step rewards -- confirm this is what was intended.
print(rewards.count(0))
print(rewards.count(1))
print(rewards.count(-1))