In [2]:
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

class TicTacToe:
    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        return np.zeros((self.row_count, self.column_count), dtype=int)
    
    def get_next_state(self, state, action, player):
        new_state = state.copy()
        row = action // self.column_count
        col = action % self.column_count
        new_state[row, col] = player
        return new_state
    
    def get_valid_moves(self, state):
        return (state.reshape(-1) == 0).astype(np.uint8)
    
    def check_win(self, state, action):
        if action == None:
            return False

        row = action // self.column_count
        col = action % self.column_count
        player = state[row, col]
        return (
            np.sum(state[row, :]) == player * self.column_count
            or np.sum(state[:, col]) == player * self.row_count
            or np.sum(np.diag(state)) == player * self.row_count
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
        )
    
    def get_value_and_terminated(self, state, action):
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False
    
    def get_opponent(self, player):
        return -player
    
    def get_opponent_value(self, value):
        return -value
    
    def change_perspective(self, state, player):
        return state * player

 


In [None]:
import math
class Node:
    def __init__(self, game, args, state, parent=None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        self.expandable_moves = game.get_valid_moves(state)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_child:
                best_child = child
                best_ucb = ucb
        return best_child
    
    def get_ucb(self, child):
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * math.sqrt(math.log(self.visit_count) / child.visit_count)
    
    def expand(self):
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0

        child_state = self.state.copy()
        child_state = self.game.get_next_state(child_state, action, 1)

        child_state = self.game.change_perspective(child_state, player=-1)

        child = Node(self.game, self.args, child_state, self, action)

        self.children.append(child)
        return child
    
    def simulate(self):
        value, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken)
        value = self.game.get_opponent_value(value)

        if is_terminal:
            return value
        rollout_state = self.state.copy()
        rollout_player = 1

        while True:
            valid_moves = self.game.get_valid_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, is_terminal = self.game.get_value_and_terminated(rollout_state, action)

            if is_terminal:
                if rollout_player == -1:
                    value = self.game.get_opponent_value(value)
                return value
            
            rollout_player = self.game.get_opponent(rollout_player)
        
    def backpropogate(self, value):
        self.value_sum += value
        self.visit_count += 1

        value = self.game.get_opponent_value(value)

        if self.parent is not None:
            self.parent.backpropogate(value)

class MCTS:
    def __init__(self, game, args):
        self.game = game
        self.args = args

    def search(self, state):
        # define root node
        root = Node(self.game, self.args, state)

        for search in range(self.args['num_searches']):
            node = root

            # selection
            while node.is_fully_expanded():
                node = node.select()
            
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
            # expansion
                node = node.expand()
            # simulation
                value = node.simulate()
            # backpropogation
            
            node.backpropogate(value)


        action_probs = np.zeros(self.game.action_size)
        for child in self.children:
            action_probs[child.action_taken] = child.visit_count

        action_probs /= np.sum(action_probs)
        return action_probs

In [6]:
import ipywidgets as widgets
from IPython.display import display
import numpy as np

# Setup
ttt = TicTacToe()
player = 1
args = {
    'C': 1.41,
    'num_searches': 1000
}
mcts = MCTS(ttt, args)
state = ttt.get_initial_state()

# Create 9 blank buttons for the grid
buttons = [widgets.Button(description='',
                          layout=widgets.Layout(width='60px', height='60px'))
           for _ in range(ttt.action_size)]

grid = widgets.GridBox(buttons,
                       layout=widgets.Layout(
                           grid_template_columns='repeat(3, 60px)',
                           grid_gap='5px',
                       ))

status = widgets.HTML(value=f"<b>Player {player}’s turn</b>")

def on_click(b):
    global state, player

    idx = buttons.index(b)
    valid_moves = ttt.get_valid_moves(state)

    if not valid_moves[idx]:
        status.value = f"<span style='color:red;'>Cell is occupied!</span>"
        return

    # Player 1 move
    state = ttt.get_next_state(state, idx, player)
    b.disabled = True
    b.description = 'X'

    value, is_terminal = ttt.get_value_and_terminated(state, idx)

    if is_terminal:
        if value == 1:
            status.value = f"<h3>Player {player} wins! 🎉</h3>"
        else:
            status.value = "<h3>Draw!</h3>"
        for btn in buttons:
            btn.disabled = True
        return

    # Player -1 (AI) move via MCTS
    player = -player
    status.value = f"<b>Player {player} is thinking...</b>"

    neutral_state = ttt.change_perspective(state, player)
    mcts_probs = mcts.search(neutral_state)
    ai_action = int(np.argmax(mcts_probs))
    state = ttt.get_next_state(state, ai_action, player)

    buttons[ai_action].disabled = True
    buttons[ai_action].description = 'O'

    value, is_terminal = ttt.get_value_and_terminated(state, ai_action)

    if is_terminal:
        if value == 1:
            status.value = f"<h3>Player {player} wins! 🎉</h3>"
        else:
            status.value = "<h3>Draw!</h3>"
        for btn in buttons:
            btn.disabled = True
    else:
        player = -player
        status.value = f"<b>Player {player}’s turn</b>"

for btn in buttons:
    btn.on_click(on_click)

display(grid, status)



GridBox(children=(Button(layout=Layout(height='60px', width='60px'), style=ButtonStyle()), Button(layout=Layou…

HTML(value='<b>Player 1’s turn</b>')

AttributeError: 'numpy.ndarray' object has no attribute 'game'