In [68]:
from torch import nn, tensor
import torch.nn.functional as F
import gym
import time 
import numpy as np
import torch
from tqdm import tqdm_notebook
import random
from collections import namedtuple
import math

In [2]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


In [3]:
act_fn = nn.ReLU(inplace=True)

def conv(ni, nf, ks=3, stride=1, bias=False):
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)

def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    bn = nn.BatchNorm2d(nf)
    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
    layers = [conv(ni, nf, ks, stride=stride), bn]
    if act: layers.append(act_fn)
    return nn.Sequential(*layers)

class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)

In [34]:
class Q_model(nn.Module) :
    """
    model takes 2 frames as input uses same cnn on them 
    """
    
    def __init__(self, cnn, cnn_out_sz, n_actions) :
        super(Q_model, self).__init__()
        self.cnn = cnn
        self.lin = nn.Linear(2*cnn_out_sz, n_actions)
    
    def forward(self, frame1, frame2) :
        out1 = self.cnn(frame1)
        out2 = self.cnn(frame2)
        
        linear_input = torch.cat([out1, out2], 1)
        action_vals = self.lin(linear_input)
        return action_vals

In [70]:
env = gym.make('Breakout-v0')

In [71]:
env.observation_space, env.action_space

(Box(210, 160, 3), Discrete(4))

In [108]:
layers = [conv_layer(3, 16, stride=2), conv_layer(16, 32, stride=2), conv_layer(32, 128, stride=2), 
          nn.AvgPool2d(4), Flatten()]
cnn = nn.Sequential(*layers)

q_mod = Q_model(cnn, 3840, 4)
q_mod_target = Q_model(cnn, 3840, 4)
q_mod_target.load_state_dict(q_mod.state_dict())

it_bf = 50 #iterations before updating the target network
q_mod.cuda()
q_mod_target.cuda()

opt = torch.optim.Adam(q_mod.parameters())

discount = 0.1

replay_mem = ReplayMemory(capacity=1000)

rewards = []

#times = []
batch_size = 64

count = 0
p_init = 0.5 # init probability of exploration
p_final = 0.01 # final probability of exploration
decay = 0.001

for i_episode in tqdm_notebook(range(300)):
    last_observation = tensor(env.reset(), dtype=torch.float32).view(3, 160, -1).cuda()
    acc_reward = 0
    theta = 0.1
    total_iterations = 1000
    
    for t in range(total_iterations):
        
        if t == 0 :
            observation, reward, done, info = env.step(env.action_space.sample())
            observation = tensor(observation, dtype=torch.float32).view(3, 160, -1).cuda()
    
        count += 1
        if count % it_bf :
            q_mod_target.load_state_dict(q_mod.state_dict())
        
        env.render()

        prob_exploration = p_init + (p_final - p_init) * math.exp(-1 * decay * (i_episode+1) * t) #prob of exploration
        
        if np.random.rand() < prob_exploration :
            action = tensor([env.action_space.sample()]).cuda()
        else :
            inp = (last_observation.unsqueeze(0).cuda(), observation.unsqueeze(0).cuda())
            action = q_mod(*inp).argmax(1).cuda()


        new_observation, reward, done, info = env.step(action)
        new_observation = tensor(new_observation, dtype=torch.float32).view(3, 160, -1).cuda()
        

        replay_mem.push((last_observation, observation), action, (observation, new_observation), tensor([[reward]]).cuda())
        
        #training
        transitions = replay_mem.sample(min(batch_size, len(replay_mem)))
        model_input = (torch.cat([frame[0][0].unsqueeze(0) for frame in transitions]),
                       torch.cat([frame[0][1].unsqueeze(0) for frame in transitions]))
        action_batch = torch.cat([a[1] for a in transitions]).unsqueeze(1)
        reward_batch = torch.cat([a[3] for a in transitions]).cuda()
        model_target_input = (torch.cat([frame[2][0].unsqueeze(0) for frame in transitions]),
                               torch.cat([frame[2][1].unsqueeze(0) for frame in transitions]))
        preds = q_mod(*model_input).gather(1, action_batch)
        targets = q_mod_target(*model_target_input).max(1)[0].detach().unsqueeze(1)
        targets = targets* 0.1  + reward_batch
        loss = F.smooth_l1_loss(preds, targets)

        #time_b = time.time()

        opt.zero_grad()
        loss.backward()
        #time_a= time.time()

        #times.append(time_a-time_b)
        for param in q_mod.parameters():
            param.grad.data.clamp_(-1, 1)
        opt.step()
        
        acc_reward += reward
        
        last_observation = observation
        observation = new_observation
        
        if done:
            print("Episode finished after {} timesteps, acc_reward {}".format(t+1, acc_reward))
            rewards.append(acc_reward)
            break

env.close()


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Episode finished after 246 timesteps, acc_reward 2.0
Episode finished after 298 timesteps, acc_reward 2.0
Episode finished after 305 timesteps, acc_reward 2.0
Episode finished after 247 timesteps, acc_reward 1.0
Episode finished after 203 timesteps, acc_reward 0.0
Episode finished after 172 timesteps, acc_reward 0.0
Episode finished after 196 timesteps, acc_reward 0.0
Episode finished after 249 timesteps, acc_reward 2.0
Episode finished after 255 timesteps, acc_reward 1.0
Episode finished after 185 timesteps, acc_reward 0.0
Episode finished after 181 timesteps, acc_reward 0.0
Episode finished after 191 timesteps, acc_reward 0.0
Episode finished after 327 timesteps, acc_reward 2.0
Episode finished after 184 timesteps, acc_reward 0.0
Episode finished after 209 timesteps, acc_reward 1.0
Episode finished after 277 timesteps, acc_reward 2.0
Episode finished after 186 timesteps, acc_reward 0.0
Episode finished after 279 timesteps, acc_reward 1.0
Episode finished after 235 timesteps, acc_rewa

Episode finished after 190 timesteps, acc_reward 0.0
Episode finished after 250 timesteps, acc_reward 1.0
Episode finished after 434 timesteps, acc_reward 4.0
Episode finished after 182 timesteps, acc_reward 0.0
Episode finished after 218 timesteps, acc_reward 1.0
Episode finished after 203 timesteps, acc_reward 1.0
Episode finished after 190 timesteps, acc_reward 0.0
Episode finished after 202 timesteps, acc_reward 0.0
Episode finished after 182 timesteps, acc_reward 0.0
Episode finished after 192 timesteps, acc_reward 0.0
Episode finished after 186 timesteps, acc_reward 0.0
Episode finished after 243 timesteps, acc_reward 1.0
Episode finished after 175 timesteps, acc_reward 0.0
Episode finished after 173 timesteps, acc_reward 0.0
Episode finished after 303 timesteps, acc_reward 2.0
Episode finished after 241 timesteps, acc_reward 1.0
Episode finished after 241 timesteps, acc_reward 1.0
Episode finished after 197 timesteps, acc_reward 1.0
Episode finished after 424 timesteps, acc_rewa