In [6]:
import numpy as np
import math

In [None]:
class TicTacToe:
    """Rules of Tic-Tac-Toe on a ``rows x cols`` NumPy board.

    Empty cells hold ``0``. ``get_opponent`` negates the player id, so the
    two players are expected to be a value and its negation (e.g. ``1`` and
    ``-1``); the sum-based win check relies on that encoding.
    Actions are flat cell indices counted row by row.
    """

    def __init__(self):
        self.rows = 3
        self.cols = 3
        self.total_cells = self.rows * self.cols

    def initialize_board(self):
        """Return a new all-zero (empty) board of shape ``(rows, cols)``."""
        return np.zeros((self.rows, self.cols))

    def get_next_state(self, state, action, player):
        """Write ``player`` into the cell addressed by ``action``.

        The board is modified in place and the same array is returned; pass
        a copy if the original position must be preserved.
        """
        row = action // self.cols
        col = action % self.cols
        state[row, col] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat ``uint8`` mask: 1 where the cell is empty, else 0."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the mark placed at ``action`` completes a line.

        Returns False when ``action`` is None or when the addressed cell is
        empty.
        """
        if action is None:
            return False

        row = action // self.cols
        col = action % self.cols
        player = state[row, col]

        # An empty cell has player == 0, and any all-empty line sums to
        # 0 == 0 * n, which would otherwise be reported as a win.
        if player == 0:
            return False

        return bool(
            np.sum(state[row, :]) == player * self.cols  # row
            or np.sum(state[:, col]) == player * self.rows  # column
            or np.sum(np.diag(state)) == player * self.rows  # diagonal
            # Flipping the rows turns the anti-diagonal into the diagonal.
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.rows
        )

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` for the position after ``action``.

        ``(1, True)`` if ``action`` won the game, ``(0, True)`` if no empty
        cell is left (draw), ``(0, False)`` otherwise.
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player's id (the negation of ``player``)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` as seen by the opponent (its negation)."""
        return -value


In [None]:
# Interactive two-player game on the console: players alternate entering a
# flat cell index until one of them wins or the board is full.
game = TicTacToe()
cur_player = 1

state = game.initialize_board()

while True:
    print(state)
    valid_moves = game.get_valid_moves(state)
    print("valid moves", [i for i in range(game.total_cells) if valid_moves[i] == 1])

    # Non-numeric input is treated like any other illegal move instead of
    # aborting the game with a ValueError.
    try:
        action = int(input(f'{cur_player}:'))
    except ValueError:
        print("Invalid move")
        continue

    # Range check first: an index >= total_cells would raise IndexError and
    # a negative one would silently wrap around to a cell at the end.
    if not 0 <= action < game.total_cells or valid_moves[action] == 0:
        print("Invalid move")
        continue

    state = game.get_next_state(state, action, cur_player)

    value, terminated = game.get_value_and_terminated(state, action)

    if terminated:
        print(state)
        if value == 1:
            print(f"Player {cur_player} wins!")
        else:
            print("It's a draw!")
        break

    cur_player = game.get_opponent(cur_player)

In [None]:
class Node:
    """One position in the Monte Carlo search tree.

    Args:
        game: game object; must provide ``get_valid_moves(state)``.
        args: search settings; ``args['C']`` is the exploration constant
            used by ``get_ucb``.
        state: board position this node represents.
        parent: node this one was expanded from, or None for the root.
        action_taken: action that led from ``parent`` to this node.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        # Mask of moves that have not been turned into child nodes yet.
        self.expandable_moves = game.get_valid_moves(state)

        self.visits_count = 0
        self.value_sum = 0

    def is_full_expanded(self):
        """Return True if no move is left to expand and children exist."""
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (None if no child)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child):
        """Return the UCB score of ``child`` as seen from this node.

        Formula: Q(s, a) + C * sqrt(ln(N(s)) / N(s, a)).
        ``child`` must have been visited at least once.
        """
        # Mean child value, rescaled with (v + 1) / 2 (which maps [-1, 1]
        # onto [0, 1]) and then flipped with 1 - x, because a position that
        # is good for the child is bad for this node.
        # NOTE(review): assumes child values are stored from the child's
        # (opponent's) point of view - confirm once backpropagation exists.
        q_value = 1 - ((child.value_sum / child.visits_count) + 1) / 2
        return q_value + self.args['C'] * math.sqrt(math.log(self.visits_count) / child.visits_count)

    def expand(self):
        """Create and return a child for one not-yet-expanded move.

        Not implemented yet: the method had no body at all, which made the
        whole file a syntax error.
        """
        # TODO: pick a move from self.expandable_moves, clear it, build the
        # child state and append/return the new Node.
        raise NotImplementedError("Node.expand is not implemented yet")

class MCTS:
    """Monte Carlo Tree Search driver (work in progress).

    Args:
        game: game object; ``search`` calls its
            ``get_value_and_terminated`` and ``get_opponent_value``.
        args: search settings; ``args['num_searches']`` is the number of
            search iterations run by ``search``.
    """

    def __init__(self, game, args):
        self.game = game
        self.args = args

    def search(self, state):
        """Run ``args['num_searches']`` search iterations from ``state``.

        Only selection and expansion are in place; simulation,
        backpropagation and the choice of the best action are still to be
        written, so nothing is returned yet.
        """
        root = Node(self.game, self.args, state)

        for _ in range(self.args['num_searches']):
            node = root

            # selection: walk down while every move already has a child
            while node.is_full_expanded():
                node = node.select()

            # Evaluate the reached node once, after the descent. Inside the
            # loop this never ran for the root, which starts out with
            # unexpanded moves.
            value, is_terminated = self.game.get_value_and_terminated(node.state, node.action_taken)
            # The move leading to this node was made by the opponent, so the
            # value is negated.
            value = self.game.get_opponent_value(value)

            # expansion
            if not is_terminated:
                node = node.expand()

            # TODO: simulation
            # TODO: backpropagation

        # TODO: return best action
            