In [1]:
!pip install gym[atari]

Collecting atari-py~=0.2.0; extra == "atari"
  Downloading atari_py-0.2.6-cp37-cp37m-manylinux1_x86_64.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 605 kB/s eta 0:00:01
Installing collected packages: atari-py
Successfully installed atari-py-0.2.6
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
import gym
from gym.wrappers import AtariPreprocessing, FrameStack
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random

In [8]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [9]:
# The raw (unwrapped) environment is only needed to read its action and
# observation spaces, so release it as soon as they have been recorded.
env = gym.make('BreakoutNoFrameskip-v0')
try:
    num_actions = env.action_space.n
    state_shape = env.observation_space.shape
finally:
    env.close()

print('Actions: {} -- State space dimensions: {}'.format(num_actions, state_shape))

Actions: 4 -- State space dimensions: (210, 160, 3)


# Preprocessing

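We preprocess the raw frames with gym wrappers: `AtariPreprocessing` handles the grayscale conversion, frame resizing and rescaling, and `FrameStack` stacks several consecutive frames into a single observation.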

In [10]:
frames = 4

# Inner wrapper: grayscale, frame resize and frame rescale (scale_obs=True).
# Outer wrapper: frame stack of `frames` observations.
env = FrameStack(
    AtariPreprocessing(gym.make('BreakoutNoFrameskip-v0'), scale_obs=True),
    frames,
)

# DQN

In [11]:
class BreakoutDQN(nn.Module):
    """Convolutional Q-network: three conv + batch-norm stages, then two linear layers.

    Args:
        num_actions: width of the output layer (one value per action).
        frame_h: height of each input frame.
        frame_w: width of each input frame.
        frame_stack: number of input channels (stacked frames).
    """

    def __init__(self, num_actions, frame_h, frame_w, frame_stack=4):
        super().__init__()

        # Convolutional trunk; each stage is followed by batch normalisation.
        self.conv1 = nn.Conv2d(frame_stack, 32, kernel_size=6, stride=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        def trunk_out(size) -> int:
            # Apply the un-padded convolution size formula once per stage,
            # using the same (kernel, stride) pairs as the layers above.
            for kernel, stride in ((6, 3), (4, 2), (4, 1)):
                size = (size - kernel) // stride + 1
            return size

        # Number of features left after flattening the last conv output.
        linear_input = trunk_out(frame_h) * trunk_out(frame_w) * 64

        self.fc1 = nn.Linear(linear_input, 256)
        self.head = nn.Linear(256, num_actions)

    def forward(self, x):
        """Map a (batch, frame_stack, frame_h, frame_w) tensor to (batch, num_actions)."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.leaky_relu(norm(conv(x)))

        flat = x.view(x.size(0), -1)
        hidden = F.leaky_relu(self.fc1(flat))
        return self.head(hidden)

In [12]:
class ReplayBuffer():
    """Fixed-capacity ring buffer of transitions, kept as five parallel lists."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.position = 0  # index the next push() writes to
        self.states, self.actions, self.rewards = [], [], []
        self.next_states, self.done = [], []

    def push(self, state, action, reward, next_state, done):
        """Store one transition at `position`, then advance it modulo `capacity`.

        While the buffer holds fewer than `capacity` entries each list grows
        by one slot; afterwards the slot at `position` is overwritten.
        """
        columns = (self.states, self.actions, self.rewards, self.next_states, self.done)
        values = (state, action, reward, next_state, done)

        slot = self.position
        still_growing = len(self.states) < self.capacity
        for column, value in zip(columns, values):
            if still_growing:
                column.append(None)
            column[slot] = value

        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return five lists of `batch_size` entries drawn with `np.random.choice`.

        The same random indices are used for every list, so entry i of each
        returned list belongs to the same transition.
        """
        picks = np.random.choice(range(len(self)), size=batch_size)

        def gather(column):
            return [column[i] for i in picks]

        return (
            gather(self.states),
            gather(self.actions),
            gather(self.rewards),
            gather(self.next_states),
            gather(self.done),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.states)

In [15]:
def select_action(state, env, model, epsilon):
    """Epsilon-greedy action selection.

    Draws one number from `random.random()`. If it is not greater than
    `epsilon`, a random action is sampled from `env.action_space`; otherwise
    the index of the largest entry of `model(state)` is returned.

    Note: `argmax()` is taken over the whole output tensor, so `state` is
    presumably a batch of exactly one observation -- confirm at the call site.
    """
    explore = random.random() <= epsilon
    if explore:
        return env.action_space.sample()

    # Pure inference: no autograd graph is needed to pick the greedy action.
    with torch.no_grad():
        q_values = model(state)
    return q_values.argmax().item()

In [16]:
def update_epsilon(epsilon_start, epsilon_end, epsilon_steps, total_steps):
    """Exponentially decayed epsilon.

    Returns `epsilon_end + (epsilon_start - epsilon_end) * exp(-total_steps / epsilon_steps)`,
    i.e. `epsilon_start` at step 0, approaching `epsilon_end` as `total_steps` grows.
    """
    span = epsilon_start - epsilon_end
    decay = math.exp(-1. * total_steps / epsilon_steps)
    return epsilon_end + span * decay

In [82]:
def optimize_model(policy_net, target_net, optimizer, memory, batch_size, gamma, frame_stack=4, frame_h=84, frame_w=84):
    """Run one optimisation step of `policy_net` on a mini-batch from `memory`.

    Args:
        policy_net: network being trained; called on the state batch.
        target_net: network used to evaluate the next states.
        optimizer: optimizer stepping the parameters of `policy_net`.
        memory: object whose `sample(batch_size)` returns five lists
            (states, actions, rewards, next_states, done flags). States,
            actions and rewards are tensors; a next state is `None` for a
            terminal transition.
        batch_size: number of transitions to sample.
        gamma: discount factor applied to the next-state value.
        frame_stack, frame_h, frame_w: shape of one observation.

    Returns:
        The smooth-L1 loss of this step as a detached 0-dim tensor.
    """
    state_batch, action_batch, reward_batch, next_state_batch, _ = memory.sample(batch_size)

    frame_shape = (frame_stack, frame_h, frame_w)
    state_batch = torch.stack(state_batch).view((batch_size,) + frame_shape)
    action_batch = torch.stack(action_batch).view(batch_size, 1)
    reward_batch = torch.stack(reward_batch).view(batch_size)

    # Use the device of the sampled tensors instead of a notebook-level global,
    # so the mask, the values and the batch are guaranteed to live together.
    batch_device = state_batch.device

    non_final_next = [s for s in next_state_batch if s is not None]
    non_final_mask = torch.tensor(
        [s is not None for s in next_state_batch], dtype=torch.bool, device=batch_device
    )

    # Value of the action that was actually taken in each sampled state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Terminal transitions keep a next-state value of 0.
    next_state_values = torch.zeros(batch_size, device=batch_device)
    if non_final_next:
        # Guarded: torch.stack raises on an empty list, which happens when
        # every sampled transition is terminal. The view accepts next states
        # stored with or without a leading batch dimension of 1.
        non_final_next_states = torch.stack(non_final_next).view((-1,) + frame_shape)
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(dim=1)[0].float()

    expected_state_action_values = reward_batch + gamma * next_state_values

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detached so callers that keep the value around (e.g. for logging) do not
    # hold on to autograd history.
    return loss.detach()

def train_dqn(env, policy_net, target_net, optimizer, memory, frame_h=84, frame_w=84, frame_stack=4, target_update=10, batch_size=32, episodes=100, gamma=0.99, epsilon_start=0.9, epsilon_end=0.05, epsilon_steps=1000000):
    """Train `policy_net` with epsilon-greedy exploration and experience replay.

    Args:
        env: environment with `reset()` and a `step(action)` that returns
            `(observation, reward, done, info)`.
        policy_net: network being trained and used to pick actions.
        target_net: network whose weights are overwritten with those of
            `policy_net` every `target_update` steps.
        optimizer: optimizer for `policy_net`.
        memory: replay buffer exposing `push`, `sample` and `__len__`.
        frame_h, frame_w, frame_stack: shape of one observation.
        target_update: number of environment steps between target syncs.
        batch_size: mini-batch size; optimisation starts once `memory`
            holds at least this many transitions.
        episodes: number of episodes to run.
        gamma: discount factor forwarded to `optimize_model`.
        epsilon_start, epsilon_end, epsilon_steps: exploration schedule
            forwarded to `update_epsilon`.

    Returns:
        List with the total reward collected in each episode.
    """
    # Create tensors on the device the policy network lives on instead of
    # relying on a notebook-level global.
    device = next(policy_net.parameters()).device

    def to_frames(observation):
        # np.asarray converts array-like observations (such as the output of
        # a frame-stacking wrapper) in one step; torch.tensor always copies,
        # so stored transitions never alias a buffer owned by the environment.
        frames = torch.tensor(np.asarray(observation), dtype=torch.float, device=device)
        return frames.view(frame_stack, frame_h, frame_w)

    total_rewards = []
    total_steps = 0

    epsilon = epsilon_start

    for episode in range(episodes):

        done = False
        frames = to_frames(env.reset())

        total_rewards.append(0)
        loss = 0

        while not done:

            # Batch of one for the network; a view, so no extra copy is made.
            state_tensor = frames.unsqueeze(0)
            action = select_action(state_tensor, env, policy_net, epsilon)

            next_state, reward, done, _ = env.step(action)

            total_rewards[episode] += reward

            action_tensor = torch.tensor(action, device=device, dtype=torch.int64)
            reward_tensor = torch.tensor(reward, device=device, dtype=torch.float)
            # Terminal transitions are stored with next_state=None.
            next_frames = None if done else to_frames(next_state)

            memory.push(state_tensor, action_tensor, reward_tensor, next_frames, done)

            # Reuse the tensor that was just built rather than converting the
            # same observation again on the next iteration.
            frames = next_frames

            if len(memory) >= batch_size:
                # Forward the frame geometry so non-default sizes reach the optimiser.
                loss = optimize_model(
                    policy_net, target_net, optimizer, memory, batch_size, gamma,
                    frame_stack=frame_stack, frame_h=frame_h, frame_w=frame_w,
                )

            if total_steps % target_update == 0:
                target_net.load_state_dict(policy_net.state_dict())

            total_steps += 1
            epsilon = update_epsilon(epsilon_start, epsilon_end, epsilon_steps, total_steps)

        print('{}/{} Total steps: {} Episode reward: {} Average reward: {} Loss: {} Epsilon: {}'.format(episode, episodes, total_steps, total_rewards[episode], np.mean(total_rewards), loss, epsilon))

    return total_rewards
                   
    

In [83]:
target_net = BreakoutDQN(num_actions, 84, 84).to(device)
policy_net = BreakoutDQN(num_actions, 84, 84).to(device)

# Start both networks from identical weights. The target network is only used
# to evaluate next states, so it is kept in eval mode: its BatchNorm layers
# then use the running statistics copied from the policy network instead of
# the statistics of whichever batch happens to be passed through it.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = torch.optim.Adam(policy_net.parameters())

memory = ReplayBuffer(100000)
try:
    train_dqn(env, policy_net, target_net, optimizer, memory, gamma=0.5, batch_size=128, episodes=1000, epsilon_steps=100000, epsilon_end=0.05)
finally:
    # Always release the emulator, including when training is interrupted.
    env.close()



1.61 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


KeyboardInterrupt: 