# An AI agent plays tic-tac-toe (part 2): speeding up recursive functions using memoization
*Where we increase the speed of a brute force tree search to make it practical for use in reinforcement learning*

*This article is part of a series that lets a computer play tic-tac-toe using reinforcement learning. You can find [all the articles here](https://towardsdatascience.com/tagged/rl-series-paul). The goal is to provide a complete implementation that you can really pick apart and learn reinforcement learning from. It is probably best to read the articles in order. The article including all the code [can be found on GitHub](https://github.com/PaulHiemstra/memoise_paper/blob/master/memoise_paper.ipynb).*

In part 1 of this series we implemented a tree search minimax algorithm to serve as the opponent for our Reinforcement Learning (RL) agent. The conclusion was that although it worked, the algorithm was far too slow to be used in training our RL agent. The goal of part 2 is to speed up our minimax algorithm significantly. 

One possible solution strategy is to reduce the size of the tree. For example, the tic-tac-toe board is symmetric, so mirrored positions are equivalent and we could eliminate roughly half of the tree outright. However, I chose to leave the algorithm and tree as is, and solve the problem with a programming technique called [memoization](https://youtu.be/P8Xa2BitN3I). The general idea is that when a function is called, its result is stored in a dictionary under a key built from the call arguments. The next time the function is called with the same arguments, the stored result is simply returned from the dictionary instead of being recomputed. This works because minimax is deterministic: the same board state always yields the same score. In our case, memoization reduces getting the optimal move from recursively searching a tree to looking up a value in a dictionary.
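To make the idea concrete, here is a minimal sketch of memoization on a toy recursive function, separate from the tic-tac-toe code. The names `fib`, `memo` and `call_count` are purely illustrative and do not appear elsewhere in this series.

```python
# Toy illustration of memoization; fib, memo and call_count are
# hypothetical names, not part of the tic-tac-toe code.
memo = {}          # maps the function argument to the stored result
call_count = 0     # counts how often the function body does real work

def fib(n):
    global call_count
    if n in memo:              # Seen this argument before: just look it up
        return memo[n]
    call_count += 1
    result = n if n < 2 else fib(n - 1) + fib(n - 2)
    memo[n] = result           # Store the result under the argument
    return result

first = fib(30)                # Fills the dictionary
calls_after_first = call_count
second = fib(30)               # Pure dictionary lookup, no new work
```

After the first call, every value from `fib(0)` up to `fib(30)` sits in `memo`, so the second call does no recursion at all. The same effect is what we are after for the minimax search.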

Let us start by loading the tree back into memory, and loading our minimax code. Note that the GitHub repository includes a Python script that generates this tree.

In [10]:
import dill
from treelib import Node, Tree
import numpy as np
import pandas as pd

with open('tree_tactoe_3x3.pkl', 'rb') as f:
    TicToe_3x3 = dill.load(f)    

def minmax_tt(tree, current_id, is_max):
    current_node = tree[current_id]                     # Look up the node we are at now
    if current_node.data.is_endstate():                 # Are we at the end of the game?
        return current_node.data.get_value()            # Return the value
    children_of_current_id = tree.children(current_id)  # Determine the children
    scores = [minmax_tt(tree, child.identifier, not is_max) for child in children_of_current_id]   # Recursively run this function on each of the children
    if is_max:                                          # Return the max or min score depending on which player we are
        return max(scores)
    else:
        return min(scores)
    
def determine_move(tree, current_id, is_max):
    '''
    Given a state on the board, what is the best next move? 
    '''
    potential_moves = tree.children(current_id)
    moves = [child.identifier[-1] for child in potential_moves]
    raw_scores = [minmax_tt(tree, child.identifier, not is_max) for child in potential_moves]
    if is_max:
        return moves[raw_scores.index(max(raw_scores))]
    else:
        return moves[raw_scores.index(min(raw_scores))]

Now we can request the next move for the minimizing player assuming the maximizing player has made the `a` move to start with:

In [3]:
import time

start = time.time()
determine_move(TicToe_3x3, 'a', is_max=False)
time.time()-start

2.9204039573669434

This takes around 3 seconds on my machine.

Online I found [the following memoization implementation](https://www.python-course.eu/python3_memoization.php). It creates a memoization class that we can then use to [decorate](https://www.datacamp.com/community/tutorials/decorators-python) our recursive minimax tree search. This nicely separates the memoization functionality from the function that does the actual work. Note that I exclude the first argument (the tree) from the dictionary key to keep the key small, since a large key would slow down the lookup. The consequence is that the buffer is only valid for a single tree, which is fine here because we always search the same one.

In [11]:
class Memoize_tree:
    def __init__(self, fn):
        self.fn = fn
        self.memo = {}                                      # Create our empty memo buffer

    def __call__(self, *args):
        # Build the key from every argument except the first: the tree is
        # always the same, and including it would slow down the hashing
        function_call_hash = args[1:]
        if function_call_hash not in self.memo:             # Check if the function has been called before
            self.memo[function_call_hash] = self.fn(*args)  # Store the result of the function call
        return self.memo[function_call_hash]                # return the result from the memo dictionary

@Memoize_tree   # Decorate the minimax algorithm
def minmax_tt(tree, current_id, is_max):
    current_node = tree[current_id] 
    if current_node.data.is_endstate():
        return current_node.data.get_value()
    children_of_current_id = tree.children(current_id)
    scores = [minmax_tt(tree, child.identifier, not is_max) for child in children_of_current_id]
    if is_max:
        return max(scores)
    else:
        return min(scores)

This technique works similarly to the function operator style I used in [this article](https://towardsdatascience.com/advanced-functional-programming-for-data-science-building-code-architectures-with-function-dd989cc3b0da) in separating the core functions from the helper functions. A function operator would have been a good alternative here to the decorator class. 
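As a rough sketch of what that function operator alternative could look like, the snippet below wraps a function in a closure instead of a class. The helper `memoize_skip_first`, the toy tree and `count_leaves` are hypothetical stand-ins chosen to keep the example self-contained; they are not the code used in this series.

```python
def memoize_skip_first(fn):
    """Function operator: takes a function, returns a memoized version.

    Hypothetical helper. Like Memoize_tree, it leaves the first argument
    out of the key, so it is only valid while that argument stays the same.
    """
    memo = {}

    def wrapper(*args):
        key = args[1:]                 # Skip the first argument (the tree)
        if key not in memo:
            memo[key] = fn(*args)      # Compute once and store
        return memo[key]

    wrapper.memo = memo                # Expose the buffer for inspection
    return wrapper

# Toy stand-in for the game tree: node id -> list of child ids
toy_tree = {'root': ['l', 'r'], 'l': [], 'r': ['rl'], 'rl': []}

def count_leaves(tree, node_id):
    children = tree[node_id]
    if not children:
        return 1
    return sum(count_leaves(tree, child) for child in children)

# Rebinding the name makes the recursive calls go through the wrapper too
count_leaves = memoize_skip_first(count_leaves)
n_leaves = count_leaves(toy_tree, 'root')
```

The behavior is the same as with the decorator class: the dictionary lives in the closure rather than on an instance, and `@memoize_skip_first` could equally be used as a decorator.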

With the memoization in place, we can test whether it actually performs better. In the code below we call the function twice with the same arguments and compute how much faster the second, memoized call is:

In [12]:
import time

start = time.time()
determine_move(TicToe_3x3, 'a', is_max=False)   # First time the search takes long
first_call = time.time()-start

start = time.time()
determine_move(TicToe_3x3, 'a', is_max=False)   # Second time, memoization kicks in
first_call / (time.time()-start)

49962.9296875

Nice, this yields a speedup of roughly 50,000 times. The final step is to force all possible board states through the minimax function, so that the memoization buffer holds the result of every function call:

In [18]:
from tqdm import tqdm
import itertools

all_states = []
for length in range(1,9):
    tree_states = [''.join(state) for state in list(itertools.permutations(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i'], r=length))]
    all_states.extend(tree_states)

for state in tqdm(all_states):
    try:
        move = determine_move(TicToe_3x3, state, False) 
    except Exception: # Skip any board states that cannot occur (a bare except would also swallow Ctrl-C)
        pass 

100%|██████████| 623529/623529 [00:02<00:00, 298466.79it/s]

This takes around 30 seconds on my machine.
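The total of 623529 candidate states reported by the progress bar can be double-checked without building the list: it is the number of ordered sequences of 1 to 8 distinct moves out of 9 squares. A small sketch, assuming Python 3.8 or later for `math.perm`; the variable names are illustrative.

```python
import itertools
import math

# Ordered sequences of `length` distinct squares out of 9, for length 1..8
n_candidates = sum(math.perm(9, length) for length in range(1, 9))

# Cross-check one term against itertools.permutations itself
n_length_2 = sum(1 for _ in itertools.permutations('abcdefghi', r=2))
```

Here `n_candidates` matches the 623529 iterations shown above. Many of these sequences are not reachable game states, for example because the game had already ended earlier, which is why the loop skips the ones that raise an error.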

With all the tree searches precomputed, `determine_move` is now fast enough to run the required number of tic-tac-toe games. In the next part we will implement the RL algorithm called Q-learning.