In [None]:
!sudo pip3 install pygame -q 

In [2]:
import snakai
import agents
import numpy as np
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T
import datetime



## Define how to play game and replay memory

In [3]:
game_size = (10,10)
def tuple_to_torch(tup):
    return torch.from_numpy(np.array(tup))

action2ind = {'right' : 0,
             'left' : 1,
             'up' : 2,
             'down' : 3}
ind2action = {val: key for key, val in action2ind.items()}


In [4]:
def play_game(snake, agent, epsilon = 0.05):
    cum_reward = 0.0
    snake.on_init()
    state, reward, ended = snake.on_feedback()

    for i in range(200):
        action = agent(state, th = epsilon)
        next_state, reward, ended = snake.step(action)
        cum_reward += float(reward)
        
        # Keep all the games:
        memory.push(state, action, next_state, reward, ended)
        state = next_state
        if ended == 1:
            return cum_reward, i
    return cum_reward, i


In [5]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward','ended'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
    
memory = ReplayMemory(5000)

## Define agent

In [6]:
ch = 64
ksize = 4
class DQN(nn.Module):

    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, ch, kernel_size=ksize, stride=2, padding = 2)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 1)
        self.conv3 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 0)
        #self.dense1 = nn.Linear(2592, 1024)
        self.head = nn.Linear(256, 4)

    def forward(self, x):
        #print(x.shape)
        x = F.relu(self.conv1(x))
        #print(x.shape)
        x = F.relu(self.conv2(x))
        #print(x.shape)
        x = F.relu(self.conv3(x))
        #print(x.shape)
        x = x.view(x.size(0), -1)
        # x = F.relu(self.dense1(x))
        return 2*F.tanh(self.head(x))
    
model = DQN()
batch_size = 32

In [7]:
#imitation_state_dict = torch.load("imitation_learning.pth")
#model.load_state_dict(imitation_state_dict)

In [8]:
#optimizer = optim.Adam(model.head.parameters(), lr = 0.001) # , weight_decay = 0.001
optimizer = optim.Adam(model.parameters(), lr = 0.001) # , weight_decay = 0.001

In [9]:
def train_batch():
    if len(memory) < batch_size:
        return 0
    
    # GET SAMPLE OF DATA
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))
    state_batch = tuple_to_torch(batch.state).float()
    next_state_batch = tuple_to_torch(batch.next_state).float()
    action_batch = tuple_to_torch(list(action2ind[a] for a in batch.action))
    reward_batch = tuple_to_torch(batch.reward).float()


    ## Calculate expected reward:
    GAMMA = 0.99
    with torch.set_grad_enabled(False):
        not_ended_batch = 1 -torch.ByteTensor(batch.ended)
        next_states_non_final = next_state_batch[not_ended_batch]
        next_state_values = torch.zeros(batch_size)
        reward_hat = model(next_states_non_final)
        next_state_values[not_ended_batch] = reward_hat.max(1)[0]
        expected_state_action_values = next_state_values*GAMMA + reward_batch


    # Predict value function:
    yhat = model(state_batch)
    state_action_values = yhat.gather(1, action_batch.unsqueeze(1)).squeeze()

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.data.clamp_(-1, 1)
    optimizer.step()

    return loss.data

In [10]:
def deep_agent(state, th):
    
    if random.random() < th:
        return random.sample(list(ind2action.values()), 1)[0]
    
    state = torch.unsqueeze(torch.from_numpy(state),0).float()
    yhat = model(state)
    action = [ind2action[a] for a in yhat.argmax(1).data.numpy()]
    if len(action) > 1:
        raise Exception
    action = action[0]
    return action

In [11]:
snake = snakai.Snake(render=False, 
                     game_size = game_size, 
                     time_reward = -0.01)

# Warmup memory:
for _ in range(10):
    play_game(snake, deep_agent, epsilon = 1.0)

In [12]:
def evaluate_agent(n = 100, epsilon = 0.05):
    rewards = np.zeros(n)
    for ep in range(n):
        rewards[ep],i = play_game(snake, deep_agent, epsilon = epsilon)
        
    return np.mean(rewards)

def save_checkpoint():
    filename = "models/snake_ep:%02d-reward:%.2f.pth" %( ep, eval_reward)
    torch.save(model.state_dict(), filename)
    return True

In [None]:
REPORT_INTERVAL = 100
EVAL_INTERVAL = 2000
R = []
L = []
play_length = []

EPS_START = 0.9
EPS_END = 0.05
decay = 0.1/2000
start_ep = 0

for ep in range(100000):
    
    # Play one game:
    epsilon = max(EPS_START - decay*(ep), EPS_END)
    r, i = play_game(snake, deep_agent, epsilon = epsilon)
    R.append(r)
    play_length.append(i)
    
    # Train:
    for _ in range(10):
        l = train_batch()
        L.append(float(l))
    
    if ep % REPORT_INTERVAL == 0:
        print("%s: ep: %s \t reward: %.3f \t loss: %.4f \t game len: %.1f \t epsilon: %.2f" % 
              (str(datetime.datetime.now()), ep, np.mean(R), np.mean(L), np.mean(play_length), epsilon))
        R = []
        L = []
        play_length = []
    
    if ep % EVAL_INTERVAL == 0:
        eval_reward = evaluate_agent()
        save_checkpoint()
        print("%s: ep: %s \t Reward evaluation: %.2f" % (str(datetime.datetime.now()), ep, eval_reward))

2018-05-27 11:04:58.410492: ep: 0 	 reward: -1.010 	 loss: 0.0403 	 game len: 1.0 	 epsilon: 0.90
2018-05-27 11:04:58.791300: ep: 0 	 Reward evaluation: -0.93
2018-05-27 11:05:29.820608: ep: 100 	 reward: -0.950 	 loss: 0.0096 	 game len: 2.0 	 epsilon: 0.90
2018-05-27 11:06:01.321170: ep: 200 	 reward: -0.962 	 loss: 0.0063 	 game len: 2.3 	 epsilon: 0.89
2018-05-27 11:06:32.523294: ep: 300 	 reward: -0.924 	 loss: 0.0091 	 game len: 2.4 	 epsilon: 0.89
2018-05-27 11:07:03.796442: ep: 400 	 reward: -0.943 	 loss: 0.0115 	 game len: 2.4 	 epsilon: 0.88
2018-05-27 11:07:35.210666: ep: 500 	 reward: -0.982 	 loss: 0.0132 	 game len: 2.3 	 epsilon: 0.88
2018-05-27 11:08:06.640280: ep: 600 	 reward: -0.972 	 loss: 0.0128 	 game len: 2.2 	 epsilon: 0.87
2018-05-27 11:08:37.941304: ep: 700 	 reward: -0.954 	 loss: 0.0125 	 game len: 2.5 	 epsilon: 0.86
2018-05-27 11:09:09.125471: ep: 800 	 reward: -0.931 	 loss: 0.0125 	 game len: 2.2 	 epsilon: 0.86
2018-05-27 11:09:40.291433: ep: 900 	 rew

In [17]:
# Training time now:
ep

98510

In [21]:
# evaluate agent with 5% epsilon greedy policy:
evaluate_agent(n = 1000, epsilon = 0.05)

3.785280000000004

In [22]:
# Evaluate agent with greedy policy:
evaluate_agent(n = 1000, epsilon = 0.0)

11.315360000000009

In [16]:
snake = snakai.Snake(render=False, 
                     game_size = game_size, 
                     time_reward = -0.01)
snake.on_init()
state, reward, done = snake.on_feedback()

for _ in range(10):
    print(play_game(snake, deep_agent, epsilon = 0.0))

(-0.9900000000000014, 199)
(13.100000000000016, 105)
(-0.9900000000000014, 199)
(8.760000000000005, 32)
(19.689999999999994, 153)
(13.930000000000021, 123)
(2.670000000000001, 37)
(20.569999999999993, 166)
(12.75000000000002, 140)
(11.030000000000017, 110)
