In [15]:
# MCST algorithm

# reference: https://int8.io/monte-carlo-tree-search-beginners-guide/
from tictactoe import *
import math
import random   

UBC_C = 1.41 # recommended value for UBC_C is sqrt(2)

class Node():
    def __init__(self, game:TicTacToe, parent=None):
        self.game = game
        self.parent = parent
        self.child_count = sum([1 for x in game.board if x == EMPTY])
        self.player = COMPUTER if parent and parent.player == PLAYER else PLAYER

        self.children = [] # visited
        self.visits = 0
        self.score = 0
        self.is_expanded = False

    def copy(self):
        new_node = Node(self.game.copy())
        new_node.children = self.children.copy()
        new_node.visits = self.visits
        new_node.score = self.score
        new_node.is_expanded = self.is_expanded
        new_node.player = self.player
        return new_node

def rollout_policy(node: Node):
    # pick the first empty cell and play
    possible_moves = []
    for i in range(node.game.size):
        for j in range(node.game.size):
            if node.game.board[i * node.game.size + j] == EMPTY:
                possible_moves.append(i * node.game.size + j)

    if len(possible_moves) < 0:
        return node
    
    position = random.choice(possible_moves)

    # print("select position:%d" % position)
    node.game.set_move(position//node.game.size, position%node.game.size, 3 - node.player)
    # update the player
    # print("update player: %d to %d" %(node.player, 3- node.player))
    node.player = 3 - node.player

    return node

def terminal(node: Node):
    if is_winner(node.game, node.player) or \
        is_winner(node.game, 3 - node.player):
        return True
    if is_terminal(node.game):
        return True
    return False

def result(node: Node):
    if is_winner(node.game,COMPUTER):
        return 1
    if is_winner(node.game,PLAYER):
        return -1
    return 0

def rollout(node : Node, callback= None):
    # when similating the game, we don't save the simulated game.
    tnode = node.copy()
    while not terminal(tnode):
        tnode = rollout_policy(tnode)
        # for debug 
        if callback is not None:
            callback(tnode)
        # for debug
            
    # reutrn the score of the game
    return result(tnode)

def fully_expanded(node: Node):
    return node.child_count == len(node.children)

def best_ucb(node: Node, c = UBC_C):
    best_score = -1
    best_node = None
    for child in node.children:
        if child.visits == 0:
            return child
        score = child.score / child.visits + \
            c * math.sqrt(math.log(node.visits) / child.visits)
        if score > best_score:
            best_score = score
            best_node = child
    return best_node
       
def pick_unvisited(node: Node):
    # pick a unvisited children. not in childeren list
    for i in range(node.game.size):
        for j in range(node.game.size):
            if node.game.board[i * node.game.size + j] == EMPTY:
                new_game = node.game.copy()
                new_game.set_move(i, j, COMPUTER if node.player == PLAYER else PLAYER)
                
                # an un visited game
                is_unvisited = True 
                for child in node.children:
                    if child.game.board == new_game.board:
                        is_unvisited = False

                if is_unvisited == False:
                    continue
                         
                # print("un visted kid")
                new_node = Node(new_game, node)
                node.children.append(new_node)
                return new_node
    return None

def traverse(node: Node):
    # while fully_expanded(node):
    #     node = best_ucb(node)
    # if not terminal(node):
    #     pnode = pick_unvisited(node)
    #     if (pnode is not None):
    #         return pnode
    while not terminal(node):
        if not fully_expanded(node):    
            return pick_unvisited(node)
        else:
            node = best_ucb(node)
    return node # in case no children are present / node is terminal

def backpropagate(node: Node, result):
    while node is not None:
        node.visits += 1
        node.score += result
        node = node.parent

def monte_carlo_tree_search_max_iteration(iterations, root: Node):
    while iterations:
        leaf = traverse(root) # leaf = unvisited node 
        if fully_expanded(leaf) is True:
            break
        simulation_result = rollout(leaf)
        backpropagate(leaf, simulation_result)
        iterations -=1
    return best_ucb(root, 0)


In [16]:
import math
def mcst_policy(game):
    root = Node(game.copy()) 
    print("root palyer:", root.player)
    best_node = monte_carlo_tree_search_max_iteration(1000, root)
    print("best node player:", best_node.player)
    print("computer choice:",best_node.game.board)
    print("visted:", best_node.visits)
    for i in range(game.size * game.size):
        if game.board[i] != best_node.game.board[i]:
            return i
    return -1

# # play the game with mcst policy
game_loop(mcst_policy, 3)

start, input 'i j' to select the place. from 0 to 2:
— — — 
— — — 
— — — 
player 1 move:
X — — 
— — — 
— — — 
root palyer: 1
best node player: 2
computer choice: [1, 0, 2, 0, 0, 0, 0, 0, 0]
visted: 318
computer move:
X — O 
— — — 
— — — 
player 1 move:
X — O 
— — — 
— — X 
root palyer: 1
best node player: 2
computer choice: [1, 0, 2, 0, 2, 0, 0, 0, 1]
visted: 244
computer move:
X — O 
— O — 
— — X 


ValueError: not enough values to unpack (expected 2, got 0)

In [None]:
# test rollout_policy
game = TicTacToe()
n = Node(game)
n = rollout_policy(n)
# print_board(n.game)
# print("----")
assert n.game.board[0]==2, "fail rollout_policy 1"

n = rollout_policy(n)
# print_board(n.game)
# print("----")
assert n.game.board[1]==1, "fail rollout_policy 2"

n = rollout_policy(n)
# print_board(n.game)
# print("----")
assert n.game.board[2]==2, "fail rollout_policy 3"

# test terminate
game = TicTacToe()
game.board=[1,2,1, 
            2,1,2,
            1,0,0,
           ]
n = Node(game)
assert terminal(n)==True, "fail terminal 1"
assert result(n) == -1, "fail result 1"

game.board=[1,2,1, 
            1,2,2,
            1,0,0,
           ]
n = Node(game)
assert terminal(n)==True, "fail terminal 2"
assert result(n) == -1, "fail result 2"

game.board=[1,1,1, 
            2,0,2,
            1,0,0,
           ]
n = Node(game)
assert terminal(n)==True, "fail terminal 3"
assert result(n) == -1, "fail result 3"

game.board=[1,0,2, 
            2,1,2,
            1,0,1,
           ]
n = Node(game)
assert terminal(n)==True, "fail terminal 4"
assert result(n) == -1, "fail result 4"

game.board=[2,0,1, 
            2,1,2,
            1,0,1,
           ]
n = Node(game)
assert terminal(n)==True, "fail terminal 5"
assert result(n) == -1, "fail result 5"

game.board=[2,0,1, 
            2,1,2,
            2,0,1,
           ]
n = Node(game)
assert terminal(n)==True, "fail terminal 6"
assert result(n) == 1, "fail result 5"

In [None]:
# test rollout
def print_rollout_middle_stage(n: Node):
  print("--------")
  print_board(n.game)
  pass

game = TicTacToe()
n = Node(game)

print("test rollout 1")
assert rollout(n, print_rollout_middle_stage) == 1, "fail rollout 1"

game = TicTacToe()
game.board=[1,0,0, 0,0,0, 0,0,0]
n = Node(game)

print("test rollout 2")
assert rollout(n, print_rollout_middle_stage) == -1, "fail rollout 2"

game = TicTacToe()
game.board=[1,0,0, 0,2,0, 0,0,1]
n = Node(game)

print("test rollout 3")
assert rollout(n, print_rollout_middle_stage) == -1, "fail rollout 3"


game = TicTacToe()
game.board=[1,2,1, 0,2,2, 2,1,1]
n = Node(game)
n.player=COMPUTER
print("PLAYER:",n.player)

print("test rollout 4")
assert rollout(n, print_rollout_middle_stage) == 0, "fail rollout 4"

test rollout 1
select position:0
update player: 1 to 2
--------
O — — 
— — — 
— — — 
select position:1
update player: 2 to 1
--------
O X — 
— — — 
— — — 
select position:2
update player: 1 to 2
--------
O X O 
— — — 
— — — 
select position:3
update player: 2 to 1
--------
O X O 
X — — 
— — — 
select position:4
update player: 1 to 2
--------
O X O 
X O — 
— — — 
select position:5
update player: 2 to 1
--------
O X O 
X O X 
— — — 
select position:6
update player: 1 to 2
--------
O X O 
X O X 
O — — 
test rollout 2
select position:1
update player: 1 to 2
--------
X O — 
— — — 
— — — 
select position:2
update player: 2 to 1
--------
X O X 
— — — 
— — — 
select position:3
update player: 1 to 2
--------
X O X 
O — — 
— — — 
select position:4
update player: 2 to 1
--------
X O X 
O X — 
— — — 
select position:5
update player: 1 to 2
--------
X O X 
O X O 
— — — 
select position:6
update player: 2 to 1
--------
X O X 
O X O 
X — — 
test rollout 3
select position:1
update player: 1 to 2
-----

In [None]:
# test fully_expanded
game = TicTacToe()
n = Node(game)
assert fully_expanded(n) == False, "fail fully_expanded 1"

game = TicTacToe()
n = Node(game)
for i in range(n.game.size):
  for j in range(n.game.size):
    if n.game.board[i*n.game.size+j] == EMPTY:
      # save one to children
      new_game = n.game.copy()
      new_game.set_move(i, j, COMPUTER if n.player == PLAYER else PLAYER)
      new_node = Node(new_game, n)
      n.children.append(new_node)
      break
  break
assert len(n.children) == 1, "fail add children 1"
assert fully_expanded(n) == False, "fail fully_expanded 2"

game = TicTacToe()
n = Node(game)
for i in range(n.game.size):
  for j in range(n.game.size):
    if n.game.board[i*n.game.size+j] == EMPTY:
      # save one to children
      new_game = n.game.copy()
      new_game.set_move(i, j, COMPUTER if n.player == PLAYER else PLAYER)
      new_node = Node(new_game, n)
      n.children.append(new_node)
      break
assert len(n.children) == 3, "fail add children 2"
assert fully_expanded(n) == False, "fail fully_expanded 3"



game = TicTacToe()
n = Node(game)
for i in range(n.game.size):
  for j in range(n.game.size):
    if n.game.board[i*n.game.size+j]  == EMPTY:
      # save one to children
      new_game = n.game.copy()
      new_game.set_move(i, j, COMPUTER if n.player == PLAYER else PLAYER)
      new_node = Node(new_game, n)
      n.children.append(new_node)

assert len(n.children) == 9, "fail add children 3"
assert fully_expanded(n) == True, "fail fully_expanded 4"




In [None]:
# test pick_unvisited
game = TicTacToe()
n = Node(game)

unvisted_n = pick_unvisited(n)
assert n.player != unvisted_n.player, "fail pick_unvisited 1 player"
assert n.children[0].game.board == unvisted_n.game.board, "fail pick_unvisited 1 board"
assert n.children[0].game.board[0] == 2, "fail pick_unvisited 1 value"

unvisted_n = pick_unvisited(n)
assert n.player != unvisted_n.player, "fail pick_unvisited 2 player"
assert n.children[1].game.board == unvisted_n.game.board, "fail pick_unvisited 2 board"
assert n.children[1].game.board[1] == 2, "fail pick_unvisited 2 value"
assert n.children[1].game.board[0] == 0, "fail pick_unvisited 2 value"



In [None]:
import random
class MCTSNode():
  def __init__(self, game:TicTacToe, parent=None):
      self.visits = 0
      self.score = 0
      self.game = game
      self.parent = parent
      self.children = []

  def untried_moves(self):
    return [i for i in range(self.game.size * self.game.size) if self.game.board[i] == EMPTY]
  
  def expand(self):
    move = self.untried_moves()[0]
    self.untried_moves.remove(move)

    new_game = self.game.copy()
    new_game.set_move(move//self.game.size, move%self.game.size, COMPUTER if self.parent.game.player == PLAYER else PLAYER)
    new_node = MCTSNode(new_game, self)
    self.children.append(new_node)
    return new_node
  
  def is_terminal(self):
    return is_winner(self.game, COMPUTER) or is_winner(self.game, PLAYER) or is_terminal(self.game)
  
  def rollout(self):
    tnode = self.copy()
    while not is_terminal(tnode.game):
        tnode = tnode.expand()
        possible_moves = [i for i in range(tnode.game.size * tnode.game.size) if tnode.game.board[i] == EMPTY]
        move = random.choice(possible_moves)
        tnode.game.set_move(move//tnode.game.size, move%tnode.game.size, COMPUTER if tnode.game.player == PLAYER else PLAYER)
    return result(tnode.game)  
  
  def backpropagate(self, result):
    self.visits += 1
    self.score += 1
    if self.parent:
      self.parent.backpropagate(result)
      

  