# In [86]:
import copy
import random

import numpy as np

from monte_carlo_tree_search import MCTS, Node

N = 3 * 3                       # Cell count of the 3x3 Tic Tac Toe board
three_values = [3 * 1, 3 * 7]   # Line sums of three 1s (X) or three 7s (O)

class StatePlayer:
    """Tracks whose turn it is: token "X" (board value 1) or "O" (board value 7)."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not used; play always starts with X.
        self.token, self.token_value = "X", 1

    def change_player(self):
        """Hand the turn to the other side, updating token and board value."""
        # Any token other than "X" falls back to X.
        was_x = self.token == "X"
        self.token = "O" if was_x else "X"
        self.token_value = 7 if was_x else 1

class State():
    """A tic-tac-toe position: a flat 9-cell board plus the player to move.

    Cells hold 0 when empty, otherwise the token value of the player who
    occupies them (StatePlayer uses 1 for "X" and 7 for "O").
    """

    def __init__(self, matrix, player):
        """Store the board and the player who moves next.

        Args:
            matrix: flat numpy array of 9 cell values, row by row.
            player: object exposing ``token_value`` and ``change_player()``.
        """
        self.matrix = matrix
        self.active_index = 1
        self.player = player

    def get_valid_moves(self):
        """Return a numpy array with the indices of all empty cells."""
        return np.where(self.matrix == 0)[0]

    # Modified
    def make_move(self, move, new_player=None):
        """Return the state reached when a player marks cell ``move``.

        Neither this state nor the player object passed in is modified: the
        board is copied and the mover is shallow-copied before being toggled.

        Args:
            move: index (0-8) of the cell to mark.
            new_player: player making the move. Defaults to this state's own
                player, so consecutive default moves alternate tokens.

        Returns:
            A new State whose ``player`` is the side to move next.
        """
        # Local import keeps this block usable even if the module header
        # does not import ``copy``.
        import copy

        mover = self.player if new_player is None else new_player
        # Work on a private copy: change_player() mutates in place, and the
        # parent state (or the caller) must keep its own turn marker intact.
        # A state built without any player falls back to a fresh StatePlayer.
        mover = StatePlayer() if mover is None else copy.copy(mover)

        board = State(self.matrix.copy(), mover)
        board.matrix[move] = board.player.token_value
        board.player.change_player()
        return board

    def make_random_move(self):
        """Return the state after the current player marks a random empty cell.

        Raises:
            IndexError: if the board has no empty cell left.
        """
        # A plain list keeps random.choice independent of how it tests the
        # sequence for emptiness (the truth value of an array is ambiguous).
        valid_moves = self.get_valid_moves().tolist()
        move = random.choice(valid_moves)
        return self.make_move(move, new_player=self.player)

    def get_winner(self):
        """Determine the outcome of the position.

        Returns:
            The token value found in the first cell of a completed line when
            a row, column or diagonal sums to one of ``three_values``;
            0 when the board is full with no such line (draw);
            None when the game is still in progress.
        """
        matrix2D = self.matrix.reshape((3, 3))

        # Check if there is a winner
        for i in range(3):
            # Check lines
            if np.sum(matrix2D[i]) in three_values:
                return matrix2D[i][0]
            # Check columns
            if np.sum(matrix2D[:, i]) in three_values:
                return matrix2D[0][i]
        # Check diagonals
        if np.sum(np.diagonal(matrix2D)) in three_values:
            return matrix2D[0][0]
        if np.sum(np.diagonal(np.fliplr(matrix2D))) in three_values:
            return matrix2D[0][2]

        # Check if the game is a draw
        if not np.any(self.matrix == 0):
            return 0

        # Game is not over
        return None

    def is_terminal(self):
        """Return True once the game is won or drawn."""
        return self.get_winner() is not None

    def get_score(self):
        """Return 1 if value 1 has won, -1 if value 7 has won, else 0.

        A draw and an unfinished game both score 0.
        """
        winner = self.get_winner()
        if winner == 1:
            return 1
        elif winner == 7:
            return -1
        return 0

    def __str__(self) -> str:
        """Render the board with a newline before every third cell and a
        trailing space after each cell value."""
        pieces = []
        for i, cell in enumerate(self.matrix):
            if i % 3 == 0:
                pieces.append('\n')
            pieces.append(str(cell) + ' ')
        return ''.join(pieces)


# Demo: start from a hand-built position and print the state that the
# MCTS search returns for it.
player = StatePlayer()
current_state = State(
    np.array([1, 1, 0] +
             [7, 0, 0] +
             [7, 0, 7]),
    player,
)
tree = MCTS()
new_board = tree.search(current_state, 10)
print(new_board)


# Output:
# 1 1 1
# 7 0 0
# 7 0 7
