# Homework 1

## Imports and Utilities
**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:

from typing import (Callable, Iterable, List, Sequence, Tuple, Dict, Optional,
                    Any, Union)
from abc import abstractmethod

from collections import namedtuple
import signal
import itertools
import functools
import random
import contextlib

import numpy as np

########## Graph-Search-Related Utilities and Class Definitions ##########

# States and actions are problem-specific, so they are left untyped.
State = Any
Action = Any

# A plan is a sequence of states, the actions taken between them, and the
# per-step costs (or rewards) those actions incur.
StateSeq = List[State]
ActionSeq = List[Action]
CostSeq = RewardSeq = List[float]


class Problem:
  """Common interface of path-cost problems and reward problems.

  A concrete problem describes its search space through `actions` and
  `step`; the state every search starts from is kept in `self.initial`.
  """

  def __init__(self, initial: State):
    # Starting state for every search over this problem.
    self.initial = initial

  @abstractmethod
  def actions(self, state: State) -> Iterable[Action]:
    """Return (or yield) the actions that are legal in `state`.

    Usually this is a list. When there are very many actions, yielding
    them lazily from an iterator avoids materializing them all at once.
    """

  @abstractmethod
  def step(self, state: State, action: Action) -> State:
    """Return the state reached by applying `action` in `state`.

    `action` is expected to be one of `self.actions(state)`.
    """


class PathCostProblem(Problem):
  """An abstract class for a path cost problem, based on AIMA.

  To formalize a path cost problem, subclass this and implement the
  abstract methods. Then create instances of your subclass and solve them
  with the various search functions.
  """

  @abstractmethod
  def goal_test(self, state: State) -> bool:
    """Checks if the state is a goal."""
    ...

  @abstractmethod
  def step_cost(self, state1: State, action: Action, state2: State) -> float:
    """Returns the cost of arriving at state2 from state1 via action."""
    ...

  def h(self, state: State) -> float:
    """Returns the heuristic value, a lower bound on the cost to reach a goal.

    The default is the trivial heuristic 0. It is a valid lower bound whenever
    step costs are non-negative, which best-first search requires anyway.
    A default of -inf would also be a lower bound, but it would make A*'s
    priority `g + h` equal -inf for every node. The path cost would then be
    dropped from the ordering entirely.
    """
    return 0.0


class RewardProblem(Problem):
  """Abstract finite-horizon reward problem, based on AIMA.

  Subclass it and implement the abstract members, then pass instances to
  the search functions. `horizon` is the planning horizon in steps.
  """

  def __init__(self, initial: State, horizon: int):
    super().__init__(initial)
    self.horizon = horizon

  @abstractmethod
  def reward(self, state1: State, action: Action, state2: State) -> float:
    """Reward obtained when `action` takes the agent from `state1` to `state2`.

    No single-step reward may exceed `self.rmax`.
    """

  @property
  @abstractmethod
  def rmax(self) -> float:
    """Upper bound on the reward of any single step."""


class GridProblem(PathCostProblem):
  """A grid navigation problem on the 5x5 cost grid below.

  States are `(row, col)` tuples. The agent moves one cell up, down, left or
  right, and each move costs the arrival cost of the destination cell.
  """

  def __init__(self, initial, goal=(4, 4)):
    super().__init__(initial)
    self.goal = goal
    self.all_grid_actions = ["up", "down", "left", "right"]
    self.grid_act_to_delta = {
        "up": (-1, 0),
        "down": (1, 0),
        "left": (0, -1),
        "right": (0, 1)
    }
    # Somewhat unusual cost structure: the cost of a step depends only on the
    # arrival cell s', which is determined by (s, a).
    self.grid_arrival_costs = np.array(
        [
            [1, 1, 8, 1, 1],
            [1, 8, 1, 1, 1],
            [1, 8, 1, 1, 1],
            [1, 1, 1, 8, 1],
            [1, 1, 2, 1, 1],
        ],
        dtype=int,
    )

  def actions(self, state):
    """Returns the moves that stay inside the grid, in the order up, down, left, right."""
    (r, c) = state
    n_rows, n_cols = self.grid_arrival_costs.shape
    actions = []
    for act in self.all_grid_actions:
      dr, dc = self.grid_act_to_delta[act]
      if 0 <= r + dr < n_rows and 0 <= c + dc < n_cols:
        actions.append(act)
    return actions

  def step(self, state, action):
    """Returns the cell reached by moving one cell in the direction of `action`."""
    (r, c) = state
    dr, dc = self.grid_act_to_delta[action]
    return (r + dr, c + dc)

  def goal_test(self, state):
    return state == self.goal

  def step_cost(self, state1, action, state2):
    """Returns the arrival cost of `state2` as a Python float.

    The (row, col) pair is unpacked explicitly so the lookup is a scalar
    element access even if a state arrives as a list rather than a tuple.
    """
    r, c = state2
    return float(self.grid_arrival_costs[r, c])

  def h(self, state):
    """Manhattan distance to the goal.

    This is admissible because every move changes one coordinate by 1 and
    every arrival cost in the grid is at least 1.
    """
    return abs(state[0] - self.goal[0]) + abs(state[1] - self.goal[1])


@contextlib.contextmanager
def count_step_calls(problem: "Problem"):
  """Count the calls made to `problem.step` inside the `with` block.

  While the context is active, `problem.step` is replaced by a wrapper. The
  wrapper records each call in a `collections.Counter`, keyed by
  `(state, action)` plus a running `"total"`, and then forwards the call to
  the original method and returns its result. The counter is yielded so
  callers can check, for example, that a search steps each state-action pair
  at most once. The original `step` is restored on exit, even if the block
  raises.

  Example:
    >>> problem = GridProblem((0, 0))
    >>> with count_step_calls(problem) as counter:
    ...   _ = problem.step((0, 0), "down")
    ...   _ = problem.step((1, 1), "up")
    ...   assert counter[((0, 0), "down")] == counter[((1, 1), "up")] == 1
    ...   assert counter["total"] == 2
  """
  import collections

  counter = collections.Counter()
  # Record whether `step` was already an instance attribute. If it was not,
  # exit deletes the override instead of leaving a bound-method copy that
  # shadows the class's method.
  had_instance_step = "step" in getattr(problem, "__dict__", {})
  orig_problem_step = problem.step

  def step_helper(state, action):
    counter[(state, action)] += 1
    counter["total"] += 1
    return orig_problem_step(state, action)

  problem.step = step_helper
  try:
    yield counter
  finally:
    if had_instance_step:
      problem.step = orig_problem_step
    else:
      del problem.step





## Best-first Search


### Utilities


**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:
# Search-tree node used by heuristic search. It stores the node's state,
# its parent node, the action that produced it, a `cost` (presumably the
# cost of the step into this node) and the cumulative path cost `g`.
Node = namedtuple("Node", "state parent action cost g")


class SearchFailed(ValueError):
  """Signals that a search could not produce a plan."""


import heapq as hq


def get_trace_from_actions_helper(
    problem: PathCostProblem,
    actions: List[Action]) -> Tuple[List[State], List[float]]:
  """Replay `actions` from `problem.initial` and collect the states and step costs.

  This is a debugging and testing aid only. It calls `problem.step` once per
  action to rebuild the trajectory. Do not call it from inside a search
  algorithm: a search should call `problem.step` at most once per
  state-action pair.

  Args:
    problem: a path cost problem.
    actions: the planned actions, in order.

  Returns:
    `(states, costs)`. `states` starts with `problem.initial` and has one more
    entry than `actions`; `costs[i]` is the step cost of `actions[i]`.
  """
  current = problem.initial
  states, costs = [current], []
  for act in actions:
    successor = problem.step(current, act)
    states.append(successor)
    costs.append(problem.step_cost(current, act, successor))
    current = successor
  return states, costs

### Question
Complete a generic best-first search that, depending on the priority function it is given, runs A*, GBFS, or UCS. You may assume that any heuristic is consistent.

For reference, our solution is **56** line(s) of code.

In [None]:

def run_best_first_search(
    problem: PathCostProblem,
    get_priority: Callable[[Node], float],
    max_steps: int = 1000) -> Tuple[StateSeq, ActionSeq, CostSeq, int]:
  """A generic heuristic search implementation.

  Depending on `get_priority`, can implement A*, GBFS, or UCS.

  The `get_priority` function here should determine the order
  in which nodes are expanded. For example, if you want to
  use path cost as part of this determination, then the
  path cost (node.g) should appear inside of get_priority,
  rather than in this implementation of `run_best_first_search`.

  Important: for determinism (and to make sure our tests pass),
  please break ties using the state itself. For example,
  if you would've otherwise sorted by `get_priority(node)`, you
  should now sort by `(get_priority(node), node.state)`.

  Nodes are goal-tested when they are popped, not when they are generated.
  This keeps UCS, and A* with a consistent heuristic, optimal. A state is
  re-queued only when a strictly cheaper path to it is found, and outdated
  queue entries are skipped when popped.

  Args:
    problem: a path cost problem.
    get_priority: a callable that takes a search Node and returns its priority.
    max_steps: maximum number of node expansions before giving up.

  Returns:
    state_sequence: A list of states, starting with `problem.initial`.
    action_sequence: A list of actions.
    cost_sequence: A list of step costs, one per action.
    num_expansions: number of nodes expanded; the goal node is not counted.
      Always less than or equal to `max_steps`.

  Raises:
    SearchFailed: if the frontier empties without reaching a goal, or the
      expansion budget `max_steps` runs out first.
  """
  root = Node(state=problem.initial, parent=None, action=None, cost=0.0, g=0.0)
  # Heap entries are (priority, state, insertion index, node). The insertion
  # index settles any remaining ties, so Node objects (whose `parent` may be
  # None) are never compared with each other.
  insertion_index = itertools.count()
  frontier = [(get_priority(root), root.state, next(insertion_index), root)]
  # Cheapest node discovered so far for each state.
  reached = {root.state: root}
  num_expansions = 0

  while frontier:
    _, _, _, node = hq.heappop(frontier)
    if reached[node.state] is not node:
      # Stale entry: a cheaper path to this state was queued after this one.
      continue

    if problem.goal_test(node.state):
      states, actions, costs = [], [], []
      while node.parent is not None:
        states.append(node.state)
        actions.append(node.action)
        costs.append(node.cost)
        node = node.parent
      states.append(node.state)
      states.reverse()
      actions.reverse()
      costs.reverse()
      return states, actions, costs, num_expansions

    if num_expansions >= max_steps:
      raise SearchFailed(f"Expansion budget of {max_steps} exhausted.")
    num_expansions += 1

    for action in problem.actions(node.state):
      child_state = problem.step(node.state, action)
      step_cost = problem.step_cost(node.state, action, child_state)
      child = Node(state=child_state, parent=node, action=action,
                   cost=step_cost, g=node.g + step_cost)
      best = reached.get(child_state)
      if best is None or child.g < best.g:
        reached[child_state] = child
        hq.heappush(frontier, (get_priority(child), child_state,
                               next(insertion_index), child))

  raise SearchFailed("Frontier exhausted without reaching a goal.")

#### Tests

In [None]:
# We will test this implementation more thoroughly with the
# specific heuristic search algorithms that follow
grid_problem = GridProblem((0, 0))
get_priority_fn = lambda node: 0
result = run_best_first_search(grid_problem, get_priority_fn)
# (state_sequence, action_sequence, cost_sequence, num_expansions): every
# later test unpacks these four values.
assert len(result) == 4

def best_first_search_test2():
  # With a constant priority, expansion order comes purely from the state
  # tie-break. Two plans are accepted: the one from a textbook implementation,
  # and the one from an implementation that tracks the best known cost to
  # each state. More thorough tests follow with the specific algorithms.
  states, actions, costs, num_expansions = run_best_first_search(
      GridProblem((0, 0)), lambda node: 0)

  textbook_plan = (
      [(0, c) for c in range(5)] + [(r, 4) for r in range(1, 5)],
      ['right'] * 4 + ['down'] * 4,
      [1.0, 8.0] + [1.0] * 6,
  )
  if (states, actions, costs) != textbook_plan:
    assert states == [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (3, 2),
                      (4, 2), (4, 3), (4, 4)]
    assert actions == ['down', 'down', 'down', 'right', 'right', 'down',
                       'right', 'right']
    assert costs == [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0]
  assert num_expansions <= 35

best_first_search_test2()

def best_first_search_test3():
  """Ordering by path cost g alone must find the cheapest plan.

  A mismatch usually means ties are not being broken by state, as the
  docstring of `run_best_first_search` requires.
  """
  states, actions, costs, num_expansions = run_best_first_search(
      GridProblem((0, 0)), lambda node: node.g)
  assert states == [(r, 0) for r in range(4)] + [(3, 1), (3, 2), (4, 2),
                                                  (4, 3), (4, 4)]
  assert actions == ['down'] * 3 + ['right', 'right', 'down', 'right', 'right']
  assert costs == [1.0] * 5 + [2.0, 1.0, 1.0]
  assert num_expansions <= 35

best_first_search_test3()

print('Tests passed.')

## Uniform Cost Search


### Question
Use your implementation of `run_best_first_search` to implement uniform cost search.

For reference, our solution is **4** line(s) of code.

In addition to all of the utilities defined at the top of the colab notebook, the following functions are available in this question environment: `run_best_first_search`. You may not need to use all of them.

In [None]:

def run_uniform_cost_search(problem: PathCostProblem,
                            max_expansions: int = 1000):
  """Uniform-cost search: best-first search ordered by path cost `g` alone.

  Args:
    problem: a path cost problem.
    max_expansions: expansion budget, forwarded as `max_steps`.

  Returns:
    `(state_sequence, action_sequence, cost_sequence, num_expansions)`, as
    returned by `run_best_first_search`.

  Raises:
    SearchFailed: if no plan is found within the budget.
  """
  return run_best_first_search(problem, lambda node: node.g,
                               max_steps=max_expansions)

#### Tests

In [None]:
def ucs_test1():
  # The expected values assume ties are broken by state, as the docstring
  # of `run_best_first_search` describes.
  expected_states = [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (3, 2),
                     (4, 2), (4, 3), (4, 4)]
  expected_actions = ['down'] * 3 + ['right', 'right', 'down', 'right', 'right']
  expected_costs = [1.0] * 5 + [2.0, 1.0, 1.0]

  states, actions, costs, num_expansions = run_uniform_cost_search(
      GridProblem((0, 0)))
  assert states == expected_states
  assert actions == expected_actions
  assert costs == expected_costs
  assert num_expansions <= 30

ucs_test1()

print('Tests passed.')

## A* Search


### Question
Use your implementation of `run_best_first_search` to implement A* search.

For reference, our solution is **4** line(s) of code.

In addition to all of the utilities defined at the top of the colab notebook, the following functions are available in this question environment: `run_best_first_search`. You may not need to use all of them.

In [None]:

def run_astar_search(problem: PathCostProblem, max_expansions: int = 1000):
  """A* search: best-first search ordered by `g + problem.h(state)`.

  Optimal when `problem.h` is consistent (and therefore admissible).

  Args:
    problem: a path cost problem.
    max_expansions: expansion budget, forwarded as `max_steps`.

  Returns:
    `(state_sequence, action_sequence, cost_sequence, num_expansions)`, as
    returned by `run_best_first_search`.

  Raises:
    SearchFailed: if no plan is found within the budget.
  """
  return run_best_first_search(problem,
                               lambda node: node.g + problem.h(node.state),
                               max_steps=max_expansions)

#### Tests

In [None]:
def astar_test1():
  """A* with the Manhattan heuristic must find the cheapest plan quickly.

  A mismatch usually means ties are not being broken by state, as the
  docstring of `run_best_first_search` requires.
  """
  expected_states = [(r, 0) for r in range(4)] + [(3, 1), (3, 2), (4, 2),
                                                   (4, 3), (4, 4)]
  expected_actions = ['down'] * 3 + ['right', 'right', 'down', 'right', 'right']
  expected_costs = [1.0] * 5 + [2.0, 1.0, 1.0]

  states, actions, costs, num_expansions = run_astar_search(
      GridProblem((0, 0)))
  assert states == expected_states
  assert actions == expected_actions
  assert costs == expected_costs
  assert num_expansions <= 20

astar_test1()

print('Tests passed.')

## Greedy Best-First Search


### Question
Use your implementation of `run_best_first_search` to implement GBFS.

For reference, our solution is **4** line(s) of code.

In addition to all of the utilities defined at the top of the colab notebook, the following functions are available in this question environment: `run_best_first_search`. You may not need to use all of them.

In [None]:

def run_greedy_best_first_search(problem: PathCostProblem,
                                 max_expansions: int = 1000):
  """Greedy best-first search: best-first search ordered by `problem.h` alone.

  Fast, but not guaranteed to return the cheapest plan.

  Args:
    problem: a path cost problem.
    max_expansions: expansion budget, forwarded as `max_steps`.

  Returns:
    `(state_sequence, action_sequence, cost_sequence, num_expansions)`, as
    returned by `run_best_first_search`.

  Raises:
    SearchFailed: if no plan is found within the budget.
  """
  return run_best_first_search(problem, lambda node: problem.h(node.state),
                               max_steps=max_expansions)

#### Tests

In [None]:
def gbfs_test1():
  """GBFS follows the heuristic along the top row, then down the right edge.

  A mismatch usually means ties are not being broken by state, as the
  docstring of `run_best_first_search` requires.
  """
  problem = GridProblem((0, 0))
  states, actions, costs, num_expansions = run_greedy_best_first_search(
      problem)
  assert states == [(0, c) for c in range(5)] + [(r, 4) for r in range(1, 5)]
  assert actions == ['right'] * 4 + ['down'] * 4
  assert 7 <= num_expansions <= 9

gbfs_test1()

print('Tests passed.')

## Path-cost Problem from Reward Problem


### Question
Implement the reduction from a reward problem to a path-cost problem **as in lecture**.

For reference, our solution is **30** line(s) of code.

In [None]:

def path_cost_problem_from_reward_problem(
    reward_problem: RewardProblem) -> PathCostProblem:
  """Reduce a reward maximization problem into a path search problem.

  You should take a close look at the class definition of `RewardProblem`,
  since it will be handy.
  Especially note that the horizon value is inclusive -- for a horizon value of $H$,
  the agent should be allowed to step exactly $H$ number of steps.

  The reduction works as follows:
  - A path-cost state is a pair `(s, h)`: an underlying state `s` and the
    number of steps `h` still to take. The start is
    `(reward_problem.initial, reward_problem.horizon)`.
  - Goal states are those with `h == 0`, so every plan has exactly
    `horizon` steps.
  - Each step costs `rmax - reward`, which is non-negative as long as rewards
    never exceed `rmax`.
  - Because every plan has the same length, the total cost is
    `horizon * rmax - total_reward`. Minimizing it therefore maximizes the
    total reward.
  """

  class _RewardToPathCostProblem(PathCostProblem):
    """Path-cost view of `reward_problem` over `(state, steps_left)` pairs."""

    def actions(self, state):
      s, steps_left = state
      if steps_left <= 0:
        # The horizon is used up: no further steps are allowed.
        return []
      return reward_problem.actions(s)

    def step(self, state, action):
      s, steps_left = state
      return (reward_problem.step(s, action), steps_left - 1)

    def goal_test(self, state):
      return state[1] == 0

    def step_cost(self, state1, action, state2):
      return reward_problem.rmax - reward_problem.reward(state1[0], action,
                                                         state2[0])

    def h(self, state):
      # Step costs are non-negative, so 0 is a consistent lower bound.
      return 0.0

  return _RewardToPathCostProblem(
      (reward_problem.initial, reward_problem.horizon))

#### Tests

In [None]:
def path_reward_reduction_test1(horizon=2, max_expansions=1000):
  # NOTE: `ForageProblem` is defined in the MCTS utilities cell further down;
  # make sure that cell has been run before this one.
  path_problem = path_cost_problem_from_reward_problem(
      ForageProblem(horizon=horizon))
  state_sequence, action_sequence, cost_sequence, num_expansions = run_uniform_cost_search(
      path_problem, max_expansions=max_expansions)
  print('actions', action_sequence)
  print('states', state_sequence)
  print('number of expansions', num_expansions)
  assert len(action_sequence) == horizon
  if horizon == 2:
    # Best 2-step plan from s1: a1 -> s3 (reward 1), then a1 -> s1 (reward 1).
    assert action_sequence == ["a1", "a1"]
    assert [s for s, h in state_sequence] == ["s1", "s3", "s1"]

path_reward_reduction_test1()

print('Tests passed.')

## Visualize Reward Fields


### Utilities

The fractal problem and its different reward fields.
          You should read the code below to get a rough idea of what each reward field looks like.

**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:
import matplotlib.pyplot as plt

FractalProblemState = namedtuple("FractalProblemState", "x y t")  # position (x, y) and step counter t


class FractalProblem(RewardProblem):
  """A base class for "fractal problems".

  Briefly, a fractal problem is as follows:
    - The state space is the entire 2D Euclidean plane. A state is a
      `FractalProblemState(x, y, t)`, where `t` counts steps starting from 1.
    - The initial state is (0, 0), with t = 1.
    - At step t the agent moves along one of `action_directions`, scaled by
      `step_scale ** t`. With the default `step_scale=0.5`, the moves shrink
      geometrically.
    - At each step the agent receives a scalar reward from a reward field.
  """

  def __init__(self,
               action_directions="8-neighbors",
               step_scale=0.5,
               horizon=6):
    """A base constructor for a fractal problem.

    Args:
      action_directions: the directions the agent may move in at each step.
        Either "4-neighbors", "8-neighbors", or an explicit list of
        (dx, dy) pairs.
      step_scale: at step t, move lengths are scaled by `step_scale ** t`.
      horizon: number of steps in the problem.
    """
    super().__init__(FractalProblemState(0, 0, 1), horizon)
    if action_directions == "4-neighbors":
      self.action_directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    elif action_directions == "8-neighbors":
      self.action_directions = [
          (x, y)
          for x, y in itertools.product([-1, 0, 1], [-1, 0, 1])
          if (x, y) != (0, 0)
      ]
    else:
      self.action_directions = action_directions
    self.step_scale = step_scale
    self.field_of_view = self.compute_field_of_view()

  def compute_field_of_view(self, precision: float = 1.) -> Tuple[Tuple[float]]:
    """Computes a minimum rectangle that encloses all reachable states within horizon.

    Args:
      precision: the returned rectangle will have vertices that are integer
        multiples of `precision`. Defaults to 1., which means that the rectangle
        has integral vertices.

    Returns:
      the lower-left and upper-right vertices of the rectangle.
    """
    # Total distance available over the horizon. The move length at steps
    # t = 1..H is step_scale**t.
    extent = sum([self.step_scale**(t + 1) for t in range(self.horizon)])
    dir_extent = np.array(self.action_directions) * extent
    # Clamp against 0 so the start state (0, 0) is always inside the view,
    # even when no direction points toward it along some axis.
    lower = np.minimum(np.min(dir_extent, axis=0), 0.)
    upper = np.maximum(np.max(dir_extent, axis=0), 0.)
    return (
        tuple(np.floor(lower / precision) * precision),
        tuple(np.ceil(upper / precision) * precision),
    )

  @property
  def field_of_view_size(self):
    """Width and height of `field_of_view`."""
    return (self.field_of_view[1][0] - self.field_of_view[0][0],
            self.field_of_view[1][1] - self.field_of_view[0][1])

  def actions(self, state):
    """The action directions scaled by `step_scale ** state.t`."""
    scale = self.step_scale**state.t
    return [(x * scale, y * scale) for x, y in self.action_directions]

  def step(self, state, action):
    """Translate by `action` and advance the step counter."""
    ax, ay = action
    return FractalProblemState(state.x + ax, state.y + ay, state.t + 1)

  def reward(self, state, action, next_state):
    """The reward field evaluated at the coordinates of `state` (the source state)."""
    coord = np.array([state.x, state.y])
    return self.reward_field(coord).item()

  def reward_field(self, coords: np.ndarray):
    """The reward(s) at `coords`. Handles a batch of coordinates so plotting is efficient.

    Subclasses must override this to provide a meaningful reward field.

    Args:
      coords: an array with shape (2,) or (...batch, 2) holding the
        coordinate(s) at which to evaluate the reward field.

    Returns:
      A scalar or an array of shape (...batch,)
    """
    return np.zeros(coords.shape[:-1])

  @property
  def rmax(self):
    return 1.

  def visualize_reward_field(self, resolution=(1000, 1000), show=False):
    """Visualize a reward field.

    Args:
      resolution: number of pixels (w, h) used to discretize the field of view.
      show: whether to call `plt.show()` after drawing.

    Returns:
      self, so calls can be chained (e.g. with `visualize_plan`).
    """
    fov = self.field_of_view
    width, height = resolution
    coords = np.dstack(
        np.meshgrid(
            np.linspace(fov[0][0], fov[1][0], width),
            np.linspace(fov[0][1], fov[1][1], height),
        )).reshape(-1, 2)
    # np.meshgrid(x, y) produces arrays of shape (len(y), len(x)) = (h, w).
    # That is also the (rows, cols) layout imshow expects. Reshaping to
    # (w, h) would scramble the image whenever w != h.
    vals = self.reward_field(coords).reshape(height, width)
    plt.imshow(vals,
               origin="lower",
               extent=(fov[0][0], fov[1][0], fov[0][1], fov[1][1]))
    plt.colorbar()
    if show:
      plt.show()
    return self

  def visualize_plan(self,
                     state_sequence: StateSeq,
                     show=False,
                     **arrow_kwargs):
    """Visualize a plan on the reward field.

    Args:
      state_sequence: A sequence of states in the fractal problem.
      show: whether to call `plt.show()` after drawing.
      arrow_kwargs: passed to `plt.arrow`.
    """
    for s1, s2 in zip(state_sequence[:-1], state_sequence[1:]):
      x1, y1, d = s1
      x2, y2, _ = s2
      dx, dy = x2 - x1, y2 - y1
      # Arrows get thinner for later steps (larger t), matching the shrinking moves.
      plt.arrow(x1,
                y1,
                dx,
                dy,
                length_includes_head=True,
                width=0.01 / d**0.5,
                fill=True,
                **arrow_kwargs)
    if show:
      plt.show()
    return plt


from scipy.stats import multivariate_normal


class GradientRewardFieldProblem(FractalProblem):
  """A fractal problem whose reward field is a sum of scaled Gaussian densities."""

  def __init__(self,
               locs: Sequence[Tuple[int, int]] = ((0, 0),),
               covs: Union[float, Sequence[Union[float, np.ndarray]]] = 0.1,
               strengths: Union[float, Sequence[float]] = 1.,
               **kwargs):
    """A reward field made of a mixture of Gaussian gradients.

    Args:
      locs: the centers of the Gaussians.
      covs: a scalar, a sequence of scalars, or a sequence of matrices
        giving the covariances of the Gaussians. A scalar c stands for c * I.
      strengths: a scalar or a sequence of scalars giving the scaling factor
        of each Gaussian.
      **kwargs: forwarded to `FractalProblem.__init__`.
    """
    super().__init__(**kwargs)
    self.locs = locs
    if np.isscalar(covs):
      covs = [covs] * len(self.locs)
    covs = [np.eye(2) * cov if np.isscalar(cov) else cov for cov in covs]
    self.covs = covs
    if np.isscalar(strengths):
      strengths = [strengths] * len(self.locs)
    self.strengths = strengths

  def reward_field(self, coords):
    return sum([
        multivariate_normal.pdf(coords, mean=loc, cov=cov) * strength
        for loc, cov, strength in zip(self.locs, self.covs, self.strengths)
    ])

  @property
  def rmax(self):
    """An upper bound on the reward field: the sum of each Gaussian's peak value.

    A 2-D Gaussian density reaches its maximum, 1 / (2*pi*sqrt(det(cov))), at
    its mean. Summing the positive scaled peaks therefore bounds the mixture
    from above.

    The inherited constant 1 is not a valid bound here. For example,
    2 / (2*pi*0.3) ~= 1.06 for "reward-field-1". Using it would give negative
    step costs in the reward-to-path-cost reduction.
    """
    return float(sum(
        max(strength, 0.) /
        (2 * np.pi * np.sqrt(np.linalg.det(np.asarray(cov))))
        for cov, strength in zip(self.covs, self.strengths)))


class NoisyRewardFieldProblem(FractalProblem):
  """A fractal problem with a piecewise-constant random reward field.

  The field of view is cut into square bins of side `bin_size`. Each bin gets
  an independent Beta(*beta_params) reward drawn from a RandomState seeded
  with `seed`.
  """

  def __init__(self, seed=0, bin_size=5e-2, beta_params=(1., 2.), **kwargs):
    super().__init__(**kwargs)
    # Number of bins along each axis needed to tile the field of view.
    self.nbins = tuple(
        int(extent / bin_size) for extent in self.field_of_view_size)
    rng = np.random.RandomState(seed)
    rewards = rng.beta(*beta_params, size=self.nbins)
    # Surround the grid with a ring of zero-reward cells. np.digitize maps
    # coordinates outside the field of view to index 0 or nbins + 1, and
    # those indices land in the ring.
    self.random_locs = np.pad(rewards, [(1, 1), (1, 1)], constant_values=0.)

  def reward_field(self, coords):
    lo, hi = self.field_of_view
    bin_x, bin_y = (
        np.digitize(coords[..., axis],
                    np.linspace(lo[axis], hi[axis], self.nbins[axis] + 1))
        for axis in range(2))
    return self.random_locs[bin_x, bin_y]


def get_fractal_problems() -> Dict[str, FractalProblem]:
  """We have defined here three fractal problems with different reward fields.

  You are encouraged to play with the class and function definitions above.
  But, DO NOT CHANGE THIS FUNCTION --- it exists to protect you from 
  accidentally changing the problems.

  Returns:
    A new dict mapping each problem name to a freshly constructed instance:
    two Gaussian-mixture gradient fields ("reward-field-1", "reward-field-2")
    and one seeded noisy Beta field ("reward-field-3").
  """
  return {
      # Four equal Gaussians centered on the corners (+-1.3, +-1.3).
      "reward-field-1":
          GradientRewardFieldProblem(locs=[(-1.3, -1.3), (-1.3, 1.3),
                                           (1.3, -1.3), (1.3, 1.3)],
                                     covs=[.3, .3, .3, .3],
                                     strengths=[2., 2., 2., 2.]),
      # Same corners, but the (1.3, 1.3) Gaussian is narrower and stronger.
      "reward-field-2":
          GradientRewardFieldProblem(locs=[(-1.3, -1.3), (-1.3, 1.3),
                                           (1.3, -1.3), (1.3, 1.3)],
                                     covs=[.25, .25, .25, .2],
                                     strengths=[1, 1, 1, 1.5]),
      # Beta(0.2, 2) noise on 0.05-wide bins, with a fixed seed.
      "reward-field-3":
          NoisyRewardFieldProblem(seed=42, bin_size=5e-2,
                                  beta_params=(0.2, 2.)),
  }


# Feel free to play with these problem instances however you'd like
fractal_problems = get_fractal_problems()

### Question
We have defined for you three fractal problems in `get_fractal_problems()`.
          We have provided you below some code (also in the Colab notebook) to visualize the reward fields of these three problems. 
          You should run this visualization code in the Colab notebook and get an idea of what each reward field looks like.
          **You do not need to submit any code for this question --- it is not graded.**
          Instead, we will ask you questions in the following problems to confirm your understanding.


For reference, our solution is **3** line(s) of code.

In [None]:

def visualize_fractal_problems():
  """Show the reward field of each problem in `get_fractal_problems()`, one figure at a time."""
  problems = get_fractal_problems()
  for title in problems:
    plt.title(title)
    problems[title].visualize_reward_field(show=True)
    # Clear the figure so the next field, and its colorbar, start fresh.
    plt.clf()

#### Tests

In [None]:

print("Tests passed.")

## Play with MCTS


### Utilities

Our implementation of MCTS.
**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:
##############################################################################
#
# MCTS
#
##############################################################################

import dataclasses


@dataclasses.dataclass(eq=False)
class MCTNode:
  """A node of the Monte Carlo search tree.

  `eq=False` keeps identity-based hashing. That lets nodes serve as keys in
  their parent's `children` dict, which maps each child node to the action
  that reaches it.
  """
  state: State
  U: float  # total utility backed up through this node
  N: int  # number of backups through this node
  horizon: int  # steps remaining below this node
  parent: Optional['MCTNode']
  children: Dict['MCTNode', Action] = dataclasses.field(default_factory=dict)


def ucb(n: MCTNode, C: float = 1.4) -> float:
  """Upper confidence bound of node `n`; `C` weights the exploration term.

  Unvisited nodes score +inf, so every child is tried once before any child
  is revisited.
  """
  if n.N == 0:
    return np.inf
  exploit = n.U / n.N
  explore = C * np.sqrt(np.log(n.parent.N) / n.N)
  return exploit + explore


def run_mcts_search(problem: RewardProblem,
                    C: float = 1.4,
                    iteration_budget: int = 1000,
                    step_budget: int = np.inf,
                    time_budget: float = np.inf):
  """A generic MCTS search implementation.

  Each iteration runs select -> expand -> simulate -> backup from the root.
  The search stops when any budget is exhausted, and the plan is then read
  off the tree by `finish_mcts_plan`.

  Args:
    problem: a reward problem.
    C: the UCB exploration parameter.
    iteration_budget: maximum number of iterations to run.
    step_budget: maximum number of allowed `problem.step` calls.
    time_budget: maximum allowed CPU time, in seconds. It is enforced with
      SIGPROF / ITIMER_PROF, so it only works on Unix and in the main thread.

  Returns:
    state_sequence: A list of states.
    action_sequence: A list of actions.
    reward_sequence: A list of the rewards along the plan.

  Raises:
    ValueError: if all three budgets are infinite.
  """
  if min(iteration_budget, step_budget, time_budget) == np.inf:
    raise ValueError("Must provide at least one budget")

  problem_step_count = 0

  class BudgetExceeded(Exception):
    pass

  def step_helper(state, action):
    """helper to track the problem's step count."""
    nonlocal problem_step_count
    problem_step_count += 1
    if problem_step_count > step_budget:
      raise BudgetExceeded("step budget exceeded")
    return problem.step(state, action)

  def signal_handler(signum, frame):
    raise BudgetExceeded("time budget exceeded")

  ucb_fixed_C = functools.partial(ucb, C=C)

  def select(n: MCTNode) -> MCTNode:
    """Descend from n, following the max-UCB child, until a leaf is reached."""
    if n.children:
      ucb_pick = max(n.children.keys(), key=ucb_fixed_C)
      return select(ucb_pick)
    else:
      return n

  def expand(n: MCTNode) -> MCTNode:
    """Add all children of leaf n and return one of them at random.

    Returns n itself when it is at horizon 0 or has no legal actions.
    """
    assert not n.children
    if n.horizon == 0:
      return n
    for action in problem.actions(n.state):
      child_state = step_helper(n.state, action)
      new_node = MCTNode(state=child_state,
                         horizon=n.horizon - 1,
                         parent=n,
                         U=0,
                         N=0)
      n.children[new_node] = action
    if not n.children:
      # Dead end: random.choice would raise on an empty list.
      return n
    return random.choice(list(n.children.keys()))

  def simulate(node: MCTNode) -> float:
    """Estimate the node's utility with a uniformly random rollout to the horizon."""
    state = node.state
    total_reward = 0
    for _ in range(node.horizon):
      available = list(problem.actions(state))
      if not available:
        # No legal action: the rollout ends early.
        break
      action = random.choice(available)
      child_state = step_helper(state, action)
      total_reward += problem.reward(state, action, child_state)
      state = child_state
    return total_reward

  def backup(n: MCTNode, value: float) -> None:
    """Propagate a rollout value from n up to the root.

    Each node accumulates the value plus the reward of the edge into that
    node, and the running sum is what gets passed on to the parent.
    """
    if n.parent:
      # Need to include the reward on the action *into* n
      a = n.parent.children[n]
      r = problem.reward(n.parent.state, a, n.state)
      n.U += value + r
      n.N += 1
      backup(n.parent, value + r)
    else:
      n.N += 1

  root = MCTNode(state=problem.initial,
                 horizon=problem.horizon,
                 parent=None,
                 U=0,
                 N=0)

  handler_installed = False
  prev_handler = None
  try:
    if time_budget != np.inf:
      prev_handler = signal.signal(signal.SIGPROF, signal_handler)
      handler_installed = True
      signal.setitimer(signal.ITIMER_PROF, time_budget)
    i = 0
    while i < iteration_budget:
      leaf = select(root)
      child = expand(leaf)
      value = simulate(child)
      backup(child, value)
      i += 1
  except BudgetExceeded:
    pass
  finally:
    # Always disarm the timer and restore the previous handler, including
    # when the search raises something other than BudgetExceeded.
    if handler_installed:
      signal.setitimer(signal.ITIMER_PROF, 0)
      # signal.signal returns None when the previous handler was not
      # installed from Python; fall back to the default disposition then.
      signal.signal(signal.SIGPROF,
                    signal.SIG_DFL if prev_handler is None else prev_handler)

  return finish_mcts_plan(problem, root)


def finish_mcts_plan(problem: RewardProblem, node: MCTNode):
  """Helper for run_mcts_search: read the plan off the search tree greedily.

  Starting at `node`, repeatedly move to the child with the highest average
  utility U / N (unvisited children rank last) until reaching a leaf.

  Returns:
    `(state_sequence, action_sequence, reward_sequence)` along that path.
  """
  def average_utility(child):
    return child.U / child.N if child.N > 0 else -np.inf

  states, actions, rewards = [node.state], [], []
  while node.children:
    best = max(node.children, key=average_utility)
    act = node.children[best]
    actions.append(act)
    states.append(best.state)
    rewards.append(problem.reward(node.state, act, best.state))
    node = best
  return states, actions, rewards


class ForageProblem(RewardProblem):
  """A small four-state foraging problem with two actions in every state.

  States are the strings "s1".."s4" and actions are "a1" and "a2". The
  lookup tables in `step` and `reward` are indexed by the 1-based numeric
  suffix of those names.
  """

  def __init__(self, horizon=8):
    super().__init__("s1", horizon)

  @staticmethod
  def _indices(state, action):
    # "s3", "a2" -> (2, 1): zero-based row and column into the lookup tables.
    return int(state[1:]) - 1, int(action[1:]) - 1

  def actions(self, state):
    return ["a1", "a2"]

  def step(self, state, action):
    """
    T(s, a) -> s'
    s1  a1  -> s3
    s1  a2  -> s2
    s2  a1  -> s4
    s2  a2  -> s4
    s3  a1  -> s1
    s3  a2  -> s4
    s4  a1  -> s1
    s4  a2  -> s1
    """
    next_state_ids = np.array([[3, 2], [4, 4], [1, 4], [1, 1]], dtype=int)
    row, col = self._indices(state, action)
    return "s" + str(next_state_ids[row, col])

  def reward(self, state, action, next_state):
    """
    R(s, a, s') = Food_At(s) - Length(s, a)
    s     Food_At
    s1      2
    s2      3
    s3      2
    s4      10
    (s, a) -> length
    s1  a1  -> 1
    s1  a2  -> 1
    s2  a1  -> 7
    s2  a2  -> 7
    s3  a1  -> 1
    s3  a2  -> 5
    s4  a1  -> 2
    s4  a2  -> 2
    """
    food_at = np.array([2, 3, 2, 10])
    path_length = np.array([[1, 1], [7, 7], [1, 5], [2, 2]], dtype=int)
    row, col = self._indices(state, action)
    return food_at[row] - path_length[row, col]

  @property
  def rmax(self):
    # Largest entry of Food_At(s) - Length(s, a): 10 - 2 at s4.
    return 8


def mcts_test1(horizon=2, iteration_budget=100):
  """Check MCTS on the forage problem for the given horizon."""
  problem = ForageProblem(horizon=horizon)
  # Pass the budget by keyword: the second positional parameter of
  # run_mcts_search is the UCB constant C, not the horizon.
  state_sequence, action_sequence, reward_sequence = run_mcts_search(
      problem, iteration_budget=iteration_budget)

  print('actions', action_sequence)
  print('states', state_sequence)

  if horizon == 2:
    # Best 2-step plan from s1: a1 -> s3 (reward 1), then a1 -> s1 (reward 1).
    assert state_sequence == ["s1", "s3", "s1"]
    assert action_sequence == ["a1", "a1"]

### Question
We have provided you an implementation of MCTS for reward problems. 
          You should make sure that you can run our implementation on the reward problems we have defined for you.
          You are also encouraged to look into our code to see how our MCTS is implemented.
          **You do not need to submit any code for this question --- it is not graded.**
          Instead, we will ask you questions in the following problems to confirm your understanding.


For reference, our solution is **2** line(s) of code.

In [None]:

def visualize_mcts():
  """Run MCTS on "reward-field-1" and draw the resulting plan over its reward field."""
  problem = get_fractal_problems()["reward-field-1"]
  planned_states = run_mcts_search(problem, iteration_budget=10000)[0]
  problem.visualize_reward_field().visualize_plan(planned_states, show=True)

#### Tests

In [None]:

print("Tests passed.")