In [1]:
import numpy as np
class TicTaeToe:
    """Minimal two-player tic-tac-toe environment backed by a 3x3 NumPy board.

    Cell values are 0 (empty), 1 and -1 (the two players). Actions are flat
    cell indices 0-8 in row-major order (``row = action // 3``,
    ``col = action % 3``).
    """

    def __init__(self, state = None):
        """Create the environment.

        Args:
            state: optional array-like with 9 cells used as the starting
                position (it is copied). When ``None`` the board starts empty.
        """
        if state is None:
            self.board = np.zeros((3, 3))
        else:
            # Copy so that later moves never mutate the caller's array.
            self.board = np.array(state, dtype=float).reshape(3, 3)

    def reset(self):
        """Clear the board and return it."""
        self.board = np.zeros((3, 3))
        return self.board

    def place(self, row, col, val):
        """Write ``val`` at ``(row, col)`` without any legality check."""
        self.board[row, col] = val

    def give_valid_moves(self, state = None):
        """Return the flat indices of empty cells as a 1-D integer array.

        Args:
            state: board to inspect; defaults to the environment's own board.
        """
        board = self.board if state is None else np.asarray(state)
        return np.argwhere(board.flatten() == 0).flatten()

    def step(self, action, player):
        """Place ``player``'s mark on cell ``action`` and report the outcome.

        Args:
            action: flat cell index in ``range(9)``.
            player: mark to write (1 or -1 in this notebook).

        Returns:
            ``(board, result, done)`` where ``board`` is the live internal
            array (not a copy), ``result`` is the value of ``check_winner()``
            and ``done`` is True once somebody has won or no empty cell is left.

        Raises:
            ValueError: if ``action`` is outside 0-8 or the cell is occupied.
        """
        # Negative indices would silently wrap around to another cell.
        if not 0 <= action < 9:
            raise ValueError(f"action must be in 0..8, got {action}")
        row, col = divmod(action, 3)
        if self.board[row, col] != 0:
            raise ValueError(f"cell {action} is already occupied")
        self.board[row, col] = player

        result = self.check_winner()
        done = result is not None or not (self.board == 0).any()

        return self.board, result, done

    def check_winner(self):
        """Return 1 or -1 if that player owns a full line, otherwise None.

        ``None`` is returned both for an unfinished game and for a full board
        without a winner; ``step`` distinguishes the two through ``done``.
        """
        # Eight candidate lines: 3 rows, 3 columns and the two diagonals.
        all_conditions = np.concatenate([
            self.board,
            self.board.T,
            [np.diag(self.board)],
            [np.diag(np.fliplr(self.board))]
        ])

        for condition in all_conditions:
            if np.all(condition == 1):
                return 1
            if np.all(condition == -1):
                return -1

        return None

    def show(self, state = None):
        """Print ``state``, or the environment's own board when omitted."""
        print(self.board if state is None else state)


# env = TicTaeToe()
# state = env.reset()

# done = False
# while not done:
#     legal_moves = env.give_valid_moves()
#     action = np.random.choice(legal_moves)
#     board, winner, done = env.step(action)
#     env.show()

#     if done:
#         if winner == 1:
#             print("Player 1 wins!")
#         elif winner == -1:
#             print("Player -1 wins!")
#         else:
#             print("Draw!")

In [2]:
def simulate_game(env, steps = 100000):
    """Play ``steps`` games of uniformly random moves and report the outcomes.

    Player 1 always moves first and the two players alternate turns.

    Args:
        env: environment exposing ``reset()``, ``give_valid_moves(state)`` and
            ``step(action, player)`` returning ``(board, winner, done)``.
        steps: number of games to play.

    Returns:
        dict with keys ``1``, ``-1`` and ``"draw"`` mapping to game counts.
        The corresponding fractions are also printed.
    """
    results = {
        1 : 0,
        -1 : 0,
        "draw" : 0}

    for _ in range(steps):
        state = env.reset()
        player = 1
        winner = None
        done = False

        while not done:
            legal_moves = env.give_valid_moves(state)
            action = np.random.choice(legal_moves)
            state, winner, done = env.step(action, player)
            player = -player

        # Any outcome that is not a win for 1 or -1 counts as a draw.
        if winner == 1:
            results[1] += 1
        elif winner == -1:
            results[-1] += 1
        else:
            results["draw"] += 1

    # Avoid ZeroDivisionError when no game was requested.
    total = max(steps, 1)
    print(f"Win by 1 : {results[1]/total}\nWin by -1 : {results[-1]/total}\nDraw : {results['draw']/total}")
    return results

# env = TicTaeToe()
# simulate_game(env)

In [3]:
import numpy as np

class Node:
    """A single position in the search tree.

    Attributes set here:
        state: board array for this position.
        parent: parent ``Node`` (``None`` is passed for the root).
        move: flat cell index that led from the parent to this position.
        player: player value stored for this position.
        children: dict mapping a move to the child ``Node`` it produces.
        unexpanded_children: flat indices of the empty cells of ``state``.
        N, W: counters starting at 0; ``UCB1`` reads them as ``W / N``.
    """

    def __init__(self, state, parent, move, player):
        self.state, self.parent, self.move = state, parent, move
        self.player = player

        self.children = {}
        # Every empty cell is a move that has not been tried from here yet.
        self.unexpanded_children = np.flatnonzero(self.state.flatten() == 0)

        self.N = 0
        self.W = 0

    def is_terminal(self):
        """Return True when a line is fully 1 or fully -1, or no cell is empty."""
        # Stack the 3 rows, 3 columns and both diagonals into one 8x3 array.
        lines = np.vstack([
            self.state,
            self.state.T,
            np.diag(self.state),
            np.diag(np.fliplr(self.state)),
        ])
        decided = any(
            np.all(line == mark) for line in lines for mark in (1, -1)
        )
        return decided or not (self.state == 0).any()

    def is_fully_expanded(self):
        """Return True once ``unexpanded_children`` is empty."""
        return not len(self.unexpanded_children)

    def UCB1(self, c = 2):
        """Return ``W/N + sqrt(c * ln(parent.N) / N)``.

        Returns infinity when this node or its parent has ``N == 0``.
        """
        unvisited = self.N == 0
        if unvisited or self.parent.N == 0:
            return float('inf')
        exploitation = self.W/self.N
        exploration = np.sqrt(c * np.log(self.parent.N)/self.N)
        return exploitation + exploration



In [4]:
class MCTS:
    """Monte Carlo Tree Search over tic-tac-toe positions.

    ``search`` repeats selection -> expansion -> random rollout ->
    backpropagation and returns the most visited move at the root.
    """

    def __init__(self, root_state, root_player):
        """Build a tree whose root holds ``root_state`` with ``root_player`` to move."""
        self.root = Node(
            state = root_state,
            parent = None,
            move = None,
            player = root_player)

    def selection(self):
        """Descend from the root along the highest-UCB1 child at every level.

        Stops at the first node that still has untried moves or is terminal
        and returns it.

        Raises:
            RuntimeError: if the descent exceeds 1000 levels (safety net).
        """
        current = self.root
        steps = 0
        max_steps = 1000
        while current.is_fully_expanded() and not current.is_terminal():
            if not current.children:
                # Nothing to descend into; treat the node as the leaf.
                break

            # max() always yields a child, even when every score is <= 0, and
            # keeps the first child on ties.
            current = max(current.children.values(), key=lambda child: child.UCB1())
            steps += 1
            if steps > max_steps:
                # helpful debug snapshot — avoids printing big arrays
                raise RuntimeError(
                    f"Selection exceeded {max_steps} steps. "
                    f"Last node move={current.move}, player={current.player}, "
                    f"children={len(current.children)}, unexpanded={len(current.unexpanded_children)}, "
                    f"is_term={current.is_terminal()}"
                )

        return current

    def expansion(self, node):
        """Attach one randomly chosen untried move of ``node`` as a new child.

        Returns:
            The new child, or ``None`` when ``node`` is terminal or has no
            untried move left.
        """
        if len(node.unexpanded_children)>0 and not node.is_terminal():
            move = np.random.choice(node.unexpanded_children)
            new_state = self.step(node.state, move, node.player)[0]
            # The child is reached by node.player's move, so the other player
            # is stored on the child.
            child = Node(state = new_state,
                         parent = node,
                         move = move,
                         player = -node.player)

            node.children[move] = child
            node.unexpanded_children = node.unexpanded_children[node.unexpanded_children != move]

            return child
        return None


    def simulation(self, node):
        """Play uniformly random moves from ``node`` until the game ends.

        Returns:
            ``(winner, final_state)`` with ``winner`` in ``{1, -1, 0}``
            (0 is a draw). ``node.state`` itself is never modified.
        """
        state = node.state.copy()
        current_player = node.player

        # check_winner returns 0 on a full board, so while it returns None at
        # least one empty cell is available.
        winner = self.check_winner(state)
        while winner is None:
            legal_moves = np.argwhere((state.flatten() == 0)).flatten()
            move = np.random.choice(legal_moves)
            state, winner = self.step(state, move, current_player)
            current_player = -current_player

        return winner, state


    def backpropagation(self, node, winner):
        """Update ``N`` and ``W`` on ``node`` and all of its ancestors.

        Each node's ``N`` grows by 1. ``W`` grows by 1 when ``winner`` equals
        ``-node.player`` (the player whose move produced the node, see
        ``expansion``), by 0.5 when ``winner`` is 0, and is unchanged otherwise.
        """
        current = node
        while current is not None:
            current.N += 1
            if winner == -current.player:
                current.W += 1
            elif winner == 0:
                current.W += 0.5

            current = current.parent


    def search(self, iterations=10):
        """Run ``iterations`` MCTS rounds and return the most visited root move.

        Returns:
            The move key of the root child with the largest ``N`` (first one
            on ties), or ``None`` when the root has no children.
        """
        for _ in range(iterations):
            leaf = self.selection()
            current = self.expansion(leaf)
            if current is None:
                # Terminal or exhausted leaf: roll out from the leaf itself.
                current = leaf
            winner, state = self.simulation(current)
            self.backpropagation(current, winner)

        children = self.root.children
        if not children:
            return None
        return max(children, key=lambda move: children[move].N)


    def __repr__(self):
        return f"{self.root.state}\n{self.root.parent}\n{self.root.children}\n{self.root.move}"


    @staticmethod
    def check_winner(state):
        """Classify a board.

        Returns:
            1 or -1 when that player owns a full row, column or diagonal,
            0 when the board is full without a winner, ``None`` otherwise.
        """
        all_conditions = np.concatenate([
            state,
            state.T,
            [np.diag(state)],
            [np.diag(np.fliplr(state))]
        ])

        for condition in all_conditions:
            if np.all(condition == 1):
                return 1
            if np.all(condition == -1):
                return -1

        if not (state == 0).any():
            return 0

        return None

    @staticmethod
    def step(state, action, player):
        """Return ``(new_state, result)`` after ``player`` marks cell ``action``.

        Works on a copy, so the ``state`` argument is left untouched.
        ``result`` is ``check_winner(new_state)``.
        """
        state = state.copy()
        row, col = action // 3, action % 3
        state[row, col] = player

        result = MCTS.check_winner(state)

        return state, result


        

        

In [7]:
# Interactive game: the MCTS agent plays as 1 and moves first, the human
# plays as -1 and types a cell index 0-8 (row-major).
AI_PLAYER = 1
MCTS_ITERATIONS = 1000


def _read_human_move(env, state):
    """Prompt until the user types the index of an empty cell of ``state``."""
    while True:
        try:
            move = int(input("Enter your move (0-8): "))
        except ValueError:
            # Non-numeric input: ask again instead of aborting the game.
            continue
        if move in env.give_valid_moves(state):
            return move


env = TicTaeToe()
state = env.reset()
done = False
winner = None
player = AI_PLAYER


env.show(state)

while not done:
    if player == AI_PLAYER:
        # A fresh tree is built from the current position on every AI turn.
        root = MCTS(state, player)
        best_move = root.search(iterations=MCTS_ITERATIONS)
        move = best_move
    else:
        move = _read_human_move(env, state)

    state, winner, done = env.step(move, player)
    env.show(state)

    if not done:
        player = -player

if winner == 1:
    print("\nAI won!")
elif winner == -1:
    print("\nI won!")
else:
    print("\nDraw!")

print("\nGAME OVER!!!\n")

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[0. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]


Enter your move (0-8):  8


[[ 0.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0. -1.]]
[[ 0.  0.  0.]
 [ 0.  1.  1.]
 [ 0.  0. -1.]]


Enter your move (0-8):  3


[[ 0.  0.  0.]
 [-1.  1.  1.]
 [ 0.  0. -1.]]
[[ 0.  0.  0.]
 [-1.  1.  1.]
 [ 0.  1. -1.]]


Enter your move (0-8):  1


[[ 0. -1.  0.]
 [-1.  1.  1.]
 [ 0.  1. -1.]]
[[ 0. -1.  1.]
 [-1.  1.  1.]
 [ 0.  1. -1.]]


Enter your move (0-8):  6


[[ 0. -1.  1.]
 [-1.  1.  1.]
 [-1.  1. -1.]]
[[ 1. -1.  1.]
 [-1.  1.  1.]
 [-1.  1. -1.]]

Draw!

GAME OVER!!!

