In [1]:
import numpy as np

## Environment

In [2]:
class TicTacToeEnv:
    """An environment for two-player tic-tac-toe.

    The board is a 3x3 integer array: 0 is an empty cell, +1 a mark placed by
    player 0 and -1 a mark placed by player 1. Player 0 moves first after
    `reset`.

    Parameters
    ----------
    state : array_like, optional
        Initial board. It is copied, so stepping this environment never
        mutates the caller's array. Defaults to an empty board.
    turn : int, optional
        Index (0 or 1) of the player to move. Defaults to 0.
    """

    def __init__(self, state=None, turn=None):
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the replacement.
        self.state = np.zeros((3, 3), dtype=int) if state is None else np.array(state, dtype=int)
        self.players = 2
        self.turn = 0 if turn is None else turn

    def reset(self):
        """Initialize a new game and return state and turn."""
        self.state = np.zeros((3, 3), dtype=int)
        self.turn = 0
        return self.state.copy(), self.turn

    def step(self, action):
        """Place the current player's mark at `action` and pass the turn.

        Parameters
        ----------
        action : tuple of int
            (row, col) of an empty cell.

        Returns
        -------
        state : ndarray
            Copy of the board after the move.
        rewards : ndarray of shape (2,)
            +1 for the winner and -1 for the loser if someone has three in a
            row after this move; zeros otherwise (including a draw).
        done : bool
            Whether the game is over after this move.
        turn : int
            Index of the player to move next.
        """
        assert self.state[action] == 0, f"cell {action} is already occupied"
        self.state[action] = (-1) ** self.turn
        rewards = np.zeros(self.players)
        winner = self.winner()
        if winner is not None:
            rewards[winner] = 1
            rewards[1 - winner] = -1
        self.turn = (self.turn + 1) % 2
        return self.state.copy(), rewards, self.done, self.turn

    def copy(self):
        """Return an independent environment with the same board and turn."""
        return TicTacToeEnv(state=self.state, turn=self.turn)

    def render(self):
        print(self.state)

    def _line_sums(self):
        """Yield the sum of each of the 8 lines: rows, columns, diagonal, anti-diagonal."""
        board = self.state
        for row in range(3):
            yield board[row, :].sum()
        for col in range(3):
            yield board[:, col].sum()
        yield np.trace(board)
        yield np.trace(np.fliplr(board))  # anti-diagonal (0,2), (1,1), (2,0)

    def winner(self):
        """Return 0 or 1 for the player with three in a row, or None if nobody has.

        Lines are checked in a fixed order and the first complete line decides.
        """
        for total in self._line_sums():
            if total == 3:
                return 0
            if total == -3:
                return 1
        return None

    @property
    def actions(self):
        """The empty (row, col) cells, or [] once the game is over."""
        if self.done:
            return []
        return [(a, b) for a in range(3) for b in range(3) if self.state[a, b] == 0]

    @property
    def done(self):
        """True if the game is over: some line is complete or the board is full."""
        return self.winner() is not None or np.all(self.state != 0)

## Agent

In [3]:
def epsilon_greedy(epsilon=0.05):
    """Build an epsilon-greedy tree policy.

    The returned callable takes ``(node, turn)``. With probability `epsilon`
    it returns a uniformly random child of `node`; otherwise it returns the
    child with the largest ``value[turn]`` (the first such child on ties).
    """
    def choose(node, turn):
        explore = np.random.rand() < epsilon
        if explore:
            return np.random.choice(node.children)
        return max(node.children, key=lambda child: child.value[turn])
    return choose


def ucb(c=np.sqrt(2)):
    """Build an upper-confidence-bound tree policy.

    The returned callable takes ``(node, turn)`` and returns the child of
    `node` that maximises ``value[turn] + c * sqrt(ln(parent visits) / visits)``.
    Unvisited children score infinity, so they are picked first (in list order).
    """
    def policy(node, turn):
        def score(child):
            n_child = child.visits
            if n_child == 0:
                return np.inf
            exploration = c * np.sqrt(np.log(child.parent.visits) / n_child)
            return child.value[turn] + exploration
        return max(node.children, key=score)
    return policy


class TreeNode:
    """One node of the Monte Carlo search tree.

    Parameters
    ----------
    parent : TreeNode or None
        The node this one was expanded from (None for the root).
    action : object
        The action that led from `parent` to this node (None for the root).
    reward : array_like
        Per-player rewards received on the transition into this node.
    done : bool
        Whether the game is over at this node.
    turn : int
        Index of the player to move at this node.
    env : object
        The environment positioned at this node.
    """

    def __init__(self, parent, action, reward, done, turn, env):
        # tree links
        self.parent, self.children = parent, []
        # transition that produced this node
        self.action, self.reward, self.done = action, reward, done
        # game position at this node
        self.turn, self.env = turn, env
        # search statistics: visit count and per-player mean return
        self.visits, self.value = 0, np.zeros(2)


class MCTSAgent:
    """A Monte Carlo tree search agent.

    Parameters
    ----------
    env_fn : function
        Called as ``env_fn(state=state, turn=turn)``; must return a new
        environment at that state with that player to move. The environment
        must provide ``actions``, ``step`` and ``copy``.
    tree_policy : function
        A function which maps (node, turn) to a child node. Defaults to ``ucb()``.
    rollouts : int, optional
        The number of rollouts to perform before choosing an action (>= 1).

    """

    def __init__(self, env_fn, tree_policy=ucb(), rollouts=100):
        self.env_fn = env_fn
        self.tree_policy = tree_policy
        self.rollouts = rollouts

    def act(self, state, turn):
        """Search from (state, turn) and return the most-visited root action.

        The search tree built by the last call is kept in ``self.root``.

        Raises
        ------
        ValueError
            If ``self.rollouts`` is less than 1 or the state has no legal
            actions, since no action can be chosen in either case.
        """
        if self.rollouts < 1:
            raise ValueError("MCTSAgent needs at least one rollout to choose an action")
        env = self.env_fn(state=state, turn=turn)
        if not env.actions:
            raise ValueError("no legal actions from this state")
        root = TreeNode(None, None, np.zeros(2), False, turn, env)
        for _ in range(self.rollouts):
            leaf = self.expand(root)
            value = self.simulate(leaf)
            self.backup(leaf, value)
        self.root = root
        return max(root.children, key=lambda node: node.visits).action

    def expand(self, node):
        """Descend with ``self.tree_policy`` to an unvisited or childless node.

        Before returning, this function performs all possible actions from the
        leaf node and adds new nodes for them to the tree as children of the
        leaf node (none are added when the leaf is terminal).
        """
        while node.visits != 0 and len(node.children) > 0:
            turn = node.turn
            node = self.tree_policy(node, turn)
        for action in node.env.actions:
            env = node.env.copy()
            state, rewards, done, turn = env.step(action)
            node.children.append(TreeNode(node, action, rewards, done, turn, env))
        return node

    def simulate(self, node):
        """Return one total reward from node following uniform random policy.

        Plays on a copy of the node's environment, so the tree is not modified.
        """
        env = node.env.copy()
        done = node.done
        total_rewards = np.zeros(2)
        while not done:
            # `actions` is a property that rebuilds the move list; read it once per step.
            actions = env.actions
            action = actions[np.random.choice(len(actions))]
            _, rewards, done, _ = env.step(action)
            total_rewards += rewards
        return total_rewards

    def backup(self, node, value):
        """Backup the return from a rollout from node up to the root.

        At each node the reward received on entering it is added to the running
        return, the visit count is incremented, and ``node.value`` becomes the
        running mean of the returns seen through that node.
        """
        while node is not None:
            # Rebind instead of `+=` so the caller's `value` array is not mutated.
            value = value + node.reward
            node.visits += 1
            node.value = (node.visits - 1)/node.visits * node.value + value/node.visits
            node = node.parent

In [4]:
class RandomAgent:
    """An agent that plays a uniformly random empty cell of the 3x3 board."""

    def act(self, state, turn):
        """Return a random (row, col) whose cell in `state` is 0; `turn` is ignored."""
        free_cells = []
        for row in range(3):
            for col in range(3):
                if state[row, col] == 0:
                    free_cells.append((row, col))
        pick = np.random.choice(len(free_cells))
        return free_cells[pick]

In [5]:
class HumanAgent:
    """An agent controlled by a human player's input.

    Moves are typed as a row and a column, e.g. ``1, 2``, ``(1, 2)`` or ``1 2``.
    Unparseable input, off-board cells and occupied cells are rejected and the
    player is asked again, instead of crashing the game.
    """

    def act(self, state, turn):
        """Prompt until a legal (row, col) for `state` is entered; `turn` is unused."""
        n_rows, n_cols = np.shape(state)
        while True:
            raw = input('Input action: ')
            tokens = raw.replace('(', ' ').replace(')', ' ').replace(',', ' ').split()
            try:
                action = tuple(int(x) for x in tokens)
            except ValueError:
                print(f'Could not parse {raw!r}; enter a row and a column, e.g. 1, 2.')
                continue
            if len(action) != 2:
                print(f'Expected a row and a column, got {raw!r}.')
                continue
            row, col = action
            if not (0 <= row < n_rows and 0 <= col < n_cols):
                print(f'{action} is off the board.')
                continue
            if state[row, col] != 0:
                print(f'{action} is already taken.')
                continue
            return action

## Testing

In [6]:
def run_episode(agents, env, render=False):
    """Play one game from `env.reset()` until done; return summed rewards per agent.

    On each step the agent indexed by the current turn chooses an action. When
    `render` is true the board is drawn after the reset and after every move.
    """
    def show():
        if render:
            env.render()

    state, turn = env.reset()
    show()
    totals = np.zeros(len(agents))
    finished = False
    while not finished:
        move = agents[turn].act(state, turn)
        state, step_rewards, finished, turn = env.step(move)
        totals += step_rewards
        show()
    return totals

In [7]:
%%time

env = TicTacToeEnv()
agents = [RandomAgent(), MCTSAgent(TicTacToeEnv, rollouts=100)]
np.mean([run_episode(agents, env) for _ in range(10)])

CPU times: user 8.33 s, sys: 30.5 ms, total: 8.36 s
Wall time: 8.35 s


In [10]:
# Interactive game: a 1000-rollout MCTS agent is player 0 (moves first, +1 marks);
# the human is player 1 and is prompted for each move.
env = TicTacToeEnv()
agents = [MCTSAgent(TicTacToeEnv, rollouts=1000), HumanAgent()]

In [13]:
# Profile one MCTS move from the empty board, sorted by cumulative time.
state, turn = env.reset()
%prun -s cumulative action = agents[0].act(state, turn)

 