<a href="https://colab.research.google.com/github/MIT-RIR-AI-Course/additional_code_resource/blob/master/mcts_visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Dependencies

In [1]:
!pip install GraphvizAnim

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting GraphvizAnim
  Downloading GraphvizAnim-1.1.0-py3-none-any.whl (20 kB)
Installing collected packages: GraphvizAnim
Successfully installed GraphvizAnim-1.1.0


# Utilities

In [2]:

from typing import (Callable, Iterable, List, Sequence, Tuple, Dict, Optional,
                    Any, Union)

from abc import abstractmethod
import collections
import itertools
import functools
import random

import numpy as np

########## Graph-Search-Related Utilities and Class Definitions ##########

# Opaque placeholders: the search code never inspects states or actions, it
# only passes them back to the problem's own methods.
State = Any
Action = Any

# Sequence aliases used when describing plans.
StateSeq = List[State]
ActionSeq = List[Action]  # was List[State]; identical at runtime, wrong intent
CostSeq = RewardSeq = List[float]


class Problem:
    """Common base for path-cost problems and reward problems.

    Subclasses describe the transition model by overriding `actions` and
    `step`. This class does not use an ABC metaclass, so the
    `@abstractmethod` markers are not enforced at instantiation time.
    """

    def __init__(self, initial: State):
        # The state every search starts from.
        self.initial = initial

    @abstractmethod
    def actions(self, state: State) -> Iterable[Action]:
        """Return the actions that may be taken in `state`.

        A list is the usual result. When the action set is large, an
        implementation may instead yield the actions lazily, one at a time,
        rather than materialising all of them up front.
        """
        pass

    @abstractmethod
    def step(self, state: State, action: Action) -> State:
        """Return the successor reached by applying `action` in `state`.

        `action` has to be one of the values produced by
        `self.actions(state)`.
        """
        pass


class RewardProblem(Problem):
    """An abstract class for a finite-horizon reward problem, based on AIMA.

    To formalize a reward problem, subclass this class and implement the
    abstract methods (`actions`, `step`, `reward` and `rmax`). Instances of
    the subclass can then be handed to the search functions.
    """

    def __init__(self, initial: State, horizon: int):
        """Store the start state and the number of steps a plan may take.

        Args:
          initial: the state the search starts from.
          horizon: the number of steps available from `initial`.
        """
        # Let the base class own `self.initial` instead of duplicating it.
        super().__init__(initial)
        self.horizon = horizon

    @abstractmethod
    def reward(self, state1: State, action: Action, state2: State) -> float:
        """Returns the reward given at state2 from state1 via action.

        A reward at each step must be no greater than `self.rmax`.
        """
        ...

    @property
    @abstractmethod
    def rmax(self) -> float:
        """Returns the maximum reward per step."""
        ...


class ProblemSession01GridProblem(RewardProblem):
    """Small deterministic grid world used in the problem session.

    The seven cells form two three-cell columns (x = 0 and x = 2) joined by
    the middle cell (1, 1). The agent starts at (1, 1) with a horizon of 2.
    Actions are the letters "U", "D", "L" and "R"; "U" increases the second
    coordinate and "R" increases the first.
    """

    # Reward received when *entering* a cell; it depends only on `state2`.
    _REWARDS = {
        (0, 0): 1.0,
        (0, 1): 0.5,
        (0, 2): 0.1,
        (1, 1): 0.,
        (2, 0): 0.9,
        (2, 1): 0.6,
        (2, 2): 0.9,
    }

    def __init__(self):
        super().__init__((1, 1), 2)
        # state -> {action: successor state}
        self.trans = {
            (0, 0): {
                "U": (0, 1)
            },
            (0, 1): {
                "U": (0, 2),
                "R": (1, 1),
                "D": (0, 0)
            },
            (0, 2): {
                "D": (0, 1)
            },
            (1, 1): {
                "L": (0, 1),
                "R": (2, 1)
            },
            (2, 0): {
                "U": (2, 1)
            },
            (2, 1): {
                "U": (2, 2),
                "L": (1, 1),
                "D": (2, 0)
            },
            (2, 2): {
                "D": (2, 1)
            }
        }

    def actions(self, state: State) -> Iterable[Action]:
        """Return the actions available in `state`, in table order.

        A list (not a dict view) is returned so that callers can index it,
        e.g. with `random.choice`.
        """
        return list(self.trans[state])

    def step(self, state: State, action: Action) -> State:
        """Return the cell reached by taking `action` in `state`.

        Raises:
          KeyError: if `state` is unknown or `action` is not allowed there.
        """
        return self.trans[state][action]

    def reward(self, state1: State, action: Action, state2: State) -> float:
        """Return the reward for arriving in `state2`.

        `state1` and `action` are accepted for interface compatibility but
        do not influence the result.
        """
        return self._REWARDS[state2]

    @property
    def rmax(self) -> float:
        """Return the largest reward obtainable in a single step."""
        return max(self._REWARDS.values())


# MCTS with Visualization

In [3]:
from typing import ClassVar
import dataclasses

from gvanim import Animation
from gvanim.jupyter import interactive

@dataclasses.dataclass(eq=False)
class MCTNode:
    """A node of the Monte Carlo search tree.

    The dataclass is mutable and compares by identity (`eq=False`), which
    keeps instances hashable so they can be used as keys of `children`.
    """
    # Problem state represented by this node.
    state: State
    # Sum of the returns backed up through this node.
    U: float
    # Number of times this node has been backed up through.
    N: int
    # Number of steps still available from this node.
    horizon: int
    # Parent node (None for the root); left out of repr to keep it short.
    parent: Optional['MCTNode'] = dataclasses.field(repr=False)
    # Maps each child node to the action that leads to it.
    children: Dict['MCTNode', Action] = dataclasses.field(default_factory=dict)

    # Sequential id assigned at construction; not an __init__ argument.
    name: int = dataclasses.field(init=False)
    # Class-wide source of the next `name`.
    counter: ClassVar[int] = 0

    def __post_init__(self):
        # Take the current counter value as this node's name, then advance it.
        self.name, MCTNode.counter = MCTNode.counter, MCTNode.counter + 1

def ucb(n: MCTNode, C: float = 1.4) -> float:
    """Upper confidence bound of node `n`.

    Args:
      n: a non-root node; its parent's visit count is read when `n.N > 0`.
      C: weight of the exploration term.

    Returns:
      Infinity for a node with `N == 0`, otherwise the mean utility
      `U / N` plus `C * sqrt(log(parent.N) / N)`.
    """
    if n.N == 0:
        return np.inf
    mean_utility = n.U / n.N
    exploration_bonus = C * np.sqrt(np.log(n.parent.N) / n.N)
    return mean_utility + exploration_bonus


def run_mcts_search(anime: Animation,
                    titles: List[str],
                    problem: RewardProblem,
                    C: float = 1.4,
                    iteration_budget: int = 1000,
                    step_budget: int = np.inf):
    """A generic MCTS search implementation that records an animation.

    Every select / expand / simulate / backup step adds or restyles nodes in
    `anime` and appends one caption to `titles` per animation frame.

    Args:
      anime: the animation that receives nodes, edges, labels and frames.
      titles: list that is appended to in place, one caption per frame.
      problem: a reward problem.
      C: the UCB parameter.
      iteration_budget: maximum iterations to run the search.
      step_budget: maximum number of allowed `problem.step`s.

    Returns:
      A pair `(plan, root)` where `plan` is the
      `(state_sequence, action_sequence, reward_sequence)` triple produced by
      `finish_mcts_plan` and `root` is the root `MCTNode` of the search tree.

    Raises:
      ValueError: if both budgets are infinite.
    """
    if min(iteration_budget, step_budget) == np.inf:
        raise ValueError("Must provide at least one budget")

    problem_step_count = 0

    class BudgetExceeded(Exception):
        """Raised by `step_helper` once `step_budget` is used up."""

    def step_helper(state, action):
        """helper to track the problem's step count."""
        nonlocal problem_step_count
        problem_step_count += 1
        if problem_step_count > step_budget:
            raise BudgetExceeded("step budget exceeded")
        return problem.step(state, action)

    ucb_fixed_C = functools.partial(ucb, C=C)

    # Animation bookkeeping shared by the helpers below:
    #   selections  - names of the nodes on the current selection path
    #   simulations - [expanded node, chosen child, playout nodes...]
    #   backup_node - name of the node most recently updated by `backup`
    selections = []
    simulations = []
    backup_node = None

    def highlight_chain(nodes, color, ignore_first_two=False):
        """Colour a path of animation nodes and the edges between them."""
        for i, n in enumerate(nodes):
            if ignore_first_two and i < 2:
                continue
            # Leave the node being backed up in its own highlight colour.
            if n != backup_node:
                anime.highlight_node(n, color=color)
        for n1, n2 in zip(nodes[:-1], nodes[1:]):
            anime.highlight_edge(n1, n2, color=color)

    def next_frame(title=""):
        """Re-apply the path highlights, close the frame, record its title."""
        highlight_chain(selections, "blue")
        highlight_chain(simulations, "gray", ignore_first_two=True)
        anime.next_step()
        titles.append(title)

    def select(n: MCTNode) -> MCTNode:
        """select a leaf node in the tree"""
        selections.append(n.name)
        next_frame("Choose node by UCB")
        if n.children:
            ucb_pick = max(n.children.keys(), key=ucb_fixed_C)
            # Show the UCB values for one frame, then restore plain labels.
            for c in n.children:
                anime.label_node(c.name, f"State={c.state}\nU={c.U}\nN={c.N}\nUCB={ucb_fixed_C(c):.2f}")
                anime.highlight_node(c.name, color="deepskyblue")
            next_frame("Compute UCB of children")
            for c in n.children:
                anime.label_node(c.name, f"State={c.state}\nU={c.U}\nN={c.N}")
            selection = select(ucb_pick)
        else:
            selection = n
        return selection

    def expand(n: MCTNode) -> MCTNode:
        """expand the leaf node by adding all its children states"""
        assert not n.children
        if n.horizon == 0:
            return n
        for action in problem.actions(n.state):
            child_state = step_helper(n.state, action)
            new_node = MCTNode(state=child_state,
                               horizon=n.horizon - 1,
                               parent=n,
                               U=0,
                               N=0)
            n.children[new_node] = action
            anime.add_node(new_node.name)
            anime.label_node(new_node.name, f"State={new_node.state}\nU={new_node.U}\nN={new_node.N}")
            anime.highlight_node(new_node.name, color="green")
            anime.add_edge(n.name, new_node.name)
            anime.label_edge(n.name, new_node.name, label=f"{action}")
        next_frame("Expand new children")

        child = random.choice(list(n.children.keys()))
        simulations.append(n.name)
        simulations.append(child.name)
        next_frame("Simulate random playout")
        return child

    def simulate(node: MCTNode) -> float:
        """simulate the utility of current state by randomly picking a step"""
        current_anim_node = node.name
        state = node.state
        total_reward = 0
        for _ in range(node.horizon, 0, -1):
            # `actions` may return any iterable (e.g. a dict view), while
            # `random.choice` needs an indexable sequence.
            action = random.choice(list(problem.actions(state)))
            child_state = step_helper(state, action)
            reward = problem.reward(state, action, child_state)
            total_reward += reward

            # NOTE(review): the playout node id is the hash of the *pre-step*
            # state although its label shows `child_state`. For integer
            # states hash(x) == x, which could collide with tree node names;
            # confirm whether a dedicated id scheme is wanted.
            new_anim_node = hash(state)
            simulations.append(new_anim_node)
            anime.add_node(new_anim_node)
            anime.add_edge(current_anim_node, new_anim_node)
            reward_str = f"+{reward}" if reward >= 0 else str(reward)
            anime.label_edge(current_anim_node, new_anim_node, label=f"{action}, {reward_str}")
            anime.label_node(new_anim_node, label=f"State={child_state}")
            next_frame("Simulate random playout")
            current_anim_node = new_anim_node

            state = child_state
        return total_reward

    def backup(n: MCTNode, value: float) -> None:
        """passing the utility back to all parent nodes"""
        nonlocal backup_node
        if n.parent is not None:
            # Need to include the reward on the action *into* n
            a = n.parent.children[n]
            r = problem.reward(n.parent.state, a, n.state)
            n.U += value + r
            n.N += 1
            anime.label_node(n.name, f"State={n.state}\nU={n.U}\nN={n.N}")
            anime.highlight_node(n.name, color='orange')
            reward_str = f"+{r}" if r >= 0 else str(r)
            # Show the edge reward for this frame only, then restore the label.
            anime.label_edge(n.parent.name, n.name, f"{a}, {reward_str}")
            backup_node = n.name
            next_frame("Backup update")
            anime.label_edge(n.parent.name, n.name, f"{a}")
            backup(n.parent, value + r)
        else:
            # The root only counts visits; its utility is displayed as "-".
            n.N += 1
            anime.label_node(n.name, f"State={n.state}\nU=-\nN={n.N}")
            anime.highlight_node(n.name, color='orange')
            backup_node = n.name
            next_frame("Backup update")

    root = MCTNode(state=problem.initial,
                   horizon=problem.horizon,
                   parent=None,
                   U=0,
                   N=0)
    anime.add_node(root.name)
    anime.label_node(root.name, f"State={root.state}\nU=-\nN={root.N}")
    next_frame("Create the root")

    try:
        i = 0
        while i < iteration_budget:
            leaf = select(root)
            child = expand(leaf)
            value = simulate(child)
            backup(child, value)
            i += 1

            # End backup, clear up the simulations
            selections, backup_node = [], None
            # The first two entries are tree nodes; only playout nodes go.
            for n in simulations[2:]:
                anime.remove_node(n)
            simulations.clear()
            next_frame(f"Begin iteration {i}")

    except BudgetExceeded:
        # Running out of steps ends the search; the tree built so far is used.
        pass

    return finish_mcts_plan(problem, root), root


def finish_mcts_plan(problem: RewardProblem, node: MCTNode):
    """Helper for run_mcts_search: read the plan off the search tree.

    Starting at `node`, repeatedly descend into the child with the largest
    visit count `N` (the first such child on ties) until a node without
    children is reached.

    Returns:
      A `(states, actions, rewards)` triple of lists. `states` starts with
      `node.state` and is one element longer than the other two.
    """
    states = [node.state]
    actions = []
    rewards = []

    current = node
    while current.children:
        # Each item is a (child node, action leading to it) pair.
        best_child, best_action = max(current.children.items(),
                                      key=lambda item: item[0].N)
        actions.append(best_action)
        states.append(best_child.state)
        rewards.append(
            problem.reward(current.state, best_action, best_child.state))
        current = best_child

    return states, actions, rewards


In [6]:
import os
import shutil
import tempfile
from unittest.mock import patch
import ipywidgets as widgets
from PIL import Image
from gvanim import render
from gvanim.jupyter import interactive

def interactive(animation, titles, size=320):
    """Render an animation to images and wrap them in a slider widget.

    Args:
      animation: object whose `graphs()` method returns the frames to render.
      titles: captions; `titles[n]` is printed above frame `n`.
      size: size argument forwarded to `render`.

    Returns:
      An `ipywidgets` interactive widget with an integer slider that selects
      which frame (and caption) to show.
    """
    # The directory is removed on exit even if rendering fails.
    with tempfile.TemporaryDirectory() as basedir:
        basename = os.path.join(basedir, 'graph')
        graphs = animation.graphs()
        # Keep the second half of the frames, minus the last one.
        # NOTE(review): presumably this drops the frames produced by the
        # first (layout-building) search run in `visualize_mcts` - confirm.
        graphs = graphs[len(graphs) // 2:-1]
        paths = render(graphs, basename, 'png', size)
        frames = []
        for path in paths:
            # Copy the pixels into memory and close the file right away, so
            # nothing depends on the temporary files later on.
            with Image.open(path) as im:
                frames.append(im.copy())

    # Pad every frame onto a white canvas of the largest frame size so that
    # all steps are displayed with identical dimensions.
    max_size = (max(frame.width for frame in frames),
                max(frame.height for frame in frames))
    steps = []
    for frame in frames:
        resized_im = Image.new('RGB', max_size, (255, 255, 255))  # White
        resized_im.paste(frame, None)
        steps.append(resized_im)

    slider = widgets.IntSlider(min=0, max=len(steps) - 1, step=1, value=0)

    def update_view(n):
        print(titles[n])
        display(steps[n])

    return widgets.interactive(update_view, n=slider)

def dfv(node):
    """Yield `node` and all of its descendants in depth-first pre-order.

    Children are visited in the iteration order of each node's `children`.
    """
    # Stack of child iterators; the innermost pending level is on top.
    pending = [iter((node,))]
    while pending:
        try:
            current = next(pending[-1])
        except StopIteration:
            pending.pop()
            continue
        yield current
        pending.append(iter(current.children))

def visualize_mcts(problem: RewardProblem,
                   random_order: List[int],
                   C: float = 1.4,
                   iteration_budget=8):
    """Visualizes a MCTS running on a reward problem.

    The search is run twice with the same scripted randomness: once to
    populate the animation with the whole tree, and once more to record the
    frames that are actually shown.

    Args:
      problem: a reward problem
      random_order: the source of randomness when running simulation. The
        k-th call to `random.choice` returns the element at index
        `random_order[k]`.
      C: the UCB parameter, used for both runs.
      iteration_budget: iteration budget for MCTS.
        Must be small enough such that `random_order` does not run out.

    Returns:
      The widget produced by `interactive` for the recorded animation.

    Raises:
      ValueError: if the search asks for more random choices than
        `random_order` provides.
    """
    def patched_choice(order):
        """Patch helper to `random.choice` with a fixed selection."""
        order = iter(order)

        def choice(seq):
            try:
                index = next(order)
            except StopIteration:
                # Report the real cause instead of leaking StopIteration.
                raise ValueError(
                    "`random_order` ran out of entries; lower "
                    "`iteration_budget` or supply more indices") from None
            return list(seq)[index]
        return choice

    anime = Animation()
    with patch("random.choice", patched_choice(random_order)):
        # First run to build the graph. It must use the same C as the second
        # run, otherwise the two runs can grow different trees.
        MCTNode.counter = 0
        _, root = run_mcts_search(anime, [], problem, C=C, iteration_budget=iteration_budget)
        for n in dfv(root):
            anime.remove_node(n.name)
        anime.next_step()
    anime.next_step()
    with patch("random.choice", patched_choice(random_order)):
        # Second run to visualize; node names restart at 0 so that they match
        # the names used during the first run.
        MCTNode.counter = 0
        titles = ["Begin iteration 0"]
        run_mcts_search(anime, titles, problem, C=C, iteration_budget=iteration_budget)
    return interactive(anime, titles, size=800)


# Visualization


In [7]:
problem = ProblemSession01GridProblem()
# Visualize Q1 from problem session
# `random_order` scripts the successive `random.choice` calls made during the search.
# To step through the visualization, click the slider to activate it then use the keyboard's arrow keys.
# The call is the last expression of the cell so that the notebook displays the returned widget.
visualize_mcts(problem, random_order=[0, 2, 1, 2], C=1.0, iteration_budget=5)

interactive(children=(IntSlider(value=0, description='n', max=44), Output()), _dom_classes=('widget-interact',…

In [None]:
problem = ProblemSession01GridProblem()
# Visualize Q2 from problem session
# `random_order` scripts the successive `random.choice` calls made during the search.
# To step through the visualization, click the slider to activate it then use the keyboard's arrow keys.
# The call is the last expression of the cell so that the notebook displays the returned widget.
visualize_mcts(problem, random_order=[0, 0, 1, 1], C=1.0, iteration_budget=7)

interactive(children=(IntSlider(value=0, description='n', max=62), Output()), _dom_classes=('widget-interact',…