In [7]:
import numpy as np
import random
from math import sqrt, log

In [8]:
class Environment:
    def __init__(self):
        self.state = np.zeros([2,3,3])
        self.player = 0 
    def reset(self):
        self.state = np.zeros([2,3,3])
        self.player = 0
    def start_point(self):
        self.state = np.zeros([2,3,3])
        self.state[0,1,0]=1
        self.state[0,1,1]=1
        self.state[1,0,1]=1
        self.state[1,1,2]=1
        self.player = 0
    def E_copy(self):
        copy = Environment()
        copy.player = self.player
        copy.state = np.copy(self.state)
        return copy
    def placement(self, placed):
        if np.any(self.state[:,placed[0],placed[1]]): 
            return
        if placed not in self.get():
            print("Illegal play")
            return
        self.state[self.player][placed[0],placed[1]] = 1
        self.player ^= 1

    def get(self):
        return np.argwhere(self.state[0]+self.state[1]==0).tolist()

    def result(self):
        ttt = self.state[self.player^1]
        c = np.any(np.sum(ttt,axis=0)==3)
        r = np.any(np.sum(ttt,axis=1)==3)
        d  = np.any(np.trace(ttt)==3)
        od = np.any(np.trace(np.flip(ttt,1))==3)
        return c or r or d or od
      
    def win(self):
        ttt = self.state[self.player^1]
        c = np.any(np.sum(ttt,axis=0)==3)
        r = np.any(np.sum(ttt,axis=1)==3)
        d  = np.any(np.trace(ttt)==3)
        od  = np.any(np.trace(np.flip(ttt,1))==3)
        temp = self.state
        if c or r or d or od:
            return True, self.player^1
        elif np.sum(temp) == 9:
            return True, None
        else:
            return False, None
            
    def display(self):
        ttt = self.state
        for i in range(3):
            for j in range(3):
                if ttt[0,i,j] ==1:
                    print(" X ", end = "")
                elif ttt[1,i,j] ==1:
                    print(" O ", end = "")
                else:
                    print(" ~ ", end = "")
            print()

In [14]:
class Tree:

    def __init__(self, parent=None, action=None, ttt=None):
        self.parent = parent
        self.ttt = ttt
        self.children = []
        self.wins = 0
        self.visits = 0
        self.remaining_grids = ttt.get()
        self.action = action
    def UCB(self):
        s = sorted(self.children, key=lambda c:c.wins/c.visits+sqrt(2*log(self.visits)/c.visits))
        return s[-1]
    def build(self, action, ttt):
        child = Tree(parent=self, action=action, ttt=ttt)
        self.remaining_grids.remove(action)
        self.children.append(child)
        return child
    def update(self, result):
        self.visits += 1
        self.wins += result

In [15]:
def Agent(rootstate, maxiters):

    root = Tree(ttt=rootstate)

    for i in range(maxiters):
        node = root
        ttt = rootstate.E_copy()
    while node.remaining_grids == [] and node.children != []:
        node = node.UCB()
        ttt.placement(node.action)
    if node.remaining_grids != []:
        a = random.choice(node.remaining_grids)
        ttt.placement(a)
        node = node.build(a, ttt.E_copy())
    while ttt.get() != [] and not ttt.result():
        ttt.placement(random.choice(ttt.get()))
    while node != None:
        result = ttt.result()
        if result:
            if node.ttt.player==ttt.player:
                result = 1
            else: 
                result = -1
        else: 
            result = 0
        node.update(result)
        node = node.parent
    s = sorted(root.children, key=lambda c:c.wins/c.visits)
    return s[-1].action


In [16]:
class Player:
    def __init__(self):
        pass

    def __str__(self):
        return "Player 1"
    def playa(self,x0):
        movez = random.choice(x0.get())
        return movez

In [27]:
def simulation(ttt,p1):
    players = {0: "Player 1", 1: "MCTS UCT "}
    turn = 0
    count = 0
    wins = 0
    loses = 0
    draws =  0
    first_step = 1
    while count < 100:
        if turn ==0 and first_step == 0 :
            x0 = ttt
            action = p1.playa(x0)
            ttt.placement(action)
        elif first_step == 1:
            ttt.placement([0,0])
            first_step = 0
        else:
            x0 = ttt
            action = Agent(x0,1000)
            ttt.placement(action)
        
        is_over, winner = ttt.win()
        if is_over:
            if winner != None:

                if winner == 0:
                    wins += 1
                if winner == 1:
                    loses += 1
            
            else:
                draws += 1
            count +=1
        
            ttt.start_point()
            first_step = 1

            turn =1
            expected_r = (wins - loses)/count
            print(" win probability for game : ",count, " is ",wins/count)
        

        turn ^= 1

In [28]:
if __name__ == '__main__':
    ttt = Environment()
    ttt.start_point()
    p1 = Player()
    print("Tic tac toe with MCTS UCT : ")
    ttt.display()
    simulation(ttt,p1)


Tic tac toe with MCTS UCT : 
 ~  O  ~ 
 X  X  O 
 ~  ~  ~ 
 win probability for game :  1  is  1.0
 win probability for game :  2  is  1.0
 win probability for game :  3  is  1.0
 win probability for game :  4  is  1.0
 win probability for game :  5  is  1.0
 win probability for game :  6  is  1.0
 win probability for game :  7  is  1.0
 win probability for game :  8  is  0.875
 win probability for game :  9  is  0.8888888888888888
 win probability for game :  10  is  0.9
 win probability for game :  11  is  0.9090909090909091
 win probability for game :  12  is  0.9166666666666666
 win probability for game :  13  is  0.8461538461538461
 win probability for game :  14  is  0.8571428571428571
 win probability for game :  15  is  0.8666666666666667
 win probability for game :  16  is  0.875
 win probability for game :  17  is  0.8823529411764706
 win probability for game :  18  is  0.8888888888888888
 win probability for game :  19  is  0.8947368421052632
 win probability for game :  20 

In [6]:

# import numpy as np


# class Node:
#     def __init__(self, data):
#         self.a0 = None
#         self.a1 = None
#         self.a2 = None
#         self.a3 = None
#         self.a4 = None
#         self.parent = None
#         self.children = None
#         self.data = data


# # state 0
# root = Node(0)
# root.a0 = Node(0)
# root.a1 = Node(0)
# root.a2 = Node(0)
# root.a3 = Node(0)
# root.a4 = Node(0)
# root.a0.parent = root
# root.a1.parent = root
# root.a2.parent = root
# root.a3.parent = root
# root.a4.parent = root

# a0 = root.a0
# a1 = root.a1
# a2 = root.a2
# a3 = root.a3
# a4 = root.a4
# root.children = [a0, a1, a2, a3, a4]
# # action 0 branch
# a0.a1 = Node(0)
# a0.a2 = Node(1)  # winning node
# a0.a3 = Node(0)
# a0.a4 = Node(1)  # winning node
# a0.a1.parent = a0
# a0.a2.parent = a0
# a0.a3.parent = a0
# a0.a4.parent = a0
# a0.children = [a0.a1, a0.a2, a0.a3, a0.a4]

# a0.a1.a2 = Node(1)  # winning node
# a0.a1.a3 = Node(0)
# a0.a1.a4 = Node(1)  # winning node
# a0.a3.a1 = Node(0)
# a0.a3.a2 = Node(-1)  # losing node
# a0.a3.a4 = Node(1)  # winning node
# a0.a1.a2.parent = a0.a1
# a0.a1.a3.parent = a0.a1
# a0.a1.a4.parent = a0.a1
# a0.a1.children = [a0.a1.a2, a0.a1.a3, a0.a1.a4]
# a0.a3.a1.parent = a0.a3
# a0.a3.a2.parent = a0.a3
# a0.a3.a4.parent = a0.a3
# a0.a3.children = [a0.a3.a1, a0.a3.a2, a0.a3.a4]

# # # action 1 branch
# # # state 2
# a1.a0 = Node(0)
# a1.a2 = Node(1)  # winning node
# a1.a3 = Node(0)
# a1.a4 = Node(0)
# a1.a0.parent = a1
# a1.a2.parent = a1
# a1.a3.parent = a1
# a1.a4.parent = a1
# a1.children = [a1.a0, a1.a2, a1.a3, a1.a4]

# # # state 3
# a1.a0.a2 = Node(1)  # winning node
# a1.a0.a3 = Node(0)
# a1.a0.a4 = Node(1)  # winning node
# a1.a3.a0 = Node(0)
# a1.a3.a2 = Node(1)  # winning node
# a1.a3.a4 = Node(0)
# a1.a4.a0 = Node(1)  # winning node
# a1.a4.a2 = Node(1)  # winning node
# a1.a4.a3 = Node(0)
# a1.a0.a2.parent = a1.a0
# a1.a0.a3.parent = a1.a0
# a1.a0.a4.parent = a1.a0
# a1.a0.children = [a1.a0.a2, a1.a0.a3, a1.a0.a4]
# a1.a3.a0.parent = a1.a3
# a1.a3.a2.parent = a1.a3
# a1.a3.a4.parent = a1.a3
# a1.a3.children = [a1.a3.a0, a1.a3.a2, a1.a3.a4]
# a1.a4.a0.parent = a1.a4
# a1.a4.a2.parent = a1.a4
# a1.a4.a3.parent = a1.a4
# a1.a4.children = [a1.a4.a0, a1.a4.a2, a1.a4.a3]
# # action 2 branch
# a2.a0 = Node(1)  # winning node
# a2.a1 = Node(1)  # winning node
# a2.a3 = Node(0)
# a2.a4 = Node(0)
# a2.a0.parent = a2
# a2.a1.parent = a2
# a2.a3.parent = a2
# a2.a4.parent = a2
# a2.children = [a2.a0, a2.a1, a2.a3, a2.a4]

# a2.a3.a0 = Node(-1)  # losing node
# a2.a3.a1 = Node(1)  # winning node
# a2.a3.a4 = Node(-1)  # losing node
# a2.a4.a0 = Node(1)  # winning node
# a2.a4.a1 = Node(1)  # winning node
# a2.a4.a3 = Node(-1)  # losing node
# a2.a3.a0.parent = a2.a3
# a2.a3.a1.parent = a2.a3
# a2.a3.a4.parent = a2.a3
# a2.a3.children = [a2.a3.a0, a2.a3.a1, a2.a3.a4]
# a2.a4.a0.parent = a2.a4
# a2.a4.a1.parent = a2.a4
# a2.a4.a3.parent = a2.a4
# a2.a4.children = [a2.a4.a0, a2.a4.a1, a2.a4.a3]

# # action 3 branch
# # state 2
# a3.a0 = Node(0)
# a3.a4 = Node(0)
# a3.a2 = Node(0)
# a3.a1 = Node(0)
# a3.a0.parent = a3
# a3.a4.parent = a3
# a3.a2.parent = a3
# a3.a1.parent = a3
# a3.children = [a3.a0, a3.a1, a3.a2, a3.a4]

# # state 3
# a3.a0.a1 = Node(0)
# a3.a0.a2 = Node(-1)  # losing node
# a3.a0.a4 = Node(1)  # winning node
# a3.a1.a0 = Node(0)
# a3.a1.a2 = Node(1)  # winning node
# a3.a1.a4 = Node(0)
# a3.a2.a0 = Node(-1)  # losing node
# a3.a2.a1 = Node(1)  # winning node
# a3.a2.a4 = Node(-1)  # losing node
# a3.a4.a0 = Node(1)  # winning node
# a3.a4.a1 = Node(0)
# a3.a4.a2 = Node(-1)  # losing node
# a3.a0.a1.parent = a3.a0
# a3.a0.a2.parent = a3.a0
# a3.a0.a4.parent = a3.a0
# a3.a0.children = [a3.a0.a1, a3.a0.a2, a3.a0.a4]
# a3.a1.a0.parent = a3.a1
# a3.a1.a2.parent = a3.a1
# a3.a1.a4.parent = a3.a1
# a3.a1.children = [a3.a1.a0, a3.a1.a2, a3.a1.a4]
# a3.a2.a0.parent = a3.a2
# a3.a2.a1.parent = a3.a2
# a3.a2.a4.parent = a3.a2
# a3.a2.children = [a3.a2.a0, a3.a2.a1, a3.a2.a4]
# a3.a4.a0.parent = a3.a4
# a3.a4.a1.parent = a3.a4
# a3.a4.a2.parent = a3.a4
# a3.a4.children = [a3.a4.a0, a3.a4.a1, a3.a4.a2]

# # action 4 branch
# a4.a0 = Node(1)  # winning node
# a4.a1 = Node(0)
# a4.a2 = Node(0)
# a4.a3 = Node(0)
# a4.a0.parent = a4
# a4.a1.parent = a4
# a4.a2.parent = a4
# a4.a3.parent = a4
# a4.children = [a4.a0, a4.a1, a4.a2, a4.a3]

# a4.a1.a0 = Node(0)
# a4.a1.a2 = Node(1)  # winning node
# a4.a1.a3 = Node(0)
# a4.a2.a0 = Node(1)  # winning node
# a4.a2.a1 = Node(1)  # winning node
# a4.a2.a3 = Node(-1)  # losing node
# a4.a3.a0 = Node(1)  # winning node
# a4.a3.a1 = Node(0)
# a4.a3.a2 = Node(-1)  # losing node
# a4.a1.a0.parent = a4.a1
# a4.a1.a2.parent = a4.a1
# a4.a1.a3.parent = a4.a1
# a4.a1.children = [a4.a1.a0, a4.a1.a2, a4.a1.a3]
# a4.a2.a0.parent = a4.a2
# a4.a2.a1.parent = a4.a2
# a4.a2.a3.parent = a4.a2
# a4.a2.children = [a4.a2.a0, a4.a2.a1, a4.a2.a3]
# a4.a3.a0.parent = a4.a3
# a4.a3.a1.parent = a4.a3
# a4.a3.a2.parent = a4.a3
# a4.a3.children = [a4.a3.a0, a4.a3.a1, a4.a3.a2]

# # list of leafnodes
# leaf_nodes = [a4.a3.a0, a4.a3.a1, a4.a3.a2, a4.a2.a0, a4.a2.a1, a4.a2.a3, a4.a1.a0, a4.a1.a2, a4.a1.a3, a4.a0, a3.a4.a0, a3.a4.a1, a3.a4.a2, a3.a2.a0, a3.a2.a1,
#               a3.a2.a4, a3.a1.a0, a3.a1.a2, a3.a1.a4, a3.a0.a1, a3.a0.a2, a3.a0.a4, a2.a4.a0, a2.a4.a1, a2.a4.a3, a2.a3.a0,
#               a2.a3.a1, a2.a3.a4, a2.a0, a2.a1, a1.a4.a0, a1.a4.a2, a1.a4.a3, a1.a3.a0,
#               a1.a3.a2, a1.a3.a4, a1.a0.a2, a1.a0.a3, a1.a0.a4, a1.a2, a0.a3.a1, a0.a3.a2, a0.a3.a4, a0.a1.a2,
#               a0.a1.a3, a0.a1.a4, a0.a2, a0.a4]


# win_probs = []


# def get_Q_value(x):
#     global win_probs
#     actions = np.array([a0, a1, a2, a3, a4])
#     q_values = []
#     q_value = 1
#     for action in actions:
#         a = action
#         r = a.data
#         r = int(r)

#         if not a.children:
#             return r

#         x2_nodes = []
#         x2_nodes_w_children = []  # for when there are more children, saves index
#         x3_nodes = []
#         x2_wins = 0
#         x3_wins = 0
#         total_loss = 0
#         for node in a.children:  # where a is 1 of 5 actions
#             x2_nodes.append(node.data)
#             if (node.data == 1):  # win
#                 x2_wins += 1
#             else:
#                 x2_nodes_w_children.append(node.data)
#                 if(node.children):
#                     for node in node.children:
#                         x3_nodes.append(node.data)
#                         if(node.data == 1):  # win
#                             x3_wins += 1
#                         elif (node.data == -1):  # loss
#                             total_loss += 1

#         x3_len = len(x3_nodes)
#         x2_len = len(x2_nodes)
#         total_wins = x2_wins + x3_wins
#         win_prob = 0
#         if a.children:
#             win_prob = ((x2_wins/x2_len)+(x3_wins/x3_len))
#         else:
#             win_prob = (x2_wins/x2_len)

#         win_probs.append(round(win_prob, 4))
#         q_value = total_wins-total_loss + ((1/5) * win_prob + (1/3) * win_prob)
#         q_value = q_value/(x2_len+x3_len)
#         q_values.append(q_value)

#     return q_values


# for i in range(100):
#     print(get_Q_value(1))
#     print(win_probs)
#     print('======\n')


[0.4533333333333333, 0.4945868945868946, 0.2533333333333333, 0.01111111111111111, 0.25925925925925924]
[1.0, 0.8056, 1.0, 0.3333, 0.6944]

[0.4533333333333333, 0.4945868945868946, 0.2533333333333333, 0.01111111111111111, 0.25925925925925924]
[1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944]

[0.4533333333333333, 0.4945868945868946, 0.2533333333333333, 0.01111111111111111, 0.25925925925925924]
[1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944]

[0.4533333333333333, 0.4945868945868946, 0.2533333333333333, 0.01111111111111111, 0.25925925925925924]
[1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944]

[0.4533333333333333, 0.4945868945868946, 0.2533333333333333, 0.01111111111111111, 0.25925925925925924]
[1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944, 1.0, 0.8056, 1.0, 0.3333, 0.6944, 