MCTS implementation of tic-tac-toe game

- For selection we will use UCT algorithm
- For rollout policy, we will use a simple random rollout (Keep in mind that this lower the performance)

In [1]:
import time
import copy
import random
import math
import numpy as np

In [2]:
class TicTacToe:
    def __init__(self,grid,first) -> None:
        self.grid = grid
        self.board = [[0 for _ in range(grid)] for _ in range(grid)]
        self.starting_player = 'AI' if first == 1 else 'Player'
        self.current_player = 1 if first == 1 else -1
    
    def get_legal_moves(self):
        legal_moves = []
        for i in range(self.grid):
            for j in range(self.grid):
                if self.board[i][j] == 0:
                    legal_moves.append((i,j))
        return legal_moves
    
    def make_move(self,move):
        if move not in self.get_legal_moves():
            return 'Invalid move!'
        i,j = move
        self.board[i][j] = self.current_player
        self.current_player = -1 if self.current_player == 1 else 1
    
    def is_winner(self):
    # Check Row
        for row in self.board:
            if all(cell == -1*self.current_player for cell in row):
                return True
        for col in range(len(self.board)):
            if all(self.board[row][col] == -1*self.current_player for row in range(self.grid)):
                return True
        if all(self.board[i][i] == -1*self.current_player for i in range(self.grid)) or all(self.board[i][self.grid-1-i] == -1*self.current_player for i in range(len(self.board))):
            return True
        return False
      
    def check_win(self):
        if self.get_legal_moves() == []:
            return 0
        elif self.is_winner():
            return 1 if self.current_player == -1 else -1
        else:
            return 2
    
    def is_game_over(self):
        if self.check_win() == 0:
            return (1,'Draw!')
        elif self.check_win() == 1:
            return (1,'AI win!')
        elif self.check_win() == -1:
            return (1,'Player win!')
        elif self.check_win() == 2:
            return (0,'Play on!')
        
    def visualize_board(self):
        for row in self.board:
            print('|',end='')
            print("|".join(["X" if cell == 1 else "O" if cell == -1 else " " for cell in row]),end='')
            print('|')
        print('-'*((len(self.board)*2)+1))

    def get_current_player(self):
        return 'Player' if self.current_player == -1 else 'AI'

In [3]:
#State is the current TicTacToe object

class Node:
    def __init__(self,state,c_param,parent=None) -> None:
        self.state = copy.deepcopy(state)
        self.parent = parent
        self.visits = 0
        self.wins = 0
        self.c_param = c_param
    
        self.children = []
        self.legal_move = state.get_legal_moves()

    def is_expandable(self):
        if self.legal_move == []:
            return 0
        else:
            return 1
    
    def rollout_policy(self):
        return self.legal_move.pop(random.randint(0,len(self.legal_move)-1))
    
    @property
    def ucb_score(self):
        if self.parent != None:
            return (self.wins/self.visits) + (self.c_param * math.sqrt(math.log(self.parent.visits)/self.visits))
        else:
            return 'This is a parent node!'
        
    def best_child(self):
        choice = np.array([self.children[i].ucb_score for i in range(len(self.children))])
        return self.children[np.argmax(choice)]
    

        

In [4]:
class MCTS:
    def __init__(self,root,n_iterations=100000,c_param=1.4) -> None:
        self.root = root
        self.n_iterations = n_iterations
        self.c_param = c_param
        self.count = 0
    
    def algorithm(self):
        while self.count < self.n_iterations   :
            self.count += 1
            # global all_visits 
            # all_visits = self.count
            # print(self.count)
            start_node = self.root
            start_node.visits = self.count 
            # print('This is current root')
            # start_node.state.visualize_board()
            current_node = self.selection(start_node)
            # print(current_node.ucb_score)
            # current_node.state.visualize_board()
            # print(current_node.state.board)
            # print(current_node.ucb_score)
            if current_node.state.is_game_over()[0]:
                score = 1 if current_node.state.check_win() == 1 else 0
                self.back_propagation(current_node,score)
                continue
            self.expansion(current_node)
        return self.root.best_child(),self.root

    def selection(self,node):
        if node.state.is_game_over()[0]:
            return node
        if node.is_expandable():
            return node
        else:
            # node.visits += 1
            # print(node.best_child().state.board)
            return self.selection(node.best_child()) 

    def expansion(self,node):
        if node.state.is_game_over()[0]:  # Check if the game is already over
            return
        node.visits += 1
        tmp_board = copy.deepcopy(node.state)
        tmp_board.make_move(node.legal_move.pop())
        tmp_node = Node(tmp_board,self.c_param)
        tmp_node.visits += 1
        score = self.simulation(tmp_node)
        tmp_node.wins += score
        node.wins += score
        tmp_node.parent = node
        node.children.append(tmp_node)

    def simulation(self,node):
        tmp_node = copy.deepcopy(node)
        # tmp_node.state.visualize_board()
        while not tmp_node.state.is_game_over()[0]:
            k = tmp_node.rollout_policy()
            # print(k)
            tmp_node.state.make_move(k)
            # tmp_node.state.visualize_board()
        score = 1 if tmp_node.state.check_win() == 1 else 0
        # print('score is',score)
        return score
    
    def back_propagation(self,node,is_win):
        if is_win == 1:
            node.visits += 1
            node.wins += 1
            if node.parent:
                self.back_propagation(node.parent,1)
        elif is_win == 0:
            node.visits += 1
            node.wins -= 10
            if node.parent:
                self.back_propagation(node.parent,0)
        


In [5]:
# Test
# game = TicTacToe(3,1)
# # game.make_move((0,0))
# test_node = Node(game,2)
# # print(test_node.is_expandable())
# # print('legal move',test_node.legal_move)
# test_mcts = MCTS(test_node)
# result = test_mcts.algorithm()

In [6]:
# result.state.visualize_board()

In [7]:
# result.state

In [8]:
# result.state.current_player

In [9]:
def game_on(grid):
    while True:
        x = int(input('Please select who will go first (0 is you, 1 is AI): '))
        if x not in [0,1]:
            print('Invalid input!')
            continue
        else:
            break
    game = TicTacToe(grid,x)
    game_over = False
    computation_time = []

    game.visualize_board()
    while not game_over:
        time.sleep(0.2)
        # game.visualize_board()
        if x == 0:
            print('Your turn!')
            while True:
                time.sleep(0.1)
                pos_x = int(input('Please input row position: '))
                pos_y = int(input('Please input column position: '))
                if pos_x > grid or pos_y > grid or pos_x < 1 or pos_y < 1:
                    print('Invalid position!')
                    time.sleep(0.1)
                    continue
                elif (pos_x-1,pos_y-1) not in game.get_legal_moves():
                    print('Invalid position!')
                    time.sleep(0.1)
                    continue
                else:
                    break
            game.make_move((pos_x-1,pos_y-1))
            x += 1
        elif x == 1:
            start = time.time()
            root_node = Node(game,2)
            ai_mcts = MCTS(root_node)
            ai_move,root = ai_mcts.algorithm()
            # print([root.children[i].ucb_score for i in range(len(root.children))])
            # print(root.best_child().state.board)
            # print(ai_move.state.board)
            game = ai_move.state
            end = time.time()
            total = end - start
            computation_time.append(total)
            print('AI play!')
            game.visualize_board()
            x -= 1
        
        game_over,text = game.is_game_over()
        # print(text)

    game.visualize_board()
    print(text)
    print('...............................................',end='\n')
    print('Average Computation time',np.average(computation_time))


In [10]:
grid = int(input('Please insert grid size: '))
game_on(grid)

| | | |
| | | |
| | | |
-------
Your turn!
AI play!
|O| | |
| |X| |
| | | |
-------
Your turn!
AI play!
|O| | |
|O|X| |
|X| | |
-------
Your turn!
AI play!
|O|X|O|
|O|X| |
|X| | |
-------
Your turn!
AI play!
|O|X|O|
|O|X| |
|X|O|X|
-------
Your turn!
|O|X|O|
|O|X|O|
|X|O|X|
-------
Draw!
...............................................
Average Computation time 13.434150695800781


In [11]:
# # for testing
# game = TicTacToe(3,0)
# game.make_move((0,0))
# test_node = Node(game,2)
# print(test_node.is_expandable())
# print('legal move',test_node.legal_move)
# test_mcts = MCTS(10,test_node)
# for _ in range(20):
#     test_mcts.expansion(test_mcts)
#     print('win count',test_mcts.root.wins)
#     print('legal move',test_mcts.root.legal_move)
#     print('child',test_mcts.root.children)
#     print('visit count',test_mcts.root.visits)
#     print(game.board)
#     print(test_node.state.board)

In [12]:
grid = int(input('Please insert grid size: '))
game_on(grid)

| | | | |
| | | | |
| | | | |
| | | | |
---------
Your turn!
AI play!
|O| | | |
| | | | |
| | | |X|
| | | | |
---------
Your turn!
AI play!
|O| | | |
|O| | | |
| | |X|X|
| | | | |
---------
Your turn!
AI play!
|O| | | |
|O| | | |
|O| |X|X|
|X| | | |
---------
Your turn!
AI play!
|O|O| |X|
|O| | | |
|O| |X|X|
|X| | | |
---------
Your turn!
AI play!
|O|O| |X|
|O|O| | |
|O| |X|X|
|X| | |X|
---------
Your turn!
AI play!
|O|O| |X|
|O|O|X|O|
|O| |X|X|
|X| | |X|
---------
Your turn!
AI play!
|O|O| |X|
|O|O|X|O|
|O|O|X|X|
|X| |X|X|
---------
Your turn!
AI play!
|O|O|O|X|
|O|O|X|O|
|O|O|X|X|
|X|X|X|X|
---------
|O|O|O|X|
|O|O|X|O|
|O|O|X|X|
|X|X|X|X|
---------
Draw!
...............................................
Average Computation time 32.711070477962494
