In [1]:
%%capture
# Colab environment setup. %%capture hides all output, so install failures in this cell are silent.
!pip install gym pyvirtualdisplay -qq
# NOTE(review): folium is not imported anywhere in this notebook; presumably pinned to satisfy
# some Colab dependency — confirm or remove.
!pip install folium==0.2.1
# xvfb / OpenGL / ffmpeg: presumably for headless rendering and the video recording done later.
!apt-get install -y xvfb python-opengl ffmpeg -qq

# NOTE(review): the package index is refreshed only after the install above, so that install
# runs against a possibly stale index; consider moving `apt-get update` to the top.
!apt-get update -qq
!apt-get install cmake -qq
!pip install --upgrade setuptools -qq
!pip install ez_setup -qq

import gym
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import cv2 as cv
import numpy as np

In [2]:
%%capture
# Download the Atari ROM archive, extract it, and run atari_py.import_roms on the extracted
# zips (presumably so gym can load the SpaceInvaders ROM used below).
# NOTE(review): fetched over plain http with no checksum verification; the archive contents
# are extracted and imported as-is.
import urllib.request
urllib.request.urlretrieve('http://www.atarimania.com/roms/Roms.rar','Roms.rar')
# NOTE(review): `pip install unrar` presumably installs Python bindings, not the CLI; the
# `!unrar` call below needs an `unrar` executable on PATH — confirm in the target image.
!pip install unrar
!unrar x Roms.rar
# Gather both extracted zips into rars/ for the importer.
# NOTE(review): `mkdir` without -p errors when the cell is re-run (hidden by %%capture).
!mkdir rars
!mv HC\ ROMS.zip   rars
!mv ROMS.zip  rars
!python -m atari_py.import_roms rars

In [19]:
# image preprocessing
def process_img(s):
  """Convert one raw RGB frame into the network's single-channel (1, 84, 84) input.

  Steps: luminance-weighted grayscale, keep rows 25..194 (presumably trimming the
  score area and bottom border — TODO confirm against the frame layout), then a
  bicubic resize to 84x84.

  Args:
    s: frame indexed as s[row, col, channel]; assumes channels 0..2 are R, G, B
       (the weights below are the BT.601 luma coefficients in that order).

  Returns:
    np.ndarray of shape (1, 84, 84) and dtype float32.
  """
  luma_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
  # float32 rather than the implicit float64: halves the memory of every frame kept
  # in the replay buffer, and Qnet.forward casts its input to float32 anyway.
  gray = s[:, :, :3].astype(np.float32) @ luma_weights
  gray = gray[25:195, :]
  out_img = cv.resize(gray, dsize=(84, 84), interpolation=cv.INTER_CUBIC)
  return out_img.reshape(1, 84, 84)

In [42]:
class Qnet(nn.Module):
    """Convolutional Q-network: one 84x84 single-channel frame -> one Q-value per action.

    Args:
        n_action: number of discrete actions; the size of the output layer and the
            range of random actions drawn by ``sample_action``.
    """

    def __init__(self, n_action):
        super(Qnet, self).__init__()
        hid_dim1 = 32
        hid_dim2 = 64
        hid_dim3 = 30
        self.n_action = n_action
        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, hid_dim1, kernel_size=3, stride=4),
            nn.BatchNorm2d(hid_dim1),
            nn.ReLU(),
            nn.Conv2d(hid_dim1, hid_dim2, kernel_size=4, stride=2),
            nn.BatchNorm2d(hid_dim2),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Flattened conv size for an 84x84 input: 84 -> 21 (k=3, s=4) -> 9 (k=4, s=2),
        # i.e. 64 * 9 * 9 = 5184. Derived from the layer settings instead of hard-coded.
        side = self._conv_out_side(self._conv_out_side(84, 3, 4), 4, 2)
        self.conv_out_shape = hid_dim2 * side * side
        self.linear_layer = nn.Sequential(
            nn.Linear(self.conv_out_shape, hid_dim3),
            nn.ReLU(),
            # Raw linear output, no Softmax: Q-values are regressed toward TD targets
            # and must not be forced to sum to 1 across actions.
            nn.Linear(hid_dim3, n_action),
        )

    @staticmethod
    def _conv_out_side(size, kernel, stride):
        """Spatial output length of an unpadded, undilated convolution."""
        return (size - kernel) // stride + 1

    def forward(self, x):
        """Return Q-values of shape (batch, n_action).

        Accepts a single frame of shape (1, 84, 84) or a batch of shape (batch, 1, 84, 84).
        """
        out = x.float()
        if out.dim() == 3:
            # Single frame: add the batch dimension with a torch op (np.reshape on a
            # tensor round-trips through NumPy).
            out = out.reshape(1, 1, 84, 84)
        out = self.conv_layer(out)
        return self.linear_layer(out)

    def sample_action(self, obs, epsilon):
        """Epsilon-greedy action for a single observation.

        With probability ``epsilon`` returns a uniformly random action in
        ``range(n_action)``; otherwise the argmax of the Q-values for ``obs``.

        NOTE(review): the visible notebook never calls eval(), so BatchNorm runs in
        training mode here and normalizes this single frame with its own statistics.
        """
        if random.random() < epsilon:
            # Explore over the full action space (randint(0, 1) only ever chose 0 or 1).
            return random.randrange(self.n_action)
        # Inference only: no autograd graph needed to pick an action.
        with torch.no_grad():
            out = self.forward(obs)
        return out.argmax().item()

In [49]:
# Hyperparameters — module globals read by ReplayBuffer, train() and main().
# Optimisation
learning_rate = 5e-4     # Adam learning rate (main)
batch_size = 10          # transitions per minibatch (train)
# TD target
gamma = 0.98             # discount factor (train)
# Replay / rollout limits
buffer_limit = 50000     # replay capacity; the deque drops the oldest entries beyond this (ReplayBuffer)
max_step = 30000         # per-episode step cap (main)

class ReplayBuffer():
    """Fixed-capacity FIFO store of (s, a, r, s_prime, done_mask) transitions.

    Args:
        maxlen: capacity; once full, the oldest transitions are dropped.
            Defaults to the module-level ``buffer_limit``.
    """

    def __init__(self, maxlen=None):
        if maxlen is None:
            maxlen = buffer_limit
        self.buffer = collections.deque(maxlen=maxlen)

    def put(self, transition):
        """Append one (s, a, r, s_prime, done_mask) tuple."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct transitions uniformly at random and batch them.

        Returns:
            (s, a, r, s_prime, done_mask): s and s_prime are float32 tensors of shape
            (n, *frame_shape); a, r and done_mask have shape (n, 1), with a holding
            the integer actions (usable as a ``gather`` index).

        Raises:
            ValueError: if ``n`` exceeds the number of stored transitions
                (propagated from ``random.sample``).
        """
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for s, a, r, s_prime, done_mask in mini_batch:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        # Stack the frames into a single ndarray first: torch.tensor() on a list of
        # ndarrays converts element by element and is very slow.
        return torch.tensor(np.array(s_lst), dtype=torch.float), torch.tensor(a_lst), \
               torch.tensor(r_lst), torch.tensor(np.array(s_prime_lst), dtype=torch.float), \
               torch.tensor(done_mask_lst)

    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
            
def train(q, q_target, memory, optimizer, n_updates=10):
    """Run ``n_updates`` DQN gradient steps on minibatches sampled from ``memory``.

    Each step regresses Q(s, a) from ``q`` toward the one-step TD target
    ``r + gamma * max_a' Q_target(s', a') * done_mask`` with smooth L1 loss;
    done_mask is 0.0 for terminal transitions, which drops the bootstrap term.

    Args:
        q: online network whose parameters ``optimizer`` updates (as set up in main).
        q_target: target network; only evaluated here, never updated.
        memory: ReplayBuffer providing ``sample(batch_size)``.
        optimizer: optimizer over ``q``'s parameters.
        n_updates: number of minibatch updates (default 10, the previous fixed value).

    Reads the module-level ``batch_size`` and ``gamma``.
    """
    for _ in range(n_updates):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)
        q_a = q(s).gather(1, a)
        # The target is a constant: without no_grad, backward() would also build the
        # graph through q_target and accumulate gradients on its parameters that are
        # never zeroed or used.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def main(n_episodes=1000, env_name="SpaceInvaders-v0"):
    """Train a DQN agent on a gym environment and return the online network.

    Per episode: roll out up to ``max_step`` epsilon-greedy steps, storing transitions
    with rewards scaled by 1/100; once the buffer holds more than 2000 transitions,
    call ``train`` after the episode. Every ``print_interval`` episodes the target
    network is synced to the online one and the mean episode score is printed.

    Args:
        n_episodes: number of training episodes (default 1000, as before).
        env_name: gym environment id (default "SpaceInvaders-v0", as before).

    Returns:
        The trained online Qnet.
    """
    env = gym.make(env_name)
    try:
        n_action = env.action_space.n
        q = Qnet(n_action)
        q_target = Qnet(n_action)
        q_target.load_state_dict(q.state_dict())
        memory = ReplayBuffer()

        print_interval = 20
        score = 0.0
        optimizer = optim.Adam(q.parameters(), lr=learning_rate)

        for n_epi in range(n_episodes):
            # Linear decay from 8% by 1 percentage point per 200 episodes, floored at 1%.
            # The floor is reached at episode 1400; after 1000 episodes epsilon is still 3%.
            epsilon = max(0.01, 0.08 - 0.01*(n_epi/200))
            s = process_img(env.reset())

            for _ in range(max_step):
                a = q.sample_action(torch.from_numpy(s).float(), epsilon)
                s_prime, r, done, _info = env.step(a)
                done_mask = 0.0 if done else 1.0
                s_prime = process_img(s_prime)
                memory.put((s, a, r/100.0, s_prime, done_mask))
                s = s_prime

                score += r
                if done:
                    break

            if memory.size() > 2000:
                train(q, q_target, memory, optimizer)

            # NOTE(review): the first window spans episodes 0..20 (21 episodes) but is
            # still divided by print_interval.
            if n_epi % print_interval == 0 and n_epi != 0:
                q_target.load_state_dict(q.state_dict())
                print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                                                                n_epi, score/print_interval, memory.size(), epsilon*100))
                score = 0.0
    finally:
        # Release the emulator even if training is interrupted or raises.
        env.close()
    return q

In [None]:
# Train the model. Long-running: main() runs 1000 episodes of up to max_step steps each.
# The returned online network `q` is what the "Play the model" cell below evaluates.
q = main()

In [63]:
# Second round of environment setup, largely overlapping the first setup cell; it additionally
# installs gym[atari]. Output goes to /dev/null, so failures here are silent.
# NOTE(review): this cell ran after training; for Restart & Run All it belongs at the top.
!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
# NOTE(review): unlike the other lines, this one's output is not redirected to /dev/null.
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[atari] > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1



In [64]:
import gym
from gym.wrappers import Monitor
import glob
import io
import base64
from IPython.display import HTML
from pyvirtualdisplay import Display
from IPython import display as ipythondisplay

# Start an invisible (visible=0) 1400x900 virtual display so that env.render() and the
# Monitor video recording can run without a physical screen.
# NOTE(review): this rebinds the name `display`; IPython's display is used as `ipythondisplay`.
display = Display(visible=0, size=(1400, 900))
display.start()

"""
Utility functions to enable video recording of gym environment 
and displaying it.
To enable video, just do "env = wrap_env(env)".
"""

def show_video():
  """Embed one recorded episode from ./video as an autoplaying, looping HTML5 video.

  Picks the lexicographically first ``video/*.mp4`` (deterministic, unlike raw glob
  order) and prints "Could not find video" when there is none. Returns None.
  """
  mp4list = sorted(glob.glob('video/*.mp4'))
  if not mp4list:
    print("Could not find video")
    return
  # Read-only and closed afterwards: the previous io.open(mp4, 'r+b') requested write
  # access it never used and leaked the file handle.
  with io.open(mp4list[0], 'rb') as f:
    encoded = base64.b64encode(f.read())
  ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
    

def wrap_env(env):
  """Return ``env`` wrapped in gym's Monitor, recording into ./video.

  force=True is passed through; presumably this overwrites earlier recordings in
  that directory — see the gym Monitor docs.
  """
  return Monitor(env, './video', force=True)

In [66]:
# Play the model: run the trained network greedily for several recorded episodes,
# then embed one of the recordings.
env = wrap_env(gym.make("SpaceInvaders-v0"))

n_play_episodes = 50
for _ in range(n_play_episodes):
  # Each episode starts from its own reset observation. The old loop discarded
  # env.reset() and fed a stale `s` (undefined on a fresh kernel) to the network.
  # The old extra reset before the loop is also gone; it only opened an empty episode.
  s = env.reset()
  while True:
    env.render()

    # Greedy action; no_grad because this is pure inference, and .item() gives
    # env.step a plain int instead of a 0-d tensor.
    with torch.no_grad():
      action = q(torch.from_numpy(process_img(s))).argmax().item()
    s, r, done, info = env.step(action)

    if done:
      break

env.close()
show_video()