# MCTS Tic-Tac-Toe

In [None]:
from copy import deepcopy
from math import log, sqrt
from random import choice as rndchoice
import time
import random
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class TTT:
    """Console tic-tac-toe game on a flat 9-cell board.

    Cells hold 'X', 'O' or '_' (empty). ``result`` is 0 while the game is in
    progress, 1 when X has three in a row, 2 when O has, and 3 for a draw.
    ``player`` (1 or 2) is whose turn it is for the interactive helpers
    (``player_input`` / ``ai_input`` / ``play``); player 1 plays 'X', player 2 'O'.
    """

    # Cell-index triples of every row, column and diagonal.
    WINNING_CASES = ((0, 1, 2), (3, 4, 5), (6, 7, 8),
                     (0, 3, 6), (1, 4, 7), (2, 5, 8),
                     (0, 4, 8), (2, 4, 6))

    def __init__(self):
        self.board = list('_' * 9)
        self.result = 0
        self.player = 1

    def render(self):
        """Print the board, showing the index of each empty cell in parentheses."""
        disp_board = [f' {self.board[i]} ' if self.board[i] != '_' else f'({str(i)})' for i in range(9)]
        for i in range(3):
            print(' '.join(disp_board[3 * i:3 * (i + 1)]))
        print('----------------------------------------------------------')

    @staticmethod
    def _outcome(cells):
        """Return 1 if X has a line, 2 if O has one, 3 if ``cells`` is full, else 0.

        A completed line takes precedence over a full board, so a winning final
        move is reported as a win rather than a draw.
        """
        for a, b, c in TTT.WINNING_CASES:
            if cells[a] != '_' and cells[a] == cells[b] == cells[c]:
                return 1 if cells[a] == 'X' else 2
        return 3 if '_' not in cells else 0

    def checkresult(self):
        """Update ``self.result`` from the current board.

        ``result`` is only overwritten when the game is decided; it is never reset to 0.
        """
        outcome = TTT._outcome(self.board)
        if outcome:
            self.result = outcome

    def checkresult_with_returns(self, s=None):
        """Return the outcome code of board ``s`` (default: this game's board).

        Codes match ``checkresult``: 1 X won, 2 O won, 3 draw, 0 still in progress.
        """
        return TTT._outcome(self.board if s is None else s)

    @staticmethod
    def player_marker(p):
        """Map a player id to its marker: 1 -> 'X', anything else -> 'O'."""
        if p == 1:
            return 'X'
        else:
            return 'O'

    def player_input(self):
        """Prompt the current player until they name an empty cell, then mark it."""
        marker = TTT.player_marker(self.player)
        input_msg = f'Player {self.player}({marker}), select your next position\n'

        while True:
            try:
                v = int(input(input_msg))
            except ValueError:
                v = -1  # non-numeric input is rejected like any other invalid cell
            # Explicit range check: a negative index would otherwise wrap around.
            if 0 <= v < len(self.board) and self.board[v] == '_':
                break
            print("You can't place marker on that place")

        self.board[v] = marker
        print()

    def ai_input(self, v):
        """Place the current player's marker on cell ``v`` (no validation)."""
        marker = TTT.player_marker(self.player)
        self.board[v] = marker

    def switch_player(self):
        """Toggle ``self.player`` between 1 and 2."""
        self.player = 3 - self.player

    def play(self):
        """Run an interactive two-human game on the console until it is decided."""
        while self.result == 0:
            self.render()
            self.player_input()
            self.checkresult()
            self.switch_player()
        self.render()
        if self.result == 3:
            print('Game ended in DRAW')
        else:
            print(f'Player {self.result}({TTT.player_marker(self.result)}) Won!')

    def empty_spots(self):
        """Return a numpy array of the indices of empty cells ('_' or 0)."""
        return np.array([i for i, cell in enumerate(self.board) if cell == '_' or cell == 0])

    def move(self, act, turn):
        """Place the marker for ``turn`` (1 -> 'X', otherwise 'O') on cell ``act``."""
        self.board[act] = self.player_marker(turn)

In [None]:
# Cell-index triples that win tic-tac-toe: the three rows, the three columns,
# then the two diagonals, on a board indexed 0..8 row by row.
winning_cases = (
    [tuple(range(3 * row, 3 * row + 3)) for row in range(3)]
    + [tuple(range(col, 9, 3)) for col in range(3)]
    + [(0, 4, 8), (2, 4, 6)]
)

class Node:
    """One MCTS tree node: a board state plus its visit statistics.

    ``player`` is the side to move in ``state`` (from ``MCTS.current_player``).
    ``q`` accumulates reward for the other side (the one that moved into this
    state), so a parent can compare its children by ``q / n``.
    """

    def __init__(self, s, par_node=None, pre_action=None):
        self.state = s
        self.parent = par_node
        self.pre_action = pre_action   # move that led from the parent to this state
        self.child = []
        self.q = 0                     # accumulated reward
        self.n = 0                     # visit count
        self.player = MCTS.current_player(s)
        self.utc = float('inf')        # last UCT score computed during selection
        self.result = MCTS.is_terminal(s)

    def __repr__(self):
        win_ratio = self.q / (self.n + 1)
        fields = (
            self.pre_action,
            ''.join(self.state),
            self.q,
            self.n,
            str(win_ratio)[:5],
            str(self.utc)[:5],
        )
        return ' '.join(map(str, fields))

    def update(self, v):
        """Record one playout with outcome code ``v`` (1 X won, 2 O won, 3 draw)."""
        self.n += 1
        if v == 3:
            reward = 0.5
        elif v == 3 - self.player:
            reward = 1
        else:
            return
        self.q += reward


In [None]:

class MCTS:
    """Monte Carlo Tree Search over tic-tac-toe boards (lists of 'X', 'O', '_').

    The side to move is derived from board parity by ``current_player`` (an odd
    number of empty cells means X to move), so the search assumes X moved first.
    Outcome codes from ``is_terminal``: 0 ongoing, 1 X won, 2 O won, 3 draw.
    """

    def __init__(self, s):
        """Create a tree rooted at board ``s`` and expand the root's children."""
        self.root = Node(s)
        self._expansion(self.root)

    def mcts(self, mode='iteration', criteria=10000, new_board=None):
        """Run the search and return the number of iterations performed.

        Parameters
        ----------
        mode : {'iteration', 'iter', 'time'}
            'iteration' (or its alias 'iter') runs exactly ``criteria`` iterations;
            'time' keeps iterating until ``criteria`` milliseconds have elapsed.
        criteria : int
            Iteration count or time budget in milliseconds, depending on ``mode``.
        new_board : list, optional
            If given, the existing tree is discarded and rebuilt from this board.

        Raises
        ------
        NotImplementedError
            For any other ``mode``.
        """
        if new_board is not None:
            self.__init__(new_board)
        if mode in ('iteration', 'iter'):
            for _ in range(criteria):
                self._mcts_loop()
            return criteria
        if mode == 'time':
            budget = criteria / 1000  # milliseconds -> seconds
            start = time.monotonic()  # monotonic: immune to wall-clock adjustments
            iterations = 0
            while time.monotonic() - start < budget:
                self._mcts_loop()
                iterations += 1
            return iterations
        raise NotImplementedError(f'unknown mode {mode!r}')

    def _most_visited_child(self):
        """Return the root child with the most visits (the first one on ties).

        Raises ValueError when the root has no children (terminal position).
        """
        if not self.root.child:
            raise ValueError('no legal moves: the root position is terminal')
        return max(self.root.child, key=lambda node: node.n)

    def result_view(self):
        """Return the recommended move (cell index) as an ``int``."""
        return int(self._most_visited_child().pre_action)

    def result_return(self):
        """Return the recommended move's ``pre_action`` unchanged."""
        return self._most_visited_child().pre_action

    def _mcts_loop(self):
        """Run one select / expand / simulate / backpropagate iteration."""
        node = self._selection(self.root)
        self._expansion(node)
        # Play out from a random new child, or from the leaf itself if it is terminal.
        selected_node = rndchoice(node.child) if node.child else node
        # _simulation never mutates its argument (action_result copies), so no copy is needed.
        v = self._simulation(selected_node.state)
        self._backpropagation(selected_node, v)

    def _selection(self, node):
        """Descend from ``node`` by highest UCT score until reaching a leaf."""
        while node.child:
            for c in node.child:
                c.utc = MCTS.utc(c)
            # max() keeps the first child on ties, matching a strict '>' scan.
            node = max(node.child, key=lambda c: c.utc)
        return node

    def _expansion(self, node):
        """Attach one child per legal move to ``node`` unless its state is terminal."""
        if self.is_terminal(node.state) == 0:
            for a in self.actions_available(node.state):
                node.child.append(Node(self.action_result(node.state, a), node, a))

    def _simulation(self, s):
        """Play uniformly random moves from ``s`` until the game ends; return the outcome code."""
        outcome = self.is_terminal(s)
        while outcome == 0:
            s = self.action_result(s, rndchoice(self.actions_available(s)))
            outcome = self.is_terminal(s)
        return outcome

    def _backpropagation(self, node, v):
        """Apply outcome ``v`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.update(v)
            node = node.parent

    @staticmethod
    def is_terminal(s):
        """Return 1 if X has a line, 2 if O has one, 3 if the board is full, else 0."""
        for a, b, c in winning_cases:
            if s[a] != '_' and s[a] == s[b] == s[c]:
                return 1 if s[a] == 'X' else 2
        return 3 if '_' not in s else 0

    @staticmethod
    def actions_available(s):
        """Return the indices of the empty cells of ``s`` in ascending order."""
        return [i for i in range(9) if s[i] == '_']

    @staticmethod
    def action_result(s, a):
        """Return a copy of ``s`` with the side to move's marker placed on cell ``a``."""
        new_s = list(s)
        new_s[a] = 'X' if MCTS.current_player(s) == 1 else 'O'
        return new_s

    @staticmethod
    def current_player(s):
        """Return 1 (X) when an odd number of cells is empty, else 2 (O)."""
        return 1 if s.count('_') % 2 == 1 else 2

    @staticmethod
    def utc(node):
        """UCB1 score: mean reward plus exploration bonus sqrt(2 ln(N_parent + 1) / n).

        The 1e-12 terms avoid division by zero for unvisited nodes, giving them a
        very large score so they are tried first.
        """
        return node.q / (node.n + 1e-12) + sqrt(2 * log(node.parent.n + 1) / (node.n + 1e-12))



In [None]:
def safe_agent(board, turn):
    """Choose a one-ply "safe" move: win if possible, else block, else random.

    Parameters
    ----------
    board : TTT
        Game whose ``board`` attribute (9 cells of 'X', 'O' or '_') is inspected.
        It is read only; the game is never modified.
    turn : int
        Side to play: 1 places 'X', any other value places 'O' (the same
        convention as ``TTT.player_marker`` / ``TTT.move``).

    Returns
    -------
    int
        Index of the chosen empty cell.

    Raises
    ------
    IndexError
        If the board has no empty cell.
    """
    lines = ((0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6))
    cells = list(board.board)  # private copy so trial moves never touch the game
    empty = [i for i, c in enumerate(cells) if c == '_']
    me = 'X' if turn == 1 else 'O'
    opponent = 'O' if me == 'X' else 'X'

    def completes_line(idx, marker):
        # True if putting `marker` on `idx` finishes one of the lines through `idx`.
        trial = cells[:]
        trial[idx] = marker
        return any(all(trial[k] == marker for k in line) for line in lines if idx in line)

    # Prefer an immediate win; failing that, block the opponent's immediate win.
    for marker in (me, opponent):
        for idx in empty:
            if completes_line(idx, marker):
                return idx
    return random.choice(empty)

In [None]:

def _print_summary(played, games, wins, loss, ties, startwins, startloss, startdraw):
    """Print the running score table; 'Player 1' rows count games where MCTS moved first."""
    first = startwins + startloss + startdraw
    total = wins + loss + ties
    print('--------------Summary so far. Game Number {}/{}---------------------'.format(played, games))
    print('\t \t', ' Games', ' Wins', ' Losses', ' Draws')
    print('As Player 1 \t', first, '\t', startwins, '\t', startloss, '\t', startdraw)
    print('As Player 2 \t', total - first, '\t', wins - startwins, '\t', loss - startloss, '\t', ties - startdraw)
    print('In Total \t', total, '\t', wins, '\t', loss, '\t', ties)


def startgame(inital_config=None, games=100, turn_arr=None, agent=0):
    """Play ``games`` games of MCTS (turn 1) against an opponent (turn -1) and print stats.

    A summary table is printed every 20 games and after the last game.

    Parameters
    ----------
    inital_config : list of str, optional
        Starting board (9 cells of 'X', 'O' or '_'); copied per game, never
        mutated. Defaults to an empty board.
    games : int
        Number of games to play.
    turn_arr : sequence of int, optional
        Candidates for who moves first, drawn uniformly per game: 1 = MCTS,
        -1 = opponent. Defaults to ``[-1, 1]``.
    agent : int
        Opponent: 0 = random mover, 1 = ``safe_agent``, 2 = another MCTS search.
        Every MCTS move uses a 2000 ms time budget.

    Raises
    ------
    ValueError
        If ``agent`` is not 0, 1 or 2.
    """
    if inital_config is None:
        inital_config = ['_'] * 9
    if turn_arr is None:
        turn_arr = [-1, 1]
    if agent not in (0, 1, 2):
        raise ValueError(f'unknown agent {agent!r}; expected 0 (random), 1 (safe) or 2 (MCTS)')

    wins = ties = loss = 0
    startwins = startdraw = startloss = 0
    for i in range(games):
        t = TTT()
        t.board = deepcopy(inital_config)
        t.checkresult()  # a decided starting position needs no moves
        m = MCTS(t.board)
        turn = np.random.choice(turn_arr)
        flag = turn
        # MCTS infers the side to move from board parity (odd number of blanks -> 'X'),
        # so each move must use the marker parity dictates; otherwise, whenever the
        # opponent opened, the search would plan for the wrong side.
        # Marker codes follow TTT.move: 1 -> 'X', -1 -> 'O'.
        first_code = 1 if t.board.count('_') % 2 == 1 else -1
        mcts_result = 1 if first_code * turn == 1 else 2  # result code meaning "MCTS won"
        while t.result == 0:
            code = 1 if t.board.count('_') % 2 == 1 else -1
            if turn == -1 and agent == 0:
                action = random.choice(t.empty_spots())
            elif turn == -1 and agent == 1:
                action = safe_agent(t, code)
            else:
                # MCTS's own move, or an opponent that is itself an MCTS search (agent 2).
                m.mcts(new_board=t.board, mode='time', criteria=2000)
                action = m.result_view()
            t.move(action, code)
            t.checkresult()
            turn *= -1

        if t.result == mcts_result:
            wins += 1
            if flag == 1:
                startwins += 1
        elif t.result == 3:
            ties += 1
            if flag == 1:
                startdraw += 1
        else:
            loss += 1
            if flag == 1:
                startloss += 1

        played = i + 1
        if played % 20 == 0 or played == games:
            _print_summary(played, games, wins, loss, ties, startwins, startloss, startdraw)

In [None]:
# Benchmark cell: 1000 games of MCTS against the uniformly random opponent (agent 0).
print('----------------------------Playing Against Random Agent----------------------------')
startgame(agent=0, games=1000, turn_arr=[-1, 1])

In [None]:
# Benchmark cell: 1000 games of MCTS against safe_agent (agent 1).
print('----------------------------Playing Against Safe Agent------------------------------')
startgame(agent=1, games=1000, turn_arr=[-1, 1])

In [None]:
# Benchmark cell: 1000 games of MCTS against a second MCTS search (agent 2).
print('----------------------------Playing Against MCTS Agent------------------------------')
startgame(agent=2, games=1000, turn_arr=[-1, 1])