In [1]:
class Board:
    """Tic-tac-toe position with move bookkeeping.

    Squares are numbered 0-8 row by row. Players are represented by booleans
    (``X`` is ``True``, ``O`` is ``False``) and ``X`` moves first.

    Attributes:
        turn: player to move next.
        symbol: maps a player to the character drawn on the board.
        board: list of 9 characters, '.' for an empty square.
        win_table: the eight index triples that form a winning line.
        legal_moves: set of squares that can still be played; emptied as soon
            as the game is decided.
        winner: ``None`` while the game is running, otherwise ``[1, 0]`` (X won),
            ``[0, 1]`` (O won) or ``[0, 0]`` (draw).
    """

    def __init__(self):
        self.X, self.O = True, False
        self.turn = self.X
        self.symbol = {self.X: 'X', self.O: 'O'}
        self.board = ['.'] * 9
        self.win_table = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],   # rows
            [0, 3, 6], [1, 4, 7], [2, 5, 8],   # columns
            [0, 4, 8], [2, 4, 6],              # diagonals
        ]
        self.legal_moves = set(range(9))
        self.winner = None

    def check_winner(self, turn):
        """Return True if player ``turn`` owns a complete line, else False."""
        mark = self.symbol[turn]
        return any(
            self.board[a] == mark and self.board[b] == mark and self.board[c] == mark
            for a, b, c in self.win_table
        )

    def push(self, move):
        """Play ``move`` for the side to move and update the game status.

        A move that is not in ``legal_moves`` is ignored: the board, the
        status and the side to move all stay as they were.
        """
        if move not in self.legal_moves:
            return

        mover = self.turn
        self.board[move] = self.symbol[mover]

        if self.check_winner(mover):
            # A decided game has no legal moves left.
            self.legal_moves = set()
            self.winner = [1, 0] if mover == self.X else [0, 1]
        else:
            self.legal_moves.remove(move)
            if not self.legal_moves:
                # Board is full and nobody completed a line.
                self.winner = [0, 0]

        self.turn = not mover

    def __str__(self):
        """Render the board as three newline-terminated rows."""
        rows = (self.board[start:start + 3] for start in (0, 3, 6))
        return "".join('{} {} {}\n'.format(*row) for row in rows)

    def __repr__(self):
        return str(self)

    def copy(self):
        """Return an independent copy of this position."""
        clone = Board()
        clone.board = self.board.copy()
        clone.legal_moves = self.legal_moves.copy()
        clone.turn = self.turn
        if self.winner is not None:
            clone.winner = self.winner.copy()
        return clone
        
        
        

In [2]:
import numpy as np
import time


class MCTS_graph:
    """Dump the search tree of an MCTS agent as a Graphviz DOT file.

    Every tree node below the root gets a sequential integer id (the root is
    id 0) and is labelled with its win fraction from the point of view of the
    player to move in its parent.
    """

    def __init__(self,agent):
        """Remember the root node and temperature of ``agent``.

        The root is captured at construction time; re-create the object after
        the agent's root changes.
        """
        self.root=agent.root
        self.temperature = agent.temperature

    def make_graph(self,depth=1000):
        """Collect nodes and edges of the tree down to ``depth`` levels.

        Fills ``self.nodes`` (id -> win fraction), ``self.edges``
        (``[parent_id, child_id, move]``) and ``self.cont`` (node count), then
        prints the node count. A missing root produces an empty graph.
        """
        self.cont=0
        self.nodes = {}
        self.edges = []

        self.bfs(self.root,0,depth)
        print('Total nodes: {}'.format(self.cont))

    def bfs(self,node,father,depth):
        """Recursively register the descendants of ``node``.

        Despite its name the traversal is depth-first (pre-order): a child's
        whole subtree is numbered before its next sibling.

        Args:
            node: tree node whose children are added (``None`` is ignored).
            father: graph id already assigned to ``node``.
            depth: remaining number of levels to descend; 0 stops.
        """
        if node is None or depth==0: return
        for n in node.children:
            self.cont+=1
            node_id = self.cont  # fix the id before recursion advances the counter
            # Plain win fraction, without the UCT exploration term.
            self.nodes[node_id]=n.winning_frac(node.game_state.turn)
            self.edges.append([father,node_id,n.move])
            self.bfs(n,node_id,depth-1)

    def save_graph(self,path,depth=1000):
        """Build the graph and write it to ``path`` in DOT format.

        The DOT text is assembled before the file is opened, so a failure
        while walking the tree does not leave an empty file behind.
        """
        self.make_graph(depth)
        lines = ["digraph{", "  0 [label=\"root\"];"]
        lines.extend("  {} [label=\"{:.2f}\"];".format(n,m) for n,m in self.nodes.items())
        # A `digraph` requires the directed edge operator `->`; `--` is only
        # valid inside an undirected `graph`.
        lines.extend("  {} -> {} [label=\"{}\"];".format(x,y,z) for (x,y,z) in self.edges)
        lines.append("}")
        with open(path,'w') as file:
            file.write("\n".join(lines))
        print("Grafo guardado en: {}".format(path))



class MCTSNode:
    """One node of the Monte Carlo search tree.

    Attributes:
        game_state: position reached at this node; must expose
            ``legal_moves``, ``copy()``, ``push(move)``, ``turn`` and ``winner``.
        parent: parent node, or ``None`` for the root.
        move: move that led from the parent's position to this one.
        win_counts: length-2 array accumulating rollout results; index 0 is
            read for a truthy player and index 1 for a falsy one.
        num_rollouts: number of results recorded at this node.
        children: child nodes expanded so far.
        unvisited_moves: legal moves that have no child node yet.
        scale_factor: stored and handed down to children; not otherwise used
            in this class.
        isRoot: flag marking the current root of the tree.
        bot: object providing ``random_simulation(game_state)`` for rollouts.
    """

    def __init__(self, game_state, parent = None, move = None, bot = None, isRoot = False,scale_factor=1):
        self.game_state = game_state
        self.parent = parent
        self.move = move
        self.win_counts = np.zeros([2,])
        self.num_rollouts = 0
        self.children = []
        self.unvisited_moves = list(game_state.legal_moves)
        self.scale_factor = scale_factor
        self.isRoot=isRoot
        self.bot=bot

    def add_random_child(self,bot):
        """Expand one randomly chosen unvisited move and return the new child.

        The move is removed from ``unvisited_moves`` and applied to a copy of
        this node's game state.

        Raises:
            ValueError: if there is no unvisited move left (raised by
                ``np.random.randint``); check ``can_add_child()`` first.
        """
        index = np.random.randint(len(self.unvisited_moves))
        new_move = self.unvisited_moves.pop(index)

        # Work on a copy so this node's own position is left untouched.
        new_game_state = self.game_state.copy()
        new_game_state.push(new_move)

        new_node = MCTSNode(game_state=new_game_state, parent=self, move=new_move,bot=bot,scale_factor=self.scale_factor)
        self.children.append(new_node)
        return new_node

    def record_win(self, result):
        """Add a rollout ``result`` (length-2 sequence) and count the rollout."""
        self.win_counts += result
        self.num_rollouts += 1

    def result_simulation(self):
        """Return the outcome used to score this node.

        A terminal node reports its game state's ``winner``; otherwise the
        bot plays a random simulation from this node's position.
        """
        if self.is_terminal():
            return self.game_state.winner
        return self.bot.random_simulation(self.game_state)

    def can_add_child(self):
        """Return True while some legal move has not been expanded yet."""
        return len(self.unvisited_moves) > 0

    def is_terminal(self):
        """Return True if the position has no legal moves (game over)."""
        return len(self.game_state.legal_moves)==0

    def winning_frac(self, player):
        """Return wins / rollouts (Q/N) of this node for ``player``.

        Args:
            player: truthy selects ``win_counts[0]``, falsy ``win_counts[1]``.
                With the ``Board`` class of this notebook that means X and O
                respectively.

        Returns:
            The win fraction as a float, or 0.0 when no rollout has been
            recorded (instead of raising ``ZeroDivisionError``).
        """
        if self.num_rollouts == 0:
            return 0.0
        index = 0 if player else 1
        return float(self.win_counts[index]) / float(self.num_rollouts)

class agent_MCTS:
    """Monte Carlo Tree Search player using UCT selection.

    The search tree is kept between calls: after a move is chosen the matching
    child becomes the new root, and the tree is rebuilt only when the position
    passed in does not match the root's position.

    Attributes:
        temperature: weight of the exploration term in the UCT score.
        bot: rollout policy, must provide ``random_simulation(game_state)``.
        max_iter: default number of search iterations per call.
        root: current root ``MCTSNode`` (``None`` until a position is known).
        verbose: when > 0, timing information is printed after each search.
    """

    def __init__(self, temperature=2,bot=None,game_state=None,max_iter=100,verbose=0):
        self.temperature = temperature
        self.bot = bot
        self.max_iter = max_iter
        self.root = None
        self.verbose = verbose
        if game_state is not None:
            self.root = MCTSNode(game_state.copy(),bot=self.bot,isRoot=True)

    def select_move(self,board,max_iter=None,push=True):
        """Search from ``board`` and return the move with the best win rate.

        Args:
            board: current position.
            max_iter: iterations to run; ``None`` uses ``self.max_iter``.
            push: when True the tree root is advanced to the chosen move.

        Returns:
            The chosen move, or ``None`` if the position is terminal or the
            search produced no candidate move (e.g. zero iterations on a
            fresh tree).
        """
        moves,values=self.get_move_values(board,max_iter=max_iter)
        if moves is None or len(moves)==0:
            return None
        best_move = moves[int(np.argmax(values))]
        if push:
            self.push_move(best_move)
        return best_move

    def push_move(self,move):
        """Make the child reached by ``move`` the new root of the tree.

        Prints an error message and leaves the tree unchanged when there is
        no root yet or no child corresponds to ``move``.
        """
        children = self.root.children if self.root is not None else []
        for child in children:
            if child.move==move:
                child.isRoot=True
                # A child's first rollout was simulated from the child itself
                # rather than through one of its own children; dropping it
                # appears intended to keep the root's count equal to the sum
                # over its children (the check done in get_move_values).
                child.num_rollouts-=1
                child.parent=None
                self.root=child
                return
        print("Error, movimiento no existente")
        return

    def set_max_iter(self,max_iter=100):
        """Set the default number of iterations per search."""
        self.max_iter=max_iter

    def select_child(self, node):
        """
            Select a child of ``node`` using the UCT metric
            (Upper Confidence bound for Trees).

            The score is ``win_fraction + temperature * sqrt(ln(N) / n)`` where
            N is the total rollout count over the children and n the child's
            own count. Win fractions are taken from the point of view of the
            player to move in ``node``. Returns ``None`` if ``node`` has no
            children.
        """
        # N(v): total rollouts over all children.
        total_rollouts = sum(child.num_rollouts for child in node.children)
        log_rollouts = np.log(total_rollouts)

        # -inf (rather than -1) so that a child is still chosen when every
        # score is very low, e.g. with a negative temperature.
        best_score = -np.inf
        best_child = None
        for child in node.children:
            win_percentage = child.winning_frac(node.game_state.turn)
            exploration_factor = np.sqrt(log_rollouts / child.num_rollouts)
            uct_score = win_percentage + self.temperature * exploration_factor
            if uct_score > best_score:
                best_score = uct_score
                best_child = child
        return best_child

    def get_move_values(self,game_state,max_iter=None):
        """Run the search and return the root's moves with their win rates.

        Args:
            game_state: position to search from. If its string form differs
                from the current root's, the tree is discarded and rebuilt.
            max_iter: iterations to run; ``None`` uses ``self.max_iter``.

        Returns:
            ``(moves, scores)`` where ``moves`` is a list of the root
            children's moves and ``scores`` a NumPy array of their win
            fractions for the player to move, or ``(None, None)`` when the
            root position is terminal.
        """
        if max_iter is None:
            max_iter=self.max_iter

        if self.root is None or str(self.root.game_state)!=str(game_state):
            # The stored tree does not describe this position: start over.
            self.root = MCTSNode(game_state.copy(),bot=self.bot,isRoot=True)

        if self.root.is_terminal():
            return None,None

        root=self.root
        iteration=0

        tic = time.perf_counter()
        while iteration<max_iter:
            iteration+=1
            node = root
            # Selection: descend while the node is fully expanded and not terminal.
            while (not node.can_add_child()) and (not node.is_terminal()):
                node = self.select_child(node)

            # Expansion: add one new child if any move is still unexplored.
            if node.can_add_child():
                node = node.add_random_child(self.bot)

            # Simulation: terminal result or a random playout from the node.
            result = node.result_simulation()

            # Backpropagation: update every node on the path up to the root.
            while node is not None:
                node.record_win(result)
                node = node.parent
        if self.verbose>0:
            toc = time.perf_counter()-tic
            print('MCTS - rollouts:{} Elapsed time: {:.2f}s = {:.2f}m'.format(root.num_rollouts,toc,toc/60))

        # Sanity check: every rollout through the root passes through a child.
        total_rollouts = sum(child.num_rollouts for child in root.children)
        if total_rollouts != root.num_rollouts:
            print("total/root {}/{}".format(total_rollouts,root.num_rollouts))

        score = np.zeros(len(root.children),)
        moves = []
        for i,child in enumerate(root.children):
            # Final ranking uses the plain win fraction, no exploration bonus.
            score[i] = child.winning_frac(root.game_state.turn)
            moves.append(child.move)
        return moves,score



In [3]:
class agent_ttt:
    """Player that picks uniformly at random among the legal moves."""

    def select_move(self,board):
        """Return one of ``board.legal_moves`` chosen at random."""
        candidates = list(board.legal_moves)
        return candidates[np.random.randint(len(candidates))]

    def random_simulation(self,board):
        """Play random moves on a copy of ``board`` until no legal move is left.

        The board passed in is not modified. Returns the ``winner`` attribute
        of the finished copy.
        """
        playout = board.copy()
        while playout.legal_moves:
            playout.push(self.select_move(playout))
        return playout.winner

In [4]:
# Smoke test: a fresh MCTS agent (random-playout bot, 1000 iterations) picks
# the opening move on an empty board; the move is played and the board shown.
board = Board()
agent = agent_ttt()
mcts = agent_MCTS(bot=agent,max_iter=1000)
move = mcts.select_move(board)
board.push(move)
board

. . .
. X .
. . .

In [5]:
# Build the test position by hand. Moves alternate X, O, X, O, so X ends up
# on squares 2 and 3, O on squares 6 and 1, and X is to move.
board = Board()
board.push(2)
board.push(6)
board.push(3)
board.push(1)

In [6]:
# Display the current position.
board

. O X
X . .
O . .

In [7]:
# Unseeded NumPy Generator. The classes in this notebook draw from
# np.random.randint directly, so this generator does not influence them.
from numpy.random import default_rng
rng = default_rng()

In [8]:
# Draw a single sample; the value is only displayed, not stored.
rng.random()

0.9484904128977524

In [14]:
# Experiment (max_iter=100): fraction of trials in which MCTS picks square 5
# in the position X:{2,3} / O:{6,1}, X to move. `bm` presumably stands for
# "best move": square 5 gives X two open lines at once (2-5-8 and 3-4-5).
# A new agent is created per trial, so no search tree is reused.
bm=5
acc=0  # hit counter inside the loop, converted to a fraction afterwards
sims=1000
for _ in range(sims):
    board = Board()
    agent = agent_ttt()
    mcts = agent_MCTS(bot=agent,max_iter=100)
    board.push(2)
    board.push(6)
    board.push(3)
    board.push(1)
    #print(board)
    move = mcts.select_move(board)
    if move==bm:
        acc+=1
    #board.push(move)
    #print(board)
acc/=sims
print("accuracy: {:.3f}".format(acc))

accuracy: 0.824


In [10]:
# Same experiment with max_iter=1000: fraction of trials in which MCTS picks
# square 5 in the position X:{2,3} / O:{6,1}, X to move. A new agent is
# created per trial, so no search tree is reused.
bm=5
acc=0  # hit counter inside the loop, converted to a fraction afterwards
sims=1000
for _ in range(sims):
    board = Board()
    agent = agent_ttt()
    mcts = agent_MCTS(bot=agent,max_iter=1000)
    board.push(2)
    board.push(6)
    board.push(3)
    board.push(1)
    #print(board)
    move = mcts.select_move(board)
    if move==bm:
        acc+=1
    #board.push(move)
    #print(board)
acc/=sims
print("accuracy: {:.6f}".format(acc))

accuracy: 1.000000


In [13]:
# Same experiment with max_iter=300: fraction of trials in which MCTS picks
# square 5 in the position X:{2,3} / O:{6,1}, X to move. A new agent is
# created per trial, so no search tree is reused.
bm=5
acc=0  # hit counter inside the loop, converted to a fraction afterwards
sims=1000
for _ in range(sims):
    board = Board()
    agent = agent_ttt()
    mcts = agent_MCTS(bot=agent,max_iter=300)
    board.push(2)
    board.push(6)
    board.push(3)
    board.push(1)
    #print(board)
    move = mcts.select_move(board)
    if move==bm:
        acc+=1
    #board.push(move)
    #print(board)
acc/=sims
print("accuracy: {:.6f}".format(acc))

accuracy: 1.000000


In [11]:
# Repeat of the max_iter=1000 experiment (square 5 in the position
# X:{2,3} / O:{6,1}, X to move).
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so
# it produced no accuracy value; it duplicates an earlier cell.
bm=5
acc=0  # hit counter inside the loop, converted to a fraction afterwards
sims=1000
for _ in range(sims):
    board = Board()
    agent = agent_ttt()
    mcts = agent_MCTS(bot=agent,max_iter=1000)
    board.push(2)
    board.push(6)
    board.push(3)
    board.push(1)
    #print(board)
    move = mcts.select_move(board)
    if move==bm:
        acc+=1
    #board.push(move)
    #print(board)
acc/=sims
print("accuracy: {:.6f}".format(acc))

KeyboardInterrupt: 

In [None]:
# Another copy of the max_iter=1000 experiment (square 5 in the position
# X:{2,3} / O:{6,1}, X to move); this cell has no recorded execution.
# It leaves `board`, `agent` and `mcts` from the last trial in the namespace.
bm=5
acc=0  # hit counter inside the loop, converted to a fraction afterwards
sims=1000
for _ in range(sims):
    board = Board()
    agent = agent_ttt()
    mcts = agent_MCTS(bot=agent,max_iter=1000)
    board.push(2)
    board.push(6)
    board.push(3)
    board.push(1)
    #print(board)
    move = mcts.select_move(board)
    if move==bm:
        acc+=1
    #board.push(move)
    #print(board)
acc/=sims
print("accuracy: {:.6f}".format(acc))

In [None]:
# Re-print whatever accuracy the most recently executed experiment left in `acc`.
print("accuracy: {:.6f}".format(acc))

In [None]:
# Self-play: the same MCTS agent chooses the move for whichever side is to
# move until no legal moves remain, printing the board after each move.
# Uses the `board` and `mcts` objects left by the previously executed cell.
while len(board.legal_moves)>0:
    move = mcts.select_move(board)
    board.push(move)
    print(board)

In [None]:
# NOTE(review): `stop` is not defined in any visible cell, so this presumably
# raises NameError on purpose to halt "Run All" before the scratch cells
# below — confirm.
stop

In [None]:
# Play one random playout on a copy of `board` and show the resulting winner vector.
agent.random_simulation(board)

In [None]:
# Fresh board with X already played on square 0 (O to move); display it.
board = Board()
board.push(0)
board

In [None]:
# Small search budget (10 iterations) to keep the tree small enough to inspect.
mcts = agent_MCTS(bot=agent,max_iter=10)

In [None]:
# Let the agent choose a move, play it on the board and display the result.
# select_move also advances the agent's tree root to the chosen move.
move = mcts.select_move(board)
board.push(move)
board

In [None]:
# Inspect the visit count of the second child of the current root
# (fails if the root has fewer than two expanded children).
mcts.root.children[1].num_rollouts

In [None]:
# Run another search of mcts.max_iter iterations and show (moves, win fractions)
# for the root's children; returns (None, None) if the position is terminal.
mcts.get_move_values(board)

In [None]:
from PIL import Image
from IPython.display import display
import networkx as nx
import pydot
from networkx.drawing.nx_pydot import graphviz_layout

In [None]:
# Export the agent's current search tree to a DOT file in the working directory.
G = MCTS_graph(mcts)
G.save_graph("mcts_ttt.dot")

In [None]:
# Read the DOT file back, convert it to an undirected networkx graph, render
# it to PNG through pydot and display the image inline.
# NOTE(review): write_png presumably needs the Graphviz `dot` executable — confirm.
g = nx.Graph(nx.drawing.nx_pydot.read_dot("mcts_ttt.dot"))
p=nx.drawing.nx_pydot.to_pydot(g)
p.write_png('example.png')
pil_im = Image.open('example.png', 'r')
display(pil_im)