In [1]:
import numpy as np
import ipywidgets as widgets
from tqdm import tqdm
import random
import matplotlib.pyplot as plt

In [254]:
class MCTSNode:
    def __init__(self, state, parent_node):
        self.state = state
        self.parent_node = parent_node
        self.total_visits = 0
        self.total_score = 0
        self.children_nodes = []
        self.player = self.check_player(state)
        self.terminate_state = False
        self.all_children_nodes = False

    def check_player(self, state):
        if np.sum(state==1) > np.sum(state==2):
            return 2
        else:
            return 1

class MCTS:
    def __init__(self, exploration_constant = 2):
        self.exploration_constant = exploration_constant

    def is_terminal(self, board):
        return not np.any(board == 0)

    def is_win(self, state, player):
        col_win = (np.sum(state == player, axis=0) == 3).any()
        row_win = (np.sum(state == player, axis=1) == 3).any()
        diagonal_win = np.trace(state == player) == 3
        opposite_diagonal = np.trace(np.fliplr(state) == player) == 3
        return col_win or row_win or diagonal_win or opposite_diagonal

    def select(self, curr_node, should_explore=True):
        while not is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):
            if curr_node.all_children_nodes:
                highest_value = -float("inf")
                chosen_child = None

                # loop all children nodes and take the best one according to heuristic
                for child in curr_node.children_nodes:
                    # compute UCB1 score
                    child_val = (child.total_score/child.total_visits) + should_explore*self.exploration_constant*np.sqrt(np.log(curr_node.total_visits)/child.total_visits)

                    # if it has highest value then store it as the chosen child from this step
                    if child_val > highest_value:
                        highest_value = child_val
                        chosen_child = child

                # choose highest value move
                return chosen_child

            else:
                # if not all children nodes accessible then expand the node first
                return self.expand(curr_node)

        print("should never come here")

    def expand(self, curr_node):
        states = self.generate_next_states(curr_node)

        for state in states:
            # unroll children states, and ensure we do not expand to a state we have 
            # already expanded to in a previous iteration
            if str(state) not in [str(b.state) for b in curr_node.children_nodes]:
                child_node = MCTSNode(state, curr_node)
                curr_node.children_nodes.append(child_node)
                
                # if the num children nodes equal the amount of possible next states
                # we have explored all child nodes for this state
                if len(states) == len(curr_node.children_nodes):
                    curr_node.all_children_nodes = True

                return child_node


    def simulate(self, curr_node, computer_playing):
        opponent = 1 if computer_playing == 2 else 1
        
        while not is_terminal(curr_node.state) and not (self.is_win(curr_node.state, 1) or self.is_win(curr_node.state, 2)):
            next_states = self.generate_next_states(curr_node)
            curr_node = MCTSNode(next_states[random.randint(0, len(next_states) - 1)], curr_node)
        
        if self.is_win(curr_node.state, player=computer_playing):
            return 1
        elif self.is_win(curr_node.state, player=opponent):
            return -1
        else:
            return 0

        
    def backpropagate(self, node, score):
        while node:
            node.total_visits += 1
            node.total_score += score
            node = node.parent_node
    
    def generate_next_states(self, curr_node):
        player = curr_node.player
        curr_state = curr_node.state
        next_states = []
        for i in range(3):
            for j in range(3):
                if curr_state[i,j] == 0:
                    to_append = np.copy(curr_state)
                    to_append[i,j] = player
                    next_states.append(to_append)
        return next_states


    def get_move(self, root, num_iterations=1000):
        for it in range(num_iterations):
            curr_node = self.select(root)
            obtained_value = self.simulate(curr_node, root.player)
            self.backpropagate(curr_node, obtained_value)
        
        chosen_move = self.select(root, should_explore=False)
        return chosen_move

In [263]:
a = np.zeros((3,3))
root = MCTSNode(a, None)
mc = MCTS()

for i in range(9):
    row_col = input("Row and column to place with ,").split(",")
    state = np.copy(root.state)
    state[int(row_col[0]), int(row_col[1])] = 1
    next_node = MCTSNode(state, root)
    
    root = mc.get_move(next_node)
    print(root.state)

print("Final: {root.state}")
    

Row and column to place with ,1,1
[[0. 0. 0.]
 [0. 1. 0.]
 [0. 0. 2.]]
Row and column to place with ,0,0
[[1. 0. 0.]
 [0. 1. 0.]
 [2. 0. 2.]]
Row and column to place with ,2,1
[[1. 2. 0.]
 [0. 1. 0.]
 [2. 1. 2.]]
Row and column to place with ,1,2
[[1. 2. 0.]
 [2. 1. 1.]
 [2. 1. 2.]]
Row and column to place with ,0,2
should never come here


AttributeError: 'NoneType' object has no attribute 'state'