In [3]:
import numpy as np

In [4]:
class TicTacToe:
    """Rules engine for a 3x3 tic-tac-toe game played on a NumPy board.

    Board cells hold 0 while empty and a player's marker once taken. Actions
    are flat cell indices in row-major order (``action = row * cols + col``).
    The win check relies on summing lines, so markers are expected to be
    +1 / -1 (``get_opponent`` flips between them by negation).
    """

    def __init__(self):
        self.rows = 3
        self.cols = 3
        self.total_cells = self.rows * self.cols

    def initialize_board(self):
        """Return an empty board: a ``(rows, cols)`` float array of zeros."""
        return np.zeros((self.rows, self.cols))

    def get_next_state(self, state, action, player):
        """Write ``player`` into the cell addressed by ``action``.

        Note that ``state`` is modified in place and the very same array is
        returned; pass a copy if the previous position must be preserved.
        """
        row, col = divmod(action, self.cols)
        state[row, col] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat uint8 mask over all cells: 1 = empty (playable), 0 = taken."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the marker sitting on ``action`` completes a line.

        Args:
            state: board array as produced by ``initialize_board``.
            action: flat index of the move that was just played, or ``None``
                when no move has been played yet.

        Returns:
            bool: True only when the row, the column, the main diagonal or the
            anti-diagonal sums to ``player * 3`` for the marker found on the
            addressed cell.
        """
        if action is None:
            return False

        row, col = divmod(action, self.cols)
        player = state[row, col]

        # An empty cell has marker 0, and "line sum == 0 * n" is satisfied by any
        # line that happens to sum to zero (e.g. the empty board, or [1, -1, 0]).
        # Without this guard such positions would be reported as a win.
        if player == 0:
            return False

        return bool(
            np.sum(state[row, :]) == player * self.cols  # row of the move
            or np.sum(state[:, col]) == player * self.rows  # column of the move
            or np.sum(np.diag(state)) == player * self.rows  # main diagonal
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.rows  # anti-diagonal
        )

    def get_value_and_terminated(self, state, action):
        """Classify the position reached after ``action`` was played.

        Returns:
            tuple: ``(1, True)`` if the move won the game, ``(0, True)`` if the
            board is full without a winner (draw), ``(0, False)`` otherwise.
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player's marker (the negation of ``player``)."""
        return -player


In [5]:
# Interactive two-player game: players 1 and -1 alternate typing a cell index
# (0-8, row-major) until one of them wins or the board is full.
game = TicTacToe()
cur_player = 1

state = game.initialize_board()

while True:
    print(state)
    valid_moves = game.get_valid_moves(state)
    print("valid moves", [i for i in range(game.total_cells) if valid_moves[i] == 1])

    # Non-numeric text makes int() raise ValueError; treat it like any other
    # invalid move and re-prompt instead of aborting the game.
    try:
        action = int(input(f'{cur_player}:'))
    except ValueError:
        print("Invalid move")
        continue

    # Range-check before indexing: an index >= total_cells would raise
    # IndexError, and a negative one would silently wrap around to another cell.
    if not 0 <= action < game.total_cells or valid_moves[action] == 0:
        print("Invalid move")
        continue

    state = game.get_next_state(state, action, cur_player)

    value, terminated = game.get_value_and_terminated(state, action)

    if terminated:
        print(state)
        if value == 1:
            print(f"Player {cur_player} wins!")
        else:
            print("It's a draw!")
        break

    # Hand the turn to the other player only after a legal, non-terminal move.
    cur_player = game.get_opponent(cur_player)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves [0, 1, 2, 3, 4, 5, 6, 7, 8]
[[1. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves [1, 2, 3, 4, 5, 6, 7, 8]
[[ 1.  0.  0.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]
valid moves [1, 2, 3, 5, 6, 7, 8]
[[ 1.  0.  1.]
 [ 0. -1.  0.]
 [ 0.  0.  0.]]
valid moves [1, 3, 5, 6, 7, 8]
[[ 1.  0.  1.]
 [ 0. -1.  0.]
 [ 0.  0. -1.]]
valid moves [1, 3, 5, 6, 7]
[[ 1.  0.  1.]
 [ 1. -1.  0.]
 [ 0.  0. -1.]]
valid moves [1, 5, 6, 7]
[[ 1.  0.  1.]
 [ 1. -1.  0.]
 [ 0. -1. -1.]]
valid moves [1, 5, 6]
Invalid move
[[ 1.  0.  1.]
 [ 1. -1.  0.]
 [ 0. -1. -1.]]
valid moves [1, 5, 6]
Invalid move
[[ 1.  0.  1.]
 [ 1. -1.  0.]
 [ 0. -1. -1.]]
valid moves [1, 5, 6]
[[ 1.  1.  1.]
 [ 1. -1.  0.]
 [ 0. -1. -1.]]
Player 1 wins!


In [None]:
class MCTS: 
    """Container pairing a game object with an ``args`` settings object.

    NOTE(review): the name suggests Monte Carlo Tree Search, but no search
    logic is visible in this cell — confirm against the rest of the notebook.
    """
    def __init__(self, game, args):
        """Store ``game`` and ``args`` on the instance without validating them."""
        self.game = game  # presumably a TicTacToe-like rules object — TODO confirm
        self.args = args  # type and contents are not shown here — TODO confirm