In [None]:
# default_exp core

# AlphaZero Core

In [None]:
#hide
from nbdev.showdoc import *

# Search Module

## MCTS

In [None]:
import numpy as np

In [None]:
import igraph
import alphazero.display

In [None]:
# Shared random generator used by the cells below. It is created without a seed,
# so the random draws differ from one kernel run to the next.
rng = np.random.default_rng()

Data Structure

Search Tree
- The search tree is stored as a 1-level dictionary of index:node key:value pairs.
- Keys are dropped when no longer part of the active search tree.
- Keys are initialized & point to a null value when their parent node is created.
- A node is created when its key is first accessed, replacing its null value.

Node
- dictionary containing data on a state.
- contains keys indexing into parent and branch nodes.

In [None]:
def create_node(state:np.ndarray,
                action_probs:np.ndarray,
                valid_actions,
                branches,
                visits:int=1,
                value:float=0.0,
                end:bool=False,
                parent_index:int=-1,
                index:int=-1):
    """Build the dictionary that represents one node of the search tree.

    Every argument is stored unchanged under the key of the same name.

    Args:
        state: data describing the state held by this node.
        action_probs: per-action probabilities for this state.
        valid_actions: per-action validity flags.
        branches: tree indices of this node's child nodes.
        visits: visit counter for this node (defaults to 1).
        value: value stored for this node (defaults to 0.0).
        end: whether the node is an end state.
        parent_index: tree index of the parent node (defaults to -1).
        index: tree index of this node (defaults to -1).

    Returns:
        dict with the keys 'state', 'action_probs', 'valid_actions',
        'branches', 'visits', 'value', 'end', 'parent_index' and 'index'.
    """
    # The key was previously misspelled 'action_probls', which made the
    # probabilities unreachable under the name used everywhere else.
    return {'state':state,
            'action_probs':action_probs,
            'valid_actions':valid_actions,
            'branches':branches,
            'visits':visits,
            'value':value,
            'end':end,
            'parent_index':parent_index,
            'index':index}

def create_random_node(state_size=4, n_actions=1, index=0, counter=0, parent_index=-1, end=False):
    """Create a node filled with random placeholder data, for exercising the tree code.

    Args:
        state_size: side length of the square random state array.
        n_actions: number of actions, which is also the number of branches.
        index: tree index assigned to the new node.
        counter: first tree index to hand out to the node's branches; the
            branches receive `counter`, `counter + 1`, ..., `counter + n_actions - 1`.
        parent_index: tree index of the parent node (-1 by default).
        end: whether the node is marked as an end state.

    Returns:
        The node dictionary produced by `create_node`.
    """
    # The draws are kept in this order so a seeded generator yields the same nodes as before.
    return create_node(state=rng.integers(0, n_actions, size=(state_size, state_size)),
                       action_probs=rng.random(size=n_actions),
                       # random 0/1 flags; nothing guarantees that at least one action is valid
                       valid_actions=rng.integers(0, 2, size=n_actions),
                       branches=[counter + offset for offset in range(n_actions)],
                       value=rng.random(),
                       # forward the caller's flag (it used to be hard-coded to False)
                       end=end,
                       parent_index=parent_index,
                       index=index)

In [None]:
%%time

n_sims = 100
n_actions = 4

tree = {}
counter = 1 # keeps track of tree's next node index (equal to total nodes added)
root_idx = 0 # keeps track of tree root (tree.root)
graph_vis = igraph.Graph() # used for visualization
# igraph treats integer arguments to add_edge as positional vertex IDs, not as
# vertex names. Graph vertices are appended in creation order while tree indices
# are handed out in blocks of n_actions, so the two numberings diverge; this
# mapping records which graph vertex holds each tree index.
vertex_ids = {}

# add node (tree.add_node)
index = root_idx
tree[index] = create_random_node(n_actions=n_actions, index=index, counter=counter)
counter += n_actions
vertex_ids[index] = graph_vis.vcount()  # ID the vertex added on the next line will get
graph_vis.add_vertex(index)
# reserve the keys of the not-yet-created child nodes
for branch_idx in tree[index]['branches']:
    tree[branch_idx] = None

# simulations (tree.simulate)
for _ in range(n_sims):
    # every simulation starts at the root (the backup loop below also ends there)
    index = root_idx

    # select action
    select_idx = int(rng.choice(tree[index]['branches']))

    # repeat until new or end state reached
    while tree[select_idx] is not None:
        index = select_idx
        tree[index]['visits'] += 1

        # end state found
        if tree[index]['end']:
            break

        # select next node
        select_idx = int(rng.choice(tree[index]['branches']))

    # create new node (tree.add_node)
    if tree[select_idx] is None:
        parent = index
        index = select_idx
        tree[index] = create_random_node(n_actions=n_actions, index=index, counter=counter, parent_index=parent)
        counter += n_actions
        for branch_idx in tree[index]['branches']:
            tree[branch_idx] = None

        # mirror the new node in the visualization graph, linking by graph vertex ID
        vertex_ids[index] = graph_vis.vcount()
        graph_vis.add_vertex(index)
        graph_vis.add_edge(vertex_ids[parent], vertex_ids[index])

    # update path values
    # NOTE(review): each ancestor receives the already-updated total of the node
    # below it rather than the leaf's value, so contributions compound towards the
    # root. Kept as in the original; confirm whether this is the intended backup.
    value = tree[index]['value']
    visits = tree[index]['visits']  # read but not used yet
    while index != root_idx:
        index = tree[index]['parent_index']
        tree[index]['value'] += value
        value = tree[index]['value']
    


In [None]:
# alphazero.display.plot_tree(graph_vis, ['rt_circular','rt','tree','drl'][1])

`compute_state` computes the new game state given a `state` and an `action`. It's called whenever a new state is encountered. Its output is stored as the new `node`'s `state` and `valid_actions`.

In [None]:
def compute_state(state, action_idx, valid_actions):
    """Placeholder for the state-transition step (not implemented yet).

    For now it only checks that `valid_actions[action_idx]` equals 1, raising
    `AssertionError` otherwise, and then returns `None`.
    """
    action_is_valid = valid_actions[action_idx] == 1
    assert action_is_valid

    # TODO
    return None


## Exploitation and Exploration functions

The *exploitation* function favors high returns and consistently winning paths. *Exploration* initially favors new paths and eventually favors winning paths.

Reference: [Deep Learning and the Game of Go [14.2.1]](https://livebook.manning.com/book/deep-learning-and-the-game-of-go/chapter-14/39) | [MIT 16.412J s16 Cognitive Robotics: Advanced 4. Monte Carlo Tree Search](https://www.youtube.com/watch?v=xmImNoDc9Z4)

```
exploit(s,a) = total_val(s) / visit_count(s,a)

explore(s,a) = c_puct * action_probabilities(s,a) * sqrt(visit_count(previous_node)) / (1 + visit_count(s,a))
```

`total_val` is the sum of all node values on the current path up to the current node. `visit_count` is the number of visits to the current node.

In [None]:
# Replace the shared generator with one seeded at 0 so the sample below is repeatable.
rng = np.random.default_rng(seed=0)
# One draw from a symmetric 3-component Dirichlet with concentration 0.03.
x = rng.dirichlet(alpha=[0.03, 0.03, 0.03], size=None)
(x, x.size)