# Homework 7

## Imports and Utilities
**Note**: these imports and functions are available in catsoop. You do not need to copy them in.

In [None]:
from collections import defaultdict
import abc
import numpy as np


class MDP:
    """Base class describing a Markov Decision Process.

    Subclasses provide the state set, the action set, the reward function and
    the transition model. The discount factor, horizon and terminal-state test
    have defaults that subclasses may override.

    Note: ``MDP`` does not use ``abc.ABCMeta`` as its metaclass, so the
    ``@abc.abstractmethod`` markers below are informational only. An
    un-overridden member raises ``NotImplementedError`` when it is used, not
    when the class is instantiated.
    """

    @property
    @abc.abstractmethod
    def state_space(self):
        """The collection of MDP states.

        Treat this as a set unless a subclass documents otherwise.
        """
        raise NotImplementedError("Override me")

    @property
    @abc.abstractmethod
    def action_space(self):
        """The collection of MDP actions.

        Treat this as a set unless a subclass documents otherwise.
        """
        raise NotImplementedError("Override me")

    @property
    def temporal_discount_factor(self):
        """Discount factor gamma (1.0 unless overridden)."""
        return 1.0

    @property
    def horizon(self):
        """Horizon H (infinite unless overridden)."""
        return float("inf")

    def state_is_terminal(self, state):
        """Report whether ``state`` ends the episode.

        The base implementation treats no state as terminal.

        Args:
            state: A state.

        Returns:
            is_terminal: A bool.
        """
        return False

    @abc.abstractmethod
    def get_reward(self, state, action, next_state):
        """Deterministic reward for taking ``action`` in ``state``.

        Args:
            state: The current state.
            action: The action taken.
            next_state: The state that resulted.

        Returns:
            reward: The single-step reward.
        """
        raise NotImplementedError("Override me")

    @abc.abstractmethod
    def get_transition_distribution(self, state, action):
        """Distribution over successor states.

        Unless a subclass says otherwise, this is a dict from next states to
        probabilities. For a state space {0, 1, 2}, for instance, the result
        could be {0: 0.3, 1: 0.2, 2: 0.5}.

        Args:
            state: The current state.
            action: The action taken.

        Returns:
            next_state_distribution: Mapping of next states to probabilities.
        """
        raise NotImplementedError("Override me")

    def sample_next_state(self, state, action, rng=np.random):
        """Draw one successor state from the transition distribution.

        Subclasses whose explicit distribution is too large to enumerate may
        override this.

        Args:
            state: A state from the state space.
            action: An action from the action space.
            rng: Random number generator exposing ``choice(n, p=...)``.

        Returns:
            next_state: The sampled successor state.
        """
        distribution = self.get_transition_distribution(state, action)
        # Split into parallel tuples so outcomes can be picked by index
        # (states themselves may be tuples, which numpy's choice would not
        # treat as atomic outcomes).
        candidates, weights = zip(*distribution.items())
        return candidates[rng.choice(len(candidates), p=weights)]


class MarshmallowMDP(MDP):
    """The Marshmallow MDP.

    A state is ``(hunger_level, marshmallow_remains)`` with hunger in
    {0, 1, 2}. The reward for a step is the negative squared hunger level of
    the next state.

    Eating while the marshmallow remains resets hunger to 0 and consumes the
    marshmallow. In every other case, including "eat" with nothing left,
    hunger stays the same with probability 0.75 and rises by one (capped at 2)
    with probability 0.25.
    """

    @property
    def state_space(self):
        # (hunger level, marshmallow remains)
        return {(h, m) for h in {0, 1, 2} for m in {True, False}}

    @property
    def action_space(self):
        return {"eat", "wait"}

    @property
    def horizon(self):
        return 4

    def get_reward(self, state, action, next_state):
        """Return the negative squared hunger level of ``next_state``."""
        next_hunger_level = next_state[0]
        return -(next_hunger_level**2)

    def get_transition_distribution(self, state, action):
        """Return a ``defaultdict(float)`` mapping next states to probabilities.

        Any state absent from the mapping has probability 0.
        """
        hunger = state[0]
        marshmallow_remains = state[1]

        # The marshmallow status updates deterministically: eating always
        # leaves none, waiting keeps whatever was there.
        next_m = False if action == "eat" else marshmallow_remains

        dist = defaultdict(float)

        if action == "wait" or not marshmallow_remains:
            # Waiting, or "eating" with nothing left to eat:
            # hunger stays the same with probability 0.75 ...
            dist[(hunger, next_m)] += 0.75
            # ... and increases by 1 (capped at 2) with probability 0.25.
            # At hunger 2 both updates hit the same key, summing to 1.0.
            dist[(min(hunger + 1, 2), next_m)] += 0.25
        else:
            assert action == "eat" and marshmallow_remains
            # Hunger is deterministically reset to 0 after eating.
            dist[(0, next_m)] = 1.0

        return dist

class ChaseMDP(MDP):
    """A 2D grid bunny chasing MDP.

    A state is ``(agent_pos, goal_pos)``, each a ``(row, col)`` cell.

    The agent moves deterministically in the chosen direction and stays put if
    the target cell is off the grid or an obstacle. The bunny (goal) stays put
    with probability 0.5; otherwise it tries each of the four directions with
    probability 0.125 each, under the same blocking rule.

    A transition whose next state has the agent on the bunny earns
    ``goal_reward``, and such states are terminal. Every other transition
    earns ``living_reward``.
    """

    @property
    def obstacles(self):
        # Nonzero entries mark blocked cells; the array's shape sets the grid size.
        return np.zeros((2, 3))  # by default, 2x3 grid with no obstacles

    @property
    def goal_reward(self):
        return 1

    @property
    def living_reward(self):
        return 0

    @property
    def height(self):
        return self.obstacles.shape[0]

    @property
    def width(self):
        return self.obstacles.shape[1]

    @property
    def state_space(self):
        pos = [(r, c) for r in range(self.height) for c in range(self.width)]
        return {(p1, p2) for p1 in pos for p2 in pos}

    @property
    def action_space(self):
        return {'up', 'down', 'left', 'right'}

    @property
    def temporal_discount_factor(self):
        return 0.9

    def action_to_delta(self, action):
        """Map an action name to its ``(d_row, d_col)`` step."""
        return {
            'up': (-1, 0),
            'down': (1, 0),
            'left': (0, -1),
            'right': (0, 1),
        }[action]

    @staticmethod
    def _move(pos, delta, obstacles, height, width):
        """Return the cell reached from ``pos`` by ``delta``.

        If the target cell is off the grid or an obstacle, return ``pos``
        (as a fresh tuple) instead.
        """
        row, col = pos
        dr, dc = delta
        r, c = row + dr, col + dc
        # Bounds are checked first so negative indices never wrap into
        # ``obstacles``.
        if not (0 <= r < height and 0 <= c < width) or obstacles[r, c]:
            return (row, col)
        return (r, c)

    def get_transition_distribution(self, state, action):
        """Return a ``defaultdict(float)`` mapping next states to probabilities."""
        # Read the grid once per call. ``obstacles`` (and ``height``/``width``,
        # which go through it) builds a new array on every access.
        obstacles = self.obstacles
        height, width = self.height, self.width

        agent_pos, goal_pos = state

        # The agent's move is deterministic.
        next_agent_pos = self._move(
            agent_pos, self.action_to_delta(action), obstacles, height, width)

        next_state_dist = defaultdict(float)
        # The bunny stays in the same place with probability 0.5 ...
        next_state_dist[(next_agent_pos, goal_pos)] += 0.5
        # ... otherwise it tries up/down/left/right uniformly. Blocked moves
        # fold back onto the staying state.
        for delta in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            next_goal_pos = self._move(goal_pos, delta, obstacles, height, width)
            next_state_dist[(next_agent_pos, next_goal_pos)] += 0.5*0.25

        return next_state_dist

    def get_reward(self, state, action, next_state):
        """Return ``goal_reward`` if the agent is on the bunny in ``next_state``."""
        agent_pos, goal_pos = next_state
        if agent_pos == goal_pos:
            return self.goal_reward
        return self.living_reward

    def state_is_terminal(self, state):
        """A state is terminal once the agent has caught the bunny."""
        agent_pos, goal_pos = state
        return agent_pos == goal_pos


class LargeChaseMDP(ChaseMDP):
    """A larger (8 x 10) 2D grid bunny chasing MDP with interior obstacles."""

    @property
    def obstacles(self):
        # One string per grid row: '1' is a blocked cell, '0' a free one.
        # A new array is built on every access, so callers may mutate it freely.
        layout = (
            "0000000000",
            "0001000011",
            "0001000000",
            "0101101000",
            "0001001000",
            "0110000110",
            "0000000000",
            "0000000000",
        )
        return np.array([[int(cell) for cell in row] for row in layout])



## Expectimax Search


### Question
Complete the implementation of expectimax search for a finite-horizon MDP.

For reference, our solution is **16** line(s) of code.

In [None]:
def expectimax_search(initial_state, mdp, horizon):
    """Use expectimax search to determine a next action.

    Note that we're just computing the single next action to
    take, we do not need to store the entire partial V.

    Horizon is given as a separate argument so that we can use
    expectimax search with receding horizon control, for example,
    even if mdp.horizon is inf.

    The search alternates max nodes (the agent picks an action) with
    expectation nodes (the environment samples a next state):

        Q(s, a, h) = sum_{s'} P(s' | s, a) * (R(s, a, s') + gamma * V(s', h - 1))
        V(s, h)    = 0 if h <= 0 or s is terminal, else max_a Q(s, a, h)

    V is memoized on (state, steps_left) within a single call. States must
    therefore be hashable; they already are, being keys of the transition
    distribution. Ties between actions go to the smallest action in sorted
    order, so results do not depend on set iteration order.

    Args:
        initial_state: A state in the mdp.
        mdp: An MDP.
        horizon: An int horizon (must be finite).

    Returns:
        action: An action in the mdp.

    Raises:
        ValueError: If ``horizon`` is infinite.
    """
    if horizon == float("inf"):
        raise ValueError("expectimax_search requires a finite horizon")

    gamma = mdp.temporal_discount_factor
    try:
        # Sorted order gives deterministic tie-breaking; string hashing (and so
        # set order) varies between interpreter runs.
        actions = sorted(mdp.action_space)
    except TypeError:  # actions not mutually orderable; keep the set's order
        actions = list(mdp.action_space)

    # The same (state, steps_left) pair is reached through many branches of the
    # tree. Caching V turns the exponential tree walk into work proportional to
    # the number of reachable (state, depth) pairs.
    value_cache = {}

    def q_value(state, action, steps_left):
        dist = mdp.get_transition_distribution(state, action)
        return sum(
            prob * (mdp.get_reward(state, action, next_state)
                    + gamma * state_value(next_state, steps_left - 1))
            for next_state, prob in dist.items()
        )

    def state_value(state, steps_left):
        if steps_left <= 0 or mdp.state_is_terminal(state):
            return 0.0
        key = (state, steps_left)
        if key not in value_cache:
            value_cache[key] = max(q_value(state, a, steps_left) for a in actions)
        return value_cache[key]

    return max(actions, key=lambda a: q_value(initial_state, a, horizon))

### Tests

In [None]:
def test1_expectimax_search():
    """Check expectimax decisions on the Marshmallow MDP."""
    mdp = MarshmallowMDP()
    # (initial state, search horizon, expected action)
    expectations = [
        ((0, True), mdp.horizon, "wait"),
        ((0, True), 1, "eat"),
        ((1, True), mdp.horizon, "eat"),
        ((2, True), mdp.horizon, "eat"),
        ((1, True), 10, "wait"),
    ]
    for state, search_horizon, expected_action in expectations:
        assert expectimax_search(state, mdp, search_horizon) == expected_action

test1_expectimax_search()


def test2_expectimax_search():
    """Check expectimax decisions on the small bunny-chasing MDP."""
    mdp = ChaseMDP()
    # (initial state, search horizon, accepted actions); more than one action
    # is accepted where the reference allows either choice.
    cases = [
        (((0, 0), (0, 1)), 1, ["right"]),
        (((0, 0), (0, 2)), 2, ["right"]),
        (((0, 0), (1, 0)), 1, ["down"]),
        (((0, 0), (1, 2)), 2, ["right", "down"]),
        (((1, 2), (0, 0)), 2, ["up", "left"]),
    ]
    for state, search_horizon, accepted in cases:
        assert expectimax_search(state, mdp, search_horizon) in accepted

test2_expectimax_search()

print('Tests passed.')