In [1]:
import snakai
import agents
import numpy as np
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T
import datetime



## Define how to play game and replay memory

In [2]:
game_size = (10,10)
def tuple_to_torch(tup):
    return torch.from_numpy(np.array(tup))

action2ind = {'right' : 0,
             'left' : 1,
             'up' : 2,
             'down' : 3}
ind2action = {val: key for key, val in action2ind.items()}


In [3]:
def play_game(snake, agent, epsilon = 0.05):
    cum_reward = 0.0
    snake.on_init()
    state, reward, ended = snake.on_feedback()

    for i in range(200):
        action = agent(state, th = epsilon)
        next_state, reward, ended = snake.step(action)
        cum_reward += float(reward)
        
        # Keep all the games:
        memory.push(state, action, next_state, reward, ended)
        state = next_state
        if ended == 1:
            return cum_reward, i
    return cum_reward, i


In [4]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward','ended'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
    
memory = ReplayMemory(5000)

## Define agent

In [5]:
ch = 64
ksize = 4
class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, ch, kernel_size=ksize, stride=2, padding = 2)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 1)
        self.conv3 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 0)
        #self.dense1 = nn.Linear(2592, 1024)
        self.head = nn.Linear(256, 4)

    def forward(self, x):
        #print(x.shape)
        x = F.relu(self.conv1(x))
        #print(x.shape)
        x = F.relu(self.conv2(x))
        #print(x.shape)
        x = F.relu(self.conv3(x))
        #print(x.shape)
        x = x.view(x.size(0), -1)
        # x = F.relu(self.dense1(x))
        return 2*F.tanh(self.head(x))
    
model = DQN()
batch_size = 32

In [6]:
imitation_state_dict = torch.load("imitation_learning.pth")
#model.load_state_dict(imitation_state_dict)

In [7]:
#optimizer = optim.Adam(model.head.parameters(), lr = 0.001) # , weight_decay = 0.001
optimizer = optim.Adam(model.parameters(), lr = 0.001) # , weight_decay = 0.001

In [8]:
def train_batch():
    if len(memory) < batch_size:
        return 0
    
    # GET SAMPLE OF DATA
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))
    state_batch = tuple_to_torch(batch.state).float()
    next_state_batch = tuple_to_torch(batch.next_state).float()
    action_batch = tuple_to_torch(list(action2ind[a] for a in batch.action))
    reward_batch = tuple_to_torch(batch.reward).float()


    ## Calculate expected reward:
    GAMMA = 0.99
    with torch.set_grad_enabled(False):
        not_ended_batch = 1 -torch.ByteTensor(batch.ended)
        next_states_non_final = next_state_batch[not_ended_batch]
        next_state_values = torch.zeros(batch_size)
        reward_hat = model(next_states_non_final)
        next_state_values[not_ended_batch] = reward_hat.max(1)[0]
        expected_state_action_values = next_state_values*GAMMA + reward_batch


    # Predict value function:
    yhat = model(state_batch)
    state_action_values = yhat.gather(1, action_batch.unsqueeze(1)).squeeze()

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.data.clamp_(-1, 1)
    optimizer.step()

    return loss.data

In [9]:
def deep_agent(state, th):
    
    if random.random() < th:
        return random.sample(list(ind2action.values()), 1)[0]
    
    state = torch.unsqueeze(torch.from_numpy(state),0).float()
    yhat = model(state)
    action = [ind2action[a] for a in yhat.argmax(1).data.numpy()]
    if len(action) > 1:
        raise Exception
    action = action[0]
    return action

In [10]:
snake = snakai.Snake(render=False, 
                     game_size = game_size, 
                     time_reward = -0.01)

# Warmup memory:
for _ in range(10):
    play_game(snake, deep_agent, epsilon = 0.9)

In [None]:
REPORT_INTERVAL = 100
R = []
L = []
play_length = []
for ep in range(10000):
    
    # Play one game:
    r, i = play_game(snake, deep_agent, epsilon = 0.5)
    R.append(r)
    play_length.append(i)
    
    # Train:
    for _ in range(10):
        l = train_batch()
        L.append(float(l))
    
    if ep % REPORT_INTERVAL == 0:
        print("%s: ep: %s \t reward: %.3f \t loss: %.4f \t game len: %.1f" % 
              (str(datetime.datetime.now()), ep, np.mean(R), np.mean(L), np.mean(play_length)))
        R = []
        L = []
        play_length = []

2018-05-27 11:51:29.220465: episodes: 0, reward: -1.050, loss: 0.0135, game length: 5.0
2018-05-27 11:51:53.874540: episodes: 100, reward: -0.964, loss: 0.0113, game length: 6.5
2018-05-27 11:52:19.607622: episodes: 200, reward: -0.909, loss: 0.0112, game length: 6.0
2018-05-27 11:52:45.438028: episodes: 300, reward: -0.881, loss: 0.0097, game length: 5.3
2018-05-27 11:53:11.695321: episodes: 400, reward: -0.868, loss: 0.0105, game length: 6.0


In [None]:
for layer in model.parameters():
    print(layer.require_grad)

In [None]:
snake = snakai.Snake(render=True, 
                     game_size = game_size, 
                     time_reward = -0.1)
snake.on_init()
state, reward, done = snake.on_feedback()

for _ in range(10):
    print(play_game(snake, deep_agent))