In [None]:
!pip install gym[atari]

In [2]:
import gym
from gym.wrappers import AtariPreprocessing, FrameStack
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random
from collections import namedtuple

In [3]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [4]:
# Throwaway, unwrapped environment used only to inspect the raw action and
# observation spaces.
env = gym.make('BreakoutNoFrameskip-v0')

num_actions = env.action_space.n
state_shape = env.observation_space.shape

print('Actions: {} -- State space dimensions: {}'.format(num_actions, state_shape))

# Release the emulator now: `env` is rebound to a fresh, wrapped environment
# in the preprocessing cell, which would otherwise orphan this instance
# without ever closing it.
env.close()

Actions: 4 -- State space dimensions: (210, 160, 3)


# Preprocessing

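We preprocess the frames with gym wrappers: `AtariPreprocessing` converts each frame to grayscale, resizes it and rescales its pixel values, and `FrameStack` stacks the most recent frames into a single observation.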

In [5]:
frames = 4  # number of consecutive observations stacked into one state

# Build the training environment as a single wrapper chain:
#   raw emulator -> AtariPreprocessing (grayscale, frame resize and frame rescale)
#                -> FrameStack (frame stack of `frames` observations)
env = FrameStack(
    AtariPreprocessing(gym.make('BreakoutNoFrameskip-v0'), scale_obs=True),
    frames,
)

# DQN

In [6]:
class BreakoutDQN(nn.Module):
    """Convolutional Q-network: three conv + batch-norm stages and two linear layers.

    Args:
        num_actions: Size of the output layer (one Q-value per action).
        frame_h: Height of each input frame, in pixels.
        frame_w: Width of each input frame, in pixels.
        frame_stack: Number of input channels, i.e. frames stacked per state.
    """

    def __init__(self, num_actions, frame_h, frame_w, frame_stack=4):
        super().__init__()

        # Convolutional trunk. Each stage uses no padding, so the spatial size
        # shrinks according to its (kernel, stride) pair.
        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=6, stride=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Walk the (kernel, stride) pairs of the three convolutions to find
        # the spatial size of the final feature map.
        out_h, out_w = frame_h, frame_w
        for kernel, stride in ((6, 3), (4, 2), (4, 1)):
            out_h = (out_h - kernel) // stride + 1
            out_w = (out_w - kernel) // stride + 1

        # 64 is the channel count produced by conv3.
        flat_features = out_h * out_w * 64

        self.fc1 = nn.Linear(flat_features, 256)
        self.head = nn.Linear(256, num_actions)

    def forward(self, x):
        """Return a ``(batch, num_actions)`` tensor of Q-values for input ``x``.

        ``x`` is passed straight to ``conv1``, so it must have ``frame_stack``
        channels and the frame size given at construction.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.leaky_relu(norm(conv(x)))

        # Flatten everything except the batch dimension.
        flat = x.view(x.size(0), -1)
        hidden = F.leaky_relu(self.fc1(flat))

        # The head output is scaled by a fixed factor of 15.
        return self.head(hidden) * 15

In [7]:
# One step of experience. `next_state` is stored as None by the training loop
# when the episode ended on this step.
Transition = namedtuple("Transition", ("state", "action", "reward", "next_state"))

class ReplayBuffer(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    Once ``capacity`` transitions are stored, each new push overwrites the
    oldest entry.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index the next push will write to

    def push(self, *args):
        """Saves a transition.

        ``args`` are the four ``Transition`` fields, in order.

        Raises:
            TypeError: If the wrong number of fields is given. The buffer is
                left untouched in that case.
        """
        # Build the record first so that a bad call cannot leave a `None`
        # placeholder behind in the buffer.
        entry = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = entry
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states)``. The first three are
            built with ``torch.stack``; ``next_states`` stays a tuple because
            its entries may be ``None``.

        Raises:
            ValueError: From ``random.sample`` if ``batch_size`` exceeds the
                number of stored transitions.
        """
        batch = random.sample(self.memory, batch_size)
        state_sample, action_sample, reward_sample, next_state_sample = Transition(*zip(*batch))
        return torch.stack(state_sample), \
               torch.stack(action_sample), \
               torch.stack(reward_sample), \
               next_state_sample

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [8]:
def select_action(state, env, model, epsilon):
    """Pick an action epsilon-greedily.

    A random action from ``env.action_space`` is returned when a uniform draw
    does not exceed ``epsilon``; otherwise the index of the largest value in
    ``model(state)`` is returned.
    """
    explore = random.random() <= epsilon
    if explore:
        return env.action_space.sample()

    # Greedy branch: no gradients are needed for action selection.
    with torch.no_grad():
        q_values = model(state)
        best = q_values.argmax()
        return best.item()

In [9]:
def update_epsilon(epsilon_start, epsilon_end, epsilon_steps, total_steps):
    """Exponentially anneal epsilon from ``epsilon_start`` towards ``epsilon_end``.

    ``epsilon_steps`` is the decay time constant: after that many steps the
    gap to ``epsilon_end`` has shrunk by a factor of e.
    """
    decay = math.exp(-1. * total_steps / epsilon_steps)
    gap = epsilon_start - epsilon_end
    return epsilon_end + gap * decay

In [16]:
def optimize_model(policy_net, target_net, optimizer, memory, batch_size, gamma, frame_stack=4, frame_h=84, frame_w=84):
    """Run one DQN optimisation step on a batch sampled from ``memory``.

    Args:
        policy_net: Network being trained; gives Q(s, a) for the sampled states.
        target_net: Network used to bootstrap the value of the next states.
        optimizer: Optimizer that steps the parameters being trained.
        memory: Object whose ``sample(batch_size)`` returns
            ``(states, actions, rewards, next_states)``; ``next_states`` is a
            sequence whose entries are ``None`` for terminal transitions.
        batch_size: Number of transitions to sample.
        gamma: Discount factor applied to the bootstrapped next-state value.
        frame_stack, frame_h, frame_w: Shape each sampled state is viewed as.

    Returns:
        The scalar smooth-L1 (Huber) loss tensor for this batch.
    """
    state_batch, action_batch, reward_batch, next_state_batch = memory.sample(batch_size)

    state_batch = state_batch.view((batch_size, frame_stack, frame_h, frame_w))
    # Follow the sampled data's device rather than relying on a global.
    batch_device = state_batch.device

    # Q(s, a) for the action that was actually taken in each transition.
    state_action_values = policy_net(state_batch)
    state_action_values = state_action_values.gather(1, action_batch.reshape((batch_size, 1)))

    # Terminal transitions keep a next-state value of zero.
    next_state_values = torch.zeros(batch_size, device=batch_device)
    non_final_next_states = [s for s in next_state_batch if s is not None]
    # Guard: torch.cat rejects an empty list, which happens whenever every
    # sampled transition is terminal.
    if non_final_next_states:
        non_final_mask = torch.tensor(
            [s is not None for s in next_state_batch],
            dtype=torch.bool,
            device=batch_device,
        )
        # Targets are constants for the loss, so no graph is needed here.
        with torch.no_grad():
            next_state_values[non_final_mask] = (
                target_net(torch.cat(non_final_next_states)).max(dim=1)[0].float()
            )

    expected_state_action_values = reward_batch + gamma * next_state_values

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss

def train_dqn(env, policy_net, target_net, optimizer, memory, frame_h=84, frame_w=84, frame_stack=4, target_update=10, batch_size=32, episodes=100, gamma=0.99, epsilon_start=0.9, epsilon_end=0.05, epsilon_steps=1000000):
    """Train ``policy_net`` with epsilon-greedy DQN and print one line per episode.

    Expects ``env.reset()`` to return an observation and ``env.step(action)``
    to return a 4-tuple ``(observation, reward, done, info)``. Tensors are
    created on the module-level ``device``.

    Args:
        env: Environment to interact with.
        policy_net: Network used for action selection and optimised each step.
        target_net: Network used for bootstrap targets; overwritten with the
            policy weights every ``target_update`` steps (including step 0).
        optimizer: Optimizer passed through to ``optimize_model``.
        memory: Replay buffer supporting ``push``, ``sample`` and ``len``.
        frame_h, frame_w, frame_stack: Shape each observation is viewed as.
        target_update: Number of environment steps between target syncs.
        batch_size: Minibatch size; optimisation starts once the buffer holds
            at least this many transitions.
        episodes: Number of episodes to run.
        gamma: Discount factor.
        epsilon_start, epsilon_end, epsilon_steps: Exploration schedule
            parameters passed to ``update_epsilon``.

    Returns:
        List with the total (undiscounted) reward of each episode.
    """
    total_rewards = []
    total_steps = 0

    epsilon = epsilon_start

    for episode in range(episodes):

        done = False
        state = torch.tensor(env.reset(), device=device).float().view(1, frame_stack, frame_h, frame_w)

        total_rewards.append(0)
        loss = 0  # reported as-is if no optimisation step runs this episode

        while not done:
            action = select_action(state, env, policy_net, epsilon)

            next_state, reward, done, _ = env.step(action)

            total_rewards[episode] += reward

            action_tensor = torch.tensor(action, device=device, dtype=torch.int64)
            reward_tensor = torch.tensor(reward, device=device, dtype=torch.float)
            if done:
                # Terminal transitions are stored with no successor state.
                next_state = None
            else:
                next_state = torch.tensor(next_state, device=device, dtype=torch.float).view(1, frame_stack, frame_h, frame_w)

            memory.push(state, action_tensor, reward_tensor, next_state)

            state = next_state

            if len(memory) >= batch_size:
                # Forward the frame geometry so non-default shapes reach the
                # optimiser step as well.
                loss = optimize_model(
                    policy_net, target_net, optimizer, memory, batch_size, gamma,
                    frame_stack=frame_stack, frame_h=frame_h, frame_w=frame_w,
                )

            if total_steps % target_update == 0:
                target_net.load_state_dict(policy_net.state_dict())

            total_steps += 1
            epsilon = update_epsilon(epsilon_start, epsilon_end, epsilon_steps, total_steps)

        print('{}/{} Total steps: {} Episode reward: {} Average reward: {} Loss: {} Epsilon: {}'.format(episode, episodes, total_steps, total_rewards[episode], np.mean(total_rewards), loss, epsilon))

    return total_rewards
                   
    

In [None]:
# Two networks with the same architecture: `policy_net` is the one being
# optimised, while `target_net` only supplies bootstrap targets and is
# overwritten with the policy weights inside train_dqn.
target_net = BreakoutDQN(num_actions, 84, 84).to(device)
policy_net = BreakoutDQN(num_actions, 84, 84).to(device)

# The target network is never trained directly. Keep it in eval mode so its
# BatchNorm layers use the running statistics copied from the policy network
# instead of per-batch statistics, and so computing targets does not mutate
# those running statistics.
target_net.eval()

optimizer = torch.optim.Adam(policy_net.parameters())

memory = ReplayBuffer(100000)
try:
    train_dqn(env, policy_net, target_net, optimizer, memory, gamma=0.9, batch_size=32, episodes=1000, epsilon_steps=100000, epsilon_end=0.05)
finally:
    # Always release the emulator, even if training is interrupted.
    env.close()

0/1000 Total steps: 261 Episode reward: 3.0 Average reward: 3.0 Loss: 0.017512204125523567 Epsilon: 0.8977843926253687
1/1000 Total steps: 493 Episode reward: 2.0 Average reward: 2.5 Loss: 0.001796769560314715 Epsilon: 0.8958198126284538
2/1000 Total steps: 627 Episode reward: 0.0 Average reward: 1.6666666666666667 Loss: 0.0005694336141459644 Epsilon: 0.8946871731174847
3/1000 Total steps: 876 Episode reward: 3.0 Average reward: 2.0 Loss: 0.0027269867714494467 Epsilon: 0.8925865184568302
4/1000 Total steps: 1160 Episode reward: 4.0 Average reward: 2.4 Loss: 0.0007982361712493002 Epsilon: 0.8901969675128499
5/1000 Total steps: 1457 Episode reward: 4.0 Average reward: 2.6666666666666665 Loss: 0.0018873510416597128 Epsilon: 0.8877052845001843
6/1000 Total steps: 1614 Episode reward: 1.0 Average reward: 2.4285714285714284 Loss: 0.020840086042881012 Epsilon: 0.8863911190933038
7/1000 Total steps: 1800 Episode reward: 1.0 Average reward: 2.25 Loss: 0.002523174975067377 Epsilon: 0.88483687750

62/1000 Total steps: 12064 Episode reward: 1.0 Average reward: 1.380952380952381 Loss: 0.015500007197260857 Epsilon: 0.803400040854187
63/1000 Total steps: 12192 Episode reward: 0.0 Average reward: 1.359375 Loss: 6.808726175222546e-05 Epsilon: 0.8024363057239591
64/1000 Total steps: 12346 Episode reward: 1.0 Average reward: 1.353846153846154 Loss: 0.00035163311986252666 Epsilon: 0.8012784455942757
65/1000 Total steps: 12601 Episode reward: 3.0 Average reward: 1.378787878787879 Loss: 0.015256588347256184 Epsilon: 0.7993651260771745
66/1000 Total steps: 12755 Episode reward: 0.0 Average reward: 1.3582089552238805 Loss: 0.01586870476603508 Epsilon: 0.7982119919242111
67/1000 Total steps: 12890 Episode reward: 0.0 Average reward: 1.338235294117647 Loss: 3.5427932743914425e-05 Epsilon: 0.797202587236581
68/1000 Total steps: 13068 Episode reward: 1.0 Average reward: 1.3333333333333333 Loss: 2.7581034373724833e-05 Epsilon: 0.7958737496476115
69/1000 Total steps: 13296 Episode reward: 2.0 Aver

122/1000 Total steps: 23631 Episode reward: 1.0 Average reward: 1.3739837398373984 Loss: 0.0002495630760677159 Epsilon: 0.7211054978885854
123/1000 Total steps: 23849 Episode reward: 2.0 Average reward: 1.3790322580645162 Loss: 0.00040044443449005485 Epsilon: 0.7196440814259021
124/1000 Total steps: 24123 Episode reward: 3.0 Average reward: 1.392 Loss: 0.0005169283249415457 Epsilon: 0.7178117680584647
125/1000 Total steps: 24259 Episode reward: 0.0 Average reward: 1.380952380952381 Loss: 0.0009648395935073495 Epsilon: 0.7169041613663483
126/1000 Total steps: 24476 Episode reward: 2.0 Average reward: 1.3858267716535433 Loss: 0.0012485700426623225 Epsilon: 0.7154585483935294
127/1000 Total steps: 24650 Episode reward: 1.0 Average reward: 1.3828125 Loss: 0.00012436654651537538 Epsilon: 0.714301657306454
128/1000 Total steps: 24835 Episode reward: 1.0 Average reward: 1.37984496124031 Loss: 7.765452755847946e-05 Epsilon: 0.713073835325954
129/1000 Total steps: 25040 Episode reward: 2.0 Aver

182/1000 Total steps: 35648 Episode reward: 3.0 Average reward: 1.4316939890710383 Loss: 0.000917077821213752 Epsilon: 0.6451160029502144
183/1000 Total steps: 35785 Episode reward: 0.0 Average reward: 1.423913043478261 Loss: 8.23110094643198e-05 Epsilon: 0.6443012522578307
184/1000 Total steps: 35971 Episode reward: 1.0 Average reward: 1.4216216216216215 Loss: 0.0011378306662663817 Epsilon: 0.6431968793138597
185/1000 Total steps: 36160 Episode reward: 1.0 Average reward: 1.4193548387096775 Loss: 0.015347982756793499 Epsilon: 0.6420767960240862
186/1000 Total steps: 36418 Episode reward: 3.0 Average reward: 1.427807486631016 Loss: 5.8239675126969814e-05 Epsilon: 0.640551206746756
187/1000 Total steps: 36607 Episode reward: 1.0 Average reward: 1.425531914893617 Loss: 0.016347568482160568 Epsilon: 0.6394361190558063
188/1000 Total steps: 36750 Episode reward: 0.0 Average reward: 1.417989417989418 Loss: 1.646095552132465e-05 Epsilon: 0.638593827787347
189/1000 Total steps: 36958 Episode 

242/1000 Total steps: 47501 Episode reward: 3.0 Average reward: 1.4567901234567902 Loss: 1.6943979062489234e-05 Epsilon: 0.5785970119987173
243/1000 Total steps: 47627 Episode reward: 0.0 Average reward: 1.4508196721311475 Loss: 0.015435711480677128 Epsilon: 0.5779313991877303
244/1000 Total steps: 47845 Episode reward: 2.0 Average reward: 1.453061224489796 Loss: 0.0001501801743870601 Epsilon: 0.5767817622970064
245/1000 Total steps: 48054 Episode reward: 2.0 Average reward: 1.4552845528455285 Loss: 0.00010929422569461167 Epsilon: 0.5756819381304049
246/1000 Total steps: 48194 Episode reward: 0.0 Average reward: 1.4493927125506072 Loss: 6.32812298135832e-05 Epsilon: 0.5749464983449939
247/1000 Total steps: 48357 Episode reward: 0.0 Average reward: 1.4435483870967742 Loss: 3.05540525005199e-05 Epsilon: 0.5740915325391199
248/1000 Total steps: 48601 Episode reward: 2.0 Average reward: 1.4457831325301205 Loss: 6.476070120697841e-05 Epsilon: 0.5728143080472782
249/1000 Total steps: 48786 E

302/1000 Total steps: 59318 Episode reward: 3.0 Average reward: 1.4587458745874586 Loss: 0.030636392533779144 Epsilon: 0.519682225221561
303/1000 Total steps: 59467 Episode reward: 0.0 Average reward: 1.4539473684210527 Loss: 0.00029597265529446304 Epsilon: 0.518982919817884
304/1000 Total steps: 59614 Episode reward: 0.0 Average reward: 1.4491803278688524 Loss: 5.994707316858694e-05 Epsilon: 0.5182940213901495
305/1000 Total steps: 59972 Episode reward: 5.0 Average reward: 1.4607843137254901 Loss: 8.112948125926778e-05 Epsilon: 0.5166205261374234
306/1000 Total steps: 60103 Episode reward: 0.0 Average reward: 1.4560260586319218 Loss: 0.015269852243363857 Epsilon: 0.5160096534571489
307/1000 Total steps: 60315 Episode reward: 2.0 Average reward: 1.4577922077922079 Loss: 0.0005281114717945457 Epsilon: 0.5150227594690717
308/1000 Total steps: 60487 Episode reward: 1.0 Average reward: 1.4563106796116505 Loss: 0.00011990615166723728 Epsilon: 0.5142236077902462
309/1000 Total steps: 60618 E

362/1000 Total steps: 70691 Episode reward: 1.0 Average reward: 1.4269972451790633 Loss: 0.00014020627713762224 Epsilon: 0.4691908684468395
363/1000 Total steps: 70864 Episode reward: 1.0 Average reward: 1.4258241758241759 Loss: 0.0007776610436849296 Epsilon: 0.46846629518101596
364/1000 Total steps: 71039 Episode reward: 1.0 Average reward: 1.4246575342465753 Loss: 1.6588328435318545e-05 Epsilon: 0.4677346195673409
365/1000 Total steps: 71227 Episode reward: 1.0 Average reward: 1.4234972677595628 Loss: 0.00026355450972914696 Epsilon: 0.4669500162407731
366/1000 Total steps: 71378 Episode reward: 0.0 Average reward: 1.4196185286103542 Loss: 0.031210558488965034 Epsilon: 0.46632089682094946
367/1000 Total steps: 71640 Episode reward: 3.0 Average reward: 1.423913043478261 Loss: 0.0004244451702106744 Epsilon: 0.4652315637207745
368/1000 Total steps: 71779 Episode reward: 0.0 Average reward: 1.4200542005420054 Loss: 0.00012562649499159306 Epsilon: 0.46465479279586036
369/1000 Total steps: 

422/1000 Total steps: 83001 Episode reward: 1.0 Average reward: 1.4373522458628842 Loss: 0.00045226470683701336 Epsilon: 0.42063818697290356
423/1000 Total steps: 83165 Episode reward: 0.0 Average reward: 1.4339622641509433 Loss: 0.00010023333743447438 Epsilon: 0.4200308385081361
424/1000 Total steps: 83498 Episode reward: 4.0 Average reward: 1.44 Loss: 0.00014970141637604684 Epsilon: 0.41880068515798574
425/1000 Total steps: 83708 Episode reward: 2.0 Average reward: 1.4413145539906103 Loss: 0.000366375083103776 Epsilon: 0.41802701635571965
426/1000 Total steps: 83891 Episode reward: 1.0 Average reward: 1.440281030444965 Loss: 0.01598259061574936 Epsilon: 0.41735414278288996
427/1000 Total steps: 84027 Episode reward: 0.0 Average reward: 1.4369158878504673 Loss: 3.036432826775126e-05 Epsilon: 0.41685488072385835
428/1000 Total steps: 84204 Episode reward: 1.0 Average reward: 1.435897435897436 Loss: 3.5496021155267954e-05 Epsilon: 0.41620612190590567
429/1000 Total steps: 84337 Episode 