In [1]:
import snakai
import agents
import numpy as np
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T
import datetime


# if gpu is to be used
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor

action2ind = {'right' : 0,
             'left' : 1,
             'up' : 2,
             'down' : 3}
ind2action = {val: key for key, val in action2ind.items()}
ind2action

{0: 'right', 1: 'left', 2: 'up', 3: 'down'}

## Replay memory

In [2]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward','ended'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
    
memory = ReplayMemory(5000)

## Q-network

In [3]:
class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        #self.dense1 = nn.Linear(1600, 2048)
        self.dense2 = nn.Linear(1600, 512)
        self.head = nn.Linear(512, 4)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        #x = F.relu(self.bn2(self.conv2(x)))
        #x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(x.size(0), -1)
        #x = F.relu(self.dense1(x))
        x = F.relu(self.dense2(x))
        return (self.head(x))
    
model = DQN()

In [4]:
def torch_step(action):
    model.train(mode=False)
    if 'torch' in action.type():
        action_pure = ind2action[action.numpy()[0][0]]
        next_state, reward, ended = snake.step(action_pure)
        next_state, reward, ended = torch.unsqueeze(torch.from_numpy(next_state),0).float(), FloatTensor([[reward]]), LongTensor([[ended]])
        return next_state, reward, ended
    else:
        return snake.step(action)

In [5]:
def max_idx(a):
    one = a == np.max(a)
    col = np.argmax(one.max(axis=0))
    row = np.argmax(one[:,col])
    return row, col

def simple_agent(state):
    use_torch = 'torch' in state.type()
    if use_torch:
        state = state.numpy()
        
    y_player, x_player = max_idx(state[0,0,:,:])
    y_apple, x_apple = max_idx(state[0,1,:,:])
    action = "down"
    if y_player < y_apple:
        action = "down"
    elif y_player > y_apple:
        action = "up"
    elif x_player < x_apple:
        action = "right"
    elif x_player > x_apple:
        action = "left"
    
    action = action2ind[action]
    if use_torch:
        return LongTensor([[action]])
    else: 
        return action
    
def random_agent(state):
    return LongTensor([[random.randrange(4)]])

def model_agent(state):
    return model(Variable(state)).data.max(1)[1].view(1, 1)

def epsilon_agent(state, th = 0.05):
    if random.random() > th:
        return model_agent(state)
    else:
        return random_agent(state)

In [6]:
def optimize_model():
    
    if len(memory) < batch_size:
        return None, 0

    # fetch and concat batch:
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))

    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))
    reward_batch = Variable(torch.cat(batch.reward))

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken
    model.train(mode=True)
    state_action_values = model(state_batch).gather(1, action_batch)

    ended_batch = torch.cat(batch.ended)
    non_final_mask = ByteTensor(1 - ended_batch.numpy())
    
    non_final_next_states = Variable(torch.cat(
    [state for end, state in zip(ended_batch.numpy().flatten(), batch.next_state) if end !=1])
                                     ,volatile=True)
    

    # Compute V(s_{t+1}) for all next states.
    model.train(mode=False)
    next_state_values = Variable(torch.zeros(batch_size).type(Tensor))
    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]

    # Now, we don't want to mess up the loss with a volatile flag, so let's
    # clear it. After this, we'll just end up with a Variable that has
    # requires_grad=False
    next_state_values.volatile = False
    
    
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch.view(batch_size).float()
    #print(reward_batch)
    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    #for param in model.parameters():
    #    param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss, expected_state_action_values.mean()

In [7]:
def play_game(snake, epsilon = 0.05):
    cum_reward = 0.0
    snake.on_init()
    state, reward, ended = snake.on_feedback()
    state = torch.unsqueeze(torch.from_numpy(state),0).float()
    for i in range(ep_length):
        action = epsilon_agent(state, th = epsilon)
        next_state, reward, ended = torch_step(action)
        cum_reward += float(reward)
        
        memory.push(state, action, next_state, reward, ended)
        state = next_state
        if ended.numpy()[0][0] == 1:
            return cum_reward, i


## Train

In [8]:
start = datetime.datetime.now()
game_size = (10, 10)
snake = snakai.Snake(render=False, game_size = game_size, time_reward = 0.01)

ep_length = 10000
num_episode = 100000
avg_reward = -1.0
avg_steps = 1.0
best_reward = -1.0

In [9]:
ch = 32
ksize = 4
batch_size = 64
GAMMA = 0.99
optimizer = optim.RMSprop(model.parameters(), lr = 0.001, weight_decay = 0.001)


class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, ch, kernel_size=ksize, stride=1, padding = 0)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 2)
        self.bn2 = nn.BatchNorm2d(ch)
        self.conv3 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 2)
        self.bn3 = nn.BatchNorm2d(ch)
        self.dense1 = nn.Linear(2592, 512)
        self.head = nn.Linear(512, 4)

    def forward(self, x):
        #print(x.shape)
        x = F.relu(self.bn1(self.conv1(x)))
        #print(x.shape)
        x = F.relu(self.bn2(self.conv2(x)))
        #print(x.shape)
        x = F.relu(self.bn3(self.conv3(x)))
        #print(x.shape)
        x = x.view(x.size(0), -1)
        x = F.relu(self.dense1(x))
        return (self.head(x))
    
model = DQN()

In [10]:
ch = 32
ksize = 4
batch_size = 64
GAMMA = 0.99
optimizer = optim.RMSprop(model.parameters(), lr = 0.001, weight_decay = 0.001)


class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, ch, kernel_size=ksize, stride=1, padding = 0)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 2)
        self.conv3 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 2)
        self.dense1 = nn.Linear(2592, 512)
        self.head = nn.Linear(512, 4)

    def forward(self, x):
        #print(x.shape)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        #print(x.shape)
        x = F.relu(self.conv3(x))
        #print(x.shape)
        x = x.view(x.size(0), -1)
        x = F.relu(self.dense1(x))
        return (self.head(x))
    
model = DQN()

In [11]:
play_game(snake, epsilon = 0.2)

(0.07999999821186066, 9)

In [12]:
for i_episode in range(num_episode):
    
    cum_reward, steps = play_game(snake, epsilon = 0.1)
    
    for _ in range(4):
        l, exp_val = optimize_model()
    
    avg_steps = float(steps)*0.01 + avg_steps*0.99
    avg_reward = float(cum_reward)*0.01 + avg_reward*0.99
    if i_episode % 100 == 0 and l is not None:
        print('%s episode %d: avg_reward: %.3f, steps: %d loss: %.2f, exp_val: %.2f' % 
              (str(datetime.datetime.now() - start), i_episode, avg_reward, avg_steps, l.data[0], exp_val.data[0]))
        
        if best_reward < avg_reward and i_episode % 500 == 0:
            print("saving model..")
            torch.save(model, "best_model.torch")
            best_reward = avg_reward
            
        start = datetime.datetime.now()

0:00:33.898656 episode 100: avg_reward: -0.342, steps: 3 loss: 0.16, exp_val: 0.04
0:00:35.476840 episode 200: avg_reward: -0.142, steps: 4 loss: 0.17, exp_val: 0.06
0:00:35.655435 episode 300: avg_reward: 0.001, steps: 5 loss: 0.18, exp_val: 0.01
0:00:35.435553 episode 400: avg_reward: 0.077, steps: 5 loss: 0.19, exp_val: -0.06
0:00:35.403011 episode 500: avg_reward: 0.111, steps: 5 loss: 0.10, exp_val: 0.11
saving model..


  "type " + obj.__name__ + ". It won't be checked "


0:00:35.405853 episode 600: avg_reward: 0.075, steps: 5 loss: 0.15, exp_val: 0.04
0:00:35.391613 episode 700: avg_reward: 0.053, steps: 5 loss: 0.12, exp_val: -0.00
0:00:35.551344 episode 800: avg_reward: 0.085, steps: 5 loss: 0.11, exp_val: 0.11
0:00:35.517134 episode 900: avg_reward: 0.060, steps: 5 loss: 0.17, exp_val: 0.03
0:00:35.591884 episode 1000: avg_reward: 0.007, steps: 5 loss: 0.13, exp_val: 0.08
0:00:35.499971 episode 1100: avg_reward: 0.040, steps: 5 loss: 0.14, exp_val: 0.03
0:00:35.630427 episode 1200: avg_reward: 0.038, steps: 5 loss: 0.15, exp_val: 0.09
0:00:35.523032 episode 1300: avg_reward: 0.045, steps: 5 loss: 0.16, exp_val: 0.11
0:00:35.430479 episode 1400: avg_reward: 0.016, steps: 5 loss: 0.14, exp_val: 0.06
0:00:35.427241 episode 1500: avg_reward: 0.052, steps: 5 loss: 0.14, exp_val: 0.17
0:00:35.467064 episode 1600: avg_reward: 0.075, steps: 5 loss: 0.13, exp_val: 0.01
0:00:35.468855 episode 1700: avg_reward: 0.031, steps: 5 loss: 0.16, exp_val: 0.07
0:00:34

## Test with greedy-policy

In [13]:
model = torch.load("best_model.torch")
snake = snakai.Snake(render=True, game_size = game_size)
while True:
    cum_reward, steps = play_game(snake, epsilon = 0.0)

KeyboardInterrupt: 

In [None]:
avg_reward