# Hyperparameters

In [None]:
# Learning rate handed to the optimizer (`lr=LR`).
LR = 1e-4
# Weight-decay coefficient; presumably meant for the same optimizer.
WEIGHT_DECAY = 1e-4

# NN Agent Definition

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import *
import numpy as np
import random

class Policy(nn.Module):
    """Actor-critic network for a single 3x3 board.

    A shared convolution + linear backbone feeds two heads:

    * a policy head that yields one probability per board cell, with cells
      whose absolute value is 1 (treated as occupied) forced to 0;
    * a value head squashed into [-1, 1] by ``tanh``.
    """

    def __init__(self):
        super().__init__()

        # Shared backbone: a 2x2 convolution over a 3x3 board produces
        # 16 feature maps of size 2x2, i.e. 16*2*2 features once flattened.
        self.conv = nn.Conv2d(1, 16, kernel_size=2, stride=1, bias=False)
        self.size = 2 * 2 * 16
        self.fc = nn.Linear(self.size, 32)

        # Layers for the policy (action) head: one logit per board cell.
        self.fc_action_1 = nn.Linear(32, 16)
        self.fc_action_2 = nn.Linear(16, 9)

        # Layers for the critic (value) head.
        self.fc_value_1 = nn.Linear(32, 8)
        self.fc_value_2 = nn.Linear(8, 1)
        self.tanh_value = nn.Tanh()  # value will be between -1 and 1

    def forward(self, x):
        """Evaluate one board position.

        Args:
            x: board tensor holding 9 cells in a layout the convolution
                accepts (e.g. shape ``(1, 1, 3, 3)``). Cells with absolute
                value 1 are treated as occupied. Only a single board is
                supported because the result is reshaped to ``(3, 3)``.

        Returns:
            A tuple ``(prob, value)``: ``prob`` is a ``(3, 3)`` tensor of move
            probabilities that is zero on occupied cells and sums to 1 over
            the free cells; ``value`` is the tanh-squashed critic output.
            If no cell is free the normalisation divides by zero.
        """
        # Backbone: the conv output is (batch, 16, 2, 2); flatten it per sample.
        y = F.relu(self.conv(x))
        y = y.view(-1, self.size)
        y = F.relu(self.fc(y))

        # Policy head (raw logits, one per cell).
        a = F.relu(self.fc_action_1(y))
        a = self.fc_action_2(a)

        # 1.0 where a move is available, 0.0 where the cell is occupied.
        # Converting with `.to(a.dtype)` keeps the mask on x's device instead
        # of forcing a CPU tensor.
        avail = (torch.abs(x) != 1).to(a.dtype).reshape(-1, 9)

        # Subtract the max logit for numerical stability; softmax is
        # invariant to this shift.
        stable = a - torch.max(a)

        # Masked softmax: exponentiate, zero out unavailable actions, then
        # normalise. The mask must multiply the exponentials, not the
        # shifted logits.
        exp = avail * torch.exp(stable)
        prob = exp / torch.sum(exp)

        # Critic head.
        value = F.relu(self.fc_value_1(y))
        value = self.tanh_value(self.fc_value_2(value))

        return prob.view(3, 3), value
        
        

# Players Definition: 

1. Random Player
2. MCTS NN Player

In [None]:
import MCTS
from copy import copy
import random

def Random_Player(game):
    """Pick one of the game's currently available actions uniformly at random."""
    legal_moves = game.available_actions()
    return random.choice(legal_moves)

def Policy_Player_MCTS(game, policy, explore_steps=50):
    """Choose a move by tree search guided by ``policy``.

    A search tree is rooted at a shallow copy of ``game`` and explored
    ``explore_steps`` times. The child picked by ``next(temperature=0.1)``
    determines the move that is returned.
    """
    root = MCTS.Node(copy(game))

    for _step in range(explore_steps):
        root.explore(policy)

    # The statistics returned alongside the chosen child are not needed here.
    chosen_child, (_v, _nn_v, _p, _nn_p) = root.next(temperature=0.1)
    return chosen_child.game.last_move

# Training Loop

In [5]:
# Progress bar
import progressbar as pb
# Progress-bar layout: label, percentage, bar and ETA.
widget = [
    'training loop: ',
    pb.Percentage(),
    ' ',
    pb.Bar(),
    ' ',
    pb.ETA(),
]
# `episodes` must already be defined when this runs; it sets the bar's maximum.
timer = pb.ProgressBar(
    widgets=widget,
    maxval=episodes,
).start()

training loop:   0% |                                          | ETA:  --:--:--

In [None]:
def take_turn(tree, policy, explore_steps=50):
    """Explore ``tree`` and advance it by one move.

    Args:
        tree: search-tree node exposing ``explore(policy)``, ``game.player``
            and ``next(temperature=...)``.
        policy: passed through to ``tree.explore``.
        explore_steps: number of ``tree.explore`` calls made before a move
            is chosen.

    Returns:
        ``(next_tree, stats, player)`` where ``next_tree`` and ``stats`` are
        what ``tree.next(temperature=0.1)`` returned and ``player`` is
        ``tree.game.player`` read before the move was chosen.
    """
    # explore() returns nothing; it only updates the tree in place.
    # Honour the caller's budget instead of a hard-coded 50.
    for _ in range(explore_steps):
        tree.explore(policy)

    # Record whose turn it is before moving on to the next node.
    player = tree.game.player

    # Choose the action and obtain the resulting subtree plus its statistics.
    next_tree, (v, nn_v, p, nn_p) = tree.next(temperature=0.1)
    return next_tree, (v, nn_v, p, nn_p), player

In [None]:
# Instantiate the policy, game and optimizer
import torch.optim as optim
from ConnectN import ConnectN

# Keyword arguments used to build each ConnectN game
# (presumably a 3x3 board where 3 in a row wins; confirm against ConnectN).
game_setting = {'size': (3, 3), 'N': 3}
# game = ConnectN(**game_setting)
policy = Policy()
# Take both hyperparameters from the constants defined at the top of the file
# so that changing WEIGHT_DECAY actually affects the optimizer.
optimizer = optim.Adam(policy.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

In [None]:
# # Training loop
# from collections import deque

def train(game_setting, policy, episodes, explore_steps=50):
    """Train ``policy`` on ``episodes`` games played through the search tree.

    Each episode builds a fresh ``ConnectN`` game, plays it to the end with
    ``take_turn``, and then performs one optimizer step on the summed value
    and policy losses of every move in that game.

    Args:
        game_setting: keyword arguments forwarded to ``ConnectN``.
        policy: network handed to the tree search on every turn.
        episodes: number of games to play (one optimizer step per game).
        explore_steps: exploration budget forwarded to ``take_turn``.

    Returns:
        ``(outcomes, losses)``: the outcome of every game and the scalar loss
        of every episode, in play order.

    Note:
        Parameter updates go through the module-level ``optimizer``.
    """
    # Local import so the function does not depend on a module-level
    # progress bar having been created with a matching maximum.
    import progressbar as pb

    outcomes = []
    losses = []

    # Size the bar from this call's `episodes`; at least 1 so that an empty
    # run does not create a zero-length bar.
    widgets = ['training loop: ', pb.Percentage(), ' ',
               pb.Bar(), ' ', pb.ETA()]
    timer = pb.ProgressBar(widgets=widgets, maxval=max(episodes, 1)).start()

    for e in range(episodes):
        # New tree for each episode, rooted at a new game.
        mytree = MCTS.Node(ConnectN(**game_setting))

        vterm = []
        logterm = []
        while mytree.outcome is None:  # play until the end of the game
            mytree, (_v, nn_v, p, nn_p), current_player = take_turn(
                mytree, policy, explore_steps=explore_steps)
            mytree.detach_mother()

            # Policy loss for this move: KL(p || nn_p)
            #   = -sum(p*log(nn_p) - p*log(p)).
            # log(nn_p) must be weighted by p; without the weight the term
            # ignores the target distribution p entirely.
            loglist = torch.log(nn_p) * p
            # p*log(p) is defined as 0 where p == 0 (avoids 0*log(0) = nan).
            constant = torch.where(p > 0, p * torch.log(p), torch.tensor(0.))
            logterm.append(-torch.sum(loglist - constant))

            # Adjust the sign of the value estimate for the player to move.
            vterm.append(nn_v * current_player)

        outcome = mytree.outcome
        outcomes.append(outcome)

        # Value loss (squared error against the outcome) + policy loss.
        loss = torch.sum((torch.stack(vterm) - outcome) ** 2
                         + torch.stack(logterm))
        optimizer.zero_grad()  # clear the gradients
        loss.backward()        # compute the gradients
        losses.append(float(loss))
        optimizer.step()       # update the parameters of the model

        if (e + 1) % 50 == 0:
            print("game: ",e+1, ", mean loss: {:3.2f}".format(np.mean(losses[-20:])),
                ", recent outcomes: ", outcomes[-10:])

        timer.update(e + 1)

    timer.finish()
    return outcomes, losses

In [None]:
# Training parameters: number of episodes (one game is played per episode).
episodes = 400

train(game_setting, policy, episodes=episodes, explore_steps=50)