## Monte Carlo Tree Search on Starter Battle

In this experiment, I will be exploring the use of Monte Carlo Tree Search (MCTS) on the Starter Battle environment. The goal is to see how well MCTS can perform in this environment and how it compares to the DQN model from the [initial_pokemon_battleing_agent](./initial_pokemon_battleing_agent.ipynb) experiment notebook.

In [1]:
# Ensure relative imports work correctly
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import random

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Space
import poke_battle_sim as pb

from services.starter_pokemons import starter_stats_df

### A model-based approach

The DQN model from the earlier experiment is a model-free approach to solving the starter battle. For my learning outcomes, however, I am obligated to explore a model-based approach, and MCTS is my approach of choice. My plan is to implement the transition model as a tree and use MCTS to search through that tree for the best move. I was inspired by Google's AlphaGo, which uses MCTS to search the game tree for the best move.

### How MCTS works
![Image from wikipedia about how MCTS works](https://upload.wikimedia.org/wikipedia/commons/a/a6/MCTS_Algorithm.png)
*Image from wikipedia about how MCTS works*
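
The figure shows the four phases of one MCTS iteration: selection, expansion, simulation and backpropagation. During selection, a common choice is the UCT score, which adds an exploration bonus to a child's average reward; the `_uct_select` method in the implementation further down uses the same formula. A minimal sketch with made-up visit counts (the function name `uct_score` and all numbers are only for illustration):

```python
import math


def uct_score(total_reward, visits, parent_visits, exploration_weight=1.0):
    """Average reward plus an exploration bonus that shrinks as visits grow."""
    exploitation = total_reward / visits
    exploration = exploration_weight * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration


# A child with a good average that has been visited often...
well_explored = uct_score(total_reward=6, visits=10, parent_visits=12)
# ...versus a child with a slightly worse average that has barely been tried.
barely_explored = uct_score(total_reward=1, visits=2, parent_visits=12)
```

With these numbers the barely explored child gets the higher score, so selection would descend into it first even though its average reward is lower.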

### How a tree representing Pokémon states differs from a traditional game tree

TODO: describe that a tree representing a Pokémon battle does not have a fixed root node, unlike a traditional game tree such as the one for chess (which always has the same starting state)

### MCTS requires a perfect information game

TODO: describe that MCTS requires a perfect information game, which is essentially the case for Pokémon. However, I stated before that in this research I would like to treat the game as if it's not. So I am breaking my own rules at this point, but I am willing to do so for learning purposes.

## Inspiration for tree implementation

https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1

In [2]:
"""
A minimal implementation of Monte Carlo tree search (MCTS) in Python 3
Luke Harold Miles, July 2019, Public Domain Dedication
See also https://en.wikipedia.org/wiki/Monte_Carlo_tree_search
https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1
"""
from abc import ABC, abstractmethod
from collections import defaultdict
import math


class MCTS:
    "Monte Carlo tree searcher. First rollout the tree then choose a move."

    def __init__(self, exploration_weight=1):
        self.Q = defaultdict(int)  # total reward of each node
        self.N = defaultdict(int)  # total visit count for each node
        self.children = dict()  # children of each node
        self.exploration_weight = exploration_weight

    def choose(self, node):
        "Choose the best successor of node. (Choose a move in the game)"
        if node.is_terminal():
            raise RuntimeError(f"choose called on terminal node {node}")

        if node not in self.children:
            return node.find_random_child()

        def score(n):
            if self.N[n] == 0:
                return float("-inf")  # avoid unseen moves
            return self.Q[n] / self.N[n]  # average reward

        return max(self.children[node], key=score)

    def do_rollout(self, node):
        "Make the tree one layer better. (Train for one iteration.)"
        path = self._select(node)
        leaf = path[-1]
        self._expand(leaf)
        reward = self._simulate(leaf)
        self._backpropagate(path, reward)

    def _select(self, node):
        "Find an unexplored descendent of `node`"
        path = []
        while True:
            path.append(node)
            print("Current path", path)
            if node not in self.children or not self.children[node]:
                # node is either unexplored or terminal
                return path
            unexplored = self.children[node] - self.children.keys()
            if unexplored:
                n = unexplored.pop()
                path.append(n)
                return path
            node = self._uct_select(node)  # descend a layer deeper

    def _expand(self, node):
        "Update the `children` dict with the children of `node`"
        if node in self.children:
            return  # already expanded
        self.children[node] = node.find_children()

    def _simulate(self, node):
        "Returns the reward for a random simulation (to completion) of `node`"
        invert_reward = True
        while True:
            if node.is_terminal():
                reward = node.reward()
                return 1 - reward if invert_reward else reward
            node = node.find_random_child()
            invert_reward = not invert_reward

    def _backpropagate(self, path, reward):
        "Send the reward back up to the ancestors of the leaf"
        for node in reversed(path):
            self.N[node] += 1
            self.Q[node] += reward
            reward = 1 - reward  # 1 for me is 0 for my enemy, and vice versa

    def _uct_select(self, node):
        "Select a child of node, balancing exploration & exploitation"

        # All children of node should already be expanded:
        assert all(n in self.children for n in self.children[node])

        log_N_vertex = math.log(self.N[node])

        def uct(n):
            "Upper confidence bound for trees"
            return self.Q[n] / self.N[n] + self.exploration_weight * math.sqrt(
                log_N_vertex / self.N[n]
            )

        return max(self.children[node], key=uct)


class Node(ABC):
    """
    A representation of a single board state.
    MCTS works by constructing a tree of these Nodes.
    Could be e.g. a chess or checkers board state.
    """

    @abstractmethod
    def find_children(self):
        "All possible successors of this board state"
        return set()

    @abstractmethod
    def find_random_child(self):
        "Random successor of this board state (for more efficient simulation)"
        return None

    @abstractmethod
    def is_terminal(self):
        "Returns True if the node has no children"
        return True

    @abstractmethod
    def reward(self):
        "Assumes `self` is terminal node. 1=win, 0=loss, .5=tie, etc"
        return 0

    @abstractmethod
    def __hash__(self):
        "Nodes must be hashable"
        return 123456789

    @abstractmethod
    def __eq__(node1, node2):
        "Nodes must be comparable"
        return True

## State Space

In [3]:
hp_space = gym.spaces.Discrete(starter_stats_df['hp'].max() + 1)
attack_space = gym.spaces.Discrete(starter_stats_df['attack'].max() + 1)
defense_space = gym.spaces.Discrete(starter_stats_df['defense'].max() + 1)
# sp_atk_space = gym.spaces.Discrete(starter_df['sp. atk'].max() + 1)
# sp_def_space = gym.spaces.Discrete(starter_df['sp. def'].max() + 1)
speed_space = gym.spaces.Discrete(starter_stats_df['speed'].max() + 1)

In [4]:
stat_stage_space = gym.spaces.Box(low=0, high=12, shape=(6,), dtype=np.int8)
def map_stat_stages(stat_stages: list[int]) -> np.ndarray:
    if len(stat_stages) != 6:
        raise ValueError('Expected exactly 6 stat stages')
    
    # Shift the stages from the range [-6, 6] to [0, 12] so they fit the Box space above
    return np.array(stat_stages, dtype=np.int8) + 6

## Action Space

In [5]:
import poke_battle_sim as pb

action_mappings = {
    0: ('move', 0, 0),
    1: ('move', 0, 1),
}
action_space = gym.spaces.Discrete(len(action_mappings))

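def get_action(action: int, trainer: pb.Trainer) -> list[str]:
    # Each mapping is (action type, index into the party, index into that Pokémon's moves)
    action_type, pokemon_index, move_index = action_mappings[action]
    return [
        action_type, # The actual action
        trainer.poke_list[pokemon_index].moves[move_index].name
    ]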

### Move resolving

From https://bulbapedia.bulbagarden.net/wiki/Damage

![Image from bulbapedia about how damage is calculated](./mcts_on_starter_battle/damage_formula.png)

Where:
- **Level** is the level of the attacking Pokémon. If the used move is Beat Up, L is instead the level of the Pokémon performing the strike.
- **A** is the effective Attack stat of the attacking Pokémon if the used move is a physical move, or the effective Special Attack stat of the attacking Pokémon if the used move is a special move (for a critical hit, negative Attack or Special Attack stat stages are ignored). If the used move is Beat Up, A is instead the base Attack of the Pokémon performing the strike.
- **D** is the effective Defense stat of the target if the used move is a physical move, or the effective Special Defense stat of the target if the used move is a special move (for a critical hit, positive Defense or Special Defense stat stages are ignored). If the used move is Beat Up, D is instead the base Defense of the target.
- **Power** is the effective power of the used move.
- **Burn** is 0.5 if the attacker is burned, its Ability is not Guts, and the used move is a physical move, and 1 otherwise.
- **Screen** is 0.5 if the used move is physical and Reflect is present on the target's side of the field, or special and Light Screen is present. For a Double Battle, Screen is instead 2/3; however, if in a Double Battle when the move is executed, the only Pokémon on the target's side of the field is the target (for moves with only one target), or there is only one target when the move is executed (for moves with more than one target), Screen remains as 0.5. Screen is 1 otherwise or if the used move lands a critical hit.
- **Targets** is 0.75 in Double Battles if the used move has more than one target (provided there is more than one such target when the move is executed, regardless of whether the move actually hits or can hit all the targets), and 1 otherwise.
- **Weather** is 1.5 if a Water-type move is being used during rain or a Fire-type move during harsh sunlight, and 0.5 if a Water-type move is used during harsh sunlight or a Fire-type move during rain, or SolarBeam during any non-clear weather besides harsh sunlight, and 1 otherwise or if any Pokémon on the field have the Ability Cloud Nine or Air Lock.
- **FF** is 1.5 if the used move is Fire-type, and the attacker's Ability is Flash Fire that has been activated by a Fire-type move, and 1 otherwise.
- **Critical** is 2 for a critical hit, 3 if the move lands a critical hit and the attacker's Ability is Sniper, and 1 otherwise. It is always 1 if Future Sight or Doom Desire is used, the target's Ability is Battle Armor or Shell Armor, the target is under the effect of Lucky Chant, or if the battle is the first one against Starly (DP).
- **Item** is 1.3 if the attacker is holding a Life Orb, $1+ \frac {n}{10} $ if the attacker is holding a Metronome, where n is the amount of times the same move has been successfully and consecutively used, up to 10, and 1 otherwise.
- **First** is 1.5 if the used move was stolen with Me First.
- **random** is realized as a random integer from 85 to 100, inclusive, divided by 100. random is always 1 if Spit Up is used.
- **STAB** is the same-type attack bonus. This is equal to 1.5 if the move's type matches any of the user's types, 2 if the user of the move additionally has Adaptability, and 1 otherwise.
- **Type1** is the type effectiveness of the used move against the target's first type (or only type, if it only has a single type). This can be 0.5 (not very effective), 1 (normally effective), or 2 (super effective). If the used move is Struggle, Future Sight, Beat Up, or Doom Desire, both Type1 and Type2 are always 1.
- **Type2** is the type effectiveness of the used move against the target's second type. This can be 0.5 (not very effective), 1 (normally effective), or 2 (super effective). If the target only has a single type, Type2 is 1.
- **SRF** is 0.75 if the used move is super effective, the target's Ability is Solid Rock or Filter, and the attacker's Ability is not Mold Breaker, and 1 otherwise.
- **EB** is 1.2 if the used move is super effective and the attacker is holding an Expert Belt, and 1 otherwise.
- **TL** is 2 if the used move is not very effective and the attacker's Ability is Tinted Lens, and 1 otherwise.
- **Berry** is 0.5 if the used move is super effective and the target is holding the Berry that weakens it, or Normal-type and the target is holding a Chilan Berry, and 1 otherwise.

Since we do not need everything for the starter battle, we can simplify the formula to:

$Damage = \frac{ \left( \frac{2 \times Level}{5} + 2 \right) \times Power \times \frac{A}{D} }{50}$
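
As a quick sanity check, the simplified formula can be evaluated by hand. The sketch below uses illustrative numbers (a level 5 attacker, a move with 35 power, and equal Attack and Defense of 10); the function name `simplified_damage` is only for this example, and the rounding mirrors the `simulate_move` function in this notebook.

```python
def simplified_damage(level, power, attack, defense):
    """Simplified damage: ((2 * level / 5 + 2) * power * attack / defense) / 50, rounded."""
    return round((2 * level / 5 + 2) * power * (attack / defense) / 50)


# (2 * 5 / 5 + 2) = 4, then 4 * 35 * (10 / 10) = 140, then 140 / 50 = 2.8, rounded to 3
damage = simplified_damage(level=5, power=35, attack=10, defense=10)
print(damage)  # prints 3
```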


In [6]:
from poke_battle_sim.core.move import Move

def simulate_move(
        attacking_starter: pb.Pokemon, 
        defending_starter: pb.Pokemon, 
        move: Move
    ):
    """Apply `move` in place and return (attacker, defender).

    Index 1 is Attack and index 2 is Defense in stat_stages and stats_effective.
    """
    if move.power == 0 or move.power is None or isinstance(move.power, str):
        # Status moves: only growl, leer and withdraw are handled here
        if move.name == 'growl':
            # Lower the target's Attack by one stage, but never below -6
            defending_starter.stat_stages[1] = max(-6, defending_starter.stat_stages[1] - 1)
        elif move.name == 'leer':
            # Lower the target's Defense by one stage, but never below -6
            defending_starter.stat_stages[2] = max(-6, defending_starter.stat_stages[2] - 1)
        elif move.name == 'withdraw':
            # Raise the user's Defense by one stage, but never above 6
            attacking_starter.stat_stages[2] = min(6, attacking_starter.stat_stages[2] + 1)
    else:
        # Simplified damage formula, using physical Attack and Defense only
        damage = round(
            (2 * attacking_starter.level / 5 + 2) * move.power * (attacking_starter.stats_effective[1] / defending_starter.stats_effective[2]) / 50
        )
        print(damage)
        defending_starter.cur_hp = defending_starter.cur_hp - damage

    return attacking_starter, defending_starter

## Tree Implementation

Below is a class implementing the tree I will be using for the MCTS algorithm.

In [7]:
# from poke_battle_sim import simulate_move  # Assuming this handles move mechanics

class BattleState(Node):
    def __init__(self, player_pokemon, rival_pokemon, is_player_turn):
        self.player_pokemon = player_pokemon
        self.rival_pokemon = rival_pokemon
        self.is_player_turn = is_player_turn

    def find_children(self):
        if self.is_terminal():
            return set()  # No further moves if the game is over

        active_pokemon = self.player_pokemon if self.is_player_turn else self.rival_pokemon
        return {
            self.apply_move(move)
            for move in active_pokemon.moves
        }

    def find_random_child(self):
        if self.is_terminal():
            return None  # No further moves if the game is over

        active_pokemon = self.player_pokemon if self.is_player_turn else self.rival_pokemon
        random_move = random.choice(active_pokemon.moves)
        return self.apply_move(random_move)

    def reward(self):
        if not self.is_terminal():
            raise RuntimeError("Reward called on non-terminal state.")
        if self.rival_pokemon.cur_hp <= 0:
            return 1  # Player wins
        if self.player_pokemon.cur_hp <= 0:
            return 0  # Rival wins
        return 0

    def is_terminal(self):
        return self.player_pokemon.cur_hp <= 0 or self.rival_pokemon.cur_hp <= 0

    def apply_move(self, move):
        # Simulate the move
        player_pokemon = self.player_pokemon
        rival_pokemon = self.rival_pokemon

        if self.is_player_turn:
            player_pokemon, rival_pokemon = simulate_move(player_pokemon, rival_pokemon, move)
        else:
            rival_pokemon, player_pokemon = simulate_move(rival_pokemon, player_pokemon, move)

        return BattleState(player_pokemon, rival_pokemon, not self.is_player_turn)

    def __hash__(self):
        return hash((self.player_pokemon.name, self.rival_pokemon.name, self.is_player_turn))

    def __eq__(self, other):
        return (
            self.player_pokemon.name == other.player_pokemon.name
            and self.rival_pokemon.name == other.rival_pokemon.name
            and self.is_player_turn == other.is_player_turn
        )
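
Note that `__hash__` and `__eq__` above only look at the two names and the turn flag, so every child produced by `apply_move` compares equal and the set returned by `find_children` collapses to a single element. Below is a minimal sketch of a key that also includes the parts of the state that change during a battle. `StateKey`, `pokemon_key` and `make_key` are hypothetical names, the HP values are illustrative, and `SimpleNamespace` objects stand in for `pb.Pokemon`, using only attributes this notebook already relies on (`name`, `cur_hp`, `stat_stages`).

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(frozen=True)
class StateKey:
    """Hashable snapshot of everything that distinguishes two battle states."""
    player: tuple
    rival: tuple
    is_player_turn: bool


def pokemon_key(pokemon):
    # Lists are not hashable, so the stat stages are turned into a tuple
    return (pokemon.name, pokemon.cur_hp, tuple(pokemon.stat_stages))


def make_key(player, rival, is_player_turn):
    return StateKey(pokemon_key(player), pokemon_key(rival), is_player_turn)


# Stand-ins for pb.Pokemon objects (assumption: only these attributes are needed)
turtwig = SimpleNamespace(name="turtwig", cur_hp=15, stat_stages=[0] * 6)
chimchar_full = SimpleNamespace(name="chimchar", cur_hp=14, stat_stages=[0] * 6)
chimchar_hit = SimpleNamespace(name="chimchar", cur_hp=11, stat_stages=[0] * 6)

before = make_key(turtwig, chimchar_full, True)
after = make_key(turtwig, chimchar_hit, True)
```

A `BattleState` could delegate `__hash__` and `__eq__` to such a key, so that two states with the same Pokémon but different HP are stored as separate nodes.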

In [10]:
from services.starter_pokemons import get_random_starter, get_rival_starter

non_effective_nature = 'hardy'

# Initialize player and rival starters
player_starter = get_random_starter(non_effective_nature)  # Random starter for the player, e.g. Turtwig
rival_starter = get_rival_starter(player_starter.name, non_effective_nature)

# Initialize MCTS and the starting state
mcts = MCTS()
initial_state = BattleState(player_starter, rival_starter, True)
print(initial_state.player_pokemon.cur_hp)
print(initial_state.rival_pokemon.cur_hp)

# Perform MCTS rollouts
for _ in range(5):  # You can adjust the number of rollouts
    mcts.do_rollout(initial_state)

print(initial_state.player_pokemon.cur_hp)
print(initial_state.rival_pokemon.cur_hp)
print(initial_state.is_terminal())
# Choose the best move based on MCTS
best_move_state = mcts.choose(initial_state)

# Display the chosen move and updated battle state
print(f"Best move chosen for player: {best_move_state}")

15
14
Current path [<__main__.BattleState object at 0x0000020DEE972F00>]
3
3
3
3
3
3
3
3
3
Current path [<__main__.BattleState object at 0x0000020DEE972F00>]
Current path [<__main__.BattleState object at 0x0000020DEE972F00>]
Current path [<__main__.BattleState object at 0x0000020DEE972F00>, <__main__.BattleState object at 0x0000020DEE972EA0>]
Current path [<__main__.BattleState object at 0x0000020DEE972F00>]
Current path [<__main__.BattleState object at 0x0000020DEE972F00>, <__main__.BattleState object at 0x0000020DEE972EA0>]
Current path [<__main__.BattleState object at 0x0000020DEE972F00>]
Current path [<__main__.BattleState object at 0x0000020DEE972F00>, <__main__.BattleState object at 0x0000020DEE972EA0>]
0
2
True


RuntimeError: choose called on terminal node <__main__.BattleState object at 0x0000020DEE972F00>
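
The printed HP values show that `initial_state` itself went from 15 and 14 HP to 0 and 2 HP during the rollouts. This is consistent with `apply_move` handing the same Pokémon objects to `simulate_move`, which changes `cur_hp` in place, so every simulated move also damages the root state. A minimal sketch of one way around this is to copy before mutating. `Mon` is a hypothetical stand-in for `pb.Pokemon`, and `apply_damage_copy` is a made-up helper; whether `copy.deepcopy` is cheap or safe on real `pb.Pokemon` objects is an assumption I have not tested.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Mon:
    """Hypothetical stand-in for pb.Pokemon with only the fields used here."""
    name: str
    cur_hp: int
    stat_stages: list = field(default_factory=lambda: [0] * 6)


def apply_damage_copy(attacker, defender, damage):
    """Return fresh copies with the damage applied; the inputs stay untouched."""
    attacker = copy.deepcopy(attacker)
    defender = copy.deepcopy(defender)
    defender.cur_hp -= damage
    return attacker, defender


root_player = Mon("turtwig", cur_hp=15)
root_rival = Mon("chimchar", cur_hp=14)
child_player, child_rival = apply_damage_copy(root_player, root_rival, damage=3)
```

After the call, `root_rival` still has 14 HP while `child_rival` has 11, so simulating a move no longer changes the state the search started from.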

In [17]:
print(best_move_state.player_pokemon.cur_hp)
print(best_move_state.rival_pokemon.cur_hp)
print(best_move_state.is_player_turn)

5
-1
False
False


So I just found a massive bug in my implementation, and I don't know how I did not realise this earlier: there is no such thing as `is_player_turn` in Pokémon, as the agent and the opponent both choose their moves at the same time each turn. I will need to rethink my approach to this problem, as I do not have the time to completely implement a model-based approach. The reason is explained in the conclusion.
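
One possible direction, sketched under heavy simplifications: treat a node as the state at the start of a turn and a child as the result of a *pair* of moves, resolved in speed order. Everything in this sketch is hypothetical (the `Mon` class, `resolve_turn`, `joint_children`, the flat damage numbers and the speed values); status moves are modelled as 0 damage, and speed ties simply go to the player.

```python
from dataclasses import dataclass, replace
from itertools import product


@dataclass(frozen=True)
class Mon:
    name: str
    hp: int
    speed: int
    moves: tuple  # (move name, flat damage) pairs; the numbers are made up


def resolve_turn(player, rival, player_move, rival_move):
    """Resolve one turn in which both sides have already picked a move."""
    # The faster Pokémon acts first; ties go to the player here for simplicity
    order = [("player", player_move), ("rival", rival_move)]
    if rival.speed > player.speed:
        order.reverse()
    for side, (_, damage) in order:
        if side == "player":
            rival = replace(rival, hp=max(0, rival.hp - damage))
        else:
            player = replace(player, hp=max(0, player.hp - damage))
        if player.hp == 0 or rival.hp == 0:
            break  # a fainted Pokémon does not get to act
    return player, rival


def joint_children(player, rival):
    """Map every (player move, rival move) pair to the resulting state."""
    return {
        (player_move[0], rival_move[0]): resolve_turn(player, rival, player_move, rival_move)
        for player_move, rival_move in product(player.moves, rival.moves)
    }


turtwig = Mon("turtwig", hp=15, speed=7, moves=(("tackle", 3), ("withdraw", 0)))
chimchar = Mon("chimchar", hp=14, speed=9, moves=(("scratch", 3), ("leer", 0)))
children = joint_children(turtwig, chimchar)
```

With two moves per side this gives four children per node instead of two, and the search would then need a rule for how the rival picks its half of each pair.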

## Conclusion

It seems to me that building a transition model (i.e. using a model-based approach) is most viable when it is easy (i.e. won't take too long) to exhaustively turn all the rules and actions of a problem into code. Chess, for example, has a pretty simple and relatively small set of rules. There are only so many moves a piece can make, and there are not that many pieces. The pieces behave very predictably, which makes it easy to implement a transition model.

Pokémon, on the other hand, has 493 pieces and 215 unique move effects. Not to mention the fact that:
- a trainer can have items at their disposal
- the battlefield can vary in state (e.g. weather, terrain, etc.)
- the battlefield can have different rules (e.g. double battle, triple battle, etc.)
- Pokémon can level up and learn new moves

This makes it very hard to implement a complete transition model for Pokémon. Building on other people's work (for example, the effect methods from `poke_battle_sim.util.process_move`) could make this easier, but given the time constraints of this project, I do not have the time to implement a complete transition model.

Then there is the issue that model-based approaches are more memory-intensive, as storing large rollouts of the transition model can take up a lot of memory. This problem could be mitigated with smart use of persistent storage, for example by storing explored states on disk and retrieving them when needed.

Transition models also seem more explainable, as it is clearer how, for example, MCTS evaluates a path than how a DQN does. The evaluation of MCTS is also more transparent than that of a DQN, as it is built from explicit visit counts and rewards, while the DQN uses a neural network to approximate the value of a state.

A model-based approach also tends to require fewer iterations to converge than a model-free approach. This is because a model-based approach can use the transition model to evaluate a path ahead of time, while a model-free approach has to rely on trial-and-error exploration to converge towards an optimal policy.

To conclude with a comparison between model-based and model-free approaches:
> A model-based approach is most viable when a problem has simple rules and pieces at play (like with chess), as that makes it easier to implement a complete transition model.  
> If the problem has complex rules and pieces at play and an implementation is already provided (like with Pokémon or other games), a model-free approach is more viable.

In [13]:
raise MemoryError("Nope!")

MemoryError: Nope!