In [None]:
#|default_exp search

# Search Module

In [None]:
#hide
from nbdev.showdoc import *

## MCTS

In [None]:
#|export
import numpy as np

In [None]:
import igraph
import alphazero.display

In [None]:
#|export
class MCTSNode():
    """A single node of the search tree.

    Parameters
    ----------
    state : the state this node represents (kept as given, not copied).
    action_probs : per-action probabilities for this node.
    valid_actions : per-action validity flags for this node.
    branches : tree indices reserved for this node's children, one per action.
    value : the node's own value estimate; also seeds `path_value`.
    is_endstate : True when the node is terminal and must not be expanded further.
    parent_index : tree index of the parent node, or None for the root.
    index : this node's own index in the tree.
    """

    def __init__(self, state, action_probs, valid_actions, branches, value, is_endstate, parent_index, index):
        # `state` used to be accepted and then dropped; keep it on the node so
        # the position a node stands for can be recovered from the tree.
        self.state = state
        self.action_probs = action_probs
        self.valid_actions = valid_actions
        self.branches = branches
        # number of times the node has been visited; starts at zero
        self.visits = 0
        self.value = value
        # running total of values along paths through this node, seeded with
        # the node's own value
        self.path_value = value
        self.is_endstate = is_endstate
        self.parent_index = parent_index
        self.index = index

In [None]:
#|export
# shared random generator; no seed is given, so each run draws different samples
rng = np.random.default_rng()

# toy problem dimensions and search settings
state_size, n_actions, n_sims = 10, 4, 100

# the tree is a flat dict: integer keys map to nodes, while 'counter' holds the
# next index that has not yet been handed out to any node
root_idx = 0
tree = dict(counter=root_idx + 1)

# initialize the root node of the search tree with random placeholder contents;
# its children are assigned the next `n_actions` free indices
tree[root_idx] = MCTSNode(
    state=rng.integers(0, n_actions, size=(state_size, state_size)),
    action_probs=rng.random(size=n_actions),
    valid_actions=rng.integers(0, 2, size=n_actions),
    branches=list(range(tree['counter'], tree['counter'] + n_actions)),
    value=rng.random(),
    is_endstate=False,
    parent_index=None,
    index=0,
)
# the indices just reserved for the root's children are no longer free
tree['counter'] = tree['counter'] + n_actions


In [None]:
#|export
# sample simulation: each pass walks from the root along random branches to an
# index that has no node yet (or to an end state), expands it if needed, and
# backs the reached node's value up to the root
# NOTE(review): the pass count is fixed at 10 here; `n_sims` is not consulted - confirm intent
for i in range(10):
    # reset to root node on sim start
    select_idx = parent_idx = root_idx

    # selection: loop until we reach an unexpanded index or an end state
    while tree.get(select_idx) is not None and not tree[select_idx].is_endstate:
        # parent_idx trails one step behind select_idx so that a node created
        # below can record which node it was reached from
        parent_idx = select_idx
        select_idx = rng.choice(tree[select_idx].branches)

    # expansion: no node exists at the selected index yet, so create one and
    # reserve the next `n_actions` indices for its children
    if tree.get(select_idx) is None:
        tree[select_idx] = MCTSNode(state=rng.integers(0,n_actions, size=(state_size,state_size)),
                                    action_probs=rng.random(size=n_actions),
                                    valid_actions=rng.integers(0, 2, size=n_actions),
                                    branches=[tree['counter']+i for i in range(n_actions)],
                                    value=rng.random(),
                                    is_endstate=False,
                                    parent_index=parent_idx,
                                    index=select_idx)
        tree['counter'] += n_actions

    # backup, in all cases: propagate the reached node's value up the path to
    # the root. Every node that receives the value also gets its visit counted,
    # so path_value == node's own value + sum of the `visits` backed-up values.
    value = tree[select_idx].value
    while parent_idx is not None:
        tree[parent_idx].path_value += value
        tree[parent_idx].visits += 1
        select_idx = parent_idx
        parent_idx = tree[select_idx].parent_index


In [None]:
#|hide
# tests TODO

# tree index counter increases by correct amount on each expansion

In [None]:
#|hide
# def create_node(state:np.ndarray,
#                 action_probs:np.ndarray,
#                 valid_actions,
#                 branches,
#                 visits:int=1,
#                 value:float=0.0,
#                 end:bool=False,
#                 parent_index:int=-1,
#                 index:int=-1):
    
#     return {'state':state,
#             'action_probls':action_probs,
#             'valid_actions':valid_actions,
#             'branches':branches,
#             'visits':visits,
#             'value':value,
#             'end':end,
#             'parent_index':parent_index,
#             'index':index}

# def create_random_node(state_size=4, n_actions=1, index=0, counter=0, parent_index=-1, end=False):
#     return create_node(state=rng.integers(0,n_actions, size=(state_size,state_size)),
#                        action_probs=rng.random(size=n_actions),
#                        valid_actions=rng.integers(0,2, size=n_actions),
#                        branches=[counter+i for i in range(n_actions)],
#                        value=rng.random(),
#                        end=False,
#                        parent_index=parent_index,
#                        index=index)

In [None]:
#|hide
# %%time

# n_sims = 100
# n_actions = 4

# tree = {}
# counter = 1 # keeps track of tree's next node index (equal to total nodes added)
# root_idx = 0 # keeps track of tree root (tree.root)
# graph_vis = igraph.Graph() # used for visualization

# # add node (tree.add_node)
# index = 0
# tree[index] = create_random_node(n_actions=n_actions, index=index, counter=counter)
# counter += n_actions
# graph_vis.add_vertex(index)
# print(list(graph_vis.vs))
# for i in tree[0]['branches']:
#     tree[i] = None
#     # graph_vis.add_vertex(i)
#     # graph_vis.add_edge(index, i)

# # simulations (tree.simulate)
# for i in range(n_sims):
#     # select action
#     select_idx = rng.choice(tree[index]['branches'])

#     # repeat until new or end state reached
#     while tree[select_idx] is not None:
#         index = select_idx
#         tree[index]['visits'] += 1

#         # end state found
#         if tree[index]['end']:
#             break

#         # select next node
#         select_idx = rng.choice(tree[index]['branches'])

#     # create new node (tree.add_node)
#     if tree[select_idx] is None:
#         parent = index
#         index = select_idx
#         tree[index] = create_random_node(n_actions=n_actions, index=index, counter=counter, parent_index=parent)
#         counter += n_actions
#         for i in tree[index]['branches']:
#             tree[i] = None
#             # graph_vis.add_vertex(i)
#             # graph_vis.add_edge(index, i)
       
#         graph_vis.add_vertex(select_idx)
#         print(select_idx)
#         print(list(graph_vis.vs))
#         graph_vis.add_edge(parent, select_idx)

#     # update path values
#     value = tree[index]['value']
#     visits = tree[index]['visits']
#     while index != root_idx:
#         index = tree[index]['parent_index']
#         tree[index]['value'] += value
#         value = tree[index]['value']
    


In [None]:
# alphazero.display.plot_tree(graph_vis, ['rt_circular','rt','tree','drl'][1])