Import necessary libraries

In [9]:

import chess
import chess.pgn
import chess.engine
import random
import time
from math import log,sqrt,e,inf

Create MCTS node class for holding board position, action taken to get to position, children nodes, parent node, N for number of times parent node has been visited, n for number of times child node has been visited and V for score of current node. 
Formula to be used for explore/exploit:
UCB = V + 2*sqrt(ln N/n)

In [10]:
class node():
    def __init__(self):
        self.state = chess.Board()
        self.action = ''
        self.children = set()
        self.parent = None
        self.N = 0
        self.n = 0
        self.v = 0

Function for calculation of UCB:

In [12]:
def ucb1(curr_node):
    ans = curr_node.v+2*(sqrt(log(curr_node.N+e+(10**-6))/(curr_node.n+(10**-10))))
    return ans

Function for expanding nodes before random playout

In [14]:
def expand(curr_node,white):
    if(len(curr_node.children)==0):
        return curr_node
    max_ucb = -inf
    if(white):
        idx = -1
        max_ucb = -inf
        sel_child = None
        for i in curr_node.children:
            tmp = ucb1(i)
            if(tmp>max_ucb):
                idx = i
                max_ucb = tmp
                sel_child = i

        return(expand(sel_child,0))

    else:
        idx = -1
        min_ucb = inf
        sel_child = None
        for i in curr_node.children:
            tmp = ucb1(i)
            if(tmp<min_ucb):
                idx = i
                min_ucb = tmp
                sel_child = i

        return expand(sel_child,1)



Simulate play function (Rollout in MCTS)

In [13]:
def rollout(curr_node):
    
    if(curr_node.state.is_game_over()):
        board = curr_node.state
        if(board.result()=='1-0'):
            #print("h1")
            return (1,curr_node)
        elif(board.result()=='0-1'):
            #print("h2")
            return (-1,curr_node)
        else:
            return (0.5,curr_node)
    
    all_moves = [curr_node.state.san(i) for i in list(curr_node.state.legal_moves)]
    
    for i in all_moves:
        tmp_state = chess.Board(curr_node.state.fen())
        tmp_state.push_san(i)
        child = node()
        child.state = tmp_state
        child.parent = curr_node
        curr_node.children.add(child)
    rnd_state = random.choice(list(curr_node.children))

    return rollout(rnd_state)

Rollback function for backpropagation of new info for UCB calculation

In [15]:
def rollback(curr_node,reward):
    curr_node.n+=1
    curr_node.v+=reward
    while(curr_node.parent!=None):
        curr_node.N+=1
        curr_node = curr_node.parent
    return curr_node

# Main function for use


1. Checks if game is over and returns -1 if it is
2. Creates map_state_move dict with all move and nodes from current node
3. Creates loop to iterate for input amount of iterations
4. If white maximize, if black minimize
5. Find min/max UCB in children, then expand, rollout, rollback.
6. After iterations, find and returned min/max UCB depending on color

In [21]:
def mcts_pred(curr_node,over,white,iterations=10):
    if(over):
        return -1
    all_moves = [curr_node.state.san(i) for i in list(curr_node.state.legal_moves)]
    map_state_move = dict()
    
    for i in all_moves:
        tmp_state = chess.Board(curr_node.state.fen())
        tmp_state.push_san(i)
        child = node()
        child.state = tmp_state
        child.parent = curr_node
        curr_node.children.add(child)
        map_state_move[child] = i
        
    while(iterations>0):
        if(white):
            idx = -1
            max_ucb = -inf
            sel_child = None
            for i in curr_node.children:
                tmp = ucb1(i)
                if(tmp>max_ucb):
                    idx = i
                    max_ucb = tmp
                    sel_child = i
            ex_child = expand(sel_child,0)
            reward,state = rollout(ex_child)
            curr_node = rollback(state,reward)
            iterations-=1
        else:
            idx = -1
            min_ucb = inf
            sel_child = None
            for i in curr_node.children:
                tmp = ucb1(i)
                if(tmp<min_ucb):
                    idx = i
                    min_ucb = tmp
                    sel_child = i

            ex_child = expand(sel_child,1)

            reward,state = rollout(ex_child)

            curr_node = rollback(state,reward)
            iterations-=1
    if(white):
        
        mx = -inf
        idx = -1
        selected_move = ''
        for i in (curr_node.children):
            tmp = ucb1(i)
            if(tmp>mx):
                mx = tmp
                selected_move = map_state_move[i]
        return selected_move
    else:
        mn = inf
        idx = -1
        selected_move = ''
        for i in (curr_node.children):
            tmp = ucb1(i)
            if(tmp<mn):
                mn = tmp
                selected_move = map_state_move[i]
        return selected_move

# How to Use

In [24]:

board = chess.Board()
root = node()
root.state = board

is_white_to_move = board.turn


best_move_san = mcts_pred(root, board.is_game_over(), is_white_to_move, iterations=10)

print("Best Move:", best_move_san)

Best Move: h4


# How long does it take?

In [27]:
import time

board = chess.Board()
root = node()
root.state = board

is_white_to_move = board.turn

start = time.time()
best_move_san = mcts_pred(root, board.is_game_over(), is_white_to_move, iterations=10)
end = time.time()

duration = end - start

print(f"10 iterations took {duration} seconds to run.")


board = chess.Board()
root = node()
root.state = board

is_white_to_move = board.turn

start = time.time()
best_move_san = mcts_pred(root, board.is_game_over(), is_white_to_move, iterations=15)
end = time.time()

duration = end - start

print(f"15 iterations took {duration} seconds to run.")

board = chess.Board()
root = node()
root.state = board

is_white_to_move = board.turn

start = time.time()
best_move_san = mcts_pred(root, board.is_game_over(), is_white_to_move, iterations=20)
end = time.time()

duration = end - start

print(f"20 iterations took {duration} seconds to run.")

board = chess.Board()
root = node()
root.state = board

is_white_to_move = board.turn

start = time.time()
best_move_san = mcts_pred(root, board.is_game_over(), is_white_to_move, iterations=30)
end = time.time()

duration = end - start

print(f"30 iterations took {duration} seconds to run.")

10 iterations took 7.329426050186157 seconds to run.
15 iterations took 12.141134977340698 seconds to run.
20 iterations took 16.26421594619751 seconds to run.
30 iterations took 24.37513566017151 seconds to run.
