In [None]:
import math
import random
import time
from copy import copy
from dataclasses import dataclass, field
from enum import Enum
from typing import Self

import numpy as np
from numpy import ndarray

In [None]:
class Player(Enum):
    """The two participants; values match the non-empty BoardCell values."""

    FIRST = 1   # moves first on a fresh game
    SECOND = 2


def next_player(player: Player) -> Player | None:
    if player == Player.FIRST:
        return Player.SECOND
    elif player == Player.SECOND:
        return Player.FIRST
    else:
        return None


class BoardStatus(Enum):
    """Outcome of a board position as reported by Board.board_status."""

    GAME_IN_PROGRESS = -1  # no winner yet and at least one empty cell remains
    DRAW = 0
    FIRST_WON = 1
    SECOND_WON = 2


class BoardCell(Enum):
    """Content of a single board cell; 0 means unoccupied."""

    EMPTY = 0
    FIRST = 1
    SECOND = 2

    @staticmethod
    def from_player(player: Player) -> 'BoardCell':
        """Map a player to the mark it places; anything other than FIRST maps to SECOND."""
        return BoardCell.FIRST if player == Player.FIRST else BoardCell.SECOND


@dataclass
class Board:
    """Square grid of BoardCell marks, stored as an integer ndarray of cell values."""

    size: int = field(default=3)
    _cells: ndarray = field(init=False)

    def __copy__(self):
        """Return a new Board with an independent copy of the cell array."""
        b = Board(size=self.size)
        b._cells = np.copy(self._cells)
        return b

    def __post_init__(self):
        # Integer storage matches the integer BoardCell values; 0 is BoardCell.EMPTY.
        self._cells = np.zeros((self.size, self.size), dtype=int)

    def __eq__(self, other):
        """Boards are equal when they have the same size and identical cells."""
        # The dataclass-generated __eq__ would compare ndarrays with ==, which
        # yields an array (ambiguous truth value) instead of a bool.
        if not isinstance(other, Board):
            return NotImplemented
        return self.size == other.size and np.array_equal(self._cells, other._cells)

    def _get_board_status(self, player: Player) -> bool:
        """Return True if `player` owns a full row, column, or either diagonal."""
        owned = self._cells == BoardCell.from_player(player).value
        return bool(
            owned.all(axis=1).any()                 # horizontal
            or owned.all(axis=0).any()              # vertical
            or np.diagonal(owned).all()             # main diagonal
            or np.diagonal(np.fliplr(owned)).all()  # secondary diagonal
        )

    @property
    def board_status(self) -> BoardStatus:
        """Return the winner if there is one, else in-progress or draw.

        FIRST is checked before SECOND, so a (normally impossible) board where
        both have a line reports FIRST_WON.
        """
        if self._get_board_status(Player.FIRST):
            return BoardStatus.FIRST_WON
        elif self._get_board_status(Player.SECOND):
            return BoardStatus.SECOND_WON
        elif np.any(self._cells == BoardCell.EMPTY.value):
            return BoardStatus.GAME_IN_PROGRESS
        else:
            return BoardStatus.DRAW

    def __getitem__(self, item: tuple[int, int]) -> BoardCell:
        """Return the mark at (row, column)."""
        i, j = item
        return BoardCell(int(self._cells[i, j]))

    def __setitem__(self, key: tuple[int, int], value: BoardCell):
        """Place `value` at (row, column), overwriting whatever was there."""
        i, j = key
        self._cells[i, j] = value.value


@dataclass
class State:
    """A game position together with the MCTS statistics accumulated for it."""

    board: Board
    # The player whose mark was placed most recently; the other player moves next.
    last_move_by: Player
    visit_count: int = field(default=0)
    win_score: float = field(default=0)

    def _empty_cells(self) -> list[tuple[int, int]]:
        """Return the (row, column) coordinates of every empty cell, row-major."""
        size = self.board.size
        return [(i, j) for i in range(size) for j in range(size)
                if self.board[i, j] == BoardCell.EMPTY]

    @property
    def all_possible_states(self) -> list[Self]:
        """Return one successor state per legal move of the player to move.

        Each successor owns a copy of the board with the mover's mark placed on
        one empty cell, records that mover as `last_move_by`, and starts with
        zero statistics. A finished game has no successors.
        """
        if self.board.board_status != BoardStatus.GAME_IN_PROGRESS:
            return []
        mover = next_player(self.last_move_by)
        mark = BoardCell.from_player(mover)
        successors = []
        # Only empty cells are legal moves (the previous check was inverted).
        for cell in self._empty_cells():
            board = copy(self.board)
            board[cell] = mark
            # The successor's last move was made by `mover`, not by self.last_move_by.
            successors.append(State(board, mover, 0, 0))
        return successors

    def random_play(self):
        """Place the next player's mark on a uniformly random empty cell, in place.

        Mutates self.board and self.last_move_by (copy the board first if it is
        shared). Does nothing once the game is no longer in progress.
        """
        if self.board.board_status != BoardStatus.GAME_IN_PROGRESS:
            return
        mover = next_player(self.last_move_by)
        # GAME_IN_PROGRESS implies at least one empty cell, so choice() cannot fail.
        self.board[random.choice(self._empty_cells())] = BoardCell.from_player(mover)
        self.last_move_by = mover

    def add_visit_only(self):
        """Record one visit without changing the score."""
        self.visit_count += 1

    def add_reward(self, reward: float):
        """Record one visit and add `reward` to the accumulated score."""
        self.win_score += reward
        self.visit_count += 1

@dataclass(eq=False)
class Node:
    """A node of the Monte Carlo search tree.

    eq=False: the generated field-wise __eq__ would compare `parent` and
    `children`, which reference each other, so nodes compare by identity.
    """

    state: State
    # repr=False stops repr(node) from printing the parent and, through it,
    # every sibling and ancestor.
    parent: Self | None = field(default=None, repr=False)
    children: list[Self] = field(default_factory=list)

    def get_random_child(self) -> Self | None:
        """Return a uniformly random child, or None if there are no children."""
        return random.choice(self.children) if self.children else None

    @property
    def child_with_max_score(self) -> Self:
        """Return the child with the highest average win_score per visit.

        Unvisited children score 0.

        Raises:
            ValueError: if the node has no children.
        """
        if not self.children:
            raise ValueError("node has no children to choose from")
        scores = [child.state.win_score / child.state.visit_count
                  if child.state.visit_count != 0
                  else 0
                  for child in self.children]
        return self.children[int(np.argmax(scores))]

    def __copy__(self):
        """Return a new Node whose State (and board) is independent of this one.

        The parent reference and the child nodes themselves are shared; only
        the children list object is new.
        """
        state = State(copy(self.state.board), self.state.last_move_by,
                      self.state.visit_count, self.state.win_score)
        return Node(state, self.parent, list(self.children))

@dataclass
class Tree:
    """Holder for the root node of a search tree."""

    # Node requires a State, so `default_factory=Node` crashed. The default is
    # an empty board with SECOND recorded as last mover, so FIRST moves next.
    root: Node = field(default_factory=lambda: Node(State(Board(), Player.SECOND)))

In [None]:
def uct_value(total_visit: int, node_win_score: float, node_visit: int,
              exploration: float = math.sqrt(2)) -> float:
    """Return the UCT (UCB1) score of a child node.

    Args:
        total_visit: visit count of the parent node.
        node_win_score: accumulated reward of the child.
        node_visit: visit count of the child.
        exploration: weight of the exploration term (sqrt(2) by default).

    Returns:
        +inf for an unvisited child (so every child is tried once). Otherwise
        the mean reward plus exploration * sqrt(ln(total_visit) / node_visit).
    """
    if node_visit == 0:
        return math.inf
    exploitation = node_win_score / node_visit
    # max(..., 1) guards math.log against a parent count that lags the child's.
    exploration_term = exploration * math.sqrt(math.log(max(total_visit, 1)) / node_visit)
    return exploitation + exploration_term


def find_best_node_uct(node: 'Node') -> 'Node':
    """Return the child of `node` with the highest UCT value.

    Raises:
        ValueError: if `node` has no children (from max()).
    """
    parent_visits = node.state.visit_count
    return max(
        node.children,
        key=lambda child: uct_value(parent_visits, child.state.win_score, child.state.visit_count),
    )


def select_promising_node(root: Node) -> Node:
    """Descend from `root` via find_best_node_uct and return the first childless node reached."""
    current = root
    while len(current.children) != 0:
        current = find_best_node_uct(current)
    return current


def expand_node(node: Node):
    """Append one child Node (parented to `node`) per successor of node.state."""
    new_children = [Node(successor, node) for successor in node.state.all_possible_states]
    node.children.extend(new_children)


def back_propagation(node_to_explore: Node, player: Player | None, reward: float):
    """Propagate a playout result from `node_to_explore` up to the root.

    Every node on the path gains one visit. Nodes whose state's last_move_by
    equals `player` also gain `reward` in win_score. Pass None as `player`
    (e.g. for a draw) to only count visits.
    """
    node: Node | None = node_to_explore
    # Explicit None check: the walk ends when the root's parent (None) is reached.
    while node is not None:
        if node.state.last_move_by == player:
            node.state.add_reward(reward)
        else:
            node.state.add_visit_only()
        node = node.parent


def simulate_random_playout(node: Node) -> BoardStatus:
    """Play uniformly random moves from node's position until the game ends.

    Works on a copy of the board, so the search tree is left untouched.

    Returns:
        The final board status (never GAME_IN_PROGRESS).
    """
    board = copy(node.state.board)  # Board.__copy__ duplicates the cell array
    mover = node.state.last_move_by
    status = board.board_status
    while status == BoardStatus.GAME_IN_PROGRESS:
        mover = next_player(mover)
        empty_cells = [(i, j) for i in range(board.size) for j in range(board.size)
                       if board[i, j] == BoardCell.EMPTY]
        # GAME_IN_PROGRESS implies at least one empty cell, so choice() cannot fail.
        board[random.choice(empty_cells)] = BoardCell.from_player(mover)
        status = board.board_status
    return status


WIN_REWARD: int = 10  # reward mcts_find_next_move hands to back_propagation for a won playout

def mcts_find_next_move(board: Board, player: Player, comp_time_seconds: float) -> Board:
    """Search for `player`'s next move with MCTS for about `comp_time_seconds`.

    Args:
        board: current position; it is not modified (successors are copies).
        player: the player to move.
        comp_time_seconds: wall-clock search budget.

    Returns:
        The board of the root child with the best average win_score, i.e. the
        position after `player`'s chosen move.

    Fails (via Node.child_with_max_score) if the root never got children,
    e.g. the game is already over or the budget is too small.
    """
    opponent = next_player(player)
    # The root records the opponent as last mover so expansion places `player`'s marks.
    tree = Tree(Node(State(board, opponent)))
    # back_propagation expects the winning Player; a draw maps to None (visits only).
    winner_of = {BoardStatus.FIRST_WON: Player.FIRST, BoardStatus.SECOND_WON: Player.SECOND}

    t_end = time.monotonic() + comp_time_seconds
    while time.monotonic() < t_end:
        promising_node = select_promising_node(tree.root)

        if promising_node.state.board.board_status == BoardStatus.GAME_IN_PROGRESS:
            expand_node(promising_node)

        node_to_explore = promising_node.get_random_child() or promising_node

        playout_result = simulate_random_playout(node_to_explore)

        back_propagation(node_to_explore, winner_of.get(playout_result), WIN_REWARD)
    winner_node = tree.root.child_with_max_score
    return winner_node.state.board