In [3]:
from enum import Enum
from copy import deepcopy
from math import log
import pickle

In [4]:
def back_to_start_state(board, nb_moves_beginning):
    while len(board.moves) > nb_moves_beginning:
        board.undo_last_move()

In [5]:
class Node:
    
    def __init__(self, board):
        self.board = board
        self.wins = 0
        self.draws = 0
        self.visits = 0
        self.parents = set()
        
    def add_parent(self, parent):
        if isinstance(parent, Board):
            parent_representation = parent.get_representation()
            self.parents.add(parent_representation)
        else:
            self.parents.add(parent)
            
    def update(self, player, winner):
        if player == winner:
            self.wins += 1
        elif winner == 0:
            self.draws += 1
        
        self.visits += 1
            
    def calculate_parents_visits(self):
        total_parents_visits = 0
        for parent in self.parents:
            total_parents_visits += parent.visits
        return total_parents_visits
    
    def calculate_winrate(self):
        if self.visits == 0:
            return 0
        
        return (self.wins + self.draws) / self.visits
            
    def calculate_uct(self, c):
        if self.visits == 0:
            return float('inf')
        
        winrate = (self.wins + self.draws) / self.visits #we count a draw as a win because, under best play, a draw is the best result one can get
        
        parents_visits = self.calculate_parents_visits()
        exploration = c * (log(parents_visits) / self.visits)**(1/2)
        
        return winrate + exploration

In [6]:
class direction(Enum):
    ROW = 1
    COL = 2
    DIAG_TOP_BOTTOM = 3
    DIAG_BOTTOM_TOP = 4

class Board:
    ##### Players #####
    ## 1 = player X
    ## -1 = player O

    ##### Game State #####
    ## 0 and 9 moves played: draw
    ## -1: player O wins the game
    ## 1: player X wins the game


    def __init__(self):
        self.grid = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
        self.moves = []
        self.gameState = 0
        self.winningMove = None
        
    def copy(self):
        board_copy = Board()
        board_copy.grid = deepcopy(self.grid)
        board_copy.moves = deepcopy(self.moves)
        board_copy.gameState = self.gameState    
        board_copy.winningMove = self.winningMove    
        return board_copy

    def display_cell(self, i, j):
        if self.grid[i][j] == 0:
            print('   ', end = '')
        elif self.grid[i][j] == 1:
            print(' X ', end = '')
        else:
            print(' O ', end = '')

    def display(self, addSpace=False):
        for i in range(3):
            for j in range(3):
                self.display_cell(i, j)
                if j < 2:
                    print('|', end='')
                else:
                    print()
            if i < 2:
                print('-----------')
        if addSpace:
            print()

    def display_player_name(self, player):
        if player == 1:
            return "X"
        return "O"

    def updateGameState(self, i, j, player):
        #i, j stand for the coordinates of the last move played 
        #player is the player that played the last move
        if len(self.moves) == 9 or self.gameState != 0:
            pass
        else:
            #check if the player has won on the row
            if self.playerWins_row(i, j, player):
                self.gameState = player
                self.winningMove = ((i, j), direction.ROW) 
                return
            if self.playerWins_column(i, j, player):
                self.gameState = player
                self.winningMove = ((i, j), direction.COL) 
                return
            
            won_on_diagonal = self.playerWins_diag(player)
            if won_on_diagonal[0]:
                self.gameState = player
                if won_on_diagonal[1] == 1:
                    dir = direction.DIAG_TOP_BOTTOM
                else:
                    dir = direction.DIAG_BOTTOM_TOP
                self.winningMove = ((i, j), dir)

    def playerWins_row(self, i, j, player):
        if j == 0:
            return self.grid[i][j+1] == self.grid[i][j+2] == player
        if j == 1:
            return self.grid[i][j-1] == self.grid[i][j+1] == player
        return self.grid[i][j-2] == self.grid[i][j-1] == player

    def playerWins_column(self, i, j, player):
        if i == 0:
            return self.grid[i+1][j] == self.grid[i+2][j] == player
        if i == 1:
            return self.grid[i-1][j] == self.grid[i+1][j] == player
        return self.grid[i-1][j] == self.grid[i-2][j] == player

    def playerWins_diag(self, player):
        if self.grid[0][0] == self.grid[1][1] == self.grid[2][2] == player:
            #diagonal from top right corner to bottom left corner
            return (True, 1)
        return (self.grid[0][2] == self.grid[1][1] == self.grid[2][0] == player, 2)

    def play(self, i, j, player):
        if self.gameState != 0 or len(self.moves) >= 9:
            text = "The game is a draw." if self.gameState == 0 else f"Player {self.display_player_name(self.gameState)} has won."
            raise Exception(text)
        elif self.grid[i][j] == 0:
            if len(self.moves) != 0 and self.moves[-1][2] == player:
                players_turn = 1 if player == -1 else -1
                raise Exception(f"It's player {self.display_player_name(players_turn)}'s turn to play.")
            else:
                self.grid[i][j] = player
                self.moves.append((i, j, player))
                self.updateGameState(i, j, player)
        else:
            raise Exception(f"Cell [{i}, {j}] is not empty.")
        
    def undo_last_move(self, return_move=False):
        if len(self.moves) == 0:
            return
        
        (i, j, player) = self.moves.pop()
        self.grid[i][j] = 0
        self.gameState = 0 #undoing the last move necessarily makes it so that the game cannot be over
        self.winningMove = None
        
        if return_move:
            return (i, j, player)
        
    def get_square_representation(self, i, j):
        if self.grid[i][j] == 0:
            return '*'
        if self.grid[i][j] == 1:
            return 'X'
        return 'O'
        
    def get_representation(self):
        representation = ''
        for i in range(3):
            for j in range(3):
                representation += self.get_square_representation(i, j)
        return representation
    
    def isGameOver(self):
        return self.gameState != 0 or (self.gameState == 0 and len(self.moves) >= 9)

In [7]:
def choose_node_uct(board, player, dic_nodes_visited, c):
    #dic has the board representation as key and the node as value
    
    if board.isGameOver():
        return None
    
    nb_moves_beginning = len(board.moves)
    best_uct = -float('inf')
    best_move = None
    for i in range(3):
        for j in range(3):
            try:
                board.play(i, j, player)
                
                if dic_nodes_visited.get(board.get_representation()) != None: #the node has already been visited and exists in the dictionary
                    node = dic_nodes_visited.get(board.get_representation())
                    uct = node.calculate_uct(c)
                else:
                    node = Node(board.copy())
                    dic_nodes_visited[node.board.get_representation()] = node
                    uct = node.calculate_uct(c)
                    
                if uct >= best_uct:
                    best_uct = uct
                    best_move = (i, j, node)
                
                back_to_start_state(board, nb_moves_beginning)
            except:
                back_to_start_state(board, nb_moves_beginning)
    
    back_to_start_state(board, nb_moves_beginning)      
    return best_move

def choose_node_winrate(board, player, dic_nodes_visited):    
    if board.isGameOver():
        return None
    
    nb_moves_beginning = len(board.moves)
    best_winrate = -1
    best_move = None
    for i in range(3):
        for j in range(3):
            try:
                board.play(i, j, player)
                
                if dic_nodes_visited.get(board.get_representation()) != None: #the node has already been visited and exists in the dictionary
                    node = dic_nodes_visited.get(board.get_representation())
                    winrate = node.calculate_winrate()
                else:
                    winrate = 0
                    
                if winrate >= best_winrate:
                    best_winrate = winrate
                    best_move = (i, j, node)
                
                back_to_start_state(board, nb_moves_beginning)
            except:
                back_to_start_state(board, nb_moves_beginning)
    
    back_to_start_state(board, nb_moves_beginning)      
    return best_move

def play_move_in_simulation(board, player, dic_nodes_visited, path, c):
    
    try:
        (i, j, best_node) = choose_node_uct(board, player, dic_nodes_visited, c)
        try:
            board.play(i, j, player)
            path.append((best_node, player))
            
            try:
                parent = dic_nodes_visited[board.get_representation()]
            except:
                parent = Node(board.copy())
                
            best_node.add_parent(parent)
            
        except:
            pass
    except:
        pass
    

def backpropagation(path, winner):
    for i in range(len(path)-1, -1, -1):
        node, player = path[i]
        node.update(player, winner)
            
    
def make_complete_simulation(board, player, dic_nodes_visited, c):
    path = []
    
    while not board.isGameOver():
        play_move_in_simulation(board, player, dic_nodes_visited, path, c)
        if player == -1:
            player = 1
        else:
            player = -1
    
    winner = board.gameState
    backpropagation(path, winner)
 
def play_mcts(board, player, nb_simulations, c = 2**(1/2)):
    dic_nodes_visited = {}
    
    for i in range(nb_simulations):
        copy_board = board.copy()
        make_complete_simulation(copy_board, player, dic_nodes_visited, c)
        copy_board = None
        #print(f"Simulation {i} ended -> len(dic) = {len(dic_nodes_visited)}")
        
    #print('len dic = ' + str(len(dic_nodes_visited)))
    #print('len moves = ' + str(len(board.moves)))
    
    (i, j, _) = choose_node_winrate(board, player, dic_nodes_visited)
    #board.play(i, j, player)
    return dic_nodes_visited

### Test

In [8]:
b = Board()
b.play(0, 0, 1)
b.play(1, 1, -1)
b.play(2, 2, 1)
b.play(0, 2, -1)
b.display()

 X |   | O 
-----------
   | O |   
-----------
   |   | X 


In [9]:
c = 2**(1/2)

In [22]:
d = play_mcts(b, 1, 100000, c)

In [23]:
len(d)

48

In [24]:
(i, j, node) = choose_node_winrate(b, 1, d)

In [25]:
i, j

(1, 0)

In [26]:
nb_0 = 0
for key in d:
    #print(d[key].visits)
    if d[key].visits == 0:
        nb_0 += 1

In [27]:
nb_0

22

In [42]:
b.display()

 X |   | O 
-----------
   | O |   
-----------
   |   | X 


In [43]:
nodes = []
for i in range(3):
    for j in range(3):
        try:
            copy = b.copy()
            copy.play(i, j, 1)
            node = d[copy.get_representation()]
            nodes.append(node)
            print(f"{i}, {j} -> {node.visits} visits, winrate = {node.calculate_winrate()}, uct = {node.calculate_uct(c)}")
            print(f"nb parents = {len(node.parents)}, parents' visits = {node.calculate_parents_visits()}")
            for parent in node.parents:
                parent.board.display()
            print()
        except:
            pass

0, 1 -> 1 visits, winrate = 1.0, uct = 1.0
nb parents = 1, parents' visits = 1
 X | X | O 
-----------
   | O |   
-----------
   |   | X 

1, 0 -> 1 visits, winrate = 1.0, uct = 1.0
nb parents = 1, parents' visits = 1
 X |   | O 
-----------
 X | O |   
-----------
   |   | X 

1, 2 -> 1 visits, winrate = 0.0, uct = 0.0
nb parents = 1, parents' visits = 1
 X |   | O 
-----------
   | O | X 
-----------
   |   | X 

2, 0 -> 99996 visits, winrate = 0.999989999599984, uct = 1.0151645480268796
nb parents = 1, parents' visits = 99996
 X |   | O 
-----------
   | O |   
-----------
 X |   | X 

2, 1 -> 1 visits, winrate = 0.0, uct = 0.0
nb parents = 1, parents' visits = 1
 X |   | O 
-----------
   | O |   
-----------
   | X | X 



In [38]:
nodes

[<__main__.Node at 0x193c8a30df0>,
 <__main__.Node at 0x193c8a306d0>,
 <__main__.Node at 0x193c8a30e50>,
 <__main__.Node at 0x193c9846610>,
 <__main__.Node at 0x193c9846400>]

In [40]:
for parent in nodes[0].parents:
    print(parent.get_representation())

AttributeError: 'Node' object has no attribute 'get_representation'