# Imports

In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict
import pickle
from itertools import product
import time
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

# Model

In [3]:
class MonteCarloTreeSearchNode():
    # class wide lookup table (state:MCTS_Node)
    state_tbl = {}
    # class wide state sizes
    n_rows = 3
    n_cols = 3
    # class wide set to avoid double rewarding nodes in backprop
    marked = set([])

    def __init__(self,c_param=np.sqrt(2), win_reward_scalar=1, state = ((0,0,0),(0,0,0),(0,0,0)), parent=None, player = -1, reset_state_tbl=False):
        self.state = state
        if reset_state_tbl:
            type(self).state_tbl = {}
        # results[0] -> ties, results[1] -> wins, results[-1] -> losses
        self.results = [0,0,0]

        self.parents = [parent] if parent else []
        self.children = []
        self.player = player

        # Hyperparameters
        self.c_param = c_param
        self.win_reward_scalar = win_reward_scalar
        self.tie_reward_scalar = 1 - self.win_reward_scalar

        self.n_visits = 0

        self.is_game_over_node = self.get_winner_code() is not None
        self.unexplored_actions = self.get_legal_actions()

    def get_legal_actions(self, state=None):
        if not state:
            state = self.state
        return [(row, col) \
                for row, col in product(range(type(self).n_rows),range(type(self).n_cols)) \
                if state[row][col] == 0]

    def expand(self):
        # get an unexplored action
        action = self.unexplored_actions.pop(0)
        # apply that action to create the new state (make state list to mutate then tuple for tbl key)
        new_state = self.apply_action(action=action)

        if new_state in type(self).state_tbl:
            child = type(self).state_tbl[new_state]
            child.parents.append(self)
        else:
            child = MonteCarloTreeSearchNode(c_param=self.c_param, 
                                             win_reward_scalar = self.win_reward_scalar, 
                                             state=new_state, 
                                             parent=self, 
                                             player=-self.player)
            type(self).state_tbl[new_state] = child
        self.children.append(child)
        return type(self).state_tbl[new_state]

    def q(self):
        wins = self.results[1]
        ties = self.results[0]
        return self.win_reward_scalar * wins + self.tie_reward_scalar * ties

    def explore_exploit_val(self, parent):
        return (self.q() / self.n_visits) + self.c_param * np.sqrt( np.log(parent.n_visits)/self.n_visits )
    
    def explore_exploit_val_list(self, parent):
        return [child.explore_exploit_val(parent) for child in self.children]

    def expore_exploit(self):
        choices_weights = self.explore_exploit_val_list(parent=self)
        return self.children[np.argmax(choices_weights)]
    
    def best_child(self):
        child_visits = [child.n_visits for child in self.children]
        return self.children[np.argmax(child_visits)]

    def is_fully_expanded(self):
        return len(self.unexplored_actions) == 0
    
    def get_winner_code(self, state=None):
        if not state:
            state = self.state
        game_result = None
        # check rows
        for row in range(type(self).n_rows):
            if state[row][0] == state[row][1] == state[row][2] != 0: 
                game_result = state[row][0]

        # check cols
        for col in range(type(self).n_cols):
            if state[0][col] == state[1][col] == state[2][col] != 0: 
                game_result = state[row][0]
        
        # check diagonals
        if  (state[0][0] == state[1][1] == state[2][2] != 0) or \
            (state[0][2] == state[1][1] == state[2][0] != 0):
            game_result = state[1][1] 
        
        # check if board is full
        elif len(self.get_legal_actions(state)) == 0:
            game_result = 0
        # return None => game is not over
        return game_result
    
    def tree_policy(self):
        if self.is_game_over_node: return self
        if self.is_fully_expanded():
            child = self.expore_exploit()
            return child.tree_policy()
        return self.expand()

    def apply_action(self, action, state=None):
        if not state:
            state = self.state
        new_state = list(list([col for col in row]) for row in state)
        new_state[action[0]][action[1]] = -self.player
        new_state = tuple(tuple(col for col in row) for row in new_state)
        return new_state

    def random_rollout(self, state=None):
        if not state:
            state = self.state

        winner_code = self.get_winner_code(state)
        if winner_code is not None:
            return winner_code

        legal_actions = self.get_legal_actions(state)
        random_action = legal_actions[np.random.randint(len(legal_actions))]
        new_state = self.apply_action(action = random_action, state=state)
        # print(self.get_winner_code(state), state, legal_actions, random_action, new_state, self.get_winner_code(new_state))
        return self.random_rollout(new_state)
    
    def backprop(self, reward):
        self.results[1]  += (reward == self.player)
        self.results[0]  += (reward == 0)
        self.results[-1] += (reward != self.player)
        self.n_visits += 1
        type(self).marked.add(id(self))
        if not self.parents: return
        for parent in self.parents:
            if id(parent) in type(self).marked: continue
            parent.backprop(reward)

        
    def train(self, n_rollouts, rollouts_sofar=0, batch_size=10000):
        # variables for runtime analysis
        total_rollouts = n_rollouts + rollouts_sofar
        batches_remaining = n_rollouts/batch_size
        batch_time = 0
        start = start_total = time.perf_counter()

        for _ in range(n_rollouts):
            # ----time prints calls----
            rollouts_sofar += 1
            if (rollouts_sofar+1) % batch_size == 0:
                prev_batch_time = batch_time
                batch_time = (time.perf_counter()-start)/60 
                start = time.perf_counter()
                batches_remaining -= 1
                time_remaining = batch_time * batches_remaining
                batch_time_diff = batch_time - prev_batch_time
                print(f"rollouts_ran: {rollouts_sofar + 1}/{total_rollouts}, estimated_time_remaining: {time_remaining:.2f} min, batch_time_diff: {(batch_time_diff * 60):.2f} sec             ", end='\r')
            # -------------------------

            leaf = self.tree_policy()
            winner_code = leaf.random_rollout()
            type(self).marked = set([])
            leaf.backprop(winner_code)
        print("\n*DONE* total training time: ", (time.perf_counter() - start_total)/60, "min")

In [10]:
root = MonteCarloTreeSearchNode(c_param=1,win_reward_scalar=.75,reset_state_tbl=True)
while not root.is_game_over_node:
    action = eval(input('action'))
    new_state = root.apply_action(action)
    if new_state in root.state_tbl:
        root = root.state_tbl[new_state]
    else:
        root = MonteCarloTreeSearchNode(c_param=1,win_reward_scalar=.75, state=new_state, reset_state_tbl=True, player=1)
    if root.is_game_over_node:
        break
    root.train(10000)
    root = root.best_child()
    for i in range(3):
        print(root.state[i])
    


rollouts_ran: 10000/10000, estimated_time_remaining: 0.00 min, batch_time_diff: 1.71 sec             
*DONE* total training time:  0.028460616666666282 min
(0, 0, 0)
(0, 0, 0)
(-1, 0, 1)
rollouts_ran: 10000/10000, estimated_time_remaining: 0.00 min, batch_time_diff: 1.47 sec             
*DONE* total training time:  0.02456400333333022 min
(-1, 0, 1)
(0, 0, 0)
(-1, 0, 1)
rollouts_ran: 10000/10000, estimated_time_remaining: 0.00 min, batch_time_diff: 0.53 sec             
*DONE* total training time:  0.008765358333334916 min
(1, 0, 1)
(0, -1, 0)
(-1, 0, 1)
