Implementation of https://www.nature.com/nature/journal/v518/n7540/full/nature14236.html

PDF [here](https://pdfs.semanticscholar.org/340f/48901f72278f6bf78a04ee5b01df208cc508.pdf)



In [None]:
%pylab inline

import matplotlib.cm as cm
import matplotlib.patches as patches

from collections import defaultdict
from six.moves import zip_longest
# For some reason Notebook doesn't like this...
#from builtins import input
import numpy as np
import random, sys, pickle, os, time

from IPython.core.debugger import set_trace

# Python 2/3 compatibility shim: on Python 2 rebind `input` to `raw_input` so
# that it returns the typed text unevaluated, as Python 3's input() does.
try:
    input = raw_input
except NameError:
    # `raw_input` does not exist here (Python 3); keep the builtin input().
    pass

In [None]:
class TicTacToeBoard(object):
    """A tic-tac-toe position: nine cell marks plus the player about to move.

    Cells are indexed row by row:

        0 1 2
        3 4 5
        6 7 8

    Each cell holds 'X', 'O' or '_' (empty).  Boards compare and hash by
    value (marks and active player), so equal positions share one dictionary
    key.  `step` returns a new board and leaves the receiver unchanged.
    """
    # Index triples forming a line; owning all three cells of any row wins.
    winning_spots = np.array([
        [0, 1, 2], [3, 4, 5], [6, 7, 8], # Horizontal
        [0, 3, 6], [1, 4, 7], [2, 5, 8], # Vertical
        [0, 4, 8], [2, 4, 6]             # Diagonal
        ])

    # Text template used by __str__; takes the nine marks in index order.
    board_format = '\n'.join([
        ' {} | {} | {} ',
        '---+---+---',
        ' {} | {} | {} ',
        '---+---+---',
        ' {} | {} | {} ',
        ])

    def __init__(self, prev=None, action=None):
        """Create the empty start position, or the successor of `prev`.

        prev   -- board to derive from; None builds an empty board with 'X' to move
        action -- cell index marked for `prev.active_player` (used only with `prev`)
        """
        if prev is not None:
            # Copy so the predecessor's marks are never modified.
            self.marks = prev.marks.copy()
            self.marks[action] = prev.active_player
            self.active_player = 'X' if prev.active_player == 'O' else 'O'
        else:
            self.active_player = 'X'
            self.marks = np.array(['_']*9)

    def __repr__(self):
        # Compact form "<9 marks>,<active player>", e.g. "____X____,O".
        return ''.join(self.marks) + ',' + self.active_player

    def __str__(self):
        return TicTacToeBoard.board_format.format(*self.marks)

    def __eq__(self, other):
        return isinstance(other, self.__class__) \
            and np.array_equal(self.marks, other.marks) \
            and self.active_player == other.active_player

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; define it explicitly so the
        # two operators always agree.
        return not self.__eq__(other)

    def __hash__(self):
        # repr() encodes exactly the fields compared by __eq__.
        return hash(repr(self))

    @staticmethod
    def from_repr(s):
        """Rebuild a board from the string produced by `repr(board)`."""
        out = TicTacToeBoard()
        out.active_player = s[-1]
        # Drop the trailing ",<player>" to keep only the nine marks.
        out.marks = np.array(list(s[:-2]))
        return out

    def render(self):
        """Print the human-readable board."""
        print(str(self))

    # returns (next_state, reward, done)
    def step(self, action):
        # type: (int) -> (TicTacToeBoard, float, bool)
        """Apply `action` for the active player.

        Reward is from the mover's point of view: -1 for marking an occupied
        cell (the game ends), 1 for completing a line, 0 otherwise.  `done`
        is True for an illegal move, a win, or a full board.

        The successor is built before legality is checked, so after an
        illegal move the returned board has the occupied cell overwritten.
        Negative indices follow numpy indexing and count from the end.
        """
        next_state = TicTacToeBoard(self, action)

        if self.marks[action] != '_':
            return (next_state, -1, True)
        elif next_state.check_win(self.active_player):
            return (next_state, 1, True)
        elif next_state.is_full():
            return (next_state, 0, True)

        return (next_state, 0, False)

    def check_win(self, player):
        """Return whether `player` owns all three cells of some winning line."""
        # Fancy indexing yields one row of three marks per winning line.
        slices = self.marks[TicTacToeBoard.winning_spots]
        return (slices == player).all(axis=1).any()

    def is_full(self):
        """Return whether no empty ('_') cell remains."""
        return (self.marks != '_').all()

In [None]:
class TabularAgent(object):
    """Base agent keeping one state -> value table per action.

    alpha     -- step size used when blending a new target into a stored value
    gamma     -- discount factor, available to subclasses
    epsilon   -- a uniformly random action is chosen when random.random()
                 exceeds this value; otherwise the best-valued action is taken
                 (so a larger epsilon means greedier play)
    default_Q -- value reported for (action, state) pairs never updated
    """
    def __init__(self, num_actions, alpha=0.75, gamma=1, epsilon=1, default_Q=0):
        self.num_actions = num_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.default_Q = default_Q
        # One independent lookup table per action index.
        self.action_buffers = [dict() for _ in range(self.num_actions)]

    def max_action(self, state):
        # type: (TicTacToeBoard) -> int
        # np.argmax returns the lowest index among tied maxima.
        values = []
        for action_ndx in range(self.num_actions):
            values.append(self.buffer_value(action_ndx, state))
        return np.argmax(values)

    def choose_action(self, state):
        # type: (TicTacToeBoard) -> int
        explore = random.random() > self.epsilon
        if not explore:
            return self.max_action(state)
        return random.choice(range(self.num_actions))

    def buffer_value(self, ndx, state):
        # type: (int, TicTacToeBoard) -> float
        table = self.action_buffers[ndx]
        return table.get(state, self.default_Q)

    def update_buffer(self, buffer_ndx, state, new_val):
        # type: (int, TicTacToeBoard, float) -> None
        table = self.action_buffers[buffer_ndx]
        try:
            old_val = table[state]
        except KeyError:
            # First visit: store the target as-is, no blending with default_Q.
            table[state] = new_val
        else:
            table[state] = (1-self.alpha)*old_val + self.alpha*new_val

    def train(self, history):
        # Learning rule is supplied by subclasses.
        raise NotImplementedError()

# - - - - - - - - - - - - - - - - - - - - - 

class MonteCarloAgent(TabularAgent):
    """Agent trained on the full discounted return observed after each move.

    `history` is a list of (action, reward, state) tuples: entry 0 is
    (None, None, start_state) and every later entry holds the action taken,
    the reward received and the state reached.
    """
    def train(self, history):
        if len(history) == 0:
            return
        # Walk backwards from the end of the episode.  `prev_action` is the
        # action taken from the state currently being visited, and `return_`
        # is the discounted sum of rewards collected from that action onward.
        prev_action, return_, _ = history[-1]
        for (action, reward, state) in reversed(history[:-1]):
            self.update_buffer(prev_action, state, return_)

            prev_action = action
            # The first history entry carries no reward (None).
            if reward is not None:
                # Discount everything that came later by gamma; with the
                # default gamma=1 this is a plain sum of rewards.
                return_ = reward + self.gamma * return_
                
class TemporalDifferenceAgent(TabularAgent):
    """Base class for agents that update each move from a one-step target.

    Subclasses implement `new_val(history, ndx)`, returning the target for
    the action taken from the state stored at `history[ndx]`.  `history` is a
    list of (action, reward, state) tuples whose first entry is
    (None, None, start_state).
    """
    def new_val(self, history, ndx):
        raise NotImplementedError()

    def train(self, history):
        # At least the start entry plus one move is needed; this also covers
        # an empty history, which would otherwise fail on the unpack below.
        if len(history) < 2:
            return
        for i in range(len(history)-2):
            (_, _, state), (action, _, _) = history[i:i+2]
            self.update_buffer(action, state, self.new_val(history, i))
        # The last move has no successor to bootstrap from, so its target is
        # just the reward it received.
        (_, _, state), (action, reward, _) = history[-2:]
        self.update_buffer(action, state, reward)

# - - - - - - - - - - - - - - - - - - - - - 

class QLearningAgent(TemporalDifferenceAgent):
    """TD agent whose target bootstraps from the best-valued next action."""
    def new_val(self, history, ndx):
        # Only the reward and resulting state of the move after `ndx` matter.
        (_, _, _), (_, reward, next_state) = history[ndx:ndx+2]
        best_next = self.max_action(next_state)
        bootstrap = self.buffer_value(best_next, next_state)
        return reward + self.gamma * bootstrap

class SarsaAgent(TemporalDifferenceAgent):
    """TD agent whose target bootstraps from the action actually taken next."""
    def new_val(self, history, ndx):
        # Three consecutive entries: current state, the move made from it,
        # and the move that followed.
        window = history[ndx:ndx+3]
        (_, _, _), (_, reward, next_state), (next_action, _, _) = window
        followup = self.buffer_value(next_action, next_state)
        return reward + self.gamma * followup


In [None]:
def play_vs_human(agent, state=None):
    """Play one interactive game between `agent` and a person at the keyboard.

    The board is drawn before every move; each empty cell is coloured and
    labelled with the value the agent stores for playing there.  A coin flip
    decides who moves first.

    agent -- object providing choose_action(state) and buffer_value(ndx, state)
    state -- position to start from; a fresh TicTacToeBoard when None

    Raises ValueError when the typed move is not an integer.  Integers
    outside 1-9 are rejected and asked for again, so they can no longer wrap
    around to another cell (0 or negatives) or raise IndexError (above 9).
    """
    def print_state(state):
        fig = figure(figsize=[3,3])
        ax = fig.add_subplot(111)

        def draw_cell(pos, mark, val):
            # Row-major cell index -> (row, column).
            y, x = divmod(pos, 3)
            if mark == 'X':
                ax.plot([x+.2, x+.8], [y+.8, y+.2], 'k', lw=2.0)
                ax.plot([x+.2, x+.8], [y+.2, y+.8], 'k', lw=2.0)
            elif mark == 'O':
                ax.add_patch(patches.Circle((x+.5,y+.5), .35, ec='k', fc='none', lw=2.0))
            else:
                # (val+1)/2 maps the value range -1..1 onto the colormap's 0..1.
                color = cm.viridis((val+1)/2.)
                ax.add_patch(patches.Rectangle((x,y), 1, 1, ec='none', fc=color))
                ax.text(x+.5, y+.5, '%.2f'%val, ha='center', va='center')

        for i in range(9):
            draw_cell(i, state.marks[i], agent.buffer_value(i,state))

        ax.set_position([0,0,1,1])
        ax.set_axis_off()

        ax.set_xlim(0,3)
        # Inverted y-limits put row 0 at the top of the picture.
        ax.set_ylim(3,0)

        # Grid lines between the cells.
        for x in range(1,3):
            ax.plot([x, x], [0,3], 'k', lw=2.0)
            ax.plot([0,3], [x, x], 'k', lw=2.0)
        show()

    def ask_move():
        # Returns a 0-based cell index.  Non-numeric input is deliberately
        # left to raise ValueError from int().
        while True:
            choice = int(input('Choose your move [1-9]: '))
            if 1 <= choice <= 9:
                return choice - 1
            print('Please enter a number from 1 to 9.')

    if state is None:
        state = TicTacToeBoard()

    # Flip a coin for who goes first
    compToMove = random.random() > 0.5

    while True:
        print_state(state)

        if compToMove:
            state, reward, done = state.step(agent.choose_action(state))
        else:
            state, reward, done = state.step(ask_move())

        if done:
            # `reward` is from the point of view of whoever just moved.
            pronoun = 'I ' if compToMove else 'You '
            if reward == 0:
                print('Tie.')
            elif reward > 0:
                print(pronoun+'win.')
            else:
                print(pronoun+'lose.')
            break

        compToMove = not compToMove
    print('========')

In [None]:
# TRAIN
# File the agent is pickled to and loaded from further down in this cell.
fname = 'tictac.txt'

def progressbar(callback, iters, refresh_rate=2.0):
    """Call `callback(i)` for each i in range(iters), showing a text counter.

    callback     -- callable invoked with the iteration index
    iters        -- number of iterations to run
    refresh_rate -- maximum counter redraws per second

    The counter is written to stdout with carriage returns and blanked out
    at the end.  Returns the elapsed wall-clock time in seconds.
    """
    started = time.time()
    last_draw = started

    for step in range(iters):
        callback(step)
        now = time.time()
        # Skip the redraw until 1/refresh_rate seconds have passed.
        if (now - last_draw) * refresh_rate < 1:
            continue
        sys.stdout.write('\r[ %s / %s ]' % (step, iters))
        sys.stdout.flush()
        last_draw = now

    # Overwrite the widest counter that could have been printed with spaces.
    blank = ' ' * len('[ %s / %s ]' % (iters, iters))
    sys.stdout.write('\r%s\r' % blank)
    sys.stdout.flush()

    return time.time() - started

# Assumes zero-sum, two-player, sequential-turn game
def train_episode(agent, state=None):
    """Self-play one game with `agent` and train it from both players' views.

    agent -- object providing choose_action(state), train(history) and
             `action_buffers` (a list of state-keyed dicts)
    state -- position to start from; when None, a random state already stored
             in the agent's first buffer is used, or an empty board if that
             buffer has no entries yet
    """
    if state is None:
        # Start at a random previously encountered state.  The keys are
        # copied into a list because random.choice needs an indexable
        # sequence, which a Python 3 dict view is not.
        keys = list(agent.action_buffers[0])
        if keys:
            state = random.choice(keys)
        else:
            state = TicTacToeBoard()

    history = [(None, None, state)]

    # Play out a game
    while True:
        action = agent.choose_action(state)
        state, reward, done = state.step(action)
        history.append((action, reward, state))
        if done:
            break

    # `history` stores things like [(None, None, s1), (p1a1, p1r1, s2), (p2a1, p2r1, s3), (p1a2, p1r2, s4), ...]
    # `player_history` transforms that to [(None, None, s1), (p1a1, p1r1-p2r1, s3), (p1a2, p1r2-p2r2, s5), ...]
    # You subtract the reward given to the other player because of the assumption of it being a zero-sum game.
    def player_history(history):
        # e.g.  grouped('ABCDEFG', 3, 'x') --> 'ABC' 'DEF' 'Gxx'
        def grouped(iterable, n, fillvalue=None):
            "Collect data into fixed-length chunks or blocks"
            # https://docs.python.org/2/library/itertools.html#recipes
            args = [iter(iterable)] * n
            return zip_longest(fillvalue=fillvalue, *args)

        out = [(None, None, history[0][2])]
        # Pair each of this player's moves with the opponent's reply.  A
        # final move without a reply is padded with (None, None, None).
        for (action, reward, state), (_, other_reward, other_state) \
                                in grouped(history[1:], 2, (None,)*3):
            if other_reward is None:
                out.append((action, reward, state))
            else:
                out.append((action, reward-other_reward, other_state))
        return out

    # split the history into separate histories for each player
    first_history = player_history(history)
    # Dropping the first entry makes the second player's moves come first.
    second_history = player_history(history[1:])

    # train to improve performance for each player
    agent.train(first_history)
    agent.train(second_history)

    
# Resume from a previously saved agent when one exists, otherwise start fresh.
if os.path.isfile(fname):
    print('Loading agent from %s...' % fname)
    # NOTE: unpickling can run arbitrary code; only load files written by
    # this notebook.
    with open(fname, 'rb') as saved:
        agent = pickle.load(saved)
else:
    # default_Q=2 is above the largest reward a move can earn (1), so actions
    # not yet tried look best to the greedy choice.
    agent = QLearningAgent(num_actions=9, epsilon=0.8, default_Q=2)

init_state = TicTacToeBoard() # Always start from actual initial state
#init_state = None # Random restarts

episodes = 10000
print('Training for %s episodes...' % episodes)
progressbar(lambda x: train_episode(agent, init_state), episodes)

print('Saving agent to %s...' % fname)
# The context manager closes the file, so the pickle is fully written out.
with open(fname, 'wb') as out:
    pickle.dump(agent, out)


In [None]:
# PLAY
fname = 'tictac.txt'

# NOTE: unpickling can run arbitrary code; only load files written by this
# notebook.
with open(fname, 'rb') as saved:
    agent = pickle.load(saved)
# Mostly greedy play: a random move is made only when random.random() > 0.99.
agent.epsilon = 0.99
# Keep starting new games until play_vs_human raises, e.g. the ValueError
# from int() when something other than a number is typed.
try:
    while True:
        play_vs_human(agent)
except Exception as e:
    print(e)