# Implementation of hex training

In [567]:
from hex_engine import hexPosition

from collections import namedtuple, deque
from itertools import count

import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

matplotlib.use('TkAgg')
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [568]:
myboard = hexPosition(size=5)

myboard.printBoard()

     A   B   C   D   E
     _   _   _   _   _
    / \_/ \_/ \_/ \_/ \
   |   |   |   |   |   | 1 
    \_/ \_/ \_/ \_/ \_/ \
     |   |   |   |   |   | 2 
      \_/ \_/ \_/ \_/ \_/ \
       |   |   |   |   |   | 3 
        \_/ \_/ \_/ \_/ \_/ \
         |   |   |   |   |   | 4 
          \_/ \_/ \_/ \_/ \_/ \
           |   |   |   |   |   | 5 
            \_/ \_/ \_/ \_/ \_/
             A   B   C   D   E


In [569]:
class DQN(nn.Module):

    def __init__(self, size):
        super(DQN, self).__init__()

        self.input_layer = nn.Linear(size * size, 64)
        self.output_layer = nn.Linear(64, size * size)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.input_layer(x))
        return self.output_layer(x)

In [570]:

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([],maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

# Training

In [571]:
policy_net = DQN(myboard.size).to(device)
target_net = DQN(myboard.size).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()


optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)


steps_done = 0
n_actions = myboard.size * myboard.size

In [572]:
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        available_actions = myboard.getActionSpace()

        pos = random.sample(available_actions, 1)[0]
        return torch.tensor([[pos[0] * myboard.size + pos[1]]], device=device, dtype=torch.long)

In [573]:
A = [1,2,3]

[a for a in A if a != 2]

[1, 3]

In [574]:

episode_durations = []

def plot_durations():
    # print board
    plt.figure(1)
    plt.clf()

    for x in range(myboard.size):
        for y in range(myboard.size):
            if myboard.board[y][x] == 1:
                plt.plot(x, myboard.size-y-1, color='red', marker='o')
            elif myboard.board[y][x] == 2:
                plt.plot(x, myboard.size - y -1, color='blue', marker='o')
    # y = [[i,2] for i in range(1, myboard.size)]
    
    # plt.scatter(x, x, data=[0,0,0], s=100)

    # print learning rate
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())

In [575]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [576]:
def get_state():
    state = torch.tensor(myboard.board, dtype=torch.float32).view(1, -1) - 1
    return state

def do_step(action, white_plays):

    row = action // myboard.size
    col = action % myboard.size

    myboard.board[row][col] = 1 if white_plays else 2

    if white_plays:
        return 1 if myboard.whiteWin() else 0
    else:
        return 1 if myboard.blackWin() else 0

In [577]:

num_episodes = 50
for i_episode in range(num_episodes):
    # Initialize the environment and state
    myboard.reset()
    last_state = get_state()
    current_state = get_state()
    state = current_state - last_state

    white_plays = True

    for t in count():
        # Select and perform an action
        action = select_action(state)
        reward = do_step(action.item(), white_plays)
        white_plays = not white_plays
        
        reward = torch.tensor([reward], device=device)

        # Observe new state
        last_state = current_state
        current_state = get_state()
        if not myboard.winner:
            next_state = current_state - last_state
        else:
            next_state = None

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()
        if myboard.winner:
            episode_durations.append(t + 1)
            plot_durations()
            
            break
    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

In [578]:
myboard.printBoard()

     A   B   C   D   E
     _   _   _   _   _
    / \_/ \_/ \_/ \_/ \
   | ● | ● | ○ | ○ | ● | 1 
    \_/ \_/ \_/ \_/ \_/ \
     | ○ | ○ | ● | ● | ○ | 2 
      \_/ \_/ \_/ \_/ \_/ \
       | ● | ● |   | ○ | ○ | 3 
        \_/ \_/ \_/ \_/ \_/ \
         | ● | ○ | ● | ● | ○ | 4 
          \_/ \_/ \_/ \_/ \_/ \
           | ● | ● |   | ● |   | 5 
            \_/ \_/ \_/ \_/ \_/
             A   B   C   D   E
