In [42]:
import chess
from copy import copy,deepcopy
from tqdm import tqdm
import random
from random import choice
import tensorflow as tf
import numpy as np

In [71]:
#from mcts import *

import time
import math
import random


def randomPolicy(state):
    """Play uniformly random legal moves from ``state`` until the game ends.

    The rollout runs on a copy of ``state.board``, so the tree node's own
    board is never mutated.

    Args:
        state: a ``State`` whose ``board`` supports ``copy()`` and ``push()``.

    Returns:
        The reward of the position reached at the end of the rollout, as
        computed by ``State.getReward`` for ``state.player``.

    Raises:
        Exception: if a non-terminal position reports no possible actions.
    """
    # Keep the same player/turn as the node being evaluated; only the board advances.
    rollout_state = State(state.board.copy(), state.player, state.turn)
    while not rollout_state.isTerminal():
        actions = rollout_state.getPossibleActions()
        if not actions:
            raise Exception("Non-terminal state has no possible actions: " + str(state))
        # Actions are (move, 0) tuples; the move itself is element 0.
        rollout_state.board.push(random.choice(actions)[0])
    # Score the finished rollout game, not the (usually non-terminal) start position;
    # scoring `state` here would make every rollout return 0.
    return rollout_state.getReward()


class treeNode():
    """One node of the MCTS search tree.

    Attributes:
        state: the game state this node represents.
        isTerminal: cached result of ``state.isTerminal()``.
        isFullyExpanded: whether every action of ``state`` already has a child;
            terminal nodes start as True because there is nothing to expand.
        parent: the parent ``treeNode``, or ``None`` for the root.
        numVisits: how many rounds have backpropagated through this node.
        totalReward: sum of the rewards backpropagated through this node.
        children: dict mapping an action to its child ``treeNode``.
    """

    def __init__(self, state, parent):
        self.state = state
        terminal = state.isTerminal()
        self.isTerminal = terminal
        # Nothing can be expanded below a terminal position.
        self.isFullyExpanded = terminal
        self.parent = parent
        self.numVisits, self.totalReward = 0, 0
        self.children = {}


class mcts():
    """Monte Carlo Tree Search (UCT) driver.

    Each round selects a node with UCB1, expands one untried action, evaluates
    the new node with ``rolloutPolicy`` and adds the reward to every node on the
    path back to the root.

    The state object must provide ``getPossibleActions()``, ``takeAction(action)``,
    ``isTerminal()`` and ``getCurrentPlayer()``; actions are used as dict keys and
    must be hashable.

    ``getBestChild`` multiplies each child's mean reward by
    ``node.state.getCurrentPlayer()``. The opponent's nodes only minimize the
    reward if that method returns -1 for them (and +1 for the side the reward is
    scored for); any other positive value just rescales the exploitation term.
    """

    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
                 rolloutPolicy=randomPolicy):
        """
        Args:
            timeLimit: wall-clock budget for each ``search`` call, in milliseconds.
            iterationLimit: number of rounds for each ``search`` call (>= 1).
            explorationConstant: UCB1 exploration weight.
            rolloutPolicy: callable(state) -> reward used to evaluate new nodes.

        Raises:
            ValueError: if both or neither limit is given, or iterationLimit < 1.
        """
        if timeLimit is not None:
            if iterationLimit is not None:
                raise ValueError("Cannot have both a time limit and an iteration limit")
            # time taken for each MCTS search in milliseconds
            self.timeLimit = timeLimit
            self.limitType = 'time'
        else:
            if iterationLimit is None:
                raise ValueError("Must have either a time limit or an iteration limit")
            # number of iterations of the search
            if iterationLimit < 1:
                raise ValueError("Iteration limit must be at least one")
            self.searchLimit = iterationLimit
            self.limitType = 'iterations'
        self.explorationConstant = explorationConstant
        self.rollout = rolloutPolicy

    def search(self, initialState):
        """Search from ``initialState`` and return the best action at the root.

        In time mode at least one round always runs, so a tiny ``timeLimit``
        still produces a move instead of failing on an empty root.

        Raises:
            ValueError: if ``initialState`` is terminal (there is no move to choose).
        """
        self.root = treeNode(initialState, None)
        if self.root.isTerminal:
            raise ValueError("Cannot search from a terminal state")

        if self.limitType == 'time':
            deadline = time.time() + self.timeLimit / 1000
            count = 0
            # do-while: the root must get at least one child before we pick a move
            while True:
                self.executeRound()
                count += 1
                if time.time() >= deadline:
                    break
            print("Rollouts : ", count)
        else:
            for _ in range(self.searchLimit):
                self.executeRound()

        # Exploration weight 0: pick the child with the best mean reward.
        bestChild = self.getBestChild(self.root, 0)
        return self.getAction(self.root, bestChild)

    def executeRound(self):
        """Run one select/expand -> rollout -> backpropagate cycle."""
        node = self.selectNode(self.root)
        reward = self.rollout(node.state)
        self.backpropogate(node, reward)

    def selectNode(self, node):
        """Descend via UCB1 until a terminal node or one that can still be expanded.

        Returns the terminal node reached, or a newly created child.
        """
        while not node.isTerminal:
            if node.isFullyExpanded:
                node = self.getBestChild(node, self.explorationConstant)
            else:
                return self.expand(node)
        return node

    def expand(self, node):
        """Create and return a child for the first action of ``node`` without one."""
        actions = node.state.getPossibleActions()
        for action in actions:
            if action not in node.children:
                newNode = treeNode(node.state.takeAction(action), node)
                node.children[action] = newNode
                if len(actions) == len(node.children):
                    node.isFullyExpanded = True
                return newNode

        raise Exception("Should never reach here")

    def backpropogate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.numVisits += 1
            node.totalReward += reward
            node = node.parent

    def getBestChild(self, node, explorationValue):
        """Return the child with the highest UCB1 score; ties are broken at random.

        score = sign * meanReward + explorationValue * sqrt(2 * ln(N_parent) / N_child),
        with sign = ``node.state.getCurrentPlayer()``.
        """
        # The player to move at `node` does not depend on which child is scored.
        sign = node.state.getCurrentPlayer()
        bestValue = float("-inf")
        bestNodes = []
        for child in node.children.values():
            nodeValue = sign * child.totalReward / child.numVisits + explorationValue * math.sqrt(
                2 * math.log(node.numVisits) / child.numVisits)
            if nodeValue > bestValue:
                bestValue = nodeValue
                bestNodes = [child]
            elif nodeValue == bestValue:
                bestNodes.append(child)
        return random.choice(bestNodes)

    def getAction(self, root, bestChild):
        """Return the action key under which ``bestChild`` is stored in ``root.children``."""
        for action, node in root.children.items():
            if node is bestChild:
                return action

In [72]:
# Micro-benchmark: create a board, list its legal moves, copy it and push one
# move -- the same copy+push that State.takeAction performs for every new node.
# board.copy()/push() are called directly because State is defined in a later cell,
# so this cell also works on a fresh kernel (Restart & Run All).
N_BENCH = 10000
start_time = time.perf_counter()
for _ in range(N_BENCH):
    temp_board = chess.Board()
    legal_moves = list(temp_board.legal_moves)
    board_copy = temp_board.copy()
    board_copy.push(legal_moves[0])
print(time.perf_counter() - start_time)

0.7201623916625977


In [73]:
# Time a single fully random game (one rollout) from the starting position.
temp_board = chess.Board()
start_time = time.time()
done = False
while not done:
    temp_list = list(temp_board.legal_moves)
    temp_board.push(choice(temp_list))
    done = temp_board.result() != "*"
print(time.time() - start_time)

0.05001020431518555


In [74]:
class State():
    """Chess position wrapper exposing the interface the ``mcts`` search expects.

    Attributes:
        board: the position (a ``chess.Board`` in this notebook).
        player: the side rewards are scored for -- 1 = white, 2 = black.
            ``takeAction`` copies it unchanged, so it is constant across a tree.
        turn: True for white, False for black; flipped by ``takeAction``.
    """

    def __init__(self, board, player, turn):
        self.board = board
        self.player = player  # 1 player (white), 2 player (black) (only used for rewards)
        self.turn = turn  # True white, False black

    def getCurrentPlayer(self):
        """Return +1 if the side to move is ``self.player``, otherwise -1.

        ``mcts.getBestChild`` multiplies each child's mean reward (scored by
        ``getReward`` for ``self.player``) by this value. Returning -1 on the
        opponent's move makes the opponent pick moves that are bad for
        ``self.player``. The previous return value (``self.player``, 1 or 2)
        never flipped the sign.
        """
        # python-chess encodes the side to move as a bool: True == chess.WHITE.
        white_to_move = bool(self.board.turn)
        return 1 if white_to_move == (self.player == 1) else -1

    def getPossibleActions(self):
        """Return every legal move wrapped as a ``(move, 0)`` tuple.

        The search only ever uses element 0 of an action.
        """
        return [(move, 0) for move in self.board.legal_moves]

    def takeAction(self, action):
        """Return a new State with ``action[0]`` played on a copy of the board."""
        new_state = State(self.board.copy(), self.player, not self.turn)
        new_state.board.push(action[0])
        return new_state

    def isTerminal(self):
        """True once ``board.result()`` reports a finished game (anything but ``"*"``)."""
        return self.board.result() != "*"

    def getReward(self):
        """Score the position for ``self.player``.

        Returns:
            0 while the game is in progress, 0.5 for a draw, 1 if ``self.player``
            won and -1 if it lost.

        Raises:
            ValueError: for a decisive result when ``self.player`` is not 1 or 2,
                or for an unrecognized result string.
        """
        # result() re-evaluates game-over conditions in python-chess; call it once.
        result = self.board.result()
        if result == "*":
            return 0
        if result == "1/2-1/2":
            return 0.5
        winner = {"1-0": 1, "0-1": 2}.get(result)
        if winner is None or self.player not in (1, 2):
            raise ValueError("Cannot score result %r for player %r" % (result, self.player))
        return 1 if winner == self.player else -1


In [75]:
# Smoke test: one 2-second MCTS search from the opening position, scored for white.
temp_board = chess.Board()
initialState = State(board=temp_board, player=1, turn=True)
tree = mcts(timeLimit=2000)
action = tree.search(initialState)

Rollouts :  64


In [76]:
# SVG board-rendering imports (not used by the cells shown here).
import chess
import chess.svg
from IPython.display import display, SVG

board = chess.Board()  # fresh board; the self-play cell below creates its own


## Self-play loop

In [80]:
# Self-play: both colours are played by a fresh MCTS search on every move.
N_GAMES = 1
MOVE_TIME_MS = 10000  # per-move search budget, passed to mcts(timeLimit=...)

for game_idx in range(N_GAMES):
    board = chess.Board()
    while board.result() == "*":
        white_to_move = board.turn == chess.WHITE
        # State.player: 1 = white, 2 = black -- the search scores for the side to move.
        mover = 1 if white_to_move else 2
        # Defensive copy: the search works on its own board, never the game board.
        initialState = State(board.copy(), mover, white_to_move)
        tree = mcts(timeLimit=MOVE_TIME_MS)
        action = tree.search(initialState=initialState)
        board.push(action[0])
        print(board)
        print()
    print(board.result())


Rollouts :  312
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . P . .
P P P P P . P P
R N B Q K B N R

Rollouts :  315
r . b q k b n r
p p p p p p p p
n . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . P . .
P P P P P . P P
R N B Q K B N R

Rollouts :  315
r . b q k b n r
p p p p p p p p
n . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . P . N
P P P P P . P P
R N B Q K B . R

Rollouts :  310
r . b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. n . . . . . .
. . . . . P . N
P P P P P . P P
R N B Q K B . R

Rollouts :  318
r . b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. n P . . . . .
. . . . . P . N
P P . P P . P P
R N B Q K B . R

Rollouts :  305
r . b q k b n r
p p p p p p . p
. . . . . . . .
. . . . . . p .
. n P . . . . .
. . . . . P . N
P P . P P . P P
R N B Q K B . R

Rollouts :  316
r . b q k b n r
p p p p p p . p
. . . . . . . .
. . . . . . p .
. n P . . . . .
P . . . . P . N
. P . P P . P P
R 