# Miniproject 3

In [None]:

# IPython/Jupyter shell command (not plain Python): installs the backport of
# `functools.cached_property`, which the imports below fall back to when the
# stdlib version is unavailable (e.g. Colab's Python 3.7).
!pip install backports.cached_property

# Setup matplotlib animation: display animations in notebook output as
# JavaScript/HTML ('jshtml').
import matplotlib
matplotlib.rc('animation', html='jshtml')


## Imports and Utilities
**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:

from typing import (Callable, Iterable, List, Sequence, Tuple, Dict, Optional,
                    Any, Union, Set, ClassVar, Type, TypeVar)

from abc import abstractmethod, ABC
import collections
import textwrap
import math
import functools
import itertools
import random
import dataclasses
import numpy as np
import heapq as hq

try:
    from functools import cached_property
except ImportError:
    # Import for Colab (Python==3.7)
    from backports.cached_property import cached_property

import scipy.signal


def heatmap(data, ax=None, cbar_kw=None, cbarlabel="", **kwargs):
    """
    Plot a 2D numpy array as a gridded heatmap on a matplotlib Axes.

    Every row and column gets a major tick (tick labels are left at
    matplotlib's numeric defaults), x tick labels are moved to the top and
    rotated, the axes spines are hidden, and a black grid is drawn on the
    cell boundaries.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N).
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
        not provided, use current axes or create a new one.  Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`.  A
        colorbar is drawn only when this is not None (pass `{}` for a default
        colorbar).  Optional.
    cbarlabel
        The label for the colorbar; ignored when `cbar_kw` is None.  Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    The `AxesImage` created by `imshow`.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)

    # Create the colorbar only when requested.
    if cbar_kw is not None:
        cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # One major tick per column / row (cell centres sit on integer coordinates).
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_yticks(np.arange(data.shape[0]))

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(),
             rotation=-30,
             ha="right",
             rotation_mode="anchor")

    # Turn spines off and draw a black grid on the cell boundaries, using
    # half-integer minor ticks whose tick marks are then hidden.
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
    ax.grid(which="minor", color="black", linestyle='-', linewidth=1.5)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im


def annotate_heatmap(im,
                     data=None,
                     valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None,
                     **textkw):
    """
    Write each cell's value as text on top of a heatmap.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate, as a 2D list or numpy array.  If it is neither
        (e.g. None), the image's data is used.  Optional.
    valfmt
        The format of the annotations inside the heatmap.  This should either
        use the string format method, e.g. "$ {x:.2f}", or be a callable taking
        ``(value, pos)`` such as a `matplotlib.ticker.Formatter`.  Optional.
    textcolors
        A pair of colors.  The first is used for cells whose normalized value
        is not above the threshold, the second for those above.  Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied.  If None (the default), half of the normalized maximum of
        `data` is used.  Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels (they may override the default centred alignment).

    Returns
    -------
    The list of created `Text` objects, in row-major order.
    """
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()
    # Lists are accepted above but have no `.max()` / `.shape`; convert them.
    # `asanyarray` leaves ndarray subclasses (e.g. masked arrays) untouched.
    data = np.asanyarray(data)

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Turn a format string into a formatter.  matplotlib is only needed (and
    # only imported) in this case, so callable formatters work without it.
    if isinstance(valfmt, str):
        import matplotlib.ticker
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts


# Problem-specific placeholder types (states, observations, actions); the
# abstract interfaces below leave them untyped.
State = Observation = Action = Any


class Problem(ABC):
    """Abstract interface shared by path cost problems and reward problems.

    Concrete problems supply the initial state, the actions available in a
    state, and the transition function.
    """

    @property
    @abstractmethod
    def initial(self) -> State:
        """The state the problem starts in."""

    @abstractmethod
    def actions(self, state: State) -> Iterable[Action]:
        """Return the actions allowed in `state`.

        A list is the usual result, but a problem with many actions may
        instead yield them one at a time from an iterator rather than
        building them all up front.
        """

    @abstractmethod
    def step(self, state: State, action: Action) -> State:
        """Return the state reached by executing `action` in `state`.

        `action` must be one of `self.actions(state)`.
        """


class PathCostProblem(Problem):
    """Abstract path cost problem, in the style of AIMA.

    Formalize a concrete problem by subclassing this and implementing the
    abstract methods; instances of the subclass can then be handed to the
    search functions.
    """

    @abstractmethod
    def goal_test(self, state: State) -> bool:
        """Return True if `state` is a goal state."""

    @abstractmethod
    def step_cost(self, state1: State, action: Action, state2: State) -> float:
        """Return the cost of moving from `state1` to `state2` via `action`."""

    def h(self, state: State) -> float:
        """Heuristic estimate of the remaining cost to a goal (a lower bound).

        The default of 0 is trivially a lower bound, which makes A*
        (priority g + h) behave like uniform-cost search.
        """
        return 0


class POMDP(Problem):
    """A POMDP specified by a generative model.

    Transitions come from `step`, rewards from `reward`, and observations
    from `get_observation`. The discount and horizon default to 1 and
    infinity and may be overridden.
    """

    @property
    def discount(self) -> float:
        """The discount factor (default 1, i.e. undiscounted)."""
        return 1.

    @property
    def horizon(self) -> float:
        """The planning horizon (default `np.inf`, i.e. unbounded).

        Subclasses may return an int; ints are accepted wherever a float is.
        """
        return np.inf

    @abstractmethod
    def terminal(self, state: State) -> bool:
        """Return True if `state` is terminating (an absorbing state).

        Abstract, so subclasses must override it; this base version (always
        False) is still reachable through `super()`.
        """
        return False

    @abstractmethod
    def reward(self, state1: State, action: Action, state2: State) -> float:
        """Return the reward for moving from `state1` to `state2` via `action`."""
        ...

    @abstractmethod
    def get_observation(self, state: State) -> Observation:
        """Sample an observation given the current state."""
        ...


class MDP(POMDP):
    """A fully observable POMDP, i.e. an MDP given by a generative model."""

    def get_observation(self, state: State) -> State:
        """Full observability: the observation is the state itself."""
        observation = state
        return observation


class SearchFailed(ValueError):
    """Raised when a search cannot produce a plan (budget spent or frontier empty)."""


# Search-tree node for best-first search: `state` is the node's state,
# `parent` the BFSNode it was generated from (None for the root), `action`
# the action taken from the parent, `cost` that single step's cost (None for
# the root) and `g` the accumulated path cost from the root.
BFSNode = collections.namedtuple("BFSNode", "state parent action cost g")


def run_best_first_search(
    problem: PathCostProblem,
    get_priority: Callable[[BFSNode], float],
    step_budget: int = 10000
) -> Tuple[List[State], List[Action], List[float], int]:
    """A generic heuristic search implementation.

    Depending on `get_priority`, can implement A*, GBFS, or UCS.

    The `get_priority` function here should determine the order
    in which nodes are expanded. For example, if you want to
    use path cost as part of this determination, then the
    path cost (node.g) should appear inside of get_priority,
    rather than in this implementation of `run_best_first_search`.

    Important: for determinism (and to make sure our tests pass),
    please break ties using the state itself. For example,
    if you would've otherwise sorted by `get_priority(node)`, you
    should now sort by `(get_priority(node), node.state)`.
    Any tie that remains (the same state queued twice with the same
    priority, e.g. under GBFS when a cheaper path is found) is broken by
    insertion order, so BFSNode objects are never compared with each other.

    Frontier entries superseded by a cheaper path are not removed, so they
    may still be expanded later; those expansions count toward `num_steps`.

    Args:
      problem: a path cost problem.
      get_priority: a callable taking in a search Node and returns the priority
      step_budget: maximum number of `problem.step` before giving up.

    Returns:
      state_sequence: A list of states.
      action_sequence: A list of actions.
      cost_sequence: A list of costs.
      num_steps: number of taken `problem.step`s. Must be less than or equal to `step_budget`.

    Raises:
      error: SearchFailed, if no plan is found.
    """
    num_steps = 0
    # Heap of (priority, state, insertion index, node) entries.
    frontier = []
    # Best node found so far for each state.
    reached = {}
    # Monotonic counter: final tie-breaker so heap entries never compare nodes.
    tie_breaker = itertools.count()

    root_node = BFSNode(state=problem.initial,
                        parent=None,
                        action=None,
                        cost=None,
                        g=0)
    hq.heappush(frontier, (get_priority(root_node), root_node.state,
                           next(tie_breaker), root_node))
    reached[root_node.state] = root_node

    while frontier:
        _, _, _, node = hq.heappop(frontier)
        # If reached the goal, return
        if problem.goal_test(node.state):
            return (*finish_plan(node), num_steps)

        # Generate successors
        for action in problem.actions(node.state):
            if num_steps >= step_budget:
                raise SearchFailed(
                    f"Failed to find a plan in {step_budget} steps")
            child_state = problem.step(node.state, action)
            num_steps += 1
            cost = problem.step_cost(node.state, action, child_state)
            path_cost = node.g + cost
            # Queue the child only if its state is new or now reached more cheaply.
            if child_state not in reached or path_cost < reached[child_state].g:
                child_node = BFSNode(state=child_state,
                                     parent=node,
                                     action=action,
                                     cost=cost,
                                     g=path_cost)
                priority = get_priority(child_node)
                hq.heappush(frontier, (priority, child_state,
                                       next(tie_breaker), child_node))
                reached[child_state] = child_node
    raise SearchFailed(f"Frontier exhausted after {num_steps} steps")


def finish_plan(node: BFSNode):
    """Helper for run_best_first_search: rebuild the plan that ends at `node`.

    Follows parent links back to the root and returns
    ``(state_sequence, action_sequence, cost_sequence)`` in root-to-`node`
    order. The state list includes the root, so it is one element longer
    than the other two.
    """
    states, actions, costs = [], [], []
    current = node
    while current.parent is not None:
        states.append(current.state)
        actions.append(current.action)
        costs.append(current.cost)
        current = current.parent
    states.append(current.state)
    # Collected leaf-to-root; flip into execution order.
    states.reverse()
    actions.reverse()
    costs.reverse()
    return states, actions, costs


def run_astar_search(problem: PathCostProblem, step_budget: int = 10000):
    """A* search: best-first search ordered by f(n) = g(n) + h(n).

    Delegates to `run_best_first_search` and returns its result unchanged.
    """
    def f_value(node):
        # Path cost so far plus the problem's heuristic estimate to go.
        return node.g + problem.h(node.state)

    return run_best_first_search(problem, f_value, step_budget=step_budget)


@dataclasses.dataclass(eq=False)
class MCTStateNode:
    """A decision node of the Monte Carlo search tree.

    Its children are chance nodes, one per action expanded from here, stored
    as a mapping from chance node to the action it stands for. With
    `eq=False`, nodes compare and hash by identity, so they can key dicts.
    """
    # What the node is keyed on: the state for an MDP, the observation for a POMDP.
    obs: Union[State, Observation]
    # Number of backups that have passed through this node.
    N: int
    # Remaining search depth below this node.
    horizon: int
    # Chance node this node hangs under; None for the root.
    parent: Optional['MCTChanceNode']
    children: Dict['MCTChanceNode', Action] = dataclasses.field(
        default_factory=dict)


@dataclasses.dataclass(eq=False)
class MCTChanceNode:
    """A chance node of the Monte Carlo search tree: one action from a decision node.

    Its children are the decision nodes reached after taking the action,
    keyed by the resulting observation (the state itself for an MDP).
    Equality and hashing are by identity (`eq=False`).
    """
    # Sum of the values backed up through this node; U / N is its mean.
    U: float
    # Number of backups that have passed through this node.
    N: int
    # Decision node the action is taken from.
    parent: MCTStateNode
    children: Dict[State, MCTStateNode] = dataclasses.field(
        default_factory=dict)


def ucb(n: 'MCTChanceNode', C: float = 1.4) -> float:
    """UCB1 score of a chance node (the argument needs `U`, `N` and `parent.N`).

    Returns ``U / N + C * sqrt(ln(parent.N) / N)``: the node's mean
    backed-up value plus an exploration bonus scaled by `C`. Unvisited
    nodes (``N == 0``) score +inf so they are tried first. MCTS binds its
    own `C` (via functools.partial), overriding the 1.4 default.
    """
    if n.N == 0:
        return np.inf
    exploit = n.U / n.N
    explore = C * np.sqrt(np.log(n.parent.N) / n.N)
    return exploit + explore


# A rollout policy maps a state to the action to take while simulating.
RolloutPolicy = Callable[[State], Action]


def random_rollout_policy(problem: MDP, state: State) -> Action:
    """Pick an action uniformly at random from `problem.actions(state)`.

    Bind `problem` (e.g. with functools.partial) to obtain a RolloutPolicy.
    """
    candidates = list(problem.actions(state))
    return random.choice(candidates)


def run_mcts_search(problem: POMDP,
                    state: Optional[State] = None,
                    state_sampler: Optional[Iterable[State]] = None,
                    horizon: Optional[int] = None,
                    C: float = np.sqrt(2),
                    iteration_budget: int = 100,
                    n_simulations: int = 10,
                    max_backup: bool = True,
                    rollout_policy: Optional[RolloutPolicy] = None,
                    verbose: bool = False) -> Action:
    """A generic MCTS search implementation for MDP and POMDPs.

    For MDP, this is a standard MCTS implementation based on description in AIMA
    (with some additional features).
    For POMDP, this is an implementation of the POMCP algorithm.

    Each iteration draws a start state from `state_sampler`, selects a leaf
    with UCB, expands it, estimates its value as the mean of `n_simulations`
    rollouts, and backs that value up to the root.

    Args:
        problem: an MDP or POMDP.
        state: the initial state. If None, `problem.initial` is used.
        state_sampler: any iterable of states (an iterator, generator or
            list). One state is drawn to build the root and one per
            iteration, so a finite iterable needs at least
            `iteration_budget + 1` items. If None, `state` is repeated
            forever. This is used to sample the initial state for POMCP.
        horizon: the horizon of the search. If None, then the horizon is set to the problem's horizon.
        C: the C parameter for UCB.
        iteration_budget: the maximum number of iterations.
        n_simulations: the number of simulations to run at each MCTS iteration. In AIMA,
            this is set 1, but in general, we run multiple simulations to reduce variance.
        max_backup: whether to use max backup or sum backup.
        rollout_policy: the rollout policy. If None, then the random rollout policy is used.
        verbose: whether to print debug information.

    Returns:
        action: the root action with the highest mean backed-up value
            (ties broken randomly; never-visited actions are ranked last).

    Raises:
        ValueError: if the root was never expanded (terminal initial state,
            a horizon of 0, or an iteration budget of 0), so there is no
            action to choose.
    """
    if state is None:
        state = problem.initial

    if state_sampler is None:
        state_sampler = itertools.repeat(state)
    else:
        # `next` is called on it below; accept lists and other iterables too.
        state_sampler = iter(state_sampler)

    if horizon is None:
        horizon = problem.horizon

    if rollout_policy is None:
        rollout_policy = functools.partial(random_rollout_policy, problem)

    ucb_fixed_C = functools.partial(ucb, C=C)

    # Rewards collected along the current tree path: select/expand push one
    # per edge taken, backup pops one per edge on the way back up.
    rewards = []

    def select(n: MCTStateNode, state: State) -> Tuple[MCTStateNode, State]:
        """select a leaf node in the tree."""
        if n.children:
            # Select the best child, break ties randomly
            children = list(n.children.keys())
            random.shuffle(children)
            ucb_pick: MCTChanceNode = max(children, key=ucb_fixed_C)
            act = n.children[ucb_pick]
            next_state = problem.step(state, act)
            rewards.append(problem.reward(state, act, next_state))
            next_obs = problem.get_observation(
                next_state)  # For MDP this is same as next_state
            if next_obs not in ucb_pick.children:
                new_leaf = MCTStateNode(next_obs,
                                        horizon=n.horizon - 1,
                                        parent=ucb_pick,
                                        N=0)
                ucb_pick.children[next_obs] = new_leaf
                return new_leaf, next_state
            return select(ucb_pick.children[next_obs], next_state)
        return n, state

    def expand(n: MCTStateNode, state: State) -> Tuple[MCTStateNode, State]:
        """expand the leaf node by adding all its children actions."""
        assert not n.children
        if n.horizon == 0 or problem.terminal(state):
            return n, state
        for action in problem.actions(state):
            new_chance_node = MCTChanceNode(parent=n, U=0, N=0)
            n.children[new_chance_node] = action
        chance_node, action = random.choice(list(n.children.items()))
        next_state = problem.step(state, action)
        rewards.append(problem.reward(state, action, next_state))
        next_obs = problem.get_observation(
            next_state)  # For MDP this is same as next_state
        new_node = MCTStateNode(next_obs,
                                N=0,
                                horizon=n.horizon - 1,
                                parent=chance_node)
        chance_node.children[next_obs] = new_node
        return new_node, next_state

    def simulate(node: MCTStateNode, state: State) -> float:
        """simulate the utility of current state by taking a rollout policy."""
        total_reward = 0
        disc = 1
        h = node.horizon
        while h > 0 and not problem.terminal(state):
            action = rollout_policy(state)
            next_state = problem.step(state, action)
            reward = problem.reward(state, action, next_state)
            total_reward += disc * reward
            state = next_state
            disc = disc * problem.discount
            h -= 1
        return total_reward

    def backup(state_node: MCTStateNode, value: float) -> None:
        """passing the utility back to all parent nodes."""
        state_node.N += 1
        if state_node.parent:
            # Need to include the reward on the action *into* n
            parent_chance_node = state_node.parent
            parent_state_node = parent_chance_node.parent
            r = rewards.pop()
            future_val = r + problem.discount * value
            parent_chance_node.U += future_val
            parent_chance_node.N += 1
            if max_backup:
                bk_val = max(0 if n.N == 0 else n.U / n.N
                             for n in parent_state_node.children)
            else:
                bk_val = future_val
            backup(parent_state_node, bk_val)

    state = next(state_sampler)
    root = MCTStateNode(obs=problem.get_observation(state),
                        horizon=horizon,
                        parent=None,
                        N=0)

    i = 0
    while i < iteration_budget:
        state = next(state_sampler)
        assert len(rewards) == 0
        leaf, state = select(root, state)
        child, state = expand(leaf, state)
        value = np.mean([simulate(child, state) for _ in range(n_simulations)])
        backup(child, value)
        i += 1

    if not root.children:
        raise ValueError(
            "MCTS root was never expanded (terminal initial state, zero "
            "horizon or zero iteration budget); no action to return")

    def mean_value(chance_node) -> float:
        # Never-visited actions have no estimate; rank them last instead of
        # dividing by zero.
        if chance_node.N == 0:
            return -np.inf
        return chance_node.U / chance_node.N

    children = list(root.children.keys())
    random.shuffle(children)  # break ties between equal means randomly
    best_action = root.children[max(children, key=mean_value)]
    if verbose:
        print(
            {
                child_action: (c.U / c.N if c.N > 0 else 0, c.N)
                for c, child_action in root.children.items()
            }, best_action)
    return best_action


@dataclasses.dataclass(order=True, frozen=True)
class PickupProblemState:
    """State of the pickup problem.

    Immutable, hashable and ordered field by field, so states can key
    dictionaries and break priority ties in best-first search.
    """
    # Robot position as (row, col).
    robot_loc: Tuple[int, int]
    # Whether the robot already carries the patient.
    carried_patient: bool


# A blocked move: the robot may not step from cell `loc` straight into the
# adjacent cell `dest`. Blocking both directions between two cells acts as a wall.
OneWayBlock = collections.namedtuple("OneWayBlock", ("loc", "dest"))


@dataclasses.dataclass(frozen=True)
class PickupProblem(PathCostProblem):
    """A simple deterministic pickup problem in a grid.

    The robot starts at some location and moves one cell up, down, left, or
    right per step. There is a patient at some location; the robot picks the
    patient up automatically on entering that cell. The goal is to be at the
    hospital while carrying the patient.

    A move is unavailable if it would leave the grid or if the
    (current cell, next cell) pair is listed in `one_ways`. Listing a pair
    in both directions makes a two-way wall.
    """
    grid_shape: Tuple[int, int]

    initial_robot_loc: Tuple[int, int] = (0, 0)

    patient_loc: Tuple[int, int] = (4, 0)  # initial location of the patient
    hospital_loc: Tuple[int, int] = (0, 3)

    # Blocked (loc -> dest) moves.
    one_ways: List[OneWayBlock] = dataclasses.field(default_factory=list)

    # (row, col) offset of each move; rows grow downwards ("down" is +1 row).
    grid_act_to_delta: ClassVar = {
        "up": (-1, 0),
        "down": (1, 0),
        "left": (0, -1),
        "right": (0, 1)
    }
    all_grid_actions: ClassVar = tuple(grid_act_to_delta.keys())

    @property
    def initial(self) -> State:
        """Robot at its start cell, carrying the patient iff they start there too."""
        return PickupProblemState(self.initial_robot_loc,
                                  self.initial_robot_loc == self.patient_loc)

    @cached_property
    def _one_way_set(self) -> Set[OneWayBlock]:
        """Set of one-way blocks. Helps with fast lookup."""
        return set(self.one_ways)

    def actions(self, state: PickupProblemState) -> Iterable[Action]:
        """Actions from the current state: move up, down, left, or right unless blocked."""
        (r, c) = state.robot_loc
        actions = []
        for act in self.all_grid_actions:
            dr, dc = self.grid_act_to_delta[act]
            new_r, new_c = r + dr, c + dc
            if (new_r in range(self.grid_shape[0]) and
                    new_c in range(self.grid_shape[1]) and
                    OneWayBlock(state.robot_loc,
                                (new_r, new_c)) not in self._one_way_set):
                actions.append(act)
        return actions

    def step(self, state: PickupProblemState,
             action: Action) -> PickupProblemState:
        """Move the robot; the patient is picked up automatically on their square.

        `action` is assumed to be one of `self.actions(state)`; it is not
        re-validated here.
        """
        (r, c) = state.robot_loc
        dr, dc = self.grid_act_to_delta[action]
        return PickupProblemState(
            (r + dr, c + dc),
            state.carried_patient or self.patient_loc == (r + dr, c + dc),
        )

    def step_cost(self, state1, action, state2) -> float:
        """Cost of taking an action in a state (always 1).

        Actually not used in this project, but we keep it here for completeness.
        """
        return 1.

    def goal_test(self, state: PickupProblemState) -> bool:
        """True if at hospital and holding patient."""
        return state.robot_loc == self.hospital_loc and state.carried_patient

    def render(self, state: PickupProblemState, ax=None):
        """Render the state as a grid on `ax` (default: the current axes)."""
        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.gca()

        heatmap(np.zeros(self.grid_shape),
                ax=ax,
                cmap="YlOrRd",
                vmin=0,
                vmax=1,
                origin="upper")

        # Render the robot (patches take (x, y) = (col, row), hence [::-1]).
        robot = plt.Circle(state.robot_loc[::-1], 0.5, color='blue')
        ax.add_patch(robot)

        # Render the patient
        patient_loc = (state.robot_loc
                       if state.carried_patient else self.patient_loc)
        patient = plt.Circle(patient_loc[::-1], 0.3, color='orange')
        ax.add_patch(patient)

        # Render the hospital on `ax` itself, not on pyplot's current axes.
        ax.plot([self.hospital_loc[1]], [self.hospital_loc[0]],
                marker='P',
                color='r',
                markersize=15)

        # Render the walls and one-way doors: a block present in both
        # directions is drawn as a wall, otherwise as an arrow.
        one_ways = set(self._one_way_set)
        while one_ways:
            one_way = one_ways.pop()
            rev = OneWayBlock(one_way.dest, one_way.loc)
            if rev in one_ways:
                one_ways.remove(rev)
                src, dst = one_way.loc, one_way.dest
                if src < dst:
                    dst, src = src, dst
                src, dst = np.array(src), np.array(dst)
                mid_pt = (src + dst) / 2
                delta = dst - src
                wall = plt.Rectangle(mid_pt[::-1] - delta / 2 +
                                     delta[::-1] * 0.05,
                                     delta[0] if delta[0] != 0 else 0.1,
                                     delta[1] if delta[1] != 0 else 0.1,
                                     color='black')
                ax.add_patch(wall)
            else:
                src, dst = np.array(one_way.loc), np.array(one_way.dest)
                mid_pt = (src + dst) / 2
                delta = dst - src
                base = mid_pt + delta * 0.07
                length = -delta * 0.2
                # `ax.arrow` already adds the patch to `ax`; creating it with
                # pyplot and adding it again double-adds it, and fails when
                # `ax` is not the current axes.
                ax.arrow(base[1],
                         base[0],
                         length[1],
                         length[0],
                         width=0.2,
                         length_includes_head=True,
                         head_length=0.2,
                         color='black')


# 2D convolution whose output has the shape of the first argument, treating
# cells outside the grid as 0 (zero padding).
conv2D = functools.partial(
    scipy.signal.convolve2d, mode='same', boundary='fill', fillvalue=0)

# Fire grids are boolean numpy arrays (True = cell on fire).
FireGridT = np.ndarray


@dataclasses.dataclass(frozen=True)
class FireProcess:
    """A probabilistic model for the evolution of fire in a grid.

    Given the fire grid at time step $t$, each cell is on fire at step $t+1$
    independently, with probability equal to the weighted sum of the fire
    indicators in its neighbourhood at step $t$: the grid is convolved with
    `fire_weights` rescaled to sum to `attenuation`, then clipped to [0, 1].
    """

    initial_fire_grid: np.ndarray

    # Default kernel: the cell itself plus its 4-neighbours. Built by a
    # factory because an ndarray default would be shared by all instances and
    # is rejected by dataclasses on Python 3.11+ (ndarrays are unhashable).
    fire_weights: np.ndarray = dataclasses.field(
        default_factory=lambda: np.array([[0, 1, 0], [1, 4, 1], [0, 1, 0]]))
    # Total weight of the normalized kernel.
    attenuation: float = 1.0

    rng: np.random.Generator = dataclasses.field(
        default_factory=np.random.default_rng)

    @cached_property
    def normalized_fire_weights(self) -> np.ndarray:
        """`fire_weights` rescaled to sum to `attenuation` (computed once)."""
        return self.attenuation * self.fire_weights / np.sum(self.fire_weights)

    def dist(self, fire_grid: FireGridT) -> np.ndarray:
        """Given the fire grid at time t, return per-cell probabilities of fire at t + 1."""
        next_fire_dist = conv2D(fire_grid, self.normalized_fire_weights)
        return np.clip(next_fire_dist, 0, 1)  # clip for numerical stability

    def sample(self, fire_grid: FireGridT) -> FireGridT:
        """Given the fire grid at time t, sample the boolean fire grid at t + 1."""
        return self.rng.binomial(1, self.dist(fire_grid)).astype(bool)

    def render(self, fire_grid: FireGridT, ax=None):
        """Draw `fire_grid` as a heatmap on `ax` (default: the current axes)."""
        heatmap(fire_grid, ax=ax, cmap="YlOrRd", vmin=0, vmax=1, origin="upper")


@dataclasses.dataclass(frozen=True, eq=True)
class FireMDPState(PickupProblemState):
    """A state in the fire MDP: the pickup problem state plus a fire grid."""

    fire_grid: np.ndarray

    # numpy arrays are unhashable and compare elementwise, so __eq__ and
    # __hash__ are written by hand (the dataclass keeps these explicit ones).
    def __eq__(self, other):
        if not isinstance(other, FireMDPState):
            return False
        # array_equal yields a plain bool and is False on a shape mismatch,
        # where elementwise `==` would broadcast or raise.
        return (self.robot_loc == other.robot_loc and
                self.carried_patient == other.carried_patient and
                np.array_equal(self.fire_grid, other.fire_grid))

    def __hash__(self):
        # Combine the pickup-state hash with the grid's raw bytes. Equal grids
        # stored with different dtypes hash differently, so keep grid dtypes
        # consistent when states are used as dict keys.
        return hash((super().__hash__(), self.fire_grid.tobytes()))


T = TypeVar('T', bound='FireProblemCommon')


@dataclasses.dataclass(frozen=True)
class FireProblemCommon:
    """Common code for the fire MDP and POMDP problems.

    Combines a deterministic `PickupProblem` (robot movement and patient
    pickup) with a stochastic `FireProcess`. An episode ends when the robot
    stands on a burning cell (`burn_reward`) or is at the hospital with the
    patient (`goal_reward`); any other transition yields `step_reward`.
    """

    pickup_problem: PickupProblem
    fire_process: FireProcess

    # Planning horizon; the default np.inf (a float) means unbounded.
    _horizon: Union[int, float] = np.inf
    _discount: float = 0.999

    step_reward: float = 0.

    burn_reward: float = -1
    goal_reward: float = 1.

    @property
    def horizon(self):
        """Planning horizon (the `_horizon` field)."""
        return self._horizon

    @property
    def discount(self):
        """Discount factor (the `_discount` field)."""
        return self._discount

    @property
    def initial_robot_loc(self):
        """The robot's start cell in the underlying pickup problem."""
        return self.pickup_problem.initial_robot_loc

    def robot_burned(self, state) -> bool:
        """True if the cell under the robot is on fire."""
        return bool(state.fire_grid[state.robot_loc])

    def succeeded(self, state) -> bool:
        """True if the robot is at the hospital carrying the patient."""
        return self.pickup_problem.goal_test(state)

    def reward(self, state1, action, state2) -> float:
        """Reward for arriving in `state2`; burning takes precedence over success."""
        if self.robot_burned(state2):
            return self.burn_reward
        if self.succeeded(state2):
            return self.goal_reward
        return self.step_reward

    def terminal(self, state) -> bool:
        """An episode ends when the robot burns or delivers the patient."""
        return (self.robot_burned(state) or self.succeeded(state))

    @property
    def grid_shape(self):
        """(rows, cols) of the grid."""
        return self.pickup_problem.grid_shape

    @classmethod
    def from_str(cls: Type[T],
                 env_s: str,
                 fire_process_kargs: Optional[Dict[str, Any]] = None,
                 **kwargs) -> T:
        """Create a problem from a grid string.

        Legend:
            . = empty
            F = fire
            < = one way block (can go left)
            > = one way block (can go right)
            ^ = one way block (can go up)
            v = one way block (can go down)
            X = wall (two way block)
            R = robot
            P = patient
            H = hospital
        Each line must be the same length, and starts and ends with a `|`.
        Each character is separated by a space.
        Even lines hold cells (with horizontal-move blocks between them);
        odd lines hold the vertical-move blocks between the rows above and
        below. Leading/trailing whitespace on each line and blank lines are
        ignored, so an indented triple-quoted string can be passed as-is.

        Args:
            env_s: the grid string described above.
            fire_process_kargs: extra keyword arguments for `FireProcess`.
            **kwargs: extra keyword arguments for `cls` (e.g. `_horizon`).

        Raises:
            ValueError: if the robot, patient or hospital is missing.
            AssertionError: if the grid string is malformed.

        Warning:
            The user needs to make sure that no cell location is a dead end (due to roadblocks)
            since our `terminal` condition does not check for empty available actions.
        """
        if fire_process_kargs is None:
            fire_process_kargs = {}

        # Ignore surrounding indentation and blank lines; each grid row is
        # delimited by `|`, so stripping never touches grid content.
        lines = [line.strip() for line in env_s.splitlines() if line.strip()]
        assert lines, "empty environment string"
        assert all(len(line) == len(lines[0]) for line in lines), (
            "all grid lines must have the same length")
        assert all(line[0] == '|' and line[-1] == '|' for line in lines), (
            "each grid line must start and end with '|'")
        # remove the first and last character of each line
        lines = [line[1:-1] for line in lines]
        # drop the separating spaces: keep every other character
        lines = [line[::2] for line in lines]
        w = len(lines[0]) // 2 + 1
        h = len(lines) // 2 + 1
        robot_loc = None
        patient_loc = None
        hospital_loc = None
        fire_grid = np.zeros((h, w), dtype=bool)
        one_ways = []
        for i, line in enumerate(lines):
            i2 = i // 2
            if i % 2 == 0:
                # Cell row: even columns are cells, odd columns are the
                # horizontal-move markers between neighbouring cells.
                for j, c in enumerate(line):
                    j2 = j // 2
                    if j % 2 == 0:
                        if c == 'F':
                            fire_grid[i2, j2] = True
                        elif c == 'R':
                            robot_loc = (i2, j2)
                        elif c == 'P':
                            patient_loc = (i2, j2)
                        elif c == 'H':
                            hospital_loc = (i2, j2)
                        else:
                            assert c == '.', f"unknown cell character {c!r}"
                    else:
                        if c in '<X':
                            one_ways.append(OneWayBlock((i2, j2), (i2, j2 + 1)))
                        if c in '>X':
                            one_ways.append(OneWayBlock((i2, j2 + 1), (i2, j2)))
            else:
                # Border row: markers sit under the cell columns.
                for j, c in enumerate(line[::2]):
                    if c in 'vX':
                        one_ways.append(
                            OneWayBlock(loc=(i2 + 1, j), dest=(i2, j)))
                    if c in '^X':
                        one_ways.append(
                            OneWayBlock(loc=(i2, j), dest=(i2 + 1, j)))

        if robot_loc is None:
            raise ValueError("No robot location specified")
        if patient_loc is None:
            raise ValueError("No patient location specified")
        if hospital_loc is None:
            raise ValueError("No hospital location specified")

        pickup_problem = PickupProblem(fire_grid.shape, robot_loc, patient_loc,
                                       hospital_loc, one_ways)
        fire_process = FireProcess(fire_grid, **fire_process_kargs)
        return cls(pickup_problem, fire_process, **kwargs)


@dataclasses.dataclass(frozen=True)
class FireMDP(FireProblemCommon, MDP):
    """Fully observable fire problem: the whole fire grid is part of the state."""

    @property
    def initial(self) -> FireMDPState:
        """The pickup problem's initial state paired with the initial fire grid."""
        pickup_fields = dataclasses.astuple(self.pickup_problem.initial)
        start_fire = self.fire_process.initial_fire_grid
        return FireMDPState(*pickup_fields, start_fire)

    def actions(self, state: State) -> Iterable[Action]:
        """Movement options are exactly those of the pickup problem."""
        return self.pickup_problem.actions(state)

    def step(self, state: FireMDPState, action: Action) -> FireMDPState:
        """Move the robot deterministically, then sample the next fire grid."""
        moved_fields = dataclasses.astuple(
            self.pickup_problem.step(state, action))
        next_fire = self.fire_process.sample(state.fire_grid)
        return FireMDPState(*moved_fields, next_fire)

    def render(self, state: FireMDPState, ax=None):
        """Draw the pickup layer, then overlay the fire grid."""
        self.pickup_problem.render(state, ax=ax)
        self.fire_process.render(state.fire_grid, ax=ax)


def get_problem(name: str) -> MDP:
    """Build one of the predefined fully observable fire problems.

    Args:
        name: Key of the problem: "maze", "just_wait", "the_circle" or
            "the_choice".

    Returns:
        The problem built by ``FireMDP.from_str`` from that problem's map and
        parameters.

    Raises:
        ValueError: If ``name`` is not one of the predefined problems.
    """
    problem_specs = {
        "maze": dict(
            env_s="""\
                |R < . X . > H|
                |            X|
                |. X .   .   .|
                |        X   X|
                |. X .   . X .|
                |    X   ^    |
                |. X .   . X .|
                |    X   ^   ^|
                |P   . > . > .|
                """,
            fire_process_kargs=dict(fire_weights=np.array([
                [0, 1, 0],
                [1, 10, 1],
                [0, 1, 0],
            ])),
            _horizon=20,
        ),
        "just_wait": dict(
            env_s="""\
                |.   R   .   H|
                |X   v   X   ^|
                |. X . X .   .|
                |    v       ^|
                |. X . X .   .|
                |    v       ^|
                |. X . X .   .|
                |    v       ^|
                |F X . X F   .|
                |    v       ^|
                |. X . X .   F|
                |    v       ^|
                |. X . X .   .|
                |X   v   X   ^|
                |P   .   .   .|
                """,
            fire_process_kargs=dict(fire_weights=np.array([
                [0, 1, 0],
                [1, 20, 1],
                [0, 1, 0],
            ])),
        ),
        "the_circle": dict(
            env_s="""\
                |R   .   .   H|
                |    X   X   v|
                |. X .   . X .|
                |            v|
                |. X F   . X .|
                |            v|
                |. X F   . X .|
                |    X   X   v|
                |P   . < . < .|
                """,
            fire_process_kargs=dict(fire_weights=np.array([
                [0, 0, 0],
                [0, 1, 0],
                [0, 0, 0],
            ])),
        ),
        "the_choice": dict(
            env_s="""\
                |.   .   F   F   F   F|
                |.   X   X   X   X   X|
                |. X .   .   .   .   .|
                |X                    |
                |R > .   .   F   .   .|
                |v                    |
                |. X .   .   .   .   .|
                |v   X   X   X   X   v|
                |. X .   F   F   . X .|
                |v                   v|
                |. X F   .   .   . X .|
                |v                   v|
                |. X F   .   .   . X .|
                |v                   v|
                |. X F   F   .   . X .|
                |v   X   X   X   X   v|
                |. > . > H   P < . < .|
                """,
            fire_process_kargs=dict(fire_weights=np.array([
                [0, 1, 0],
                [1, 4, 1],
                [0, 1, 0],
            ])),
        ),
    }

    spec = problem_specs.get(name)
    if spec is None:
        raise ValueError(f"Unknown problem name: {name}")

    # The maps are written indented inside this function; strip that indent.
    spec["env_s"] = textwrap.dedent(spec["env_s"])
    return FireMDP.from_str(**spec)


class Agent:
    """Base interface for agents acting in an MDP or a POMDP.

    Subclasses are responsible for tracking whatever internal state they need
    between calls to `act`; `reset` should clear it.
    """

    def reset(self):
        """Clear the agent's internal state (the base class keeps none)."""

    @abstractmethod
    def act(self, obs: Union[Observation, State]) -> Action:
        """Choose the agent's next action given an observation.

        For MDP agents, `obs` is the complete state.
        """


@dataclasses.dataclass
class OpenLoopAgent(Agent):
    """Agent that replays a fixed sequence of actions, ignoring observations.

    Attributes:
        actions: The actions to play, in order.
        t: Index of the next action to play; `reset` sets it back to 0.
    """

    actions: Sequence[Action]

    t: int = dataclasses.field(default=0, init=False)

    def reset(self):
        """Restart playback from the first action."""
        self.t = 0

    def act(self, obs) -> Action:
        """Return the next action of the fixed sequence.

        Raises:
            AssertionError: If every action has already been played.
        """
        del obs  # observation is not used
        assert self.t < len(self.actions), "OpenLoopAgent ran out of actions"
        action = self.actions[self.t]
        self.t += 1
        # Return the single selected action, not the whole sequence.
        return action


@dataclasses.dataclass
class RolloutLookaheadAgent(Agent):
    """MDP agent that scores every action with Monte Carlo rollouts.

    For each legal action, the agent averages the discounted return of
    `n_rollout_per_action` simulated rollouts that start with that action and
    then follow `rollout_policy`. It picks the action with the best average.
    """

    problem: MDP
    n_rollout_per_action: int = 10

    # If None, plan until the end of the problem's (finite) horizon.
    receding_horizon: Optional[int] = None
    t: int = 0

    def reset(self):
        self.t = 0

    @property
    def planning_horizon(self):
        if self.receding_horizon is not None:
            return self.receding_horizon
        # No receding horizon: use whatever remains of the problem horizon.
        return self.problem.horizon - self.t

    def act(self, state: State) -> Action:
        """Return the action with the highest estimated expected return."""
        self.t += 1
        candidates = list(self.problem.actions(state))
        # Shuffling first makes `max` break ties uniformly at random.
        random.shuffle(candidates)

        def estimated_value(candidate):
            return self._rollout(state, candidate)

        return max(candidates, key=estimated_value)

    def _rollout(self, state: State, action: Action) -> float:
        """Average return of several rollouts that begin with `action`."""
        returns = [
            self._rollout_single(state, action)
            for _ in range(self.n_rollout_per_action)
        ]
        return sum(returns) / self.n_rollout_per_action

    def _rollout_single(self, state: State, action: Action) -> float:
        """Simulate one rollout and return its discounted cumulative reward.

        The first step takes `action`; later steps use `rollout_policy`. The
        rollout stops at the planning horizon or at a terminal state.
        """
        horizon = self.planning_horizon
        accumulated = 0
        weight = 1
        depth = 0
        while depth < horizon:
            if self.problem.terminal(state):
                break
            chosen = action if depth == 0 else self.rollout_policy(state)
            action = chosen
            successor = self.problem.step(state, chosen)
            accumulated += weight * self.problem.reward(state, chosen,
                                                        successor)
            state = successor
            weight = weight * self.problem.discount
            depth += 1
        return accumulated

    def rollout_policy(self, state: State) -> Action:
        """Pick the action taken during a rollout (uniformly at random here).

        Subclasses may override this to bias rollouts toward preferred actions.
        """
        options = list(self.problem.actions(state))
        return random.choice(options)


def benchmark_agent(problem: Union[MDP, POMDP],
                    agent: Agent,
                    n_repeats: int = 100,
                    verbose: bool = False,
                    max_steps: int = 50) -> List[float]:
    """Benchmark an agent on a problem by running repeated experiments.

    Returns:
        The total (discounted) reward of each of the `n_repeats` runs, in
        order.
    """
    import tqdm
    return [
        run_agent_on_problem(problem,
                             agent,
                             max_steps=max_steps,
                             verbose=verbose)[-1]
        for _ in tqdm.tqdm(range(n_repeats))
    ]


def compare_agents(problem: Union[POMDP, MDP],
                   agents: Dict[str, Agent],
                   n_repeats: int = 30,
                   max_steps: int = 50,
                   verbose: bool = False):
    """Compare the performance of multiple agents on a problem.

    You probably want to extend this function for a more detailed analysis of
    performance.

    Args:
        problem: The problem to run the agents on.
        agents: Either a dictionary mapping agent names to agents, or an
            iterable whose items are agents or ``(agent, name)`` tuples.
            Agents given without a name are labelled with their class name.
        n_repeats: The number of experiments to run for each agent.
        max_steps: The maximum number of steps to run each experiment. Note
            that this differs from the horizon of the problem: we may do
            receding-horizon planning but don't want to run each episode
            forever.
        verbose: If True, print the action and reward of every step of every
            experiment.
    """
    if isinstance(agents, dict):
        # Iterating a dict directly would yield only its keys (the names).
        named_agents = [(agent, name) for name, agent in agents.items()]
    else:
        named_agents = [
            entry if isinstance(entry, tuple) else
            (entry, entry.__class__.__name__) for entry in agents
        ]

    for agent, name in named_agents:
        print(f"Running {name}...")
        rewards = benchmark_agent(problem,
                                  agent,
                                  max_steps=max_steps,
                                  n_repeats=n_repeats,
                                  verbose=verbose)
        print(f"Mean reward: {np.mean(rewards):.2f} +- {np.std(rewards):.2f}")
        print(f"Median reward: {np.median(rewards):.2f}")
        print(f"Min reward: {np.min(rewards):.2f}")
        print(f"Max reward: {np.max(rewards):.2f}")


def run_agent_on_problem(
    problem: Union[MDP, POMDP],
    agent: Agent,
    verbose: bool = True,
    max_steps: int = np.inf,
) -> Tuple[Sequence[State], Sequence[Action], float]:
    """Run one episode of `agent` on `problem`.

    The episode stops at a terminal state, or once the trajectory holds
    ``min(problem.horizon, max_steps)`` states. The reward of the t-th
    transition (t = 1, 2, ...) is weighted by ``problem.discount ** t``.

    Returns:
        ``(states, actions, total_reward)`` where ``states`` includes the
        initial state, so it has one more entry than ``actions``.
    """
    agent.reset()
    state = problem.initial
    observation = problem.get_observation(state)
    visited = [state]
    taken = []
    total_reward = 0
    while True:
        if problem.terminal(state):
            break
        n_visited = len(visited)
        if n_visited >= min(problem.horizon, max_steps):
            break
        action = agent.act(observation)
        successor = problem.step(state, action)
        reward = problem.reward(state, action, successor)
        total_reward += reward * problem.discount**n_visited
        if verbose:
            print(
                f"Action={action} reward={reward} total_reward={total_reward}")
        observation = problem.get_observation(successor)
        state = successor
        taken.append(action)
        visited.append(state)
    return visited, taken, total_reward


def animate_trajectory(problem: Union[MDP, POMDP],
                       trajectory: Tuple[Sequence[State], Sequence[Action]]):
    """Build a matplotlib animation of a trajectory.

    Each frame renders one state. The frame title shows the action taken to
    reach that state, its reward, and the running discounted total.

    Args:
        problem: The problem, used for rendering and rewards.
        trajectory: A tuple whose first two items are the state and action
            sequences. Extra items (e.g. the total reward returned by
            `run_agent_on_problem`) are ignored.

    Returns:
        A matplotlib animation.
    """
    import matplotlib.pyplot as plt
    import matplotlib.animation

    states, actions, *_ = trajectory

    fig, ax = plt.subplots()

    total_reward = 0.

    def draw_frame(i):
        nonlocal total_reward
        ax.clear()
        ax.set_aspect('equal')
        if i == 0:
            # Frame 0 may be drawn more than once; restart the running sum.
            total_reward = 0
            ax.set_title(f"Step {i}: begin, total_reward={total_reward:.2f}")
        elif i < len(states):
            previous, action, current = states[i - 1], actions[i - 1], states[i]
            reward = problem.reward(previous, action, current)
            total_reward += reward * problem.discount**i
            ax.set_title(f"Step {i}: action={action}, reward={reward:.2f}, "
                         f"total_reward={total_reward:.2f}")
        problem.render(states[i], ax=ax)

    return matplotlib.animation.FuncAnimation(fig,
                                              draw_frame,
                                              frames=len(states),
                                              interval=500)





@dataclasses.dataclass(frozen=True, eq=True)
class FirePOMDPState(FireMDPState):
    """A state in the Fire POMDP: a FireMDPState plus the drone's location."""

    drone_loc: Tuple[int, int]

    def __eq__(self, other):
        # Must agree with __hash__, which mixes in drone_loc: states that
        # differ only in the drone position are different states (they yield
        # different observations), and a plain FireMDPState has no drone.
        if not isinstance(other, FirePOMDPState):
            return False
        return (self.robot_loc == other.robot_loc and
                self.carried_patient == other.carried_patient and
                self.drone_loc == other.drone_loc and
                np.array_equal(self.fire_grid, other.fire_grid))

    def __hash__(self):
        return hash((super().__hash__(), self.drone_loc))


class FirePOMDPObservation(np.ndarray):
    """Observation of the fire grid around the robot and the drone.

    For each entry in the fire grid:
        0 = no fire
        1 = fire
        np.nan = not observed

    Instances use value semantics: `==` / `!=` return a single bool (NaNs
    compare equal), and equal observations hash equally.
    """

    def __hash__(self):
        # Hash a canonical float copy so that arrays which compare equal under
        # __eq__ (dtype ignored, all NaNs equal, -0.0 == 0.0) hash equally;
        # raw bytes differ for NaN payloads, signed zeros and dtypes.
        canonical = np.asarray(self, dtype=float) + 0.0  # -0.0 -> +0.0
        canonical = np.where(np.isnan(canonical), np.nan, canonical)
        return hash((canonical.shape, canonical.tobytes()))

    def __eq__(self, other):
        return np.array_equal(self, other, equal_nan=True)

    def __ne__(self, other):
        # ndarray defines an elementwise __ne__; without this override,
        # `a != b` would return an array instead of negating __eq__.
        return not self.__eq__(other)

    @staticmethod
    def unknown(shape: Tuple[int, int]) -> "FirePOMDPObservation":
        """Return a completely unobserved (all-NaN) grid of the given shape."""
        return np.full(shape, np.nan).view(FirePOMDPObservation)


def mask_centered_at(shape: Tuple[int, int], loc: Tuple[int, int],
                     dist: int) -> np.ndarray:
    """Return a boolean mask that is True on the square window around `loc`.

    The window contains every cell whose row and column each differ from
    `loc` by at most `dist`, clipped to the bounds of a grid of `shape`.
    """
    row, col = loc[0], loc[1]
    top, bottom = max(0, row - dist), min(shape[0], row + dist + 1)
    left, right = max(0, col - dist), min(shape[1], col + dist + 1)
    window = np.zeros(shape, dtype=bool)
    window[top:bottom, left:right] = True
    return window


@dataclasses.dataclass(frozen=True)
class FirePOMDP(FireProblemCommon, POMDP):
    """A partially observable version of the Fire problem.

    Before each episode, the fire spreads for a random number of steps
    according to the same fire-process dynamics as in the MDP. Once the
    episode starts, the fire stops spreading. The agent does not know where
    the fire is: it only observes the cells in a square window around the
    robot and a square window around the drone.
    """
    # The fire has spread for (Geometric(initial_fire_spread_param) - 1) steps
    # before the episode starts. Must be in (0, 1]; 1 means no spreading at
    # all before the agent starts.
    initial_fire_spread_param: float = 1

    # The robot observes cells up to this many cells away.
    # 0 means the robot only observes the cell it is in.
    robot_view_distance: int = 1

    # The drone observes cells up to this many cells away.
    # 0 means the drone only observes the cell it is in.
    drone_view_distance: int = 0

    # The drone can fly to any cell within this Euclidean distance of the
    # robot's (post-move) location.
    drone_fly_distance: float = 3

    @property
    def initial_drone_loc(self):
        """The drone starts at the same location as the robot."""
        return self.initial_robot_loc

    @property
    def initial(self) -> FirePOMDPState:
        # Let the fire spread for a random number of steps before the start.
        process = self.fire_process
        n_spread_steps = process.rng.geometric(
            self.initial_fire_spread_param) - 1
        grid = process.initial_fire_grid
        for _ in range(n_spread_steps):
            grid = process.sample(grid)

        pickup_fields = dataclasses.astuple(self.pickup_problem.initial)
        return FirePOMDPState(*pickup_fields,
                              fire_grid=grid,
                              drone_loc=self.initial_drone_loc)

    def drone_actions(self, robot_loc: Tuple[int, int]) -> List[Action]:
        """All cells within `drone_fly_distance` of `robot_loc`, row-major."""
        n_rows, n_cols = self.grid_shape[0], self.grid_shape[1]
        reachable = []
        for row in range(n_rows):
            for col in range(n_cols):
                distance = math.hypot(robot_loc[0] - row, robot_loc[1] - col)
                if distance <= self.drone_fly_distance:
                    reachable.append((row, col))
        return reachable

    def actions(self, state: FirePOMDPState) -> Iterable[Action]:
        """Yield (robot action, drone target) pairs.

        The drone's reachable targets are measured from where the robot ends
        up after its own action.
        """
        for robot_move in self.pickup_problem.actions(state):
            landing = self.pickup_problem.step(state, robot_move).robot_loc
            for drone_target in self.drone_actions(landing):
                yield (robot_move, drone_target)

    def step(self, state: FirePOMDPState, action: Action) -> FirePOMDPState:
        # The fire grid is frozen during the episode; only robot and drone move.
        robot_move, drone_target = action
        next_pickup = self.pickup_problem.step(state, robot_move)
        return FirePOMDPState(*dataclasses.astuple(next_pickup),
                              state.fire_grid, drone_target)

    def get_observation(self, state: FirePOMDPState) -> FirePOMDPObservation:
        """Return the fire grid with every unobserved cell set to NaN."""
        visible = (
            mask_centered_at(self.grid_shape, state.robot_loc,
                             self.robot_view_distance)
            | mask_centered_at(self.grid_shape, state.drone_loc,
                               self.drone_view_distance))
        partial_grid = np.where(visible, state.fire_grid, np.nan)
        return partial_grid.view(FirePOMDPObservation)

    def render(self, state: FireMDPState, ax=None):
        """Draw the grid, the fire, and the drone, then gray out unseen cells."""

        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.gca()

        self.pickup_problem.render(state, ax=ax)
        self.fire_process.render(state.fire_grid, ax=ax)

        # 1 where the cell is unobserved, 0 where it is observed; observed
        # cells get alpha 0, so only unobserved cells are grayed out.
        hidden = np.isnan(self.get_observation(state)).astype(float)
        ax.imshow(hidden,
                  alpha=hidden * 0.5,
                  cmap="binary",
                  vmin=0,
                  vmax=1,
                  interpolation='nearest')

        # Outline the drone's cell in blue.
        drone_row, drone_col = state.drone_loc[0], state.drone_loc[1]
        outline = plt.Rectangle((drone_col - 0.5, drone_row - 0.5),
                                1,
                                1,
                                fill=False,
                                edgecolor='blue',
                                lw=4)
        ax.add_patch(outline)


def get_problem_part2(name: str) -> FirePOMDP:
    """Build one of the predefined partially observable fire problems.

    Args:
        name: Key of the problem: "only_fire", "the_circle2" or "the_choice2".

    Returns:
        The problem built by ``FirePOMDP.from_str`` from that problem's map
        and parameters.

    Raises:
        ValueError: If ``name`` is not one of the predefined problems.
    """
    problem_specs = {
        "only_fire": dict(
            env_s="""\
                |R   .   H|
                |         |
                |.   .   .|
                |         |
                |.   F   .|
                |         |
                |P   .   F|
                """,
            fire_process_kargs=dict(fire_weights=np.array([
                [0, 1, 0],
                [1, 10, 1],
                [0, 1, 0],
            ])),
            initial_fire_spread_param=0.3,
        ),
        "the_circle2": dict(
            env_s="""\
                |R   .   H|
                |    X    |
                |. X . X .|
                |         |
                |F X F X .|
                |    X    |
                |P   .   .|
                """,
            fire_process_kargs=dict(fire_weights=np.array([
                [0, 1, 0],
                [1, 4, 1],
                [0, 1, 0],
            ])),
            # The fire doesn't spread before the agent starts.
            initial_fire_spread_param=1.,
        ),
        "the_choice2": dict(
            env_s="""\
                |.   .   F   F   F|
                |X   X   X   X   X|
                |. > .   .   .   .|
                |                 |
                |. > .   .   F   .|
                |                 |
                |R > .   .   .   P|
                |    X   X   X    |
                |. > .   F   .   H|
                """,
            fire_process_kargs=dict(fire_weights=np.array([
                [0, 1, 0],
                [1, 5, 1],
                [0, 1, 0],
            ])),
            initial_fire_spread_param=0.4,
        ),
    }

    spec = problem_specs.get(name)
    if spec is None:
        raise ValueError(f"Unknown problem name: {name}")

    # The maps are written indented inside this function; strip that indent.
    spec["env_s"] = textwrap.dedent(spec["env_s"])
    return FirePOMDP.from_str(**spec)





## Determinized Min-cost Path Problem


### Utilities


**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:


@dataclasses.dataclass(frozen=True, eq=True, order=True)
class DeterminizedFireMDPState(PickupProblemState):
    """State of the DeterminizedFireMDP.

    It is a PickupProblemState extended with the time step $t$ at which the
    state is reached. Instances are immutable, comparable and orderable.
    """
    # Time step $t$ (number of actions taken since the initial state).
    time: int = 0


**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:

import scipy.special


@dataclasses.dataclass(frozen=True)
class FireMRF:
    """Pairwise Markov random field over a grid of binary fire variables.

    Attributes:
        unitary_potentials: Per-cell potentials indexed as [row, col, value].
        correlation_potential: Potential table coupling adjacent cells,
            indexed as [value_a, value_b].
    """

    unitary_potentials: np.ndarray
    correlation_potential: np.ndarray

    @cached_property
    def log_unitary_potentials(self) -> np.ndarray:
        """Natural log of `unitary_potentials` (computed once, then cached)."""
        return np.log(self.unitary_potentials)

    @cached_property
    def log_correlation_potential(self) -> np.ndarray:
        """Natural log of `correlation_potential` (computed once, then cached)."""
        return np.log(self.correlation_potential)

    @staticmethod
    def default(shape: Tuple[int, int]) -> "FireMRF":
        """MRF with uniform unitary potentials and mildly 'sticky' neighbours."""
        uniform = np.full((*shape, 2), 0.5)
        agreement_favouring = np.array([[0.7, 0.3], [0.3, 0.7]])
        return FireMRF(
            unitary_potentials=uniform,
            correlation_potential=agreement_favouring,
        )

    @property
    def grid_shape(self) -> Tuple[int, int]:
        """(rows, cols) of the fire grid: the leading two potential axes."""
        leading_axes = self.unitary_potentials.shape[:2]
        return leading_axes


def shift(array: np.ndarray,
          offset: Sequence[int],
          constant_values: float = 0) -> np.ndarray:
    """Return a copy of `array` translated by `offset` along every axis.

    Entries pushed past an edge are dropped. Cells uncovered at the opposite
    edge are filled with `constant_values`.

    Adapted from https://stackoverflow.com/a/70297929.
    """
    source = np.asarray(array)
    offsets = np.atleast_1d(offset)
    assert len(offsets) == source.ndim
    result = np.empty_like(source)

    def _trim(k):
        # Drop the first k entries of an axis (k >= 0) or the last -k (k < 0).
        if k >= 0:
            return slice(k, None)
        return slice(0, k)

    destination = tuple(_trim(o) for o in offsets)
    origin = tuple(_trim(-o) for o in offsets)
    result[destination] = source[origin]

    # Fill the band each axis shift left uncovered.
    for axis, o in enumerate(offsets):
        vacated = slice(0, o) if o >= 0 else slice(o, None)
        index = [slice(None)] * axis
        index.append(vacated)
        result[tuple(index)] = constant_values

    return result


@dataclasses.dataclass
class LBPResult:
    """Outcome of a loopy-belief-propagation run."""

    # Normalized per-cell marginal probabilities.
    marginals: np.ndarray
    # Iteration index at which message passing stopped.
    niter: int
    # Whether the convergence test passed before the iteration limit.
    converged: bool
    # Final log-messages (for warm starts), or None if not requested.
    msgs: Optional[np.ndarray] = None


def fire_mrf_lbp_marginals(log_unitary_potentials: np.ndarray,
                           log_correlation_potential: np.ndarray,
                           initial_msgs: Optional[np.ndarray] = None,
                           max_iters: int = 20,
                           return_msgs: bool = False,
                           rtol=1e-5) -> LBPResult:
    """Run loopy belief propagation on the fire MRF.

    This procedure assumes a 2D grid of random variables of shape (h, w), with
    pairwise interactions between adjacent variables and a grid of unitary
    potentials. Messages and marginals are computed with vectorized numpy; the
    only Python loops are over LBP iterations and the four message directions.

    Args:
        log_unitary_potentials: The log of the unitary potentials of a
            FireMRF, shape (h, w, x).
        log_correlation_potential: The log of the correlation potential of a
            FireMRF, shape (x, x).
        initial_msgs: Initial log-messages, one (h, w, x) array per direction
            (e.g. the `msgs` of a previous result). If None, all messages start
            uniform (zeros in log space).
        max_iters: The maximum number of iterations to run. With 0, no messages
            are passed and the marginals are the normalized unitary potentials.
        return_msgs: Whether to return the passed messages, useful for
            "warm-starting" the messages.
        rtol: Tolerance for the convergence test. The L1 change of the summed
            incoming log-messages between iterations is compared with
            log(1 + rtol).

    Returns:
        LBPResult: marginal probabilities of shape (h, w, x), the iteration
        index at which the loop stopped, whether it converged, and the final
        messages if `return_msgs` is True.
    """
    h, w, x = log_unitary_potentials.shape

    if initial_msgs is None:
        msgs = [np.zeros((h, w, x)) for _ in range(4)]
    else:
        msgs = initial_msgs

    del initial_msgs

    def _msgs_incoming(msgs):
        # Align each direction's messages with the cells receiving them; cells
        # shifted in from the border get 0 (a uniform log-message).
        offsets = ((0, 1, 0), (1, 0, 0), (-1, 0, 0), (0, -1, 0))
        return [shift(msg, offset) for msg, offset in zip(msgs, offsets)]

    logsumexp = scipy.special.logsumexp
    log_correlation_pot_rev = log_correlation_potential.T

    log_rtol = np.log1p(rtol)

    prev_sum_msgs_incoming = None

    converged = False
    # Bound up front so the result is well-formed even when max_iters == 0.
    niter = 0
    for niter in range(max_iters):

        # incoming messages from all neighbors, shape: [(h, w, x)] * 4
        msgs_incoming = _msgs_incoming(msgs)
        sum_msgs_incoming = sum(msgs_incoming)

        # Convergence test
        if prev_sum_msgs_incoming is not None:
            rerr = np.linalg.norm(
                (sum_msgs_incoming - prev_sum_msgs_incoming).ravel(), ord=1)
            if rerr < log_rtol:
                converged = True
                break
        prev_sum_msgs_incoming = sum_msgs_incoming

        # Temp terms involving all incoming messages
        unitary_term = (log_unitary_potentials[:, :, :, np.newaxis] +
                        sum_msgs_incoming[:, :, :, np.newaxis])
        tmp1 = unitary_term + log_correlation_potential  # shape: (h, w, xi, xj)
        tmp2 = unitary_term + log_correlation_pot_rev  # shape: (h, w, xi, xj)

        # compute next messages, excluding the message that came from the
        # recipient itself (hence the reversed incoming list)
        msgs = np.stack([
            logsumexp(t - mi[..., np.newaxis], axis=2)
            for t, mi in zip([tmp1, tmp1, tmp2, tmp2], reversed(msgs_incoming))
        ])

        # Normalize msgs (shape: (4, h, w, x)) over the variable's values
        msgs -= logsumexp(msgs, axis=-1, keepdims=True)

    # incoming messages from all neighbors, shape: (h, w, x)
    sum_msgs_incoming = sum(_msgs_incoming(msgs))
    marginals = log_unitary_potentials + sum_msgs_incoming
    marginals -= logsumexp(marginals, axis=-1, keepdims=True)
    marginals = np.exp(marginals)

    return LBPResult(marginals=marginals,
                     niter=niter,
                     converged=converged,
                     msgs=msgs if return_msgs else None)


@dataclasses.dataclass
class FireMRFGibbsSampler:
    """A Gibbs sampler for the FireMRF.

    Attributes:
        fire_mrf: The MRF to sample from.
        initial_fire_grid: Optional starting grid shared by every chain. If
            None, each chain starts from an independent uniformly random 0/1
            grid.
        observation_grid: Optional grid of observed values, with NaN marking
            unobserved cells. Observed cells are clamped to their values and
            only NaN cells are resampled. If None, every cell is resampled.
        burn_in_steps: Samples are only emitted after more than this many
            steps (one single-site update per chain per step).
        n_parellel_chains: Number of independent chains; each emission round
            yields one grid per chain.
        num_steps_per_sample: Emit a round of samples every this many steps.
    """

    fire_mrf: "FireMRF"

    initial_fire_grid: Optional[np.ndarray] = None
    observation_grid: Optional[np.ndarray] = None

    burn_in_steps: int = 200
    n_parellel_chains: int = 5
    num_steps_per_sample: int = 100

    def sampling_step(self, fire_grid: np.ndarray,
                      var_loc: Tuple[int, int]) -> np.ndarray:
        """Performs a single step of Gibbs sampling on the fire grid.

        Given the current configuration of the fire grid, this function samples
        a new value for the variable at `var_loc`, conditioned on its (up to
        four) grid neighbors. It returns a new fire grid; the input is not
        modified.

        Args:
            fire_grid: The current state of the fire grid (integer-valued, as
                its entries index the correlation potential).
            var_loc: The location of the variable to update.
        """
        n_rows, n_cols = fire_grid.shape
        r, c = var_loc
        neighbor_pots = []
        log_correlation_pot = self.fire_mrf.log_correlation_potential
        if r - 1 >= 0:
            neighbor_pots.append(log_correlation_pot[fire_grid[r - 1, c], :])
        if r + 1 < n_rows:
            neighbor_pots.append(log_correlation_pot[:, fire_grid[r + 1, c]])
        if c - 1 >= 0:
            neighbor_pots.append(log_correlation_pot[fire_grid[r, c - 1], :])
        if c + 1 < n_cols:
            neighbor_pots.append(log_correlation_pot[:, fire_grid[r, c + 1]])
        neighbor_pots = np.sum(neighbor_pots, axis=0)
        log_posterior = (neighbor_pots +
                         self.fire_mrf.log_unitary_potentials[var_loc])
        log_posterior -= scipy.special.logsumexp(log_posterior)
        new_fire_grid = fire_grid.copy()
        new_val = np.random.choice(2, p=np.exp(log_posterior))
        new_fire_grid[r, c] = new_val
        return new_fire_grid

    def samples_stream(self) -> Iterable[np.ndarray]:
        """Performs Gibbs sampling on a fire MRF, yields an infinite stream of samples."""
        observation = self.observation_grid

        if observation is not None and not np.isnan(observation).any():
            # If we observed everything, don't need to sample!
            while True:
                yield observation

        # If no initial fire grid is provided, we initialize it randomly
        if self.initial_fire_grid is None:
            fire_grids = list(
                np.random.binomial(
                    1,
                    0.5,
                    size=(self.n_parellel_chains, *self.fire_mrf.grid_shape),
                ))
        else:
            fire_grids = [
                self.initial_fire_grid.copy()
                for _ in range(self.n_parellel_chains)
            ]

        if observation is not None:
            unobserved = np.isnan(observation)
            # Clamp the observed cells in every chain.
            for fire_grid in fire_grids:
                fire_grid[~unobserved] = observation[~unobserved]
        else:
            unobserved = np.ones(self.fire_mrf.grid_shape, dtype=bool)

        # (N, 2) array of the (row, col) coordinates that get resampled.
        sample_coords = np.argwhere(unobserved)

        # Generate the infinite stream of samples
        steps = 0
        while True:
            for i in range(self.n_parellel_chains):
                r, c = sample_coords[np.random.choice(len(sample_coords))]
                fire_grids[i] = self.sampling_step(fire_grids[i], (r, c))

            steps += 1
            if (steps > self.burn_in_steps and
                    steps % self.num_steps_per_sample == 0):
                for i in range(self.n_parellel_chains):
                    yield fire_grids[i]

### Question
Please complete the implementation of `DeterminizedFireMDP`. In particular, you should:
- Complete the function `fire_dist_at_time` to compute the log-likelihood of each cell being on fire at time $t$ given the true fire state at time $0$. 
- Using your implementation of `fire_dist_at_time`, complete the function `step_cost`.
- Complete the rest of `DeterminizedFireMDP`, and implement the heuristic function `h` as described above.
    

For reference, our solution is **68** line(s) of code.

In [None]:
@dataclasses.dataclass(frozen=True)
class DeterminizedFireMDP(PathCostProblem):
    """Determinized version of the fire MDP: a path-cost search problem whose
    cheapest solution path is meant to be the one most likely to succeed.

    States are DeterminizedFireMDPStates (a pickup-problem state plus a time
    step), so the per-step cost can depend on when a cell is entered.
    """
    pickup_problem: PickupProblem
    fire_process: FireProcess

    # Additional cost for each step.
    # Can be 0 but we might have 0-cost arcs if the success probability is 1.
    # NOTE: this has no type annotation, so it is a plain class attribute, not
    # a dataclass field (it cannot be passed to the constructor).
    action_cost = 1e-6

    # Use this to cache precomputed fire distributions, so we don't have to recompute them.
    # Keyed by time step t. The dataclass is frozen, but the dict itself can
    # still be mutated in place.
    fire_dists_cache: Dict[int, np.ndarray] = dataclasses.field(
        init=False,
        default_factory=dict,
    )

    def __post_init__(self):
        # The pickup grid and the fire grid must describe the same map.
        assert (self.pickup_problem.grid_shape ==
                self.fire_process.initial_fire_grid.shape)

    @property
    def initial(self) -> DeterminizedFireMDPState:
        """The pickup problem's initial state, at time step 0."""
        return DeterminizedFireMDPState(
            *dataclasses.astuple(self.pickup_problem.initial),
            time=0,
        )

    def actions(self, state: DeterminizedFireMDPState) -> Iterable[Action]:
        """Legal actions in `state` (to be implemented)."""
        raise NotImplementedError("Implement me!")

    def step(self, state: DeterminizedFireMDPState,
             action: Action) -> State:
        """We automatically pick up patient if we're on that square."""
        raise NotImplementedError("Implement me!")

    def goal_test(self, state: DeterminizedFireMDPState) -> bool:
        """True if at hospital and holding patient."""
        raise NotImplementedError("Implement me!")

    def step_cost(self, state1: DeterminizedFireMDPState, action: Action,
                  state2: DeterminizedFireMDPState) -> float:
        """Induce a small action_cost and the negative log-likelihood of no fire."""
        raise NotImplementedError("Implement me!")

    def fire_dist_at_time(self, t: int) -> np.ndarray:
        """Return the marginal distribution of fire grid at time $t$."""

        raise NotImplementedError("Implement me!")

        # Template (unreachable until the raise above is removed): compute the
        # distribution for time t once, store it in fire_dists_cache, and
        # return the cached value on later calls.
        if t not in self.fire_dists_cache:
            if t == 0:
                dist = ...  # TODO: Implement this.
            else:
                dist = ...  # TODO: Implement this.
            self.fire_dists_cache[t] = dist
        return self.fire_dists_cache[t]

    def h(self, state: DeterminizedFireMDPState) -> float:
        """heuristic based on the manhattan distance to the patient and hospital."""
        raise NotImplementedError("Implement me!")

## Determinized Fire MDP Agent


### Question
Please complete the implementation of `FireMDPDeterminizedAStarAgent`. 
Note that we have filled in most of the implementation for you --- including 
the call to `run_astar_search` from HW01. 
All you need to implement is the `determinized_problem` method.


For reference, our solution is **27** line(s) of code.

In [None]:
@dataclasses.dataclass(frozen=True)
class FireMDPDeterminizedAStarAgent(Agent):
    """Agent that uses A* to plan a path to the goal in a determinized 
    version of the problem. Does not need any internal state since we 
    re-determinize the problem at each step.

    Attributes:
        problem: The fire MDP the agent acts in.
        step_budget: Passed to `run_astar_search` as its budget.
    """

    problem: FireMDP
    step_budget: int = 10000

    def determinized_problem(self,
                             state: FireMDPState) -> DeterminizedFireMDP:
        """Returns a determinized approximation of the fire MDP."""
        raise NotImplementedError("Implement me!")

    def act(self, state: FireMDPState) -> Action:
        """Re-plan from `state` with A* and return the plan's first action.

        If the search raises `SearchFailed`, a message is printed and a
        uniformly random legal action is returned instead.
        """
        problem = self.determinized_problem(state)
        try:
            plan = run_astar_search(problem, self.step_budget)
        except SearchFailed:
            print("Search failed, performing a random action")
            return random.choice(list(self.problem.actions(state)))
        # NOTE(review): presumably plan[1] is the action sequence returned by
        # run_astar_search (HW01); confirm against its return value.
        return plan[1][0]

## MCTS Agent


### Question
Please complete the implementation of `MCTSAgent`. 

_Hint: You can pass `self.planning_horizon` to `run_mcts_search` 
to handle both infinite-horizon problems (via receding-horizon planning) and finite-horizon problems._


For reference, our solution is **42** line(s) of code.

In [None]:
@dataclasses.dataclass
class MCTSAgent(Agent):
    """Agent that uses Monte Carlo Tree Search to plan a path to the goal.

    The agent simply wraps `run_mcts_search`, and it should work for any MDP.

    Attributes:
        problem: The MDP to act in.
        receding_horizon: Optional fixed planning depth used at every step.
            When None, the agent plans over the remaining steps of the
            episode (`problem.horizon - t`), which requires a finite horizon.
        C: Exploration constant for MCTS (presumably the UCB constant).
        iteration_budget: Search budget intended for MCTS at each step
            (presumably a number of iterations).
        t: Current time step; not an init argument, reset to 0 by `reset`.

    Raises:
        ValueError: On construction, if `receding_horizon` is None and the
            problem's horizon is infinite.
    """

    problem: MDP

    # An optional receding horizon to use for the planning
    # If not provided, the problem must have a finite horizon
    receding_horizon: Optional[int] = None

    C: float = np.sqrt(2)
    iteration_budget: int = 1000

    t: int = dataclasses.field(default=0, init=False)

    def __post_init__(self):
        # Planning "to the end of the episode" is undefined for an infinite
        # horizon. Raise explicitly instead of using `assert`, which is
        # stripped when Python runs with -O.
        if self.receding_horizon is None and self.problem.horizon == np.inf:
            raise ValueError(
                "MCTSAgent requires `receding_horizon` when the problem has "
                "an infinite horizon.")

    def reset(self):
        """Restart the episode clock (`self.t = 0`)."""
        self.t = 0

    @property
    def planning_horizon(self) -> int:
        """Returns the planning horizon for the current time step."""
        if self.receding_horizon is None:
            return self.problem.horizon - self.t
        return self.receding_horizon

    def act(self, state: State) -> Action:
        """Return the action to take at state.

        NOTE: `planning_horizon` depends on `self.t`, and nothing else in this
        class advances it, so an implementation is expected to increment
        `self.t` once per call.
        """
        raise NotImplementedError("Implement me!")

## Exact Marginalization


### Question
Write a function `fire_mrf_exact_marginalize` to compute the exact marginals by multiplying all potentials and then marginalizing.
We don't expect this function to work for large grids, but it should work for grids that are smaller than 3 x 3.

For reference, our solution is **37** line(s) of code.

In [None]:
def fire_mrf_exact_marginalize(fire_mrf: FireMRF) -> np.ndarray:
    """Compute the exact marginal distribution of every variable in `fire_mrf`.

    Works by brute force: every joint configuration of the fire grid is
    enumerated, so the cost grows exponentially with the number of cells.
    Only use it on very small grids.
    """
    raise NotImplementedError("Implement me!")

### Tests

In [None]:
def explicit_marginalize_test(fire_mrf_exact_marginalize, fire_mrf: FireMRF,
                              results: np.ndarray) -> None:
    """Assert that `fire_mrf_exact_marginalize(fire_mrf)` matches `results`.

    Args:
        fire_mrf_exact_marginalize: The implementation under test.
        fire_mrf: The MRF to marginalize.
        results: Expected marginals, compared with absolute tolerance 1e-4.

    Raises:
        AssertionError: If the computed marginals differ from `results`.
    """
    marginals = fire_mrf_exact_marginalize(fire_mrf)
    assert np.allclose(marginals, results, atol=1e-4), (
        f"Exact marginals mismatch.\nExpected:\n{results}\nGot:\n{marginals}")

explicit_marginalize_test(
    fire_mrf_exact_marginalize,
    FireMRF(
        unitary_potentials=np.array([[[0.5, 0.5], [0.3, 0.7]]],
                                    dtype=np.float64),
        correlation_potential=np.array([[0.7, 0.3], [0.3, 0.7]],
                                       dtype=np.float64),
    ),
    np.array([0.58, 0.7], dtype=np.float64),
)
print('Tests passed.')

## Loopy Belief Propagation


### Question
Our exact marginalization is very slow for larger grids. 
So instead, we will use loopy belief propagation to compute the marginals.

We have provided you with a function `fire_mrf_lbp_marginals` that implements loopy belief 
propagation on a `FireMRF` to compute the marginals approximately.
This function is the same as the belief propagation you saw in lecture, except that it 
works only on our FireMRF model. 
By specializing to our 2D grid model, we have taken care to make it very efficient 
by avoiding Python loops and using NumPy operations instead.
However, this function does not handle observations. 

Please complete the implementation of `fire_mrf_lbp_conditionals`.
This function should take the same arguments as `fire_mrf_lbp_marginals`, with the addition of an
argument `observation` that is a `FirePOMDPObservation`.
A `FirePOMDPObservation` is just a 2D NumPy array --- please see the docstring 
for `FirePOMDPObservation` for more details.

_Hints:_ 
- You can use `np.isnan(observation)` to find the indices of the unobserved cells.
- You can reduce `fire_mrf_lbp_conditionals` into a call to `fire_mrf_lbp_marginals` by
incorporating the observation into the unitary potentials.


For reference, our solution is **17** line(s) of code.

In [None]:
def fire_mrf_lbp_conditionals(fire_mrf: FireMRF,
                              observation: FirePOMDPObservation,
                              max_iters: int = 20,
                              return_msgs: bool = False,
                              rtol=1e-5) -> np.ndarray:
    """Conditional fire probabilities of a FireMRF given a partial observation.

    Wrapper around `fire_mrf_lbp_marginals` that additionally handles the
    observation. The observation is incorporated into the unitary
    potentials, and loopy belief propagation is then run on the resulting
    MRF. To stay computationally efficient, the unitary potentials should be
    prepared with vectorized NumPy operations rather than Python loops.

    Args:
        fire_mrf: The FireMRF.
        observation: The observed grid, shape (h, w); NaN entries mark
            unobserved cells.
        max_iters, return_msgs, rtol: Same as in `fire_mrf_lbp_marginals`.

    Returns:
        The conditional probabilities of fire in each cell, shape (h, w).
    """
    raise NotImplementedError("Implement me!")

### Tests

In [None]:
def fire_mrf_lbp_test(fire_mrf_lbp_conditionals, fire_mrf: FireMRF,
                      observation: np.ndarray, results: np.ndarray) -> None:
    """Assert that the LBP conditionals of `fire_mrf` given `observation`
    match `results`.

    Args:
        fire_mrf_lbp_conditionals: The implementation under test.
        fire_mrf: The MRF to run belief propagation on.
        observation: Partial fire observation; NaN marks unobserved cells.
        results: Expected conditionals, compared with absolute tolerance 1e-4.

    Raises:
        AssertionError: If the computed values differ from `results`.
    """
    marginals = fire_mrf_lbp_conditionals(fire_mrf, observation)
    assert np.allclose(marginals, results, atol=1e-4), (
        f"LBP conditionals mismatch.\nExpected:\n{results}\nGot:\n{marginals}")

fire_mrf_lbp_test(
    fire_mrf_lbp_conditionals,
    FireMRF(
        unitary_potentials=np.array(
            [[[0.5, 0.5], [0.3, 0.7], [0.5, 0.5]],
             [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]],
             [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]],
            dtype=np.float64),
        correlation_potential=np.array([[0.7, 0.3], [0.3, 0.7]],
                                       dtype=np.float64),
    ),
    np.array([[0., np.nan, np.nan],
              [np.nan, 1., np.nan],
              [np.nan, np.nan, 0.]], dtype=np.float64),
    np.array([[0., 0.7, 0.58],
              [0.5, 1., 0.532],
              [0.5, 0.5, 0.]], dtype=np.float64),
)
print('Tests passed.')

## MLO Determinized FirePOMDP


### Utilities


**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:


@dataclasses.dataclass(frozen=True, eq=True, order=True)
class FirePOMDPBeliefState(PickupProblemState):
    """Agent-side belief about the FirePOMDP state.

    Extends the pickup-problem state with what the agent currently knows
    about the fire and where its drone is.
    """

    # Partially observed fire grid; NaN entries are cells not yet observed.
    fire_grid: FirePOMDPObservation
    # Current drone location.
    drone_loc: Tuple[int, int]


class RewardProblem(MDP):
    """Abstract finite-horizon reward problem, possibly discounted.

    Structurally it is an MDP whose transitions happen to be deterministic.
    """


@dataclasses.dataclass
class FirePOMDPBeliefStateAgent(Agent):
    """An abstract base agent that tracks the belief state at each step.

    The agent tracks the belief state as a FirePOMDPBeliefState instance.
    Further, it assumes that the fire is distributed according to a `FireMRF`.
    Subclasses implement `_act`, which maps the current belief state to a
    (robot_action, drone_action) pair.

    Raises:
        ValueError: On construction, if `receding_horizon` is None and the
            problem's horizon is infinite.
    """

    problem: FirePOMDP
    fire_mrf: FireMRF

    # An optional receding horizon to use for the planning
    # If not provided, the problem must have a finite horizon
    receding_horizon: Optional[int] = None

    # Number of actions taken since the last `reset`.
    t: int = dataclasses.field(default=0, init=False)

    # Current belief; None until `reset` runs (`act` calls it lazily).
    belief_state: Optional[FirePOMDPBeliefState] = dataclasses.field(
        default=None,
        init=False,
    )

    def __post_init__(self):
        # Explicit check instead of `assert`, which is stripped under -O.
        if self.receding_horizon is None and self.problem.horizon == np.inf:
            raise ValueError(
                "A receding_horizon is required when the problem has an "
                "infinite horizon.")

    @property
    def planning_horizon(self):
        """Number of steps to plan ahead at the current time step.

        Without a receding horizon this is the number of steps left in the
        episode, `problem.horizon - t`, counting the step being decided.
        """
        if self.receding_horizon is None:
            return self.problem.horizon - self.t
        return self.receding_horizon

    def reset(self):
        """Reset the agent to its initial state.

        In particular, we need to:
        - Reset the `self.t` to 0
        - Reset the belief state to the initial belief state: an
          `unknown` fire grid except for the robot's starting cell (known to
          be fire-free), with the drone at `problem.initial_drone_loc`.
        """
        self.t = 0

        # Initial belief state
        fire_belief = FirePOMDPObservation.unknown(
            self.problem.pickup_problem.grid_shape)
        pickup_problem_initial = self.problem.pickup_problem.initial

        # No fire at the robot's location
        fire_belief[pickup_problem_initial.robot_loc] = 0

        self.belief_state = FirePOMDPBeliefState(
            *dataclasses.astuple(pickup_problem_initial),
            fire_grid=fire_belief,
            drone_loc=self.problem.initial_drone_loc,
        )

    def act(self, obs: FirePOMDPObservation) -> Action:
        """Take an action while maintaining the belief state.

        Steps:
        1. Merge `obs` into the belief. Cells that are still unknown (NaN)
           take the observed value; already-known cells are kept.
        2. Ask `_act` for a (robot_action, drone_action) pair.
        3. Advance the belief: apply the robot action to the pickup problem
           and record `drone_action` as the new drone location.

        If `reset` has not been called yet, it is called first.
        """
        if self.belief_state is None:
            self.reset()

        # Update belief state by observation
        next_fire_belief = np.where(
            np.isnan(self.belief_state.fire_grid), obs,
            self.belief_state.fire_grid).view(FirePOMDPObservation)
        self.belief_state = dataclasses.replace(self.belief_state,
                                                fire_grid=next_fire_belief)

        robot_action, drone_action = self._act(self.belief_state)
        # Advance the clock only after planning. While `_act` runs,
        # `planning_horizon` must still count the step being decided;
        # otherwise the last step of a finite-horizon episode would be
        # planned with a horizon of 0.
        self.t += 1

        # Update belief state by action
        self.belief_state = FirePOMDPBeliefState(
            *dataclasses.astuple(
                self.problem.pickup_problem.step(self.belief_state,
                                                 robot_action)),
            self.belief_state.fire_grid,
            drone_action,
        )

        return (robot_action, drone_action)

    @abstractmethod
    def _act(self, belief_state: FirePOMDPBeliefState) -> Action:
        """Take an action given the current belief state.

        Subclasses must implement this method and return a
        (robot_action, drone_action) pair.
        """
        ...


def firepomdp_random_rollout_policy(problem, state: FirePOMDPState) -> Action:
    """Sample a random (robot_action, drone_action) pair for rollouts.

    Unlike the plain `random_rollout_policy`, sampling is two-stage. First a
    robot action is drawn uniformly from the distinct robot actions available
    in `state`. Then a drone action is drawn uniformly from the drone actions
    paired with it. Sampling uniformly over all (robot, drone) pairs instead
    would favour robot actions that come with more drone options.

    Args:
        problem: a problem instance similar to the `FirePOMDP`. Its `actions`
            method must yield (robot_action, drone_action) tuples.
        state: The current complete FirePOMDP state.

    Returns:
        A tuple of robot action and drone action.
    """
    drone_options = {}
    for robot_act, drone_act in problem.actions(state):
        drone_options.setdefault(robot_act, []).append(drone_act)
    chosen_robot = random.choice(list(drone_options))
    chosen_drone = random.choice(drone_options[chosen_robot])
    return chosen_robot, chosen_drone


def state_sampler(fire_mrf: FireMRF,
                  belief_state: FirePOMDPBeliefState) -> Iterable[FireMDPState]:
    """Stream complete states drawn from a belief state.

    A `FireMRFGibbsSampler` is built from `fire_mrf`, with
    `belief_state.fire_grid` as its observation grid. Each yielded object is
    `belief_state` (same type) with `fire_grid` replaced by one Gibbs sample;
    all other fields are copied unchanged.

    Args:
        fire_mrf: The MRF that describes the distribution of fire.
        belief_state: A belief state.

    Yields:
        One complete state per Gibbs sample. The stream is intended to be
        infinite (it lasts as long as `samples_stream()` does).
    """
    gibbs_sampler = FireMRFGibbsSampler(fire_mrf=fire_mrf,
                                        observation_grid=belief_state.fire_grid)
    yield from (dataclasses.replace(belief_state, fire_grid=sampled_grid)
                for sampled_grid in gibbs_sampler.samples_stream())


@dataclasses.dataclass
class FirePOMDPRolloutLookaheadAgent(FirePOMDPBeliefStateAgent):
    """POMDP agent that picks actions by Monte-Carlo rollout lookahead.

    For every available robot action, it runs `n_rollout_per_action`
    rollouts. Each rollout samples a complete state from the current belief
    (via `state_sampler`) and pairs the robot action with a random compatible
    drone action. It then simulates `firepomdp_random_rollout_policy` from
    there and records the discounted return. The robot action with the
    highest average return is chosen, together with a uniformly random drone
    action compatible with it.
    """

    n_rollout_per_action: int = 10

    def _act(self, belief_state: FirePOMDPBeliefState) -> Action:
        sampler = state_sampler(self.fire_mrf, belief_state)
        # Initialize actions to the set of all possible actions, grouped by
        # robot action, using a single sampled state.
        # Note that this is specific to the FirePOMDP:
        # in general, we cannot assume that the set of actions
        # is the same for all states sampled from the current belief state.
        state = next(sampler)
        actions = collections.defaultdict(list)
        # Use the agent's own problem (previously this read an unrelated
        # global `problem`, which is a NameError or silently the wrong MDP).
        for ra, da in self.problem.actions(state):
            actions[ra].append(da)
        robot_actions = list(actions.keys())
        # Shuffle so ties in `max` below (which keeps the first maximum) are
        # broken at random.
        random.shuffle(robot_actions)
        # Put back the state into the sampler so that we don't waste it
        sampler = itertools.chain([state], sampler)

        action_rewards = {}
        for robot_action in robot_actions:
            drone_actions = actions[robot_action]
            total_rewards = []
            for _ in range(self.n_rollout_per_action):
                action = (robot_action, random.choice(drone_actions))
                # Sample a state from the belief state
                state = next(sampler)
                total_rewards.append(self._rollout_single(state, action))
            # Compute the average reward for this action
            action_rewards[robot_action] = np.mean(total_rewards)

        # Return the action with the highest average reward
        best_robot_action = max(robot_actions, key=lambda a: action_rewards[a])
        random_drone_action = random.choice(actions[best_robot_action])
        return (best_robot_action, random_drone_action)

    def _rollout_single(self, state: FireMDPState, action: Action) -> float:
        """Discounted return of one simulated rollout from `state`.

        The first step takes `action`; later steps follow
        `firepomdp_random_rollout_policy`. The rollout stops after
        `planning_horizon` steps or at a terminal state, whichever comes
        first.
        """
        total_reward = 0
        disc = 1
        t = 0
        planning_horizon = self.planning_horizon
        while t < planning_horizon and not self.problem.terminal(state):
            if t > 0:
                action = firepomdp_random_rollout_policy(self.problem, state)
            next_state = self.problem.step(state, action)
            reward = self.problem.reward(state, action, next_state)
            total_reward += disc * reward
            state = next_state
            disc = disc * self.problem.discount
            t += 1
        return total_reward

### Question
Please complete the implementation of `MLODeterminiziedFirePOMDP`, 
a deterministic reward problem that is the MLO approximation of the `FirePOMDP`.

Recall that in MLO approximation, the next state is determined by the most likely observation.
Also, recall that our agent can observe one or more grid cells at each step.
Technically, the most likely observation is the 
configuration of the observed grid cells with the highest probability.
However, computing the joint distribution of the observed cells is expensive. 
Therefore, we make another approximation: We consider each cell individually and use the most likely configuration of that cell as its most likely observation.
By using this approximation, we can use loopy belief propagation from the previous 
section to compute the most likely observation for each cell and then use 
those observations to determine the next state.


For reference, our solution is **72** line(s) of code.

In [None]:
@dataclasses.dataclass(frozen=True)
class MLODeterminiziedFirePOMDP(RewardProblem):
    """Most-likely-observation (MLO) determinization of the `FirePOMDP`.

    States are `FirePOMDPBeliefState`s. Instead of branching over possible
    observations, each transition commits to the per-cell most likely
    observation. Per-cell fire probabilities come from loopy belief
    propagation (`get_fire_marginals`) under `fire_mrf`. The result is a
    deterministic reward problem that can be searched directly.
    """

    problem: FirePOMDP

    # Belief state the determinized problem starts from (exposed as `initial`).
    _initial: FirePOMDPBeliefState

    fire_mrf: FireMRF
    # Maximum loopy-BP iterations used when computing fire marginals.
    lbp_niters: int = 10

    # Use this to cache precomputed fire distributions, so we don't have to
    # recompute them. Keyed by the (partially observed) fire grid. A caller
    # may pass in a shared dict to reuse results across instances. It is
    # excluded from repr/eq/hash because it is a cache, not part of the
    # problem's identity.
    fire_marginals_cache: Dict[FirePOMDPObservation, np.ndarray] = (
        dataclasses.field(default_factory=dict, repr=False, compare=False))

    @property
    def pickup_problem(self) -> PickupProblem:
        return self.problem.pickup_problem

    @property
    def horizon(self) -> int:
        return self.problem.horizon

    @property
    def discount(self) -> float:
        return self.problem.discount

    def actions(self, state: FirePOMDPBeliefState) -> Iterable[Action]:
        return self.problem.actions(state)

    def get_fire_marginals(self,
                           fire_grid: FirePOMDPObservation) -> np.ndarray:
        """Per-cell fire probabilities conditioned on `fire_grid`, memoized.

        `fire_grid` is used as a dict key, so it must be hashable.
        """
        cache = self.fire_marginals_cache
        if fire_grid not in cache:
            cache[fire_grid] = fire_mrf_lbp_conditionals(
                self.fire_mrf, fire_grid, max_iters=self.lbp_niters)
        return cache[fire_grid]

    @property
    def initial(self) -> FirePOMDPBeliefState:
        return self._initial

    def step(self, state: FirePOMDPBeliefState,
             action: Action) -> FirePOMDPBeliefState:
        """Transition according to the MLO approximation."""
        raise NotImplementedError("Implement me!")

    def reward(self, state1: FirePOMDPBeliefState, action: Action,
               state2: FirePOMDPBeliefState) -> float:
        """Reward according to the MLO approximation."""
        raise NotImplementedError("Implement me!")

    def terminal(self, state: FirePOMDPBeliefState) -> bool:
        """Return True if all states in the belief state are terminal."""
        raise NotImplementedError("Implement me!")

## MLO MCTS Agent for the FirePOMDP


### Question
Please complete the implementation of `FirePOMDP_MLO_MCTSAgent`.
At each step, this agent:
- Applies MLO approximation by casting the POMDP to an `MLODeterminiziedFirePOMDP` and 
- Uses MCTS to search for a good action under the MLO approximation.

Once you have the implementation, try to run the agent on the problems defined in `get_problem_part2`, 
and see how it performs. 
You may use the same utilities (e.g., `run_agent_on_problem` and `animate_trajectory`) as in part 1.

As a sanity check for correctness, you should see that the agent achieves a reward $>0.9$ 
most of the time in `get_problem_part2("only_fire")`, as long as there exists a path that 
allows the agent to pick up the patient and then reach the goal.

_Tips for using MCTS_:
- We recommend using `firepomdp_random_rollout_policy` as the rollout policy for MCTS instead of the plain random policy. Please see the docstring of that function for more details.
- You will want to use `n_simulations=1` since the environment is deterministic. We also recommend setting `max_backup=False` to perform the Monte-Carlo backup instead of the Bellman backup.


For reference, our solution is **44** line(s) of code.

In [None]:
@dataclasses.dataclass
class FirePOMDP_MLO_MCTSAgent(FirePOMDPBeliefStateAgent):
    """Belief-state agent that plans with MCTS under the MLO approximation.

    At each step it casts the FirePOMDP to an `MLODeterminiziedFirePOMDP`
    starting from the current belief state. It then runs Monte Carlo Tree
    Search on that deterministic problem to choose an action.
    """

    # Maximum loopy-BP iterations used for the MLO fire marginals.
    lbp_niters: int = 20

    # MCTS exploration constant and per-step search budget.
    C: float = np.sqrt(2)
    iteration_budget: int = 100

    # Use this to cache precomputed fire distributions, so we don't have to
    # recompute them. Keyed by fire grid; it lives on the agent, so `reset`
    # does not clear it. Not an init argument, and excluded from repr/eq
    # since it is only a cache.
    fire_marginals_cache: Dict[FirePOMDPObservation, np.ndarray] = (
        dataclasses.field(
            default_factory=dict,
            init=False,
            repr=False,
            compare=False,
        ))

    def _act(self, belief_state: FirePOMDPBeliefState) -> Action:
        """Returns the best action according to MLO and searching using MCTS."""
        raise NotImplementedError("Implement me!")