Student: Dorin Doncenco

Todo: Learn MCTS, maybe watch some YouTube videos

# TP 3 - Planning (November 30)


![Frozen Lake](https://gymnasium.farama.org/_images/frozen_lake.gif "Frozen Lake")

In this assignment, we focus on algorithms that require a **model** of the environment's behavior. You will implement:

- A Monte Carlo Tree Search (MCTS) algorithm
- A tabular Dyna-Q algorithm

You will be evaluated on:
* The implementation of the agents. Points will be awarded for clean, scalable code.
* A paragraph analyzing the behavior of the algorithms.

Send this notebook to cyriaque.rousselot@inria.fr before the next class.


In [1]:
%load_ext autoreload
%autoreload 1
%aimport utils

## Environment

### Snapshots

To support planning algorithms, we introduce the ability to take snapshots of the environment. A snapshot lets us return to a previously visited state.
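
The internals of `utils.WithSnapshots` are not shown in this notebook. As an illustration only, one simple way to implement snapshots is to deep-copy the whole environment object. The sketch below assumes a hypothetical `ToyWalkEnv` and a hypothetical `SnapshotWrapper`; neither is part of `utils` or gymnasium.

```python
import copy


class ToyWalkEnv:
    """Hypothetical 1-D walk used only to illustrate snapshots."""

    def __init__(self, length=5):
        self.length = length
        self.position = 0

    def step(self, action):
        # action 1 moves right, any other action moves left; clipped to [0, length - 1]
        delta = 1 if action == 1 else -1
        self.position = min(max(self.position + delta, 0), self.length - 1)
        return self.position


class SnapshotWrapper:
    """Hypothetical wrapper: a snapshot is a deep copy of the whole wrapped env."""

    def __init__(self, env):
        self.env = env

    def get_snapshot(self):
        # deep copy so that later steps cannot modify the saved state
        return copy.deepcopy(self.env)

    def load_snapshot(self, snapshot):
        # copy again so the same snapshot can be restored any number of times
        self.env = copy.deepcopy(snapshot)

    def step(self, action):
        return self.env.step(action)


demo = SnapshotWrapper(ToyWalkEnv())
snap = demo.get_snapshot()   # snapshot of the start state
demo.step(1)
demo.step(1)
print(demo.env.position)     # 2
demo.load_snapshot(snap)     # back to the start state
print(demo.env.position)     # 0
```

Copying again inside `load_snapshot` matters: without it, stepping after a restore would mutate the stored snapshot and it could not be reused.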

In [2]:
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

In [3]:
env = utils.WithSnapshots(gym.make("FrozenLake-v1",map_name="8x8",
                             render_mode="ansi",
                             max_episode_steps=200))
env.reset()
n_actions = env.action_space.n
n_states = env.observation_space.n

In [4]:
import matplotlib.pyplot as plt
print("initial_state:")
print(env.render())
# plt.axis('off')
env.close()

# create first snapshot
snap0 = env.get_snapshot()

initial_state:

[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG



In [5]:
while True:
    _, _, terminated, truncated, _ = env.step(env.action_space.sample())
    if terminated:
        print("Whoops! We died!")
        break
    if truncated:
        print("Time is over!")
        break

print("final state:")
print(env.render())
env.close()

Whoops! We died!
final state:
  (Right)
SFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG



In [6]:
# reload initial state
env.load_snapshot(snap0)

print("After loading snapshot")
print(env.render())
env.close()

After loading snapshot

[41mS[0mFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFFG



## Monte Carlo Tree Search

https://en.wikipedia.org/wiki/Monte_Carlo_tree_search; Sutton-Barto, Chapter 8.11

The MCTS algorithm we will implement can be divided into 4 steps:
- Selection
- Expansion
- Simulation
- Backpropagation

The first step (selection) traverses the current tree, choosing children with the UCB-1 rule, until it reaches a leaf L.
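
For reference, the UCB-1 score that `Node.ucb_score` below computes for a child $C$ of the current node $P$ is given by the formula that follows, where $W(C)$ is the child's `qvalue_sum`, $N(\cdot)$ is `times_visited`, and $c$ is the `scale` parameter. A child with $N(C) = 0$ gets an infinite score, so every child is tried once before any is revisited.

$$
\mathrm{UCB}(C) = \frac{W(C)}{N(C)} + c\,\sqrt{\frac{2 \ln N(P)}{N(C)}}
$$

Selection then moves to the child with the largest score and repeats until it reaches a leaf.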

The second (expansion) creates the children of L for the feasible moves, if the game is not finished, and picks one of them, C.

The third (simulation) plays the rest of the game from C with a default policy, here uniformly random actions, to get an estimate of the value of position C.

The fourth (backpropagation) updates the value estimates of all nodes on the path from the root to C.
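
To make the backpropagation step concrete, here is a small sketch of the undiscounted values that `Node.propagate` credits along a path. `backpropagate` is a hypothetical helper, not part of the notebook: each node adds its own immediate reward to the value coming from below.

```python
def backpropagate(immediate_rewards, rollout_return):
    """Hypothetical helper mirroring Node.propagate (undiscounted returns).

    immediate_rewards: immediate reward of each node on the selected path, ordered
    from the root (whose immediate reward is 0) down to the expanded child C.
    rollout_return: sum of rewards collected by the simulation started from C.
    Returns the value credited to each node, ordered from the root down to C.
    """
    credited = []
    value = rollout_return
    # Walk from C back up to the root; each node adds its own immediate reward.
    for reward in reversed(immediate_rewards):
        value = reward + value
        credited.append(value)
    return credited[::-1]


# Path root -> A -> C: A earned 0.5 on arrival, C earned 0, and the rollout from C earned 1.
print(backpropagate([0.0, 0.5, 0.0], 1.0))  # [1.5, 1.5, 1.0]
```

Each node then adds its credited value to `qvalue_sum` and increments `times_visited`, so `get_qvalue_estimate` is the average return observed through that node.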



![image.png](https://i.postimg.cc/6QmwnjPS/image.png)

We will use snapshots to simulate the effect of a sample model:
1. Saving a snapshot of state S
2. Sending (S, A) to the environment
3. Getting back R and S'
4. When needed, reloading the snapshot of state S (for example, to try another action from S)
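
The four steps above can be sketched as a small function. This illustration assumes that a snapshot is a deep copy of a hypothetical `LineEnv` object, which is not necessarily how `utils` does it; the hypothetical `sample_model` plays a role similar to `env.get_result` in `Node.__init__` below.

```python
import copy


class LineEnv:
    """Hypothetical deterministic env: a position on a line, with the goal at position 3."""

    def __init__(self):
        self.position = 0

    def step(self, action):
        # action 1 moves right, action 0 moves left (never below 0)
        self.position = max(self.position + (1 if action == 1 else -1), 0)
        done = self.position == 3
        return self.position, (1.0 if done else 0.0), done


def sample_model(snapshot, action):
    """Hypothetical sample model: (snapshot of S, A) -> (snapshot of S', S', R, done).

    The input snapshot is never modified, so it can be reused to try other actions from S.
    """
    env = copy.deepcopy(snapshot)                # load the snapshot of S into a fresh copy
    next_state, reward, done = env.step(action)  # send (S, A), get back R and S'
    return env, next_state, reward, done         # the stepped copy is the snapshot of S'


start = LineEnv()
start.position = 2
snap_s = copy.deepcopy(start)                           # 1. save a snapshot of S
_, s_right, r_right, done_right = sample_model(snap_s, 1)
_, s_left, r_left, done_left = sample_model(snap_s, 0)  # 4. reuse the snapshot of S
print(s_right, r_right, done_right)  # 3 1.0 True
print(s_left, r_left, done_left)     # 1 0.0 False
```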

## Building the agent

> Fill in the blanks in the code

In [7]:
class Node:
    """A node in the Monte Carlo Tree Search (MCTS) algorithm."""
    
    #metadata:
    parent = None          #parent Node
    qvalue_sum = 0.         #sum of state values from all visits (numerator)
    times_visited = 0      #counter of visits (denominator)

    def __init__(self, parent, action):
        """
        Initializes a tree node with a parent, action, and environment.

        :param parent: parent TreeNode
        :param action: action to commit from parent Node
        """

        self.parent = parent
        self.action = action
        self.children = set()

        # Capture the outcome after performing the action in the parent's state
        result = env.get_result(parent.snapshot, action)
        (
            self.snapshot,
            self.observation,
            self.immediate_reward,
            self.is_done,
            _,
        ) = result

    def is_leaf(self):
        return not self.children

    def is_root(self):
        return self.parent is None

    def get_qvalue_estimate(self):
        if self.times_visited !=0:
            return self.qvalue_sum / self.times_visited
        return 0

    def ucb_score(self, scale=10, max_value=float("inf")):
        """
        Computes the Upper Confidence Bound (UCB) score for the node.

        :param scale: Multiplies the upper bound by this value. Assumes reward range to be [0, scale].
        :param max_value: a value representing infinity (for unvisited nodes).
        """
        if self.times_visited == 0:
            return max_value

        return self.get_qvalue_estimate() + scale * np.sqrt(2*np.log(self.parent.times_visited) / self.times_visited)

    # MCTS steps

    def select_best_leaf(self):
        """
        Selects the leaf with the highest priority to expand.

        Recursively picks nodes with the best UCB score until it reaches a leaf.
        """
        # Using the UCB valuation, select the best possible child among children of a node
        if self.is_leaf():
            return self
        children = self.children
        
        best_child = max(children, key=lambda child: child.ucb_score())
        return best_child.select_best_leaf()

    def expand(self):
        """
        Expands the current node by creating all possible child nodes.

        Returns one of those children.
        """
        # You can't generate a child if there is already an existing child with the same associated action.

        assert not self.is_done, "Can't expand from terminal state"

        # Create one child for every action that has no child yet
        existing_actions = {child.action for child in self.children}
        for action in range(n_actions):
            if action not in existing_actions:
                self.children.add(Node(self, action))

        # When called on a leaf (as plan_mcts does), all children are new and unvisited,
        # so their UCB score is infinite and this returns one of them
        return self.select_best_leaf()

    def rollout(self, t_max=10**4):
        """
        Plays the game from this state until the episode ends or t_max steps have passed,
        picking an action uniformly at random on each step.

        Returns the sum of rewards collected after this node. If the node is terminal,
        returns 0: the node's own immediate_reward is added later by propagate().
        """
        if self.is_done:
            return 0.

        env.load_snapshot(self.snapshot)
        rollout_reward = 0.
        for _ in range(t_max):
            action = env.action_space.sample()
            # gymnasium's step returns (observation, reward, terminated, truncated, info)
            _, reward, terminated, truncated, _ = env.step(action)
            rollout_reward += reward
            if terminated or truncated:
                break

        return rollout_reward

    def propagate(self, child_qvalue):
        """
        Uses the child Q-value to update parents number of visits and qvalue recursively.
        """
        my_qvalue = self.immediate_reward + child_qvalue

        # Update qvalue_sum and times_visited
        self.qvalue_sum += my_qvalue
        self.times_visited += 1

        # Propagate upwards
        if not self.is_root():
            self.parent.propagate(my_qvalue)

            
    def safe_delete(self):
        """safe delete to prevent memory leak in some python versions"""
        del self.parent
        for child in self.children:
            child.safe_delete()
            del child


In [8]:
class Root(Node):
    """The root node"""

    def __init__(self, snapshot, observation):
        self.parent = self.action = None
        self.children = set()
        self.snapshot = snapshot
        self.observation = observation
        self.immediate_reward = 0
        self.is_done = False

    @staticmethod
    def from_node(node):
        root = Root(node.snapshot, node.observation)
        # Copy data
        copied_fields = ["qvalue_sum", "times_visited", "children", "is_done"]
        for field in copied_fields:
            setattr(root, field, getattr(node, field))
        return root


### Running the MCTS 

In [9]:
def plan_mcts(root, n_iters=10):
    """
    Builds a tree with Monte-Carlo Tree Search for n_iters iterations.
    :param root: Tree node to plan from.
    :param n_iters: Number of select-expand-simulate-propagate loops to make.
    """
    for _ in range(n_iters):
        node = root.select_best_leaf()

        if node.is_done:
            # All rollouts from a terminal node are empty, and thus have 0 reward.
            node.propagate(0)
        else:
            # Expand the best leaf, perform a rollout from it, and propagate the results upwards.
            node = node.expand()
            reward = node.rollout()
            node.propagate(reward)  

In [10]:
env = utils.WithSnapshots(gym.make("FrozenLake-v1",map_name="8x8",
                             render_mode="ansi",
                             max_episode_steps=200))
root_observation, _ = env.reset()  # gymnasium's reset returns (observation, info)
root_snapshot = env.get_snapshot()
root = Root(root_snapshot, root_observation)

> Use the MCTS implementation to find the optimal policy and show it. Bonus points will be given for a clear display.
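
One way to turn the tree into a readable plan is to start at the root and repeatedly follow the child that MCTS visited most often (often called the "robust child"); `display_tree_max_only` below follows the highest Q estimate instead. The sketch relies only on the `children`, `times_visited` and `action` attributes of the notebook's `Node`; `best_action_path` and `StubNode` are hypothetical names used for illustration.

```python
def best_action_path(node, max_depth=20):
    """Follows the most-visited child from `node` and returns the actions taken.

    Assumes nodes expose `.children`, `.times_visited` and `.action`, as the
    Node class above does.
    """
    actions = []
    while node.children and len(actions) < max_depth:
        node = max(node.children, key=lambda child: child.times_visited)
        actions.append(node.action)
    return actions


class StubNode:
    """Hypothetical stand-in for Node, only used to demonstrate best_action_path."""

    def __init__(self, action, times_visited, children=()):
        self.action = action
        self.times_visited = times_visited
        self.children = set(children)


leaf_down = StubNode(action=1, times_visited=2)
leaf_right = StubNode(action=2, times_visited=5)
child_down = StubNode(action=1, times_visited=8, children=[leaf_down, leaf_right])
child_left = StubNode(action=0, times_visited=3)
demo_root = StubNode(action=None, times_visited=11, children=[child_down, child_left])
print(best_action_path(demo_root))  # [1, 2]
```

With the notebook's tree, `[id_to_action[a].strip() for a in best_action_path(root)]` would list the moves. Because each `Node` stores a single sampled outcome of its action, this is an open-loop plan; on the slippery FrozenLake the agent can drift off it, so re-planning from the actual state after every step is the safer way to act.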

In [147]:
plan_mcts(root, n_iters=4000)

In [184]:
id_to_action = {
    None: "ORIGIN",
    0: " LEFT ",
    1: " DOWN ",
    2: "RIGHT ",
    3: "  UP  ",
}

In [185]:
# display tree
def print_node(node, depth):
    """Prints one node, indented by its depth in the tree."""
    prefix = "  " * depth + "|%d|" % depth
    print(prefix, "Q = %.3f" % node.get_qvalue_estimate(), "A = ", id_to_action[node.action], "N =", node.times_visited)

def display_tree(node, depth=0, max_depth=3):
    """Prints the whole tree down to max_depth."""
    if depth > max_depth:
        return
    print_node(node, depth)
    for child in node.children:
        display_tree(child, depth + 1, max_depth)

def display_tree_max_only(node, depth=0, max_depth=3):
    """Prints only the path that follows the child with the highest Q estimate."""
    if depth > max_depth:
        return
    print_node(node, depth)
    if node.is_leaf():
        return  # no children to choose from (avoids max() on an empty set)
    best_child = max(node.children, key=lambda child: child.get_qvalue_estimate())
    if not best_child.is_leaf():
        display_tree_max_only(best_child, depth + 1, max_depth)

def display_tree_non_zeros(node, depth=0, max_depth=3):
    """Prints only nodes with a non-zero Q estimate (their subtrees are still searched)."""
    if depth > max_depth:
        return
    if node.get_qvalue_estimate() != 0:
        print_node(node, depth)
    for child in node.children:
        display_tree_non_zeros(child, depth + 1, max_depth)

In [186]:
display_tree(root, max_depth=3)


|0| Q = 0.002 A =  ORIGIN N = 7000
  |1| Q = 0.002 A =   LEFT  N = 1750
    |2| Q = 0.002 A =   LEFT  N = 438
      |3| Q = 0.009 A =   DOWN  N = 110
      |3| Q = 0.000 A =    UP   N = 109
      |3| Q = 0.000 A =   LEFT  N = 109
      |3| Q = 0.000 A =  RIGHT  N = 109
    |2| Q = 0.002 A =   DOWN  N = 438
      |3| Q = 0.000 A =    UP   N = 110
      |3| Q = 0.000 A =   LEFT  N = 109
      |3| Q = 0.009 A =  RIGHT  N = 110
      |3| Q = 0.000 A =   DOWN  N = 109
    |2| Q = 0.000 A =  RIGHT  N = 436
      |3| Q = 0.000 A =   LEFT  N = 109
      |3| Q = 0.000 A =    UP   N = 109
      |3| Q = 0.000 A =  RIGHT  N = 109
      |3| Q = 0.000 A =   DOWN  N = 109
    |2| Q = 0.002 A =    UP   N = 437
      |3| Q = 0.009 A =   DOWN  N = 110
      |3| Q = 0.000 A =   LEFT  N = 109
      |3| Q = 0.000 A =    UP   N = 109
      |3| Q = 0.000 A =  RIGHT  N = 109
  |1| Q = 0.001 A =    UP   N = 1748
    |2| Q = 0.000 A =   LEFT  N = 436
      |3| Q = 0.000 A =   LEFT  N = 109
      |3| Q = 0.000 A

In [187]:
display_tree_max_only(root, max_depth=10)

|0| Q = 0.002 A =  ORIGIN N = 7000
  |1| Q = 0.003 A =   DOWN  N = 1753
    |2| Q = 0.005 A =   DOWN  N = 439
      |3| Q = 0.018 A =   DOWN  N = 111
        |4| Q = 0.071 A =  RIGHT  N = 28
          |5| Q = 0.286 A =    UP   N = 7
            |6| Q = 1.000 A =  RIGHT  N = 2


In [188]:
display_tree_non_zeros(root, max_depth=10)

|0| Q = 0.002 A =  ORIGIN N = 7000
  |1| Q = 0.002 A =   LEFT  N = 1750
    |2| Q = 0.002 A =   LEFT  N = 438
      |3| Q = 0.009 A =   DOWN  N = 110
        |4| Q = 0.036 A =  RIGHT  N = 28
          |5| Q = 0.143 A =    UP   N = 7
            |6| Q = 0.500 A =   LEFT  N = 2
              |7| Q = 1.000 A =    UP   N = 1
    |2| Q = 0.002 A =   DOWN  N = 438
      |3| Q = 0.009 A =  RIGHT  N = 110
        |4| Q = 0.036 A =   LEFT  N = 28
          |5| Q = 0.143 A =    UP   N = 7
            |6| Q = 0.500 A =  RIGHT  N = 2
              |7| Q = 1.000 A =  RIGHT  N = 1
    |2| Q = 0.002 A =    UP   N = 437
      |3| Q = 0.009 A =   DOWN  N = 110
        |4| Q = 0.036 A =    UP   N = 28
          |5| Q = 0.143 A =   LEFT  N = 7
            |6| Q = 0.500 A =   DOWN  N = 2
              |7| Q = 1.000 A =   DOWN  N = 1
  |1| Q = 0.001 A =    UP   N = 1748
    |2| Q = 0.002 A =   DOWN  N = 438
      |3| Q = 0.009 A =    UP   N = 110
        |4| Q = 0.036 A =   LEFT  N = 28
          |5| Q = 0

> Try it on the CartPole problem as well:

In [None]:
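env = utils.WithSnapshots(gym.make("CartPole-v1", render_mode="rgb_array", max_episode_steps=200))
# Node.expand iterates over the global n_actions, so update it for CartPole (2 actions)
n_actions = env.action_space.n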

## (BONUS) Introducing some Machine Learning

Planning on each iteration can be costly. You can speed things up drastically if you train a classifier to predict which action will turn out to be best according to MCTS.

> To do so, adapt the code to record which action the MCTS agent took at each step, then fit a classifier on the (state, mcts_optimal_action) pairs.
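
The bonus can be prototyped without any ML library. The sketch below uses a hypothetical `MajorityVoteClassifier` that simply memorizes, per state, the action MCTS picked most often there. This only makes sense for small discrete state spaces such as FrozenLake's 64 cells; for CartPole's continuous observations one would swap in an actual classifier (for example a decision tree or a small neural network) trained on the same pairs.

```python
from collections import Counter, defaultdict


class MajorityVoteClassifier:
    """Hypothetical stand-in for a real classifier on (state, mcts_optimal_action) pairs.

    For each discrete state it predicts the action MCTS chose most often there.
    """

    def __init__(self, default_action=0):
        self.default_action = default_action
        self.votes = defaultdict(Counter)  # state -> Counter of chosen actions

    def fit(self, states, actions):
        for state, action in zip(states, actions):
            self.votes[state][action] += 1
        return self

    def predict(self, state):
        if state not in self.votes:
            # unseen state: fall back to a default (one could run MCTS here instead)
            return self.default_action
        return self.votes[state].most_common(1)[0][0]


# (state, mcts_optimal_action) pairs, e.g. recorded after each planning call
states = [0, 0, 0, 8, 8]
actions = [1, 1, 2, 2, 2]
clf = MajorityVoteClassifier().fit(states, actions)
print(clf.predict(0), clf.predict(8), clf.predict(63))  # 1 2 0
```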

## Model Learning: Dyna-Q

> Implement a tabular Dyna-Q algorithm (Sutton-Barto, Chapter 8.2) for the Frozen Lake environment.

!["Description of Dyna Algorithm"](dyna.png)

In [None]:
class DynaAgent:
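    def __init__(self, env, epsilon=1e-3):
        """Step (a)"""
        self.n_actions = env.action_space.n
        self.n_states = env.observation_space.n
        self.epsilon = epsilon
        self.q = np.zeros((self.n_states, self.n_actions))
        self.model = np.zeros((self.n_states, self.n_actions, 2)) # self.model[s,a] returns r and s' (cast s' back to int)
        # marks the (s, a) pairs actually observed: planning (step f) may only sample these
        self.observed = np.zeros((self.n_states, self.n_actions), dtype=bool)
        self.env = env
        self.current_state, _ = env.reset()
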
    def choose(self):
        """Step (b)"""
        raise NotImplementedError

    def observe(self,action):
        """Step (c)"""
        raise NotImplementedError
    
    def update_model(self,s1,a1,r1,s2):
        """Step (e)"""
        raise NotImplementedError

    def update_value(self,s1,a1,r1,s2):
        """Step (d)"""
        raise NotImplementedError

    def planning(self,n_steps):
        """Step (f)"""
        raise NotImplementedError
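
A compact, self-contained sketch of the tabular Dyna-Q loop from Sutton & Barto (Section 8.2) follows. It is not a completion of `DynaAgent`: it assumes a hypothetical deterministic `ChainEnv` (its `n_states` attribute stands in for `env.observation_space.n`) and stores the model in a dictionary keyed by observed `(s, a)` pairs, which makes "previously observed" explicit for step (f). Comments map lines to steps (a)-(f). Storing a single outcome per pair matches the book's deterministic-environment assumption; on the slippery FrozenLake the model would have to keep several outcomes (or counts) per pair to sample from.

```python
import random


class ChainEnv:
    """Hypothetical deterministic chain: states 0..n_states-1, action 1 = right, 0 = left.

    Reaching the last state gives reward 1 and ends the episode.
    """

    def __init__(self, n_states=6):
        self.n_states = n_states  # plays the role of env.observation_space.n
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        move = 1 if action == 1 else -1
        self.state = min(max(self.state + move, 0), self.n_states - 1)
        done = self.state == self.n_states - 1
        return self.state, (1.0 if done else 0.0), done


def dyna_q(env, n_actions, n_episodes=50, n_planning=10, alpha=0.5, gamma=0.95,
           epsilon=0.1, max_steps=200, seed=0):
    """Tabular Dyna-Q sketch with a deterministic model; returns the Q-table."""
    rng = random.Random(seed)
    q = [[0.0] * n_actions for _ in range(env.n_states)]  # initialize Q(s, a)
    model = {}  # Model(s, a) -> (r, s', done), only for observed (s, a) pairs

    def epsilon_greedy(s):
        if rng.random() < epsilon:
            return rng.randrange(n_actions)
        best = max(q[s])
        return rng.choice([a for a in range(n_actions) if q[s][a] == best])  # random tie-break

    def q_update(s, a, r, s2, done):
        target = r if done else r + gamma * max(q[s2])
        q[s][a] += alpha * (target - q[s][a])

    for _ in range(n_episodes):
        s = env.reset()  # step (a): s is the current state
        for _ in range(max_steps):
            a = epsilon_greedy(s)          # step (b)
            s2, r, done = env.step(a)      # step (c)
            q_update(s, a, r, s2, done)    # step (d): direct RL update
            model[(s, a)] = (r, s2, done)  # step (e): model learning
            for _ in range(n_planning):    # step (f): planning with simulated experience
                ps, pa = rng.choice(list(model))  # a previously observed (s, a)
                q_update(ps, pa, *model[(ps, pa)])
            if done:
                break
            s = s2
    return q


q = dyna_q(ChainEnv(), n_actions=2)
greedy_policy = [max(range(2), key=lambda a: q[s][a]) for s in range(5)]
print(greedy_policy)
```

On this chain, the greedy policy extracted from `q` should move right in every non-terminal state. Setting `n_planning=0` turns the loop into plain one-step Q-learning, which makes it easy to compare how many real episodes each variant needs.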

> What are some limits of the algorithm? Does it scale? Explain.

_Parts of the code for this practical have been inspired by https://github.com/yandexdataschool/Practical_RL/_