## Monte Carlo Tree Search on Starter Battle

In this experiment, I explore using Monte Carlo Tree Search (MCTS) on the Starter Battle environment. The goal is to see how well MCTS performs in this environment and how it compares to the DQN model from the [initial_pokemon_battleing_agent](./initial_pokemon_battleing_agent.ipynb) experiment notebook.

In [57]:
# Ensure relative imports work correctly
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import random

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Space

from services.starter_pokemons import starter_df

### A model-based approach

The DQN model from the earlier experiment is a model-free approach to the starter battle: it learns which actions are good directly from experience, without an explicit model of how one state transitions into the next. For my learning outcomes, however, I am required to explore a model-based approach, and MCTS is my method of choice. My plan is to represent the transition model as a tree and use MCTS to search that tree for the best move. I was inspired by Google's AlphaGo, which combines MCTS with neural networks to search the game tree for the best move.

### How MCTS works
![Diagram of the MCTS algorithm from Wikipedia](https://upload.wikimedia.org/wikipedia/commons/a/a6/MCTS_Algorithm.png)
*Image from Wikipedia showing the four steps of MCTS: selection, expansion, simulation and backpropagation*
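
The selection step decides which child to descend into. The implementation below does this in `_uct_select` with the UCT rule (UCB1 applied to trees): a child's average reward `Q[n] / N[n]` plus an exploration bonus `c * sqrt(ln N[parent] / N[n])`, where `c` is `exploration_weight`. The standalone function here is a small sketch of that score. Returning infinity for unvisited children is a common convention I am assuming; `_uct_select` sidesteps that case because `_select` expands (and so visits) every child once before falling back to UCT.

```python
import math


def uct_score(child_q: float, child_n: int, parent_n: int, exploration_weight: float = 1.0) -> float:
    """UCT score of a child node: average reward plus an exploration bonus."""
    if child_n == 0:
        # Assumed convention: unvisited children are tried before any visited one
        return math.inf
    exploitation = child_q / child_n  # average reward observed so far
    exploration = exploration_weight * math.sqrt(math.log(parent_n) / child_n)
    return exploitation + exploration


# A parent visited 10 times with two children (total reward, visit count):
# "a" has the higher average but many visits, "b" a lower average but few visits
children = {"a": (6.0, 8), "b": (1.0, 2)}
best = max(children, key=lambda name: uct_score(*children[name], parent_n=10))
print(best)  # b
```

Even though `a` has the higher average reward (0.75 vs 0.5), `b` is selected because its exploration bonus is larger. Raising `exploration_weight` strengthens this pull toward rarely visited children; setting it to 0 reduces UCT to picking the best average.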

### How a tree of Pokemon battle states differs from a traditional game tree

TODO: describe how a tree representing a Pokemon battle has no fixed root node, unlike a traditional game tree such as the one for chess (which always starts from the same state).

### MCTS requires a perfect information game

TODO: describe that MCTS requires a perfect information game, which is essentially the case for Pokemon. However, I stated earlier in this research that I would treat the game as if it were not. So I am breaking my own rule here, but I am willing to do so for learning purposes.

## Inspiration for tree implementation

The tree implementation below is adapted from [Luke Harold Miles' minimal MCTS gist](https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1).

## State Space

In [45]:
hp_space = gym.spaces.Discrete(starter_df['hp'].max() + 1)
attack_space = gym.spaces.Discrete(starter_df['attack'].max() + 1)
defense_space = gym.spaces.Discrete(starter_df['defense'].max() + 1)
# sp_atk_space = gym.spaces.Discrete(starter_df['sp. atk'].max() + 1)
# sp_def_space = gym.spaces.Discrete(starter_df['sp. def'].max() + 1)
speed_space = gym.spaces.Discrete(starter_df['speed'].max() + 1)

In [58]:
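stat_stage_space = gym.spaces.Box(low=0, high=12, shape=(6,), dtype=np.int8)

def map_stat_stages(stat_stages: list[int]) -> np.ndarray:
    """Shift in-battle stat stages from [-6, 6] to [0, 12] so they fit `stat_stage_space`."""
    if len(stat_stages) != 6:
        raise ValueError('Expected exactly 6 stat stages')

    stages = np.asarray(stat_stages, dtype=np.int8)
    if np.any(stages < -6) or np.any(stages > 6):
        raise ValueError('Stat stages must lie between -6 and 6')

    # Stay in int8 so the result matches the dtype of the Box space
    return stages + 6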
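
A rough count of the values these spaces can take suggests why the tree has to be grown by search rather than enumerated up front. `stat_stage_space` alone allows 13 values for each of its 6 entries, i.e. 13**6 = 4,826,809 combinations per trainer, before HP, attack, defense and speed are multiplied in, and the whole product is squared for two trainers. The helper below is a small sketch of that arithmetic; `joint_state_count` is a hypothetical name, and the stat maxima in the example are placeholders, not values taken from `starter_df`.

```python
import math


def joint_state_count(discrete_sizes: list[int], stage_levels: int = 13,
                      n_stages: int = 6, n_trainers: int = 2) -> int:
    """Number of distinct battle states when every trainer uses the same spaces.

    discrete_sizes: the `n` of each Discrete space (e.g. hp, attack, defense, speed)
    stage_levels:   values per stat stage (0..12 after map_stat_stages)
    """
    per_trainer = math.prod(discrete_sizes) * stage_levels ** n_stages
    return per_trainer ** n_trainers


# Stat stages only, for a single trainer
print(joint_state_count([], n_trainers=1))  # 4826809

# Placeholder Discrete sizes (hypothetical, not from starter_df): hp, attack, defense, speed
print(joint_state_count([50, 60, 70, 70]))
```

For any realistic stat maxima the two-trainer count is far larger than anything a search can visit, which is why MCTS only builds the part of the tree that its rollouts actually reach.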

## Action Space

In [None]:
import poke_battle_sim as pb

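# Maps a discrete action index to (action type, pokemon index, move index)
action_mappings = {
    0: ('move', 0, 0),
    1: ('move', 0, 1),
}
action_space = gym.spaces.Discrete(len(action_mappings))

def get_action(action: int, trainer: pb.Trainer) -> list[str]:
    """Translate an action index into an [action type, move name] turn."""
    action_type, poke_index, move_index = action_mappings[action]
    move_name = trainer.poke_list[poke_index].moves[move_index].name
    return [action_type, move_name]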

## Tree Implementation

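Below is the MCTS implementation I will be using, taken from the gist linked above: an `MCTS` class that runs the search and an abstract `Node` class that every game state must implement.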

In [None]:
"""
A minimal implementation of Monte Carlo tree search (MCTS) in Python 3
Luke Harold Miles, July 2019, Public Domain Dedication
See also https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1
"""
from abc import ABC, abstractmethod
from collections import defaultdict
import math


class MCTS:
    "Monte Carlo tree searcher. First rollout the tree then choose a move."

    def __init__(self, exploration_weight=1):
        self.Q = defaultdict(int)  # total reward of each node
        self.N = defaultdict(int)  # total visit count for each node
        self.children = dict()  # children of each node
        self.exploration_weight = exploration_weight

    def choose(self, node):
        "Choose the best successor of node. (Choose a move in the game)"
        if node.is_terminal():
            raise RuntimeError(f"choose called on terminal node {node}")

        if node not in self.children:
            return node.find_random_child()

        def score(n):
            if self.N[n] == 0:
                return float("-inf")  # avoid unseen moves
            return self.Q[n] / self.N[n]  # average reward

        return max(self.children[node], key=score)

    def do_rollout(self, node):
        "Make the tree one layer better. (Train for one iteration.)"
        path = self._select(node)
        leaf = path[-1]
        self._expand(leaf)
        reward = self._simulate(leaf)
        self._backpropagate(path, reward)

    def _select(self, node):
        "Find an unexplored descendent of `node`"
        path = []
        while True:
            path.append(node)
            if node not in self.children or not self.children[node]:
                # node is either unexplored or terminal
                return path
            unexplored = self.children[node] - self.children.keys()
            if unexplored:
                n = unexplored.pop()
                path.append(n)
                return path
            node = self._uct_select(node)  # descend a layer deeper

    def _expand(self, node):
        "Update the `children` dict with the children of `node`"
        if node in self.children:
            return  # already expanded
        self.children[node] = node.find_children()

    def _simulate(self, node):
        "Returns the reward for a random simulation (to completion) of `node`"
        invert_reward = True
        while True:
            if node.is_terminal():
                reward = node.reward()
                return 1 - reward if invert_reward else reward
            node = node.find_random_child()
            invert_reward = not invert_reward

    def _backpropagate(self, path, reward):
        "Send the reward back up to the ancestors of the leaf"
        for node in reversed(path):
            self.N[node] += 1
            self.Q[node] += reward
            reward = 1 - reward  # 1 for me is 0 for my enemy, and vice versa

    def _uct_select(self, node):
        "Select a child of node, balancing exploration & exploitation"

        # All children of node should already be expanded:
        assert all(n in self.children for n in self.children[node])

        log_N_vertex = math.log(self.N[node])

        def uct(n):
            "Upper confidence bound for trees"
            return self.Q[n] / self.N[n] + self.exploration_weight * math.sqrt(
                log_N_vertex / self.N[n]
            )

        return max(self.children[node], key=uct)


class Node(ABC):
    """
    A representation of a single board state.
    MCTS works by constructing a tree of these Nodes.
    Could be e.g. a chess or checkers board state.
    """

    @abstractmethod
    def find_children(self):
        "All possible successors of this board state"
        return set()

    @abstractmethod
    def find_random_child(self):
        "Random successor of this board state (for more efficient simulation)"
        return None

    @abstractmethod
    def is_terminal(self):
        "Returns True if the node has no children"
        return True

    @abstractmethod
    def reward(self):
        "Assumes `self` is terminal node. 1=win, 0=loss, .5=tie, etc"
        return 0

    @abstractmethod
    def __hash__(self):
        "Nodes must be hashable"
        return 123456789

    @abstractmethod
    def __eq__(node1, node2):
        "Nodes must be comparable"
        return True
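
`_simulate` and `_backpropagate` assume rewards between 0 and 1: walking from the leaf back to the root, each node receives the reward and it is then flipped with `1 - reward` for the parent, because (as the comment in `_backpropagate` says) 1 for me is 0 for my enemy. The sketch below re-implements only that backpropagation loop on plain dictionaries, with placeholder node names, to show how the values alternate and why a reward outside [0, 1], such as -1 for a loss, breaks the flip: `1 - (-1)` becomes 2.

```python
from collections import defaultdict


def backpropagate(path: list[str], reward: float, Q: dict, N: dict) -> None:
    """Same update as MCTS._backpropagate: walk from leaf to root, flipping the reward."""
    for node in reversed(path):
        N[node] += 1
        Q[node] += reward
        reward = 1 - reward  # a win for one player is a loss for the other


Q, N = defaultdict(float), defaultdict(int)
backpropagate(["root", "child", "leaf"], 1.0, Q, N)
print(dict(Q))  # {'leaf': 1.0, 'child': 0.0, 'root': 1.0}

# With a -1/1 reward scheme the flipped value leaves the expected range
Q_bad, N_bad = defaultdict(float), defaultdict(int)
backpropagate(["root", "leaf"], -1.0, Q_bad, N_bad)
print(dict(Q_bad))  # {'leaf': -1.0, 'root': 2.0}
```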

The tree starts with a special root node:
- In board games, the starting state is usually fixed. This is not the case for Pokemon battles.
- Thus our root node is a placeholder rather than an actual game state.
- Once the tree is populated, its children will be all possible starting states.

In [74]:
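class RootNode(Node):
    """Placeholder root whose children are all possible starting states."""

    def __init__(self):
        self.children = set()

    def find_children(self):
        return self.children

    def find_random_child(self):
        return random.choice(list(self.children))

    def is_terminal(self):
        return False

    def reward(self):
        return 0

    # Node declares __hash__ and __eq__ abstract, so RootNode cannot be
    # instantiated without them. There is only one root, so all roots are equal.
    def __hash__(self):
        return hash('root')

    def __eq__(self, other):
        return isinstance(other, RootNode)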

The tree follows a structure similar to the one in the image above:
- Each vertex represents a state of the game
- Each edge represents an action available from that state
In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)
class TrainerState:
    """Stats of a trainer's active pokemon. Frozen, so equal states compare and hash equally."""
    hp: int
    attack: int
    defense: int
    speed: int
    stat_stages: tuple[int, ...]  # a tuple (not a list or ndarray) keeps the state hashable

In [None]:
class StateNode(Node):
    def __init__(self, agent: TrainerState, opponent: TrainerState):
        self.agent = agent
        self.opponent = opponent
        self.children = set()

    def find_children(self):
        return self.children

    def find_random_child(self):
        return random.choice(list(self.children))

    def is_terminal(self):
        return self.agent.hp <= 0 or self.opponent.hp <= 0

    def reward(self):
        # MCTS expects rewards in [0, 1] (1 = win, 0 = loss, 0.5 = tie) because it flips them with `1 - reward`
        # TODO find a more complex way to get rewards, as this will just result in finding the quickest way to win
        # The Q values arising from this won't be very interesting
        if self.agent.hp <= 0 and self.opponent.hp <= 0:
            return 0.5  # both pokemon fainted: treat as a tie
        if self.agent.hp <= 0:
            return 0
        if self.opponent.hp <= 0:
            return 1

        return 0.5  # neutral value; reward() is meant to be called on terminal nodes only

    def __hash__(self):
        return hash((self.agent, self.opponent))

    def __eq__(self, other):
        return (
            isinstance(other, StateNode)
            and self.agent == other.agent
            and self.opponent == other.opponent
        )

## Conclusion

It seems to me that building a transition model (i.e. using a model-based approach) is most viable when it is easy (i.e. doesn't take too long) to exhaustively turn all the rules and actions of a problem into code. Chess, for example, has a fairly simple and relatively small set of rules: there are only so many moves a piece can make, and there are not that many pieces. Their behavior is very predictable, which makes a transition model easy to implement.

Pokemon, on the other hand, has 493 "pieces" (species) and 215 unique move effects, not to mention the items a trainer has at their disposal. This makes a complete transition model very hard to implement. Building on other people's work (for example, reusing the effect methods from `poke_battle_sim.util.process_move`) could make a complete transition model easier to implement.