In [14]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T

# Raw Breakout environment (`.unwrapped` bypasses gym's wrapper layer).
env = gym.make('Breakout-v0').unwrapped

# Matplotlib: live-updating figures. With the 'inline' backend we are inside
# IPython/Jupyter, where IPython.display is used to redraw plots in place.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()

# Tensor-type aliases that point at the CUDA types when a GPU is available.
use_cuda = torch.cuda.is_available()
if use_cuda:
    FloatTensor, LongTensor, ByteTensor = (torch.cuda.FloatTensor,
                                           torch.cuda.LongTensor,
                                           torch.cuda.ByteTensor)
else:
    FloatTensor, LongTensor, ByteTensor = (torch.FloatTensor,
                                           torch.LongTensor,
                                           torch.ByteTensor)
Tensor = FloatTensor  # default tensor type used throughout the notebook

In [15]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` tuples.

    Until ``capacity`` transitions are stored the buffer grows; after that
    every push overwrites the slot at ``position``, i.e. the oldest entry.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes into

    def push(self, *args):
        """Store ``Transition(*args)`` in the buffer."""
        needs_new_slot = len(self.memory) < self.capacity
        if needs_new_slot:
            # While filling up, `position` equals len(self.memory), so this
            # placeholder is exactly the slot written below.
            self.memory.append(None)
        slot = self.position
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions, chosen uniformly."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [16]:
class DQN(nn.Module):
    """Q-network: two strided convolutions, a 5x5 max-pool and a linear head.

    Takes a batch of stacked screens with 6 input channels and returns one
    score per action (4 outputs) for each batch element. The head expects
    exactly 896 flattened features, which pins the input spatial size: for
    example an (N, 6, 101, 146) input gives 32 x 4 x 7 = 896 features.
    """

    def __init__(self):
        super(DQN, self).__init__()
        # Registration order (conv1, conv2, head) fixes both the state_dict
        # key order and the order in which parameters are initialised.
        self.conv1 = nn.Conv2d(6, 16, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        self.head = nn.Linear(896, 4)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        pooled = F.max_pool2d(hidden, 5)
        flat = pooled.view(pooled.size(0), -1)  # (N, 896)
        return self.head(flat)

In [17]:
# NOTE(review): despite its name, `resize` does not resize anything. It is only
# a tensor -> PIL -> tensor round trip (the Grayscale step is commented out).
# get_screen no longer routes frames through it; the name is kept so any
# other cell that references it keeps working.
resize = T.Compose([T.ToPILImage(),
#                    T.Grayscale(),
                    T.ToTensor()])

# Crop applied to the rendered frame (rows, then columns, in pixels).
# TODO confirm these bounds keep the intended play area.
SCREEN_ROWS = slice(92, -17)
SCREEN_COLS = slice(7, -7)


def get_screen():
    """Render the current frame as a (1, C, H, W) tensor with values in [0, 1].

    The RGB render is transposed to channel-first order, cropped to
    SCREEN_ROWS x SCREEN_COLS, scaled by 1/255, given a leading batch
    dimension and cast to the notebook-wide ``Tensor`` type.
    """
    # HWC -> CHW (torch channel order)
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    screen = screen[:, SCREEN_ROWS, SCREEN_COLS]

    # Contiguous float32 copy scaled to [0, 1]. Converting directly avoids a
    # per-frame PIL round trip that added no resizing, only overhead.
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    return torch.from_numpy(screen).unsqueeze(0).type(Tensor)

In [18]:
# --- Hyperparameters --------------------------------------------------------
BATCH_SIZE = 64   # transitions sampled per optimisation step
GAMMA = 0.99      # discount factor for future rewards
# Epsilon-greedy schedule read by select_action: epsilon decays exponentially
# from EPS_START towards EPS_END with time constant EPS_DECAY (in steps).
EPS_START, EPS_END, EPS_DECAY = 0.99, 0.05, 200

# --- Model, optimiser, replay buffer ----------------------------------------
# (A checkpoint load, torch.load("DQN_model.pth"), used to be commented out here.)
model = DQN()
if use_cuda:
    model.cuda()

optimizer = optim.RMSprop(model.parameters(), lr=1e-4)
memory = ReplayMemory(capacity=10000)

# Global step counter, advanced by select_action; drives the epsilon decay.
steps_done = 0

def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy policy.

    Epsilon decays exponentially from EPS_START towards EPS_END with time
    constant EPS_DECAY, counted in calls to this function (``steps_done``).

    Returns:
        A (1, 1) LongTensor holding the chosen action index.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        # Greedy: index of the largest predicted Q-value.
        return model(
            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1)
    # Explore uniformly over every action the Q-network scores. The previous
    # randrange(2) (a CartPole leftover) could only ever pick actions 0 and 1,
    # although the head produces one value per action (4).
    n_actions = model.head.out_features
    return LongTensor([[random.randrange(n_actions)]])


episode_durations = []  # length (in steps) of each finished episode


def plot_durations():
    """Plot episode durations and, once 100 exist, their 100-episode mean.

    Redraws matplotlib figure 2 from scratch on every call. Under an inline
    (IPython) backend the cell output is replaced with the updated figure.
    """
    window = 100

    plt.figure(2)
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')

    durations = torch.FloatTensor(episode_durations)
    plt.plot(durations.numpy())

    if len(durations) >= window:
        # unfold -> (n - window + 1, window) sliding windows; average each one.
        rolling = durations.unfold(0, window, 1).mean(1).view(-1)
        # Left-pad with zeros so the curve lines up with the raw series.
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    plt.pause(0.001)  # give the GUI event loop a moment to redraw
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())

In [19]:
# NOTE(review): last_sync is not read anywhere in the visible notebook (no
# target network is synchronised); kept so existing references keep working.
last_sync = 0


def optimize_model():
    """Perform one DQN update on a random minibatch drawn from ``memory``.

    Does nothing until the replay memory holds at least BATCH_SIZE
    transitions. Otherwise regresses Q(s_t, a_t) towards the one-step target
    r_t + GAMMA * max_a Q(s_{t+1}, a), where the bootstrap term is 0 for
    terminal transitions (next_state is None). It uses a Huber loss, clips
    every gradient element to [-1, 1] and applies one optimizer step.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)

    # Transpose the batch: a list of Transitions becomes one Transition whose
    # fields are tuples (see http://stackoverflow.com/a/19343/3343043).
    batch = Transition(*zip(*transitions))

    # 1 where the transition has a successor state, 0 where the episode ended.
    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)))
    non_final_states = [s for s in batch.next_state if s is not None]

    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))   # (BATCH_SIZE, 1)
    reward_batch = Variable(torch.cat(batch.reward))   # (BATCH_SIZE,)

    # Q(s_t, a_t): the model scores every action; gather picks the column of
    # the action actually taken. Shape (BATCH_SIZE, 1).
    state_action_values = model(state_batch).gather(1, action_batch)

    # V(s_{t+1}) = max_a Q(s_{t+1}, a) for non-terminal successors, 0 otherwise.
    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
    if non_final_states:  # torch.cat fails on an empty list (all-terminal batch)
        non_final_next_states = Variable(torch.cat(non_final_states),
                                         volatile=True)
        # detach(): the bootstrap target must not propagate gradients. On
        # PyTorch >= 0.4 `volatile` is a no-op, so without this the loss would
        # also backpropagate through the target values.
        next_state_values[non_final_mask] = \
            model(non_final_next_states).max(1)[0].detach()
    # Legacy (PyTorch < 0.4): clear the volatile flag inherited from the
    # target computation so the loss can be backpropagated.
    next_state_values.volatile = False
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss. The target is (BATCH_SIZE,) while the prediction is
    # (BATCH_SIZE, 1); unsqueeze so the two are compared element-wise instead
    # of broadcasting to a (BATCH_SIZE, BATCH_SIZE) matrix of all pairs.
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)  # element-wise gradient clipping
    optimizer.step()

In [None]:
num_episodes = 1000
# Number of consecutive screens stacked into one state. DQN.conv1 takes 6
# input channels, i.e. 2 stacked 3-channel RGB screens.
FRAME_STACK = 2
episode_rewards = []  # undiscounted total reward of each episode

for i_episode in range(num_episodes):
    print("EPISODE NUMBER: {}".format(i_episode))
    # Reset the environment and fill the frame window from the first screen.
    env.reset()
    recent_screens = [get_screen() for _ in range(FRAME_STACK)]
    state = torch.cat(recent_screens, 1)
    cum_reward = 0
    for t in count():
        # Select and perform an action. int() hands env.step a plain Python
        # integer whether indexing the LongTensor yields an int or a 0-dim tensor.
        action = select_action(state)
        _, reward, done, _ = env.step(int(action[0, 0]))
        cum_reward += reward
        reward = Tensor([reward])

        # Slide the frame window; a terminal transition has no next state.
        recent_screens.pop(0)
        recent_screens.append(get_screen())
        next_state = None if done else torch.cat(recent_screens, 1)

        # Store the transition, move on, and take one optimisation step.
        memory.push(state, action, next_state, reward)
        state = next_state
        optimize_model()
        if done:
            break

    # Record per-episode statistics (previously computed but discarded), so
    # plot_durations has data to show.
    episode_durations.append(t + 1)
    episode_rewards.append(cum_reward)

print('Complete')
plot_durations()
env.render(close=True)
env.close()
plt.ioff()
plt.show()

EPISODE NUMBER: 0
EPISODE NUMBER: 1
EPISODE NUMBER: 2
EPISODE NUMBER: 3
EPISODE NUMBER: 4
EPISODE NUMBER: 5
EPISODE NUMBER: 6
EPISODE NUMBER: 7
EPISODE NUMBER: 8
EPISODE NUMBER: 9
EPISODE NUMBER: 10
EPISODE NUMBER: 11
EPISODE NUMBER: 12
EPISODE NUMBER: 13
EPISODE NUMBER: 14
EPISODE NUMBER: 15
EPISODE NUMBER: 16
EPISODE NUMBER: 17
EPISODE NUMBER: 18
EPISODE NUMBER: 19
EPISODE NUMBER: 20
EPISODE NUMBER: 21
EPISODE NUMBER: 22
EPISODE NUMBER: 23
EPISODE NUMBER: 24
EPISODE NUMBER: 25
EPISODE NUMBER: 26
EPISODE NUMBER: 27
EPISODE NUMBER: 28
EPISODE NUMBER: 29
EPISODE NUMBER: 30
EPISODE NUMBER: 31
EPISODE NUMBER: 32
EPISODE NUMBER: 33
EPISODE NUMBER: 34
EPISODE NUMBER: 35
EPISODE NUMBER: 36
EPISODE NUMBER: 37
EPISODE NUMBER: 38
EPISODE NUMBER: 39
EPISODE NUMBER: 40
EPISODE NUMBER: 41
EPISODE NUMBER: 42
EPISODE NUMBER: 43
EPISODE NUMBER: 44
EPISODE NUMBER: 45
EPISODE NUMBER: 46
EPISODE NUMBER: 47
EPISODE NUMBER: 48
EPISODE NUMBER: 49
EPISODE NUMBER: 50
EPISODE NUMBER: 51
EPISODE NUMBER: 52
EPI