In [1]:
!sudo apt install graphviz

Reading package lists... Done
Building dependency tree       
Reading state information... Done
graphviz is already the newest version (2.40.1-2).
0 upgraded, 0 newly installed, 0 to remove and 11 not upgraded.


In [2]:
import numpy as np
import random
import time
import copy

In [3]:
'''
    Tic tac toe game
'''
EMPTY_STR = '-'
BLACK_STR = 'B'
WHITE_STR = 'W'

'''
    Tic Tac Toe game state
'''
class TicTacToeState(object):
    def __init__(self, curr_turn = None, board = None):
        if board is None:
            self.board = np.empty((3, 3), dtype = str)
            self.board.fill(EMPTY_STR)
        else:
            self.board = board

        if curr_turn is None:
            self.curr_turn = random.choice([BLACK_STR, WHITE_STR])
        else:
            self.curr_turn = curr_turn

        self.win_color, self.terminate = self.game_result()

    def get_legal_actions(self):
        return [(i, j) for i, j in zip(*np.where(self.board == EMPTY_STR))]

    def game_result(self):
        for color in [BLACK_STR, WHITE_STR]:
            positions = self.board == color
            if positions.all(axis = 0).any() or positions.all(axis = 1).any() or np.diagonal(positions).all() or np.diagonal(np.rot90(positions)).all():
                return color, True

        if len(self.get_legal_actions()) == 0:
            return None, True

        return None, False

    def move(self, pos):
        assert self.board[pos] == EMPTY_STR

        new_board = np.copy(self.board)
        new_board[pos] = self.curr_turn

        next_turn = BLACK_STR if self.curr_turn == WHITE_STR else WHITE_STR

        return TicTacToeState(next_turn, new_board)

    def random_move(self):
        return random.choice(self.get_legal_actions())

    def show(self):
        # print('\n'.join(''.join(row) for row in self.board))
        return '\n'.join(''.join(row) for row in self.board)

In [4]:
'''
    Monte Carlo Tree Search
'''
class MCTSNode(object):
    def __init__(self, state, parent = None):
        self.state = state
        self.parent = parent
        self.untried_actions = self.state.get_legal_actions()
        self.children = []
        self.N = 0 # Number of simulations
        self.Q = {BLACK_STR : 0, WHITE_STR : 0} # Number of wins

    def random_child(self):
        return random.choice(self.children)

    def best_child(self, constant = 2**0.5):
        # UCT = Q / child.N + sqrt(ln(parent.N) / child.N)
        # Where Q here considers both win and loss.
        ucts = [(self.Q[self.state.curr_turn] - self.Q[BLACK_STR if self.state.curr_turn == WHITE_STR else WHITE_STR]) / c.N + constant * (np.log(self.N) / c.N) ** 0.5 for c in self.children]
        return self.children[np.argmax(ucts)]

    def is_leaf(self):
        return self.state.terminate

    def is_fully_expanded(self):
        return len(self.untried_actions) == 0

    def backpropagate(self, win_color):
        self.N += 1
        if win_color:
            self.Q[win_color] += 1
        
        if self.parent:
            # Recursively backpropagate
            self.parent.backpropagate(win_color)

    def expand(self):
        action = self.untried_actions.pop()
        next_state = self.state.move(action)
        child = MCTSNode(next_state, self)
        self.children.append(child)
        return child

    def rollout(self):
        curr_state = self.state
        while not curr_state.terminate:
            pos = curr_state.random_move()
            curr_state = curr_state.move(pos)
        return curr_state.win_color

    def traverse(self):
        if self.is_leaf():
            # Reach a leaf node (the game terminates).
            return self
        else:
            if not self.is_fully_expanded():
                # If the current node is not fully expanded, then expand it.
                return self.expand()
            else:
                # If the current node is fully expanded, choose the optimal child.
                return self.best_child().traverse()

In [5]:
def draw_graph(root, path):
    tree_str = []
    total_num = 0
    visitset = {}
    visitset[total_num] = root

    while len(visitset) > 0:
        curr_idx = min(visitset.keys())
        curr_node = visitset[curr_idx]
        visitset.pop(curr_idx)

        for child in curr_node.children:
            total_num += 1
            tree_str.append("%d[label=\"%s\"];\n%d->%d;" 
                            % (total_num, child.state.show() + '\nN=%d\nQB=%d\nQW=%d\n' % (child.N, child.Q[BLACK_STR], child.Q[WHITE_STR]), curr_idx, total_num))
            visitset[total_num] = child
    
    with open(path, 'w') as f:
        f.write("digraph mcts{")
        f.writelines(tree_str)
        f.write('}')

In [6]:
if __name__ == '__main__':
    game = TicTacToeState()
    root = MCTSNode(game)

    for _ in range(10000):
        # 1. Traverse from the root until a leaf node is reached.
        target = root.traverse()
        # 2. From the leaf, start simulation.
        simulation_res = target.rollout()
        # 3. Backpropagate with the result of the simulation.
        target.backpropagate(simulation_res)
    
    draw_graph(root, 'mcts.dot')

In [7]:
!dot -Kdot -Tpng mcts.dot -o mcts.png

dot: graph is too large for cairo-renderer bitmaps. Scaling by 0.103534 to fit


In [8]:
move = []
curr = root
while not curr.is_leaf():
    curr = curr.best_child(constant = 0)
    move.append(curr.state.show())

for m in move:
    print(m)
    print('-------------')

---
---
-W-
-------------
---
---
-WB
-------------
---
W--
-WB
-------------
---
W--
BWB
-------------
-W-
W--
BWB
-------------
-W-
W-B
BWB
-------------
WW-
W-B
BWB
-------------
WW-
WBB
BWB
-------------
WWW
WBB
BWB
-------------
