# Imports

In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict
import pickle
from itertools import product
import time
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

# Model

In [9]:
class MonteCarloTreeSearchNode():
    # class wide lookup table (state:MCTS_Node)
    state_tbl = {}
    # class wide state sizes
    n_rows = 3
    n_cols = 3
    # class wide set to avoid double rewarding nodes in backprop
    marked = set([])

    def __init__(self,c_param=np.sqrt(2), win_reward_scalar=1, state = ((0,0,0),(0,0,0),(0,0,0)), parent=None, player = -1, reset_state_tbl=False):
        self.state = state
        if reset_state_tbl:
            type(self).state_tbl = {}
        # results[0] -> ties, results[1] -> wins, results[-1] -> losses
        self.results = [0,0,0]

        self.parents = [parent] if parent else []
        self.children = []
        self.player = player

        # Hyperparameters
        self.c_param = c_param
        self.win_reward_scalar = win_reward_scalar
        self.tie_reward_scalar = 1 - self.win_reward_scalar

        self.n_visits = 0

        self.is_game_over_node = self.get_winner_code() is not None
        self.unexplored_actions = self.get_legal_actions()

    def get_legal_actions(self, state=None):
        if not state:
            state = self.state
        return [(row, col) \
                for row, col in product(range(type(self).n_rows),range(type(self).n_cols)) \
                if state[row][col] == 0]

    def expand(self):
        # get an unexplored action
        action = self.unexplored_actions.pop(0)
        # apply that action to create the new state (make state list to mutate then tuple for tbl key)
        new_state = self.apply_action(action=action)

        if new_state in type(self).state_tbl:
            child = type(self).state_tbl[new_state]
            child.parents.append(self)
        else:
            child = MonteCarloTreeSearchNode(c_param=self.c_param, 
                                             win_reward_scalar = self.win_reward_scalar, 
                                             state=new_state, 
                                             parent=self, 
                                             player=-self.player)
            type(self).state_tbl[new_state] = child
        self.children.append(child)
        return type(self).state_tbl[new_state]

    def q(self):
        wins = self.results[1]
        ties = self.results[0]
        return self.win_reward_scalar * wins + self.tie_reward_scalar * ties

    def explore_exploit_val(self, parent):
        return (self.q() / self.n_visits) + self.c_param * np.sqrt( np.log(parent.n_visits)/self.n_visits )
    
    def explore_exploit_val_list(self, parent):
        return [child.explore_exploit_val(parent) for child in self.children]

    def expore_exploit(self):
        choices_weights = self.explore_exploit_val_list(parent=self)
        return self.children[np.argmax(choices_weights)]
    
    def best_child(self):
        child_visits = [child.n_visits for child in self.children]
        return self.children[np.argmax(child_visits)]

    def is_fully_expanded(self):
        return len(self.unexplored_actions) == 0
    
    def get_winner_code(self, state=None):
        if not state:
            state = self.state
        game_result = None
        # check rows
        for row in range(type(self).n_rows):
            if state[row][0] == state[row][1] == state[row][2] != 0: 
                game_result = state[row][0]

        # check cols
        for col in range(type(self).n_cols):
            if state[0][col] == state[1][col] == state[2][col] != 0: 
                game_result = state[row][0]
        
        # check diagonals
        if  (state[0][0] == state[1][1] == state[2][2] != 0) or \
            (state[0][2] == state[1][1] == state[2][0] != 0):
            game_result = state[1][1] 
        
        # check if board is full
        elif len(self.get_legal_actions(state)) == 0:
            game_result = 0
        # return None => game is not over
        return game_result
    
    def tree_policy(self):
        if self.is_game_over_node: return self
        if self.is_fully_expanded():
            child = self.expore_exploit()
            return child.tree_policy()
        return self.expand()

    def apply_action(self, action, state=None):
        if not state:
            state = self.state
        new_state = list(list([col for col in row]) for row in state)
        new_state[action[0]][action[1]] = -self.player
        new_state = tuple(tuple(col for col in row) for row in new_state)
        return new_state

    def random_rollout(self, state=None):
        if not state:
            state = self.state

        winner_code = self.get_winner_code(state)
        if winner_code is not None:
            return winner_code

        legal_actions = self.get_legal_actions(state)
        random_action = legal_actions[np.random.randint(len(legal_actions))]
        new_state = self.apply_action(action = random_action, state=state)
        # print(self.get_winner_code(state), state, legal_actions, random_action, new_state, self.get_winner_code(new_state))
        return self.random_rollout(new_state)
    
    def backprop(self, reward):
        self.results[1]  += (reward == self.player)
        self.results[0]  += (reward == 0)
        self.results[-1] += (reward != self.player)
        self.n_visits += 1
        type(self).marked.add(id(self))
        if not self.parents: return
        for parent in self.parents:
            if id(parent) in type(self).marked: continue
            parent.backprop(reward)

        
    def train(self, n_rollouts, rollouts_sofar=0, batch_size=10000):
        # variables for runtime analysis
        total_rollouts = n_rollouts + rollouts_sofar
        batches_remaining = n_rollouts/batch_size
        batch_time = 0
        start = start_total = time.perf_counter()

        for _ in range(n_rollouts):
            # ----time prints calls----
            rollouts_sofar += 1
            if (rollouts_sofar+1) % batch_size == 0:
                prev_batch_time = batch_time
                batch_time = (time.perf_counter()-start)/60 
                start = time.perf_counter()
                batches_remaining -= 1
                time_remaining = batch_time * batches_remaining
                batch_time_diff = batch_time - prev_batch_time
                print(f"rollouts_ran: {rollouts_sofar + 1}/{total_rollouts}, estimated_time_remaining: {time_remaining:.2f} min, batch_time_diff: {(batch_time_diff * 60):.2f} sec             ", end='\r')
            # -------------------------

            leaf = self.tree_policy()
            winner_code = leaf.random_rollout()
            type(self).marked = set([])
            leaf.backprop(winner_code)
        print("\n*DONE* total training time: ", (time.perf_counter() - start_total)/60, "min")

# Train

Train with specified c_param and win_reward_scalar

In [167]:
c_param = np.sqrt(2)
win_reward_scalar = .75
root = MonteCarloTreeSearchNode(c_param=c_param, win_reward_scalar=win_reward_scalar, reset_state_tbl=True)
root.train(n_rollouts = 100000)

rollouts_ran: 100000/100000, estimated_time_remaining: 0.00 min, batch_time_diff: 0.04 sec             
*DONE* total training time:  0.515468949583313 min


# Test Game

In [168]:
# set iterable node to root
curr = root
# print size of lookup table
print('state_tbl size:', len(MonteCarloTreeSearchNode.state_tbl))
while curr.children:
    #print this node's state
    for i in range(3):
        print(curr.state[i])

    #display child stats for this node
    df = pd.DataFrame(data=[[child.player,len(child.parents),curr.n_visits,child.n_visits,child.results[1],child.results[0],child.results[-1],child.q()/child.n_visits,child.explore_exploit_val(curr)] for child in curr.children],
                      columns=['player','num_parents','parent_n','child_n','w','t','l','q/n','ee_val']).style.highlight_max(color = 'green', axis = 0)
    display(df)
    # set node to the most visited child
    curr = curr.best_child()
# print final game state
for i in range(3):
    print(curr.state[i])

state_tbl size: 3018
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,1,1,100000,27005,6270,16816,20735,0.310283,0.339483
1,1,1,100000,31827,3980,26015,27847,0.263518,0.290416
2,1,1,100000,30114,11995,15922,18119,0.424401,0.452052
3,1,1,100000,36924,12435,18765,24489,0.371059,0.396031
4,1,1,100000,51106,27845,15637,23261,0.497073,0.518299
5,1,1,100000,15223,7754,4442,7469,0.465848,0.50474
6,1,1,100000,74044,39594,26981,34450,0.500667,0.518302
7,1,1,100000,41636,21638,12422,19998,0.475425,0.498942
8,1,1,100000,45834,19855,22806,25979,0.446071,0.468484


(0, 0, 0)
(0, 0, 0)
(1, 0, 0)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,-1,1,74044,41000,6132,14530,34868,0.190527,0.213914
1,-1,1,74044,18538,2457,4592,16081,0.155572,0.190353
2,-1,1,74044,29008,2502,16792,26506,0.184777,0.212581
3,-1,1,74044,21587,1141,13589,20446,0.168185,0.200415
4,-1,1,74044,42263,4766,21272,37497,0.190881,0.213916
5,-1,1,74044,16280,1066,7916,15214,0.149631,0.186745
6,-1,1,74044,25676,393,18823,25283,0.158864,0.188417
7,-1,1,74044,29673,4518,9586,25155,0.186419,0.213909


(0, 0, 0)
(0, -1, 0)
(1, 0, 0)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,1,2,42263,13809,2252,10550,11557,0.283265,0.322542
1,1,2,42263,18450,92,17764,18358,0.196553,0.230533
2,1,2,42263,13329,2524,9895,10805,0.299962,0.339941
3,1,2,42263,16353,2510,10621,13843,0.252688,0.288781
4,1,1,42263,17250,2684,14486,14566,0.292429,0.327571
5,1,2,42263,23326,12483,6588,10843,0.484609,0.51483
6,1,2,42263,29113,14092,14637,15021,0.487789,0.51484


(0, 0, 0)
(0, -1, 0)
(1, 0, 1)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,-1,2,29113,16556,4,10605,16552,0.128304,0.163542
1,-1,2,29113,775,10,7,765,0.012129,0.174998
2,-1,2,29113,9869,8,7971,9861,0.162185,0.207825
3,-1,2,29113,7016,351,3987,6665,0.153677,0.207808
4,-1,2,29113,9874,347,6619,9527,0.162184,0.207813
5,-1,1,29113,18247,326,14595,17921,0.174264,0.20783


(0, 0, 0)
(0, -1, 0)
(1, -1, 1)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,1,2,18247,4555,283,3971,4272,0.224061,0.289698
1,1,1,18247,14538,3,14511,14535,0.199794,0.236533
2,1,1,18247,9374,2442,6613,6932,0.349499,0.395252
3,1,2,18247,11153,565,10588,10588,0.230395,0.272342
4,1,2,18247,10661,2707,7954,7954,0.35235,0.395253


(0, 1, 0)
(0, -1, 0)
(1, -1, 1)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,-1,3,14538,6589,0,6589,6589,0.2,0.253937
1,-1,2,14538,7920,0,7920,7920,0.2,0.249197
2,-1,2,14538,8,5,1,3,0.525,2.072944
3,-1,2,14538,6600,19,6579,6581,0.201667,0.255559


(0, 1, -1)
(0, -1, 0)
(1, -1, 1)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,1,2,7920,2,0,2,2,0.2,3.196189
1,1,2,7920,4,0,4,4,0.2,2.318625
2,1,2,7920,7932,0,7932,7932,0.2,0.247577


(0, 1, -1)
(0, -1, 1)
(1, -1, 1)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,-1,3,7932,3970,0,3970,3970,0.2,0.267255
1,-1,2,7932,3961,0,3961,3961,0.2,0.267332


(-1, 1, -1)
(0, -1, 1)
(1, -1, 1)


Unnamed: 0,player,num_parents,parent_n,child_n,w,t,l,q/n,ee_val
0,1,2,3970,3970,0,3970,3970,0.2,0.264611


(-1, 1, -1)
(1, -1, 1)
(1, -1, 1)


### Hyper param Tuning pipeline
Playing with:  
>         c_param  = 1,1.5,...,9.5  
>    win_reward_scalar = .5,.55,...,.9,.95

In [None]:
roots = []
# c_param = 1,1.5,...,9.5,10
for c_param in np.linspace(1,10,19):
    this_c_param_roots = []
    # win_reward_scalar = .5,.55,...,.9,.95
    for win_reward_scalar in np.linspace(.5,.95,10):
        root = MonteCarloTreeSearchNode(c_param=c_param, win_reward_scalar=win_reward_scalar, reset_state_tbl=True)
        root.train(n_rollouts = 100000)
        this_c_param_roots.append((root,c_param,win_reward_scalar))
    roots.append(this_c_param_roots)