In [1]:
from IPython.display import display, HTML
# Emit a CSS rule forcing elements with class "container" to 100% width
# (presumably to make the notebook page use the full browser width -- the
# effect depends on the front end rendering this notebook).
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
from copy import deepcopy
import random
import math

class TicTacToe:
    """Two-player TicTacToe environment on a 3x3 board.

    A board ("state") is a list of three rows, each a list of three cell
    markers taken from ``EMPTY``, ``P1`` and ``P2``. ``P1`` always moves
    first. Rewards returned by :meth:`step` are expressed from the point of
    view of ``self.player``, the side selected in :meth:`reset`.
    """

    EMPTY = ' '
    P1 = 'X'
    P2 = 'O'

    def __init__(self):
        """Build the empty-board template and start a fresh game."""

        # Each row is its own list so that rows never alias one another.
        self.initial_state = [[self.EMPTY] * 3 for _ in range(3)]

        self.players = (self.P1, self.P2)

        self.reset()

    @classmethod
    def switch_player(cls, player):
        """Return the opponent of ``player`` (anything that is not P1 maps to P1)."""

        return cls.P2 if player == cls.P1 else cls.P1

    @classmethod
    def check_termination(cls, state):
        """
        Termination check.

        Rows and columns are inspected first (row i, then column i, for
        i = 0..2), followed by the main diagonal and the anti-diagonal.

        Returns
        -------
        done: bool
            True when a line is completed by one player or the board is full.

        winning_player: str
            The marker of the winning player. ``EMPTY`` is returned when there
            is no winner (draw, or game still in progress).
        """

        for i in range(3):
            # Row i
            first_elem = state[i][0]
            if first_elem != cls.EMPTY and all(state[i][j] == first_elem for j in range(3)):
                return True, first_elem

            # Column i
            first_elem = state[0][i]
            if first_elem != cls.EMPTY and all(state[j][i] == first_elem for j in range(3)):
                return True, first_elem

        # Main diagonal
        first_elem = state[0][0]
        if first_elem != cls.EMPTY and all(state[i][i] == first_elem for i in range(3)):
            return True, first_elem

        # Anti-diagonal
        first_elem = state[0][2]
        if first_elem != cls.EMPTY and all(state[i][2 - i] == first_elem for i in range(3)):
            return True, first_elem

        # No winner: the game is over only if no free cell is left.
        done = all(state[i][j] != cls.EMPTY for i in range(3) for j in range(3))

        return done, cls.EMPTY

    @classmethod
    def get_available_moves(cls, state):
        """Return the free cells of ``state`` as ``(row, col)`` tuples, in row-major order."""

        return [(i, j) for i in range(3) for j in range(3) if state[i][j] == cls.EMPTY]

    @property
    def available_moves(self):
        """Available moves at the current state."""

        return self.get_available_moves(self.current_state)

    @classmethod
    def transition_function(cls, state, action, player):
        """
        Apply ``action`` for ``player`` and hand the turn to the opponent.

        ``state`` is modified IN PLACE and the same object is returned; pass a
        copy if the original board must be preserved.

        Parameters
        ----------
        state: list of list of str
            The board to play on.
        action: tuple
            ``(row, col)`` of the cell to mark.
        player: str
            The marker of the player making the move.

        Returns
        -------
        state, next_player

        Raises
        ------
        AssertionError
            If the target cell is not empty.
        """

        row, col = action[0], action[1]
        assert state[row][col] == cls.EMPTY

        state[row][col] = player
        return state, cls.switch_player(player)

    def step(self, action):
        """
        Play ``action`` for the current player on the environment board.

        Returns
        -------
        state, reward, done, current_player
            ``reward`` is +1 if ``self.player`` has won, -1 if its opponent
            has won and 0 otherwise. ``state`` is the environment's own board
            object, not a copy.

        Raises
        ------
        AssertionError
            If the game is already over or the target cell is taken.
        """

        self.done, winning_player = self.check_termination(self.current_state)
        assert not self.done

        self.current_state, self.current_player = self.transition_function(
            self.current_state, action, self.current_player
        )

        reward = 0
        self.done, winning_player = self.check_termination(self.current_state)
        if winning_player == self.player:
            reward = 1
        if winning_player == self.switch_player(self.player):
            reward = -1

        return self.current_state, reward, self.done, self.current_player

    def render(self, state=None):
        """Print ``state`` (the current board when omitted) as an ASCII grid."""

        if state is None:
            state = self.current_state

        for i in range(3):
            print('-----------')
            print("|".join((f" {x} " for x in state[i])))
        print('-----------')

    def reset(self, player=None):
        """
        Reset the environment.

        Player 'X' always moves first. If 'O' is selected as ``player``, a
        uniformly random opening move is played for 'X' before returning, so
        that the returned state is the first one in which 'O' has to move.

        Parameters
        ----------
        player: str, optional
            The side rewards are computed for. Defaults to ``P1``.

        Returns
        -------
        state, current_player
        """

        self.done = False
        self.player = player
        if self.player is None:
            self.player = self.P1

        self.current_state = deepcopy(self.initial_state)
        self.current_player = self.P1

        if player == self.P2:
            self.step(random.choice(self.available_moves))

        return self.current_state, self.current_player

In [3]:
class MCTSNode:
    """Node of a Monte Carlo Tree Search over TicTacToe boards."""

    def __init__(self, state, player, parent=None, action=None):
        """
        Initialize a node.

        Parameters
        ----------
        state: list of list of str
            Board represented by this node.
        player: str
            Player to move from ``state``.
        parent: MCTSNode, optional
            Node this one was expanded from (``None`` for the root).
        action: tuple, optional
            ``(row, col)`` move that led from the parent's state to ``state``.
        """

        self.state = state
        self.player = player
        self.parent = parent
        self.action = action
        self.children = []
        self.untried_actions = TicTacToe.get_available_moves(self.state)
        # Number of times this node was on a backpropagated path.
        self.n = 0
        # Net score (wins minus losses) from the point of view of the parent's
        # player, i.e. the player who chose the move leading to this node.
        self.w = 0
        self.is_terminal = TicTacToe.check_termination(self.state)[0]

    @property
    def fully_expanded(self):
        """True when every available move already has a child node."""
        return len(self.untried_actions) == 0

    def expand(self):
        """Pick an untried action, apply it, create the child node for the resulting state, register it and return it."""

        action = self.untried_actions.pop()

        # transition_function mutates the board it receives, so work on a copy.
        next_state, next_player = TicTacToe.transition_function(deepcopy(self.state), action, self.player)

        child_node = MCTSNode(next_state, next_player, parent=self, action=action)
        self.children.append(child_node)

        return child_node

    def rollout(self):
        """
        Play uniformly random moves from this node's state until termination.

        The node's own board is left untouched.

        Returns
        -------
        result: str
            The winning player's marker, or the empty marker for a draw.
        """

        done, result = TicTacToe.check_termination(self.state)
        if done:
            return result

        # A single private copy is enough: the simulation can then mutate it
        # in place instead of deep-copying the board on every move.
        state = deepcopy(self.state)
        player = self.player
        while not done:
            action = random.choice(TicTacToe.get_available_moves(state))
            state, player = TicTacToe.transition_function(state, action, player)
            done, result = TicTacToe.check_termination(state)
        return result

    def backpropagate(self, result):
        """
        Propagate a rollout result from this node up to the root.

        Every node on the path gets one more visit. For non-root nodes the
        score is updated relative to the parent's player: +1 if that player
        won, -1 if the other player won, unchanged on a draw.
        """

        self.n += 1
        if self.parent:
            if result == self.parent.player:
                self.w += 1
            elif result != TicTacToe.EMPTY:
                self.w -= 1
            self.parent.backpropagate(result)

    def traverse(self):
        """
        Descend through fully expanded nodes following the best UCT child.

        Returns the terminal node reached, or a freshly expanded child of the
        first node that still has untried actions.
        """

        node = self

        while node.fully_expanded and not node.is_terminal:
            node = node.best_uct_child()

        if node.is_terminal:
            return node

        return node.expand()

    def win_ratio(self):
        """Average score per visit (``w / n``). Raises ZeroDivisionError for an unvisited node."""

        return self.w/self.n

    def uct(self):
        """UCT value of a node: win ratio plus the exploration term sqrt(2 * ln(parent.n) / n)."""

        return self.win_ratio() + math.sqrt(2*math.log(self.parent.n)/self.n)

    def best_child(self):
        """
        Return the child with the highest win ratio.

        Ties are resolved in favour of the child with the highest index in
        ``self.children``. Raises ValueError when there are no children.
        """

        _, best_index = max((child.win_ratio(), i) for i, child in enumerate(self.children))

        return self.children[best_index]

    def best_uct_child(self):
        """
        Return the child with the highest UCT value.

        Ties are resolved in favour of the child with the highest index in
        ``self.children``. Raises ValueError when there are no children.
        """

        _, best_index = max((child.uct(), i) for i, child in enumerate(self.children))

        return self.children[best_index]
        

In [4]:
def mcts(state, player, iters=5000):
    """
    Run Monte Carlo Tree Search from ``state`` with ``player`` to move.

    Each iteration selects/expands a node, simulates a random game from it
    and backpropagates the outcome. The caller's board is copied, so it is
    not modified by the search.

    Returns
    -------
    action, root
        The action of the root's best child, and the root node of the tree.
    """
    search_root = MCTSNode(deepcopy(state), player)

    for _ in range(iters):
        selected = search_root.traverse()
        outcome = selected.rollout()
        selected.backpropagate(outcome)

    chosen = search_root.best_child()
    return chosen.action, search_root

In [5]:
# Side played by the agent under evaluation; passed to env.reset() by the
# experiment cells, so rewards are computed from this player's perspective.
player = 'X'

In [6]:
# Single environment instance shared (and reset) by every experiment cell.
env = TicTacToe()

# Random Agents (X and O)

In [7]:
# Baseline: every move, for both sides, is drawn uniformly at random from the
# free cells. env.step() returns +1 when `player` wins, -1 when the opponent
# wins and 0 otherwise, so total_reward is the episode outcome for `player`.
rewards = []
for i in range(10):
    state, cur_player = env.reset(player)
    done = False
    total_reward = 0
    while not done:
        action = random.choice(env.available_moves)
        state, reward, done, cur_player = env.step(action)
        total_reward += reward
    rewards.append(total_reward)

print('Mean Reward over 10 episodes:', sum(rewards)/len(rewards))

Mean Reward over 10 episodes: 0.4


# MCTS Agent (X) and Random Agent (O)

In [8]:
# Play and render one game: MCTS chooses the move on `player`'s turns, the
# opponent moves uniformly at random.
state, cur_player = env.reset(player)
done = False

env.render()

while not done:
    if cur_player == player:
        # Search on a copy so the environment's own board is never touched.
        action, root = mcts(deepcopy(state), player)
    else:
        action = random.choice(env.available_moves)
    state, reward, done, cur_player = env.step(action)
    env.render()

-----------
   |   |   
-----------
   |   |   
-----------
   |   |   
-----------
-----------
   |   |   
-----------
   | X |   
-----------
   |   |   
-----------
-----------
   |   |   
-----------
   | X |   
-----------
 O |   |   
-----------
-----------
   |   |   
-----------
   | X |   
-----------
 O | X |   
-----------
-----------
   |   | O 
-----------
   | X |   
-----------
 O | X |   
-----------
-----------
   | X | O 
-----------
   | X |   
-----------
 O | X |   
-----------


In [9]:
# Evaluate MCTS (`player`) against a uniformly random opponent over 10
# episodes; total_reward is the episode outcome from `player`'s perspective
# (+1 win, -1 loss, 0 draw).
rewards = []
for i in range(10):
    state, cur_player = env.reset(player)
    done = False

    total_reward = 0
    while not done:
        if cur_player == player:
            # Search on a copy so the environment's own board is never touched.
            action, root = mcts(deepcopy(state), player)
        else:
            action = random.choice(env.available_moves)
        state, reward, done, cur_player = env.step(action)
        total_reward += reward
    rewards.append(total_reward)

print('Mean Reward over 10 episodes:', sum(rewards)/len(rewards))

Mean Reward over 10 episodes: 1.0


# MCTS Agents (X and O)

In [10]:
# Play and render one self-play game: a fresh MCTS search is run for
# whichever player is to move.
state, cur_player = env.reset(player)
done = False

env.render()

while not done:
    # Search on a copy so the environment's own board is never touched.
    action, root = mcts(deepcopy(state), cur_player)
    state, reward, done, cur_player = env.step(action)
    env.render()

-----------
   |   |   
-----------
   |   |   
-----------
   |   |   
-----------
-----------
   |   |   
-----------
   | X |   
-----------
   |   |   
-----------
-----------
   |   |   
-----------
   | X |   
-----------
   |   | O 
-----------
-----------
   |   |   
-----------
   | X | X 
-----------
   |   | O 
-----------
-----------
   |   |   
-----------
 O | X | X 
-----------
   |   | O 
-----------
-----------
   |   |   
-----------
 O | X | X 
-----------
   | X | O 
-----------
-----------
   | O |   
-----------
 O | X | X 
-----------
   | X | O 
-----------
-----------
   | O | X 
-----------
 O | X | X 
-----------
   | X | O 
-----------
-----------
   | O | X 
-----------
 O | X | X 
-----------
 O | X | O 
-----------
-----------
 X | O | X 
-----------
 O | X | X 
-----------
 O | X | O 
-----------


In [11]:
# Evaluate MCTS self-play over 10 episodes: a fresh search is run for
# whichever player is to move. Rewards are still reported from `player`'s
# perspective (+1 win, -1 loss, 0 draw).
rewards = []
for i in range(10):
    state, cur_player = env.reset(player)
    done = False

    total_reward = 0
    while not done:
        # Search on a copy so the environment's own board is never touched.
        action, root = mcts(deepcopy(state), cur_player)
        state, reward, done, cur_player = env.step(action)
        total_reward += reward
    rewards.append(total_reward)

print('Mean Reward over 10 episodes:', sum(rewards)/len(rewards))

Mean Reward over 10 episodes: 0.0
