In [1]:
#!sudo pip3 install pygame -q 

In [2]:
import snakai
import agents
import numpy as np
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T
import datetime



## Define how to play a game, and the replay memory

In [3]:
# Board size handed to snakai.Snake through its `game_size` argument.
game_size = (50, 50)


def tuple_to_torch(tup):
    """Turn a tuple (or other array-like) into a torch tensor.

    The input is first copied into a NumPy array, so the resulting tensor
    inherits NumPy's inferred dtype (callers cast with ``.float()`` as needed).
    """
    as_array = np.array(tup)
    return torch.from_numpy(as_array)

# Action name -> index into the network's 4-wide output.
action2ind = dict(right=0, left=1, up=2, down=3)

# Reverse lookup: output index -> action name.
ind2action = {index: name for name, index in action2ind.items()}


In [4]:
def play_game(snake, agent, epsilon = 0.05, max_steps = 200, replay_memory = None):
    """Play one episode and record every transition in a replay memory.

    Args:
        snake: Game object exposing ``on_init()``, ``on_feedback()`` and
            ``step(action)``; the latter two return ``(state, reward, ended)``.
        agent: Callable invoked as ``agent(state, th=epsilon)`` that returns
            the action to play.
        epsilon: Exploration threshold forwarded to ``agent`` as ``th``.
        max_steps: Upper bound on the number of steps in the episode. Defaults
            to 200, the value that used to be hard-coded.
        replay_memory: Object with a
            ``push(state, action, next_state, reward, ended)`` method. When
            omitted, the notebook-global ``memory`` is used.

    Returns:
        ``(cum_reward, i)``: the summed reward as a float and the zero-based
        index of the last step played (the episode lasted ``i + 1`` steps).

    Raises:
        ValueError: If ``max_steps`` is smaller than 1 (no step would be
            played and there would be no step index to return).
    """
    if max_steps < 1:
        raise ValueError("max_steps must be at least 1")

    # Resolved at call time so the cell defining `memory` may run after this one.
    buffer = memory if replay_memory is None else replay_memory

    cum_reward = 0.0
    snake.on_init()
    # Only the initial state is needed; the initial reward/ended flags are unused.
    state, _, _ = snake.on_feedback()

    for i in range(max_steps):
        action = agent(state, th = epsilon)
        next_state, reward, ended = snake.step(action)
        cum_reward += float(reward)

        # Every transition is stored, including the terminal one.
        buffer.push(state, action, next_state, reward, ended)
        state = next_state
        if ended == 1:
            break
    return cum_reward, i


In [5]:
# One environment step as stored in the replay memory.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward','ended'))

class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    Once ``capacity`` transitions are stored, each new one overwrites the
    oldest slot (``position`` is the index that will be written next).
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept.

        Raises:
            ValueError: If ``capacity`` is smaller than 1; such a buffer could
                never accept a push.
        """
        if capacity < 1:
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition.

        ``args`` are the ``Transition`` fields in order:
        ``state, action, next_state, reward, ended``.

        Raises:
            TypeError: If the wrong number of fields is given. The buffer is
                left untouched in that case.
        """
        # Build the record BEFORE touching the buffer: if the arguments are
        # wrong, no `None` placeholder is left behind to break later sampling.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises ``ValueError`` (from ``random.sample``) if fewer than
        ``batch_size`` transitions are stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

# Notebook-wide replay buffer shared by play_game and train_batch.
memory = ReplayMemory(5000)

## Define the agent

In [6]:
ch = 64     # number of feature maps produced by every conv layer
ksize = 4   # kernel size shared by every conv layer

class DQN(nn.Module):
    """Convolutional Q-network producing one value per action (4 outputs).

    Four conv layers (three with stride 2, one with stride 1) are followed by
    a single linear head. The outputs are squashed to the open interval
    (-2, 2) by ``2 * tanh``.

    NOTE(review): the head's input size assumes a 3-channel 50x50 input
    (spatial sizes 50 -> 25 -> 12 -> 6 -> 3 with ksize = 4), i.e.
    ch * 3 * 3 = 576 features. Confirm against the state shape that snakai
    actually returns before changing ``game_size`` or ``ksize``.
    """

    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, ch, kernel_size=ksize, stride=2, padding = 1)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=2, padding = 1)
        self.conv3 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=2, padding = 1)
        self.conv4 = nn.Conv2d(ch, ch, kernel_size=ksize, stride=1, padding = 0)

        # Flattened conv4 output: ch channels x 3 x 3 (= 576 when ch = 64).
        self.head = nn.Linear(ch * 3 * 3, 4)

    def forward(self, x):
        """Map a batch of states ``(N, 3, H, W)`` to action values ``(N, 4)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        # torch.tanh replaces the deprecated F.tanh; the result is identical.
        return 2*torch.tanh(self.head(x))

# Single network instance shared by deep_agent, train_batch and save_checkpoint.
model = DQN()
# Number of transitions sampled from the replay memory per gradient step.
batch_size = 32

In [7]:
#imitation_state_dict = torch.load("imitation_learning.pth")
#model.load_state_dict(torch.load("models/snake_ep:62000-reward:4.32"))

In [8]:
# Adam over every parameter of the network.
# Alternatives tried earlier: optimising only model.head.parameters(), and
# adding weight_decay = 0.001.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [9]:
def train_batch(gamma = 0.99):
    """Run one gradient step of one-step Q-learning on a replay sample.

    Uses the notebook globals ``memory``, ``batch_size``, ``model`` and
    ``optimizer``. The bootstrap target is computed with the same ``model``
    that is being trained (there is no separate target network).

    Args:
        gamma: Discount factor applied to the bootstrapped next-state value.
            Defaults to 0.99, the value that used to be hard-coded.

    Returns:
        ``0`` when the memory holds fewer than ``batch_size`` transitions (no
        training happens), otherwise the detached scalar loss tensor.
    """
    if len(memory) < batch_size:
        return 0

    # Sample a batch and transpose list-of-Transitions -> Transition-of-tuples.
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))
    state_batch = tuple_to_torch(batch.state).float()
    next_state_batch = tuple_to_torch(batch.next_state).float()
    # gather() needs int64 indices regardless of the platform's default int.
    action_batch = tuple_to_torch([action2ind[a] for a in batch.action]).long()
    reward_batch = tuple_to_torch(batch.reward).float()

    # Target: r + gamma * max_a Q(s', a), with the bootstrap term left at 0
    # for transitions that ended the game. Integer indices are used instead of
    # a uint8 mask, whose use for indexing is deprecated in PyTorch.
    non_final_idx = torch.tensor(
        [k for k, ended in enumerate(batch.ended) if not ended],
        dtype=torch.long)
    next_state_values = torch.zeros(batch_size)
    with torch.no_grad():
        # Skip the forward pass when every sampled transition is terminal.
        if non_final_idx.numel() > 0:
            next_q = model(next_state_batch[non_final_idx])
            next_state_values[non_final_idx] = next_q.max(1)[0]
        expected_state_action_values = next_state_values*gamma + reward_batch

    # Q(s, a) for the actions that were actually taken.
    yhat = model(state_batch)
    # squeeze(1) rather than squeeze(): keeps shape (batch,) even for batch 1.
    state_action_values = yhat.gather(1, action_batch.unsqueeze(1)).squeeze(1)

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
    optimizer.zero_grad()
    loss.backward()
    # Clip the GRADIENTS to [-1, 1]. The previous code clamped `param.data`,
    # i.e. the weights themselves, which forced every weight into [-1, 1] on
    # each step and performed no gradient clipping at all.
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    optimizer.step()

    return loss.detach()

In [10]:
def deep_agent(state, th):
    """Epsilon-greedy policy backed by the notebook-global ``model``.

    Args:
        state: NumPy array holding a single game state (no batch dimension).
        th: Probability of returning a uniformly random action instead of the
            network's choice. ``0.0`` is fully greedy, ``1.0`` fully random.

    Returns:
        The action name (one of the values of ``ind2action``).
    """
    # Explore.
    if random.random() < th:
        return random.choice(list(ind2action.values()))

    # Exploit: no gradients are needed to pick an action, so skip building
    # the autograd graph.
    with torch.no_grad():
        q_values = model(torch.from_numpy(state).float().unsqueeze(0))
    # The batch always has exactly one row, so argmax yields a single index.
    return ind2action[int(q_values.argmax(1).item())]

In [11]:
# Non-rendering game instance used for training and evaluation.
# time_reward is passed straight through to snakai.Snake (presumably a
# per-step reward; confirm in snakai).
snake = snakai.Snake(render=False, game_size=game_size, time_reward=-0.01)

# Warm-up: epsilon = 1.0 makes deep_agent pick random actions every step, and
# play_game pushes each transition into the replay memory.
for _ in range(10):
    play_game(snake, deep_agent, epsilon=1.0)

In [12]:
def evaluate_agent(n = 100, epsilon = 0.05):
    """Play ``n`` games with ``deep_agent`` and return the mean total reward.

    Games are played on the notebook-global ``snake`` through ``play_game``,
    so the transitions seen during evaluation are also pushed into the replay
    memory.
    """
    episode_rewards = [
        play_game(snake, deep_agent, epsilon = epsilon)[0]
        for _ in range(n)
    ]
    return np.mean(np.array(episode_rewards, dtype=np.float64))

def save_checkpoint(episode = None, reward = None):
    """Save the global ``model``'s weights under ``models/``.

    Args:
        episode: Episode number written into the file name. Defaults to the
            training loop's global ``ep``.
        reward: Evaluation reward written into the file name. Defaults to the
            training loop's global ``eval_reward``.

    Returns:
        ``True`` once the file has been written.
    """
    import os

    # Fall back to the training loop's globals so existing no-argument calls
    # keep working unchanged.
    if episode is None:
        episode = ep
    if reward is None:
        reward = eval_reward

    filename = "models/snakeBig_ep:%02d-reward:%.2f.pth" %( episode, reward)
    # torch.save does not create missing directories; without this the very
    # first checkpoint fails when "models/" does not exist yet.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save(model.state_dict(), filename)
    return True

In [13]:
# Main training loop: one game per episode, then ten gradient steps, with
# periodic reporting and periodic evaluation + checkpointing.

# Print running averages every REPORT_INTERVAL episodes.
REPORT_INTERVAL = 100
# Evaluate the agent and save a checkpoint every EVAL_INTERVAL episodes.
EVAL_INTERVAL = 2000
# Accumulators since the last report (reset after each report):
R = []            # cumulative reward of each game
L = []            # value returned by each train_batch call
play_length = []  # last step index of each game, as returned by play_game

# Epsilon schedule: starts at EPS_START, drops by 0.1 every 2000 episodes and
# is floored at EPS_END (the floor is reached at episode 17000).
EPS_START = 0.9
EPS_END = 0.05
decay = 0.1/2000
# NOTE(review): start_ep is not referenced anywhere in the visible notebook.
start_ep = 0

for ep in range(100000):
    
    # Play one game with the current exploration rate; play_game also stores
    # the game's transitions in the replay memory.
    epsilon = max(EPS_START - decay*(ep), EPS_END)
    r, i = play_game(snake, deep_agent, epsilon = epsilon)
    R.append(r)
    play_length.append(i)
    
    # Ten gradient steps per game. train_batch returns 0 while the memory
    # holds fewer than batch_size transitions, so float() works either way.
    for _ in range(10):
        l = train_batch()
        L.append(float(l))
    
    # Report averages over the episodes since the previous report.
    if ep % REPORT_INTERVAL == 0:
        print("%s: ep: %s \t reward: %.3f \t loss: %.4f \t game len: %.1f \t epsilon: %.2f" % 
              (str(datetime.datetime.now()), ep, np.mean(R), np.mean(L), np.mean(play_length), epsilon))
        R = []
        L = []
        play_length = []
    
    # Evaluate with evaluate_agent's defaults, then checkpoint.
    # save_checkpoint() reads the globals `ep` and `eval_reward` set here, so
    # eval_reward must be assigned before it is called.
    if ep % EVAL_INTERVAL == 0:
        eval_reward = evaluate_agent()
        save_checkpoint()
        print("%s: ep: %s \t Reward evaluation: %.2f" % (str(datetime.datetime.now()), ep, eval_reward))

2018-05-28 06:11:53.353026: ep: 0 	 reward: -1.100 	 loss: 0.0566 	 game len: 10.0 	 epsilon: 0.90
2018-05-28 06:12:01.387666: ep: 0 	 Reward evaluation: -1.21
2018-05-28 08:45:54.276223: ep: 3400 	 reward: -1.034 	 loss: 0.0134 	 game len: 3.4 	 epsilon: 0.73
2018-05-28 08:50:24.036823: ep: 3500 	 reward: -1.039 	 loss: 0.0140 	 game len: 3.9 	 epsilon: 0.72
2018-05-28 08:54:55.971099: ep: 3600 	 reward: -1.039 	 loss: 0.0125 	 game len: 3.9 	 epsilon: 0.72
2018-05-28 08:59:24.513618: ep: 3700 	 reward: -1.036 	 loss: 0.0139 	 game len: 3.6 	 epsilon: 0.72
2018-05-28 09:03:12.413293: ep: 3800 	 reward: -1.040 	 loss: 0.0120 	 game len: 4.0 	 epsilon: 0.71
2018-05-28 09:05:57.748268: ep: 3900 	 reward: -1.048 	 loss: 0.0136 	 game len: 4.8 	 epsilon: 0.71
2018-05-28 09:08:13.827888: ep: 4000 	 reward: -1.036 	 loss: 0.0110 	 game len: 3.6 	 epsilon: 0.70
2018-05-28 09:08:21.119081: ep: 4000 	 Reward evaluation: -1.21
2018-05-28 09:10:40.585103: ep: 4100 	 reward: -1.042 	 loss: 0.0245 

KeyboardInterrupt: 

In [None]:
# Episode index the training loop had reached when it stopped
# (the loop variable `ep`, not a wall-clock time):
ep

In [None]:
# Mean total reward over 1000 games with a 5% epsilon-greedy policy.
evaluate_agent(n=1000, epsilon=0.05)

In [None]:
# Mean total reward over 1000 games with a fully greedy policy (epsilon = 0).
evaluate_agent(n=1000, epsilon=0.0)

In [None]:
# Rebind the global `snake` to an instance created with render=True
# (presumably draws the game; confirm in snakai). evaluate_agent reads this
# global, so any later evaluation also runs on this instance.
snake = snakai.Snake(render=True, game_size=game_size, time_reward=-0.01)
snake.on_init()
state, reward, done = snake.on_feedback()

# Play ten greedy games; each printed tuple is
# (cumulative reward, last step index) as returned by play_game.
for _ in range(10):
    print(play_game(snake, deep_agent, epsilon=0.0))