# Monte Carlo Tree Search



In [None]:
import sys

# resolve path for notebook
sys.path.append('../')

In [None]:
# ipykernel runs in an asyncio loop, so we need to allow nesting
import nest_asyncio
nest_asyncio.apply()

In [None]:
import time
import math
import copy
import asyncio
import numpy as np

from collections import defaultdict
from environments.QuestEnvironment import QuestEnvironment

In [None]:
# code adapted from: https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1

In [None]:
class MCTS_Node:
    """A node of the search tree: one observed outcome of taking an action.

    Attributes:
        state: the ``'glyphs_crop'`` entry of the observation, or ``None``
            when the node was built without an observation.
        coords: ``(col, row)`` read from the observation's ``'blstats'``
            entry, or ``None`` when the node was built without an observation.
        action: the action that led to this node (``None`` for a root).
        reward: the reward received for that action.
        is_terminal: whether the episode ended at this node.
    """

    def __init__(self, state=None, action=None, reward=0, is_terminal=False):
        """Build a node from an observation dict and the step outcome.

        Args:
            state: observation mapping with ``'glyphs_crop'`` and
                ``'blstats'`` keys, or ``None`` for a placeholder node.
            action: the action that produced this node.
            reward: the reward returned for that action.
            is_terminal: ``True`` if the episode finished on this step.
        """
        # Always define both attributes so that a node created without an
        # observation can still be inspected instead of raising
        # AttributeError on first access.
        self.state = None
        self.coords = None

        if state is not None:
            self.state = state['glyphs_crop']
            self.coords = self._get_coordinates(state)

        self.action = action
        self.reward = reward
        self.is_terminal = is_terminal

    def __repr__(self):
        return (
            f"{type(self).__name__}(coords={self.coords!r}, "
            f"action={self.action!r}, reward={self.reward!r}, "
            f"is_terminal={self.is_terminal!r})"
        )

    def _get_coordinates(self, state):
        """Return ``(col, row)`` taken from the first two ``'blstats'`` entries."""
        # index 0 is read as the column and index 1 as the row
        col = state['blstats'][0]
        row = state['blstats'][1]
        return (col, row)


class MCTS:
    """Monte Carlo Tree Search over an environment that can step and revert.

    The tree is held in three dictionaries keyed by node object:

    * ``children``     -- node -> set of child nodes (present once expanded)
    * ``rewards``      -- node -> sum of back-propagated rewards
    * ``visit_counts`` -- node -> number of back-propagations through the node

    The environment is expected to provide ``action_space.n``,
    ``step(action)`` returning a 4-tuple ``(state, reward, is_terminal, info)``,
    ``revert()`` and (for ``_find_child_async`` only) ``clone()``.

    NOTE(review): nodes are dictionary keys, so unless the node class defines
    value-based ``__eq__``/``__hash__`` two nodes are the same tree entry only
    when they are the very same object.
    """

    def __init__(
            self,
            env,
            exploration_weight=1.0,
            num_simulations=50,
            search_depth=10
        ):
        """Create an empty search tree.

        Args:
            env: the environment used for expansion and simulation.
            exploration_weight: multiplier of the exploration term in UCT.
            num_simulations: number of select/expand/simulate/back-propagate
                iterations performed by one call to ``rollout``.
            search_depth: maximum number of random steps in one simulation.
        """
        self.env = env
        self.rewards = defaultdict(float)
        self.visit_counts = defaultdict(int)
        self.children = dict()
        self.exploration_weight = exploration_weight
        self.num_simulations = num_simulations
        self.search_depth = search_depth

    def _select(self, node):
        """Walk down from ``node`` and return the path to a node to expand.

        The walk stops at the first node that is unexpanded or has no
        children, or at an unexpanded child of an expanded node. Otherwise it
        descends to the child with the best UCT score.
        """
        path = []

        while True:
            path.append(node)

            # unexpanded node, or expanded node without children (terminal)
            if node not in self.children or not self.children[node]:
                return path

            # children that have not been expanded themselves yet
            unexplored = self.children[node] - self.children.keys()

            if unexplored:
                path.append(unexplored.pop())
                return path

            # every child is expanded: move one layer deeper in the tree
            node = self._uct_select(node)

    def _uct_select(self, node):
        """Return the child of ``node`` with the highest UCT score.

        The score is the mean reward plus
        ``exploration_weight * sqrt(log(parent visits) / child visits)``.
        A child without visits scores ``+inf`` so it is tried first instead
        of triggering a division by zero.
        """
        # internal invariant: all children of node are already expanded
        assert all(n in self.children for n in self.children[node])

        log_parent_visits = math.log(self.visit_counts[node])

        def uct(n):
            visits = self.visit_counts[n]
            if visits == 0:
                return float("inf")
            mean_reward = self.rewards[n] / visits
            exploration = math.sqrt(log_parent_visits / visits)
            return mean_reward + self.exploration_weight * exploration

        return max(self.children[node], key=uct)

    def _expand(self, node):
        """Record the children of ``node`` unless it was expanded before."""
        if node in self.children:
            return

        self.children[node] = self._find_children(node)

    async def _find_child_async(self, action):
        """Return the node reached by taking ``action`` in a clone of the env.

        The coroutine contains no ``await``, so it runs to completion as soon
        as it is awaited.
        """
        # step a clone so that self.env itself is left untouched
        env_clone = self.env.clone()
        state, reward, is_terminal, _ = env_clone.step(action)

        return MCTS_Node(state, action, reward, is_terminal)

    def _find_children(self, node):
        """Return the set of nodes reachable with one action.

        A terminal node has no children, so the environment is not stepped
        for it. For any other node every action in ``action_space`` is taken
        once and ``revert()`` is called after each step.

        NOTE(review): apart from the terminal check, ``node`` is not used;
        the actions are taken from whatever state ``self.env`` is currently
        in. Presumably ``revert()`` undoes the step -- confirm against the
        environment implementation.
        """
        if node.is_terminal:
            return set()

        children = set()

        for action in range(self.env.action_space.n):
            state, reward, is_terminal, _ = self.env.step(action)
            children.add(MCTS_Node(state, action, reward, is_terminal))
            self.env.revert()

        return children

    def _simulate(self, node):
        """Play random actions and return the reward of the last step.

        The environment is reverted once, then up to ``search_depth`` random
        actions are taken, stopping early when a step reports the end of the
        episode. If ``node`` is already terminal, or ``search_depth`` is 0,
        the reward stored on ``node`` is returned.

        NOTE(review): the simulation starts from the state ``revert()``
        leaves the environment in, not from a state derived from ``node``.
        """
        self.env.revert()

        # Track the step outcome directly. The previous version rebuilt a
        # node with positional arguments, which put the reward in ``action``
        # and the terminal flag in ``reward``.
        reward = node.reward
        is_terminal = node.is_terminal

        for _ in range(self.search_depth):

            if is_terminal:
                break

            # choose a random action
            action = np.random.choice(self.env.action_space.n)

            # take the action in the environment
            _state, reward, is_terminal, _info = self.env.step(action)

        return reward

    def _backpropogate(self, path, reward):
        """Add ``reward`` and one visit to every node on ``path``."""
        for node in reversed(path):
            self.rewards[node] += reward
            self.visit_counts[node] += 1

    def rollout(self, state):
        """Run ``num_simulations`` search iterations from ``state``.

        Builds a fresh root node from the observation ``state``, prints the
        elapsed wall-clock time in seconds and returns the root node.
        """
        node = MCTS_Node(state=state, action=None)

        start_time = time.perf_counter()

        for _ in range(self.num_simulations):

            # select a path; its last node is the leaf to expand
            path = self._select(node)
            leaf = path[-1]

            # expand the leaf node one level
            self._expand(leaf)

            # estimate the value of the leaf with a random simulation
            reward = self._simulate(leaf)

            # back-propagate the reward along the selected path
            self._backpropogate(path, reward)

        stop_time = time.perf_counter()
        print(f"{stop_time-start_time:0.4f}")

        return node

    def choose(self, node):
        """Return the child of ``node`` with the best average reward.

        Children that were never visited are avoided. If ``node`` has not
        been expanded, a uniformly random child is generated and returned.

        Raises:
            RuntimeError: if ``node`` is terminal, or if it is unexpanded and
                the environment offers no actions.
        """
        if node.is_terminal:
            raise RuntimeError(f"choose called on terminal node {node}")

        if node not in self.children:
            # The node class has no ``find_random_child``; generate the
            # children through the environment and pick one at random.
            candidates = list(self._find_children(node))
            if not candidates:
                raise RuntimeError(f"no children available for node {node}")
            return candidates[np.random.randint(len(candidates))]

        def score(n):
            if self.visit_counts[n] == 0:
                return float("-inf")  # avoid unseen moves
            return self.rewards[n] / self.visit_counts[n]  # average reward

        return max(self.children[node], key=score)


# Reward settings shared by the real and the simulation environment.
reward_settings = dict(
    reward_lose=-10,
    reward_win=10,
    penalty_step=-0.002,
    penalty_time=-0.001,
)

# the environment the agent actually acts in
env = QuestEnvironment().create(**reward_settings)

# a second environment used only for tree-search simulations
env_s = QuestEnvironment().create(**reward_settings)

# the search tree operates on the simulation environment
tree = MCTS(env_s, exploration_weight=1.0, num_simulations=1)

env.reset()

for _step in range(100):

    # revert the simulation environment 50 times with the action history
    # recorded by the real environment
    for _repeat in range(50):
        state = env_s.revert(env.action_history)

    # Tree-search policy, currently disabled:
    # # do a rollout with update
    # node = tree.rollout(state)

    # # choose the action
    # node = tree.choose(node)

    # # take the action
    # _ = env.step(node.action)

    # act with a uniformly random action instead
    candidate_actions = [*range(env.action_space.n)]
    action = np.random.choice(candidate_actions)

    env.step(action)

    env.render()

