In [3]:
from __future__ import annotations

import dataclasses

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

In [4]:
@dataclasses.dataclass(frozen=True)
class Gomoku:
    """An n-in-a-row game on a `rows` x `cols` board (the defaults give tic-tac-toe).

    Board cells hold 1 / -1 for the two players' stones and 0 for an empty cell.
    Actions are flat cell indices ``row * cols + column`` in ``range(action_size())``.
    """

    rows: int = 3
    cols: int = 3
    n: int = 3  # length of the unbroken line of stones needed to win

    def action_size(self):
        """Number of distinct actions: one per board cell."""
        return self.rows * self.cols

    def init_state(self):
        """Return an empty (all-zero, float) board of shape (rows, cols)."""
        return np.zeros((self.rows, self.cols))

    def next_state(self, state: np.ndarray, action: int, player: int):
        """Place `player`'s stone on cell `action` and return the board.

        Note: `state` is modified in place and the same array is returned;
        pass a copy if the previous position must be preserved.
        """
        row = action // self.cols
        column = action % self.cols
        state[row, column] = player
        return state

    def valid_moves(self, state: np.ndarray):
        """Boolean mask of length action_size(), True where the cell is empty."""
        return state.reshape(-1) == 0

    def check_win(self, state: np.ndarray, action: int) -> bool:
        """Return True if the stone placed by `action` completes a line of `n`.

        Only lines passing through the played cell are inspected (horizontal,
        vertical, and both diagonals), so this works for any board shape and any
        `n`. Returns False when `action` is None (no move played yet) or when the
        cell is empty.
        """
        if action is None:
            return False

        row, column = divmod(int(action), self.cols)
        player = state[row, column]
        if player == 0:
            # An empty cell cannot be part of a winning line.
            return False

        # Horizontal, vertical, main diagonal, anti-diagonal.
        for d_row, d_col in ((0, 1), (1, 0), (1, 1), (1, -1)):
            run = (
                1
                + self._run_length(state, row, column, d_row, d_col, player)
                + self._run_length(state, row, column, -d_row, -d_col, player)
            )
            if run >= self.n:
                return True
        return False

    def _run_length(self, state, row, column, d_row, d_col, player) -> int:
        """Count consecutive `player` stones after (row, column) stepping by (d_row, d_col)."""
        length = 0
        r, c = row + d_row, column + d_col
        while 0 <= r < self.rows and 0 <= c < self.cols and state[r, c] == player:
            length += 1
            r += d_row
            c += d_col
        return length

    def compute_value_and_terminated(self, state: np.ndarray, action: int):
        """Return ``(value, terminated)`` for the position reached by `action`.

        value is 1 if the player who played `action` has won, otherwise 0;
        terminated is True on a win or when no empty cell remains (a draw).
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.valid_moves(state)) == 0:
            return 0, True
        return 0, False

In [5]:
@dataclasses.dataclass(frozen=True)
class Args:
    """Immutable search settings handed to `MCTS` and every `MCTSNode`.

    Attributes:
        num_searches: number of select/expand/simulate/backpropagate iterations
            performed by one call to `MCTS.search`.
        C: exploration constant multiplying the UCB exploration term.
    """

    num_searches: int  # iterations per search call
    C: float  # exploration weight in the UCB formula


@dataclasses.dataclass(eq=False)
class MCTSNode:
    """A node of the Monte Carlo search tree.

    ``state`` is the board from the point of view of the player to move at this
    node (that player's stones are 1): ``expand`` plays the chosen action as 1
    and then negates the board, so each child sees the position with the roles
    swapped. ``value_sum`` accumulates rollout outcomes from the same point of
    view, which is why ``backpropagate`` negates the value at every level.

    ``eq=False`` gives nodes identity-based ``==`` and hashing. The
    dataclass-generated ``__eq__`` would compare the numpy ``state`` arrays
    field by field (which can raise "truth value of an array ... is ambiguous")
    and would set ``__hash__`` to None, making nodes unusable in sets or dicts.
    """

    game: Gomoku
    args: Args
    state: np.ndarray
    action_taken: int | None = None  # flat cell index played from `parent` to reach this node; None at the root
    parent: MCTSNode | None = None

    def __post_init__(self):
        self.children = []
        # Moves not yet turned into children; `expand` clears entries as it uses them.
        self.expandable_moves = self.game.valid_moves(self.state)
        self.visit_count = 0
        self.value_sum = 0

    def fully_expanded(self):
        """True once no untried move remains and at least one child exists."""
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score."""
        return max(self.children, key=self.ucb)

    def ucb(self, child):
        """UCB1 score of `child` as seen from this node's player.

        Assumes `child.visit_count > 0`; `MCTS.search` guarantees this by
        backpropagating through every child right after it is expanded.
        """
        n_s = self.visit_count
        n_s_a = child.visit_count
        Q = child.value_sum / child.visit_count
        # Map the child's mean value from [-1, 1] to [0, 1] and flip it, because
        # the child's value is from the opponent's point of view.
        Q = 1 - (Q + 1) / 2
        return Q + self.args.C * np.sqrt(np.log(n_s) / n_s_a)

    def expand(self):
        """Turn one randomly chosen untried move into a new child and return it."""
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0

        # Play as 1 on a copy, then negate so the child sees the board from the
        # next player's point of view.
        child_state = self.state.copy()
        child_state = -self.game.next_state(child_state, action, 1)

        child = MCTSNode(
            game=self.game,
            args=self.args,
            state=child_state,
            action_taken=action,
            parent=self,
        )
        self.children.append(child)
        return child

    def simulate(self):
        """Play uniformly random moves to the end; return the outcome for this node's player.

        If the position is already terminal, the winner was whoever played
        `action_taken` (the opponent), so the value is negated.
        """
        value, is_terminal = self.game.compute_value_and_terminated(
            self.state, self.action_taken
        )

        rollout_player = 1
        rollout_state = self.state.copy()
        while not is_terminal:
            valid_moves = self.game.valid_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.next_state(rollout_state, action, rollout_player)
            value, is_terminal = self.game.compute_value_and_terminated(
                rollout_state, action
            )
            rollout_player = -rollout_player

        # After the loop `-rollout_player` is the player who made the last move;
        # convert that player's result into this node's point of view.
        return -rollout_player * value

    def backpropagate(self, value):
        """Record `value` here and pass its negation up to each ancestor."""
        self.value_sum += value
        self.visit_count += 1

        if self.parent is not None:
            self.parent.backpropagate(-value)


@dataclasses.dataclass
class MCTS:
    """Rollout-based Monte Carlo tree search driven by `args`."""

    game: Gomoku
    args: Args

    def search(self, state):
        """Run `args.num_searches` MCTS iterations from `state`.

        Args:
            state: board from the point of view of the player to move (that
                player's stones are 1).

        Returns:
            Array of length `game.action_size()` holding each root child's visit
            count, normalized to sum to 1. If no child was expanded (for example
            `num_searches == 0` or `state` is already terminal), returns a uniform
            distribution over `game.valid_moves(state)`, or all zeros when no move
            is valid, instead of the NaNs a 0/0 division would produce.
        """
        root = MCTSNode(game=self.game, args=self.args, state=state)

        for _ in range(self.args.num_searches):
            node = root

            # Selection: descend through fully expanded nodes by UCB score.
            while node.fully_expanded():
                node = node.select()

            value, is_terminal = self.game.compute_value_and_terminated(
                node.state, node.action_taken
            )
            # A terminal value belongs to whoever played `action_taken`, i.e. the
            # opponent of the player to move at `node`.
            value = -value

            if not is_terminal:
                # Expansion + simulation.
                node = node.expand()
                value = node.simulate()

            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size())
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count

        total = np.sum(action_probs)
        if total > 0:
            return action_probs / total

        # No child was ever expanded: avoid 0/0 and fall back to the legal moves.
        valid = np.asarray(self.game.valid_moves(state), dtype=float)
        n_valid = np.sum(valid)
        return valid / n_valid if n_valid > 0 else valid

In [None]:
# Interactive tic-tac-toe: the human plays 1 and moves first; MCTS plays -1.
tictactoe = Gomoku(rows=3, cols=3, n=3)
player = 1
args = Args(C=1.41, num_searches=1000)
mcts = MCTS(tictactoe, args)
state = tictactoe.init_state()

while True:
    print(state)

    if player == 1:
        valid_moves = tictactoe.valid_moves(state)
        print(
            "valid_moves",
            [i for i in range(tictactoe.action_size()) if valid_moves[i] == 1],
        )
        raw_action = input(f"{player}:")
        try:
            action = int(raw_action)
        except ValueError:
            print("action not valid")
            continue

        # Reject out-of-range indices explicitly: a negative number would
        # otherwise be accepted through numpy's negative indexing, and a too
        # large one would raise IndexError.
        if not 0 <= action < tictactoe.action_size() or valid_moves[action] == 0:
            print("action not valid")
            continue

    else:
        # MCTS always searches as player 1, so flip the board to the mover's view.
        neutral_state = state * player
        mcts_probs = mcts.search(neutral_state)
        action = int(np.argmax(mcts_probs))

    state = tictactoe.next_state(state, action, player)

    value, is_terminal = tictactoe.compute_value_and_terminated(state, action)

    if is_terminal:
        print(state)
        if value == 1:
            print(player, "won")
        else:
            print("draw")
        break

    player = -player

In [4]:
class _ResidualBlock(nn.Module):
    """Two 3x3 same-padded convolutions with an identity skip connection.

    The skip is added before the final ReLU, so the input must already have
    `num_hidden` channels (channels-last, Flax's convention).
    NOTE(review): no normalization layers; confirm this matches the intended block.
    """

    num_hidden: int

    @nn.compact
    def __call__(self, x):
        residual = x
        x = nn.relu(nn.Conv(self.num_hidden, kernel_size=(3, 3), padding="SAME")(x))
        x = nn.Conv(self.num_hidden, kernel_size=(3, 3), padding="SAME")(x)
        return nn.relu(x + residual)


class ResNet(nn.Module):
    """Convolutional residual trunk written as a Flax linen module.

    The previous version mixed PyTorch APIs (``__init__``, ``nn.Conv2d``,
    ``nn.ModuleList``, ``F.relu``) into a ``flax.linen`` module and referenced
    undefined names, so it could not be instantiated. Flax modules are
    dataclasses, so the constructor fields below keep the old positional order
    ``ResNet(in_channels, num_blocks, num_hidden)``.

    Inputs are channels-last ``(..., H, W, C)`` as Flax convolutions expect.
    Use ``model.init(rng, x)`` / ``model.apply(params, x)`` to run it.

    Attributes:
        in_channels: kept for interface compatibility; Flax infers input
            channels from the data, so it is not used in the computation.
        num_blocks: number of residual blocks in the trunk.
        num_hidden: channel width of the trunk (the old ``hidden_channels``).
        out_channels: channels produced by the final convolution; defaults to
            ``num_hidden``. TODO: replace with policy/value heads if intended.
    """

    in_channels: int
    num_blocks: int
    num_hidden: int
    out_channels: int | None = None

    @nn.compact
    def __call__(self, x):
        # Start block: lift the input to `num_hidden` channels.
        x = nn.relu(nn.Conv(self.num_hidden, kernel_size=(3, 3), padding="SAME")(x))
        for _ in range(self.num_blocks):
            x = _ResidualBlock(self.num_hidden)(x)
        out_channels = self.num_hidden if self.out_channels is None else self.out_channels
        return nn.relu(nn.Conv(out_channels, kernel_size=(3, 3), padding="SAME")(x))

    def forward(self, x):
        """Alias for ``__call__``, kept for the PyTorch-style name used previously."""
        return self(x)