In [None]:
import numpy as np
# Board cell values. PLAYER2 is the negation of PLAYER1, so a board can be
# viewed from the other side simply by multiplying it by -1.
EMPTY, PLAYER1, PLAYER2 = 0, 1, -1

In [None]:
class Gomoku:
    """Gomoku (k-in-a-row) rules on a square board.

    A state is a ``size x size`` int8 array whose cells hold EMPTY, PLAYER1
    or PLAYER2. Actions are flat cell indices ``row * size + col``.
    """

    def __init__(self, size=15, win_length=5):
        """Create a game on a ``size x size`` board.

        Args:
            size: edge length of the square board.
            win_length: number of contiguous stones in a line needed to win
                (5 for standard Gomoku).
        """
        self.size = size
        self.win_length = win_length

    def get_initial_state(self):
        """Return an empty ``size x size`` int8 board."""
        return np.zeros((self.size, self.size), dtype=np.int8)

    def get_next_state(self, state, action, player):
        """Place ``player``'s stone on flat cell ``action``.

        The board is modified in place and also returned; callers that need
        the previous position must pass a copy.

        Raises:
            ValueError: if the target cell is already occupied.
        """
        row, col = divmod(action, self.size)
        if state[row, col] != EMPTY:
            raise ValueError("Invalid action")
        state[row, col] = player
        return state

    def get_moves(self, state):
        """Return a flat uint8 mask that is 1 for every empty cell."""
        return (state.reshape(-1) == EMPTY).astype(np.uint8)

    def _count_run(self, state, row, col, dr, dc, player):
        """Count contiguous ``player`` stones from (row, col), exclusive, stepping by (dr, dc)."""
        count = 0
        r, c = row + dr, col + dc
        while 0 <= r < self.size and 0 <= c < self.size and state[r, c] == player:
            count += 1
            r += dr
            c += dc
        return count

    def check_win(self, state, action):
        """Return True if the stone on cell ``action`` is part of a winning line.

        Only *contiguous* stones count: for each of the four line
        orientations the run through ``action`` is extended both ways until
        a different value or the board edge is reached.

        Returns False when ``action`` is None or the cell is empty.
        """
        if action is None:
            return False

        row, col = divmod(action, self.size)
        player = state[row, col]
        if player == EMPTY:
            return False

        # Vertical, horizontal, main diagonal, anti-diagonal.
        for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1)):
            run = (1
                   + self._count_run(state, row, col, dr, dc, player)
                   + self._count_run(state, row, col, -dr, -dc, player))
            if run >= self.win_length:
                return True
        return False

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` for the position after ``action``.

        ``(1, True)`` if the player who just moved has won, ``(0, True)`` for a
        full board without a winner (draw), otherwise ``(0, False)``.
        """
        if self.check_win(state, action):
            return 1, True
        if not np.any(state == EMPTY):
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player's marker."""
        return PLAYER1 if player == PLAYER2 else PLAYER2

    def get_opponent_value(self, player):
        """Negate a value (or marker) to express it from the opponent's side."""
        return -player

    def change_perspective(self, state, player):
        """Return a copy of ``state`` seen by ``player`` (stones flipped when player is PLAYER2)."""
        return state * player

    def print(self, state):
        """Print the board using 'X' for PLAYER1, 'O' for PLAYER2 and '.' for empty."""
        symbols = {EMPTY: '.', PLAYER1: 'X', PLAYER2: 'O'}
        lines = [
            ' '.join(symbols.get(int(state[r, c]), str(state[r, c])) for c in range(self.size))
            for r in range(self.size)
        ]
        # Each row ends with a newline; print() then adds one blank line after the board.
        print(''.join(line + '\n' for line in lines))

In [None]:
game = Gomoku()
player = PLAYER1

state = game.get_initial_state()
tiles = game.size * game.size

# Human-vs-human loop. The leading `break` deliberately disables it so the
# notebook can run top to bottom without blocking on input(); remove it to play.
while True:
    break
    game.print(state)
    valid_moves = game.get_moves(state)
    # Flat index -> (row, col) divides by the board edge, not the tile count.
    print("Valid moves:", [divmod(i, game.size) for i in range(tiles) if valid_moves[i] == 1])
    user_input = input("Enter action: ")
    try:
        row_str, col_str = user_input.split(',')
        row, col = int(row_str), int(col_str)
    except ValueError:
        print("Invalid move")
        continue
    # Reject out-of-board coordinates (negative ones would otherwise wrap around).
    if not (0 <= row < game.size and 0 <= col < game.size):
        print("Invalid move")
        continue
    action = row * game.size + col

    if valid_moves[action] == 0:
        print("Invalid move")
        continue

    state = game.get_next_state(state, action, player)
    value, terminated = game.get_value_and_terminated(state, action)
    if terminated:
        if value == 1:
            print("Player", player, "wins")
        else:
            print("Draw")
        break

    player = game.get_opponent(player)

In [None]:
class MCTSArgs:
    """Hyper-parameters for MCTS.

    Attributes:
        num_searches: number of search iterations run per move.
        c: UCB1 exploration constant (1.41 is approximately sqrt(2)).
    """

    def __init__(self, num_searches=1000, c=1.41):
        self.num_searches = num_searches
        self.c = c
        
        
class Node:
    """A node of the Monte Carlo search tree.

    ``state`` is stored from the perspective of the player to move at this
    node (that player's stones are PLAYER1), and ``total_value`` accumulates
    results from that same player's point of view.
    """

    def __init__(self, game, args, state, parent=None, action=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action = action  # move that led from parent to this node (None at the root)

        self.children = []
        # 1 for every legal move that has not been turned into a child yet.
        self.expandable_actions = game.get_moves(state)

        self.visit_count = 0
        self.total_value = 0

    def is_fully_expanded(self):
        """True once every legal move has a child (and at least one child exists)."""
        return np.sum(self.expandable_actions) == 0 and len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (the first one on ties)."""
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child

    def get_ucb(self, child):
        """UCB1 score of ``child`` as seen by this node's player.

        The child's mean value in [-1, 1] is from the opponent's perspective;
        it is mapped to [0, 1] and flipped so that higher is better for us.
        """
        q_value = 1 - ((child.total_value / child.visit_count) + 1) / 2
        return q_value + self.args.c * np.sqrt(np.log(self.visit_count) / child.visit_count)

    def expand(self):
        """Create a child for one random untried move and return it.

        The move is played as PLAYER1 (this node's player to move) and the
        board is then flipped so the child is again stored from the
        perspective of its own player to move.
        """
        action = np.random.choice(np.where(self.expandable_actions == 1)[0])
        self.expandable_actions[action] = 0

        child_state = self.state.copy()
        child_state = self.game.get_next_state(child_state, action, PLAYER1)
        child_state = self.game.change_perspective(child_state, player=PLAYER2)

        child = Node(self.game, self.args, child_state, parent=self, action=action)
        self.children.append(child)
        return child

    def simulate(self):
        """Play a uniformly random game from this node; return its result for this node's player."""
        value, terminated = self.game.get_value_and_terminated(self.state, self.action)
        # A terminal node was decided by the player who moved into it, i.e.
        # this node's opponent, so the result is negated for this node's player.
        value = self.game.get_opponent_value(value)
        if terminated:
            return value

        rollout_state = self.state.copy()
        rollout_player = PLAYER1
        while True:
            valid_moves = self.game.get_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, terminated = self.game.get_value_and_terminated(rollout_state, action)
            if terminated:
                # `value` is from rollout_player's point of view; this node's
                # player is PLAYER1, so only a PLAYER2 result must be negated.
                if rollout_player == PLAYER2:
                    value = self.game.get_opponent_value(value)
                return value
            rollout_player = self.game.get_opponent(rollout_player)

    def backpropagate(self, value):
        """Add ``value`` to this node and propagate it to the root, flipping sign at each level."""
        node = self
        while node is not None:
            node.visit_count += 1
            node.total_value += value
            value = self.game.get_opponent_value(value)
            node = node.parent
        
    
class MCTS:
    """UCT Monte Carlo tree search with uniformly random rollouts."""

    def __init__(self, game, args):
        self.game = game
        self.args = args

    def search(self, state):
        """Run ``args.num_searches`` iterations from ``state``; return visit-count move probabilities.

        ``state`` must be given from the perspective of the player to move
        (their stones are PLAYER1), e.g. via ``game.change_perspective``.

        Returns:
            A flat float array of length ``size * size``. It is all zeros if no
            child was ever expanded (full board or ``num_searches == 0``)
            instead of the NaNs a 0/0 normalisation would produce.
        """
        root = Node(self.game, self.args, state)
        for _ in range(self.args.num_searches):
            node = root

            # Selection: descend while every move of the node has been tried.
            while node.is_fully_expanded():
                node = node.select()

            value, terminated = self.game.get_value_and_terminated(node.state, node.action)
            # Terminal results are decided by the previous mover; flip to this node's player.
            value = self.game.get_opponent_value(value)

            # Expansion + simulation.
            if not terminated:
                node = node.expand()
                value = node.simulate()

            node.backpropagate(value)

        action_probs = np.zeros(self.game.size * self.game.size)
        for child in root.children:
            action_probs[child.action] = child.visit_count
        total = np.sum(action_probs)
        if total > 0:
            action_probs = action_probs / total
        return action_probs
        

In [None]:
game = Gomoku(5)
player = PLAYER1

args = MCTSArgs()
mcts = MCTS(game, args)

state = game.get_initial_state()
tiles = game.size * game.size

# Human (PLAYER1) vs. MCTS (PLAYER2); the human enters moves as "row,col".
while True:
    game.print(state)
    if player == PLAYER1:
        valid_moves = game.get_moves(state)
        # Flat index -> (row, col) divides by the board edge, not the tile count.
        print("Valid moves:", [divmod(i, game.size) for i in range(tiles) if valid_moves[i] == 1])
        user_input = input("Enter action: ")
        try:
            row_str, col_str = user_input.split(',')
            row, col = int(row_str), int(col_str)
        except ValueError:
            print("Invalid move")
            continue
        # Reject out-of-board coordinates (negative ones would otherwise wrap around).
        if not (0 <= row < game.size and 0 <= col < game.size):
            print("Invalid move")
            continue
        action = row * game.size + col

        if valid_moves[action] == 0:
            print("Invalid move")
            continue
    else:
        # MCTS always plays as PLAYER1 in its own frame, so hand it the
        # board flipped to the current player's perspective.
        neutral_state = game.change_perspective(state, player)
        mcts_action_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_action_probs)

    state = game.get_next_state(state, action, player)
    value, terminated = game.get_value_and_terminated(state, action)
    if terminated:
        if value == 1:
            print("Player", player, "wins")
        else:
            print("Draw")
        break

    player = game.get_opponent(player)