# HW 4.1 Starter Code v2

See pset for deliverables.

In [None]:
from gym.spaces.discrete import Discrete as DiscreteSpace
import gym
import matplotlib.pyplot as plt
import numpy as np
import time
from collections import defaultdict, namedtuple
_ = np.seterr(divide='ignore', invalid='ignore')

## Domain

In [None]:
class FlickeringSAREnv(gym.Env):
    """Grid world whose cell types only change the agent's local dynamics.

    The hidden state is the agent's (row, col) position. Each call that
    produces an observation flips a coin: with probability
    ``darkness_prob`` the observation is ``None`` (darkness), otherwise it
    is the exact position. The four actions are up / down / left / right.

    Parameters
    ----------
    darkness_prob : float
        Probability that any single observation is ``None``.
    """
    # Actions are the integers 0..3. DELTAS[action] is the nominal
    # (d_row, d_col) move of that action; the extra last entry (0, 0)
    # is the "stay in place" move that only some cell types produce.
    ACTIONS = UP, DOWN, LEFT, RIGHT = range(4)
    DELTAS = [(-1, 0), (1, 0), (0, -1), (0, 1), (0, 0)]

    # Cell types, with one-letter aliases so the grid literal stays readable
    CELL_TYPES = EMPTY, ICE, STICKY, TRAP, CONVEYOR = \
                 E, I, S, T, C = range(5)
    # Only one layout for now. The outer ring is all traps, and a trap
    # never moves the agent, so no move can leave the grid.
    GRID = np.array([
        [T, T, T, T, T, T, T],
        [T, E, E, C, E, E, T],
        [T, E, S, E, I, E, T],
        [T, E, E, I, I, S, T],
        [T, E, E, E, I, S, T],
        [T, C, E, C, C, E, T],
        [T, T, T, T, T, T, T],
    ], dtype=int)

    def __init__(self, darkness_prob=0.):
        self.action_space = DiscreteSpace(len(self.ACTIONS))
        self._darkness_prob = darkness_prob
        self._state = None  # assigned by reset()
        self._rng = None  # assigned by seed()
        # Cache the ground-truth move distribution of every
        # (cell type, action) pair so step() is a plain lookup.
        delta_dists = {}
        for cell_type in self.CELL_TYPES:
            per_action = {}
            for action in self.ACTIONS:
                per_action[action] = self.get_delta_distribution(
                    cell_type, action)
            delta_dists[cell_type] = per_action
        self._delta_dists = delta_dists
        super().__init__()

    def seed(self, seed):
        """Seed both the environment's RNG and the action space."""
        self._rng = np.random.RandomState(seed)
        self.action_space.seed(seed)

    def reset(self):
        """Place the agent on a uniformly random cell of the grid.

        Returns
        -------
        observation : (int, int) or None
            The new position, or None if this observation is dark.
        """
        assert self._rng is not None, "Must call seed before reset"
        num_rows, num_cols = self.GRID.shape
        row = self._rng.choice(num_rows)
        col = self._rng.choice(num_cols)
        self._state = (row, col)
        return self._get_observation()

    def step(self, action):
        """Sample a move for ``action`` and apply it to the state.

        Returns
        -------
        observation : (int, int) or None
            The position after the move, or None if dark.
        reward : float
            Always 0; present only to follow the gym convention.
        done : bool
            True if the cell occupied *before* this move was a trap.
        debug_info : dict
            Always empty; present only to follow the gym convention.
        """
        assert self._state is not None, "Must call reset before step"
        row, col = self._state
        # The dynamics depend only on the type of the current cell
        cell_type = self.GRID[row, col]
        move_dist = self._delta_dists[cell_type][action]
        d_row, d_col = sample_from_dict(move_dist, self._rng)
        new_row, new_col = row + d_row, col + d_col
        num_rows, num_cols = self.GRID.shape
        assert 0 <= new_row < num_rows and \
               0 <= new_col < num_cols
        self._state = (new_row, new_col)
        # The episode ends on the step taken from inside a trap
        done = (cell_type == self.TRAP)
        observation = self._get_observation()
        return observation, 0., done, {}

    def _get_observation(self):
        """Return the current position, or None with prob. darkness_prob."""
        is_dark = self._rng.uniform() < self._darkness_prob
        return None if is_dark else tuple(self._state)

    @classmethod
    def get_delta_distribution(cls, cell_type, action):
        """Ground-truth move distribution; this is what learning recovers.

        Returns a dict {(d_row, d_col): probability} for taking
        ``action`` while standing on a cell of type ``cell_type``.
        """
        nominal = cls.DELTAS[action]
        stay = (0, 0)

        if cell_type == cls.EMPTY:
            # The action always works as intended
            dist = {nominal: 1.0}
        elif cell_type == cls.ICE:
            # The action is ignored: uniform over every possible move
            dist = dict.fromkeys(cls.DELTAS, 1./len(cls.DELTAS))
        elif cell_type == cls.STICKY:
            # The action usually fails and the agent stays put
            dist = {nominal: 0.1, stay: 0.9}
        elif cell_type == cls.TRAP:
            # The agent can never leave
            dist = {stay: 1.0}
        elif cell_type == cls.CONVEYOR:
            # The agent is always carried one row up
            dist = {(-1, 0): 1.0}
        else:
            raise Exception(f"Unrecognized cell type {cell_type}")
        return dist
        

def sample_from_dict(dict_probs, rng):
    """Helper utility: draw one key of dict_probs according to its values.

    dict_probs maps each choice to its probability (the values must sum
    to 1); rng is a numpy RandomState-like object providing ``choice``.
    """
    total_prob = sum(dict_probs.values())
    assert abs(total_prob - 1.) < 1e-6, \
        "Probabilities do not sum to 1."
    # Keys and values of a dict iterate in the same order, so the i-th
    # probability belongs to the i-th choice.
    choices = tuple(dict_probs.keys())
    probs = tuple(dict_probs.values())
    picked = rng.choice(len(choices), p=probs)
    return choices[picked]

## Belief Propagation

In [None]:
RV = namedtuple("RV", ["name", "dim"]) # Random Variable

class Potential(namedtuple("Potential", ["rvs", "table", "hash_val"])):
    """The same as the Potential from last pset, but with more efficient hashing

    A potential is a table of non-negative values with one axis per random
    variable in ``rvs``. The hash is computed once at construction from the
    variables and the raw bytes of the table, and equality is defined as
    equality of that cached hash, so potentials can be used as dict keys
    without comparing whole tables. The table must therefore not be
    modified in a way that is expected to change the potential's identity.
    """
    __slots__ = ()

    def __new__(cls, rvs, table):
        hash_val = hash(tuple(rvs)) ^ hash(table.tobytes())
        return super(Potential, cls).__new__(cls, rvs, table, hash_val)

    def __hash__(self):
        return self.hash_val

    def __eq__(self, other):
        # Only potentials carry a cached hash; for anything else let Python
        # fall back to its default handling instead of hashing the operand
        # (which raises TypeError for unhashable objects).
        if not isinstance(other, Potential):
            return NotImplemented
        return self.hash_val == other.hash_val

    def __ne__(self, other):
        # tuple defines its own __ne__, which would compare the ndarray
        # tables element-wise and raise ValueError. Keep != consistent
        # with the hash-based __eq__ above.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal


def run_belief_prop(rvs, potentials, max_iters=100):
    """Run belief propagation given random variables and potentials.

    Messages are stored as msgs[source][sink], in both directions along
    every (variable, potential) edge, and all messages are updated in
    parallel from the previous iteration's values. Iteration stops once no
    normalized message entry changes by more than 1e-6, or after max_iters
    rounds (in which case a warning is printed).

    Return the final messages.
    """
    # Start from uniform messages on every edge, in both directions
    msgs = defaultdict(dict)
    for factor in potentials:
        for var in factor.rvs:
            msgs[var][factor] = np.full(var.dim, 1. / var.dim)
            msgs[factor][var] = np.full(var.dim, 1. / var.dim)

    for _ in range(max_iters):
        updated = {}

        # Variable -> factor messages
        for var in rvs:
            neighbours = msgs[var]
            # Product of everything flowing into the variable
            var_belief = np.prod([msgs[factor][var] for factor in neighbours],
                                 axis=0)
            # Dividing out a factor's own message leaves the product of
            # the messages from all the *other* factors.
            outgoing = {}
            for factor in neighbours:
                outgoing[factor] = safe_divide(var_belief, msgs[factor][var])
            updated[var] = outgoing

        # Factor -> variable messages
        for factor in potentials:
            factor_belief = factor.table.copy()
            axes = list(range(len(factor.rvs)))
            neighbours = msgs[factor]
            # Multiply each incoming message along its variable's axis
            for var in neighbours:
                incoming = msgs[var][factor]
                axis = factor.rvs.index(var)
                factor_belief = np.einsum(factor_belief, axes,
                                          incoming, [axis], axes)
            # For each variable, divide its own message back out and sum
            # over every other axis.
            outgoing = {}
            for var in neighbours:
                incoming = msgs[var][factor]
                axis = factor.rvs.index(var)
                inverse = safe_divide(1., incoming)
                outgoing[var] = np.einsum(factor_belief, axes,
                                          inverse, [axis], [axis])
            updated[factor] = outgoing

        # Renormalize messages for numerical stability and check whether
        # anything moved by more than the tolerance.
        converged = True
        for src, outgoing in updated.items():
            for snk in outgoing:
                previous = msgs[src][snk]
                raw = outgoing[snk]
                normalized = safe_divide(raw, raw.sum())
                outgoing[snk] = normalized
                if converged and np.any(np.abs(normalized - previous) > 1e-6):
                    converged = False

        if converged:
            break

        msgs = updated

    else:
        print("WARNING: BP did not converge")

    return msgs

def safe_divide(x, y):
    """Compute x / y with every non-finite result replaced by 0.

    At least one of x or y must be a numpy array, so that the quotient is
    an array that can be patched in place. Infinities and NaNs (e.g. from
    dividing by 0) become 0, so that anything divided by 0 is 0.
    """
    quotient = x / y
    # "not finite" is exactly "infinite or NaN"
    quotient[~np.isfinite(quotient)] = 0.
    return quotient

## Learning

In [None]:
def _model_change(old_model, new_model):
    """Compare two delta-distribution models.

    Returns (max_dist, same_support): the largest absolute change of any
    single probability, and whether every (cell_type, action) entry has
    the same set of deltas in both models. A delta missing from one model
    is treated as having probability 0 there.
    """
    max_dist = 0.
    same_support = True
    for cell_type in FlickeringSAREnv.CELL_TYPES:
        for action in FlickeringSAREnv.ACTIONS:
            old_dist = old_model[cell_type][action]
            new_dist = new_model[cell_type][action]
            old_keys, new_keys = set(old_dist), set(new_dist)
            if old_keys != new_keys:
                same_support = False
            for k in old_keys | new_keys:
                diff = abs(old_dist.get(k, 0.) - new_dist.get(k, 0.))
                max_dist = max(max_dist, diff)
    return max_dist, same_support


def learn_model(data, dark_prob, max_num_iters=10):
    """Expectation-maximization

    The model that we're trying to learn is parameterized
    by the delta distributions delta_dist[cell_type][action]

    Parameters
    ----------
    data : [([observations], [actions])]
        One (observations, actions) pair per episode.
    dark_prob : float
        The known probability that an observation is None.
    max_num_iters : int
        Maximum number of EM iterations.

    Returns
    -------
    em_step_to_model : {int : delta_dist}
        The model at the start of every EM step 0..max_num_iters. Once EM
        has converged, the remaining steps are filled with the model that
        was current when convergence was detected.
    """
    # Keep track of all models for plotting
    em_step_to_model = {}

    # Initialize uniformly
    deltas = FlickeringSAREnv.DELTAS
    delta_dist = {cell_type : {action : {delta : 1./len(deltas) \
                  for delta in deltas} \
                  for action in FlickeringSAREnv.ACTIONS} \
                  for cell_type in FlickeringSAREnv.CELL_TYPES}

    for it in range(max_num_iters):
        print(f"## Starting EM iteration {it}")
        em_step_to_model[it] = delta_dist

        # E Step: compute marginals.
        marginal_seq, action_seq = [], []

        for episode, (observations, actions) in enumerate(data):
            print(f"Running E step on episode {episode}/{len(data)}", end='\r')
            # Compute pairwise marginals over consecutive states.
            # This is a list of dicts (cell_type, delta) -> prob.
            marginals = compute_marginals(observations, actions, delta_dist,
                                          dark_prob)
            marginal_seq.extend(marginals)
            action_seq.extend(actions)

        # M Step: get MLE parameters
        new_delta_dist = compute_mle_delta_dist(marginal_seq, action_seq)

        # Check for convergence. A change in the set of deltas of any
        # entry counts as "not converged", no matter how small the
        # numerical change is.
        max_dist, same_support = _model_change(delta_dist, new_delta_dist)
        converged = same_support and max_dist < 1e-4
        print("\nChange in model between iterations:", max_dist)
        if converged:
            print(f"EM converged after {it} iterations.")
            for i in range(it+1, max_num_iters+1):
                em_step_to_model[i] = delta_dist
            break

        # Update dist
        delta_dist = new_delta_dist

    else:
        print("WARNING: EM did not converge.")
        # Index by max_num_iters rather than it+1 so that this also works
        # when the loop body never ran (max_num_iters == 0).
        em_step_to_model[max_num_iters] = delta_dist

    return em_step_to_model

### E Step

In [None]:
def _build_transition_table(delta_dist, state_domain, action_domain):
    """Build the evidence-free transition table for one time step.

    table[i, j, k] is the probability, under delta_dist, of moving from
    state_domain[i] to state_domain[k] when taking action_domain[j].
    Moves that would leave the grid are skipped, so their mass is dropped.
    """
    num_rows, num_cols = FlickeringSAREnv.GRID.shape
    # O(1) state -> index lookup instead of list.index in the inner loop
    state_index = {state : i for i, state in enumerate(state_domain)}
    table = np.zeros((len(state_domain), len(action_domain),
                      len(state_domain)))
    for i, (r, c) in enumerate(state_domain):
        cell_type = FlickeringSAREnv.GRID[r, c]
        for j, a in enumerate(action_domain):
            for (dr, dc), p in delta_dist[cell_type][a].items():
                # Can't go off screen
                if not (0 <= r + dr < num_rows and 0 <= c + dc < num_cols):
                    continue
                table[i, j, state_index[(r + dr, c + dc)]] += p
    return table


def compute_marginals(observations, actions, delta_dist, dark_prob):
    """Helper for the E step of EM. Given a list of observations
    and corresponding actions, and the current model parameters
    delta_dist, and the known observation model parameter dark_prob,
    run inference to determine the marginal distributions over
    (cell_type, delta), one per time step.

    A factor graph is created with the following variables:
        - One observation variable for each observation
        - One state variable for each observation
        - One action variable for each action
    and the following potentials:
        - One observation potential per observation,
          relating the state and observation (built from dark_prob).
        - One transition potential per action,
          relating the state, action, and next state
          (built from delta_dist).
    Observed values are incorporated by zeroing the inconsistent
    entries of each potential.

    Notes:
        - FlickeringSAREnv.ACTIONS are the actions
        - FlickeringSAREnv.DELTAS are all possible local moves
        - FlickeringSAREnv.GRID holds the cell type of each state

    Belief propagation is run on that factor graph, and the resulting
    messages are used to compute the joint distribution over pairs of
    (state, next state).

    Finally, that joint pairwise distribution over states is converted
    into a distribution over (cell_type, delta) for each time step.
    The cell type is determined from the state:
    FlickeringSAREnv.GRID[state[0], state[1]] is the cell type; and
    delta is determined by the pair of states:
    delta = (next_state[0] - state[0], next_state[1] - state[1]).

    Return a list of dicts {(cell_type, delta) : prob}, one per
    transition (len(observations) - 1 entries).
    """
    T = len(observations)
    grid = FlickeringSAREnv.GRID

    # Set up variable spaces
    action_domain = sorted(FlickeringSAREnv.ACTIONS)
    state_domain = [(r, c) for r in range(grid.shape[0]) \
                           for c in range(grid.shape[1])]
    obs_domain = state_domain + [None]
    # The observation table below relies on darkness being the last row
    assert obs_domain[-1] is None

    # Random variables
    all_rvs = []
    state_vars = {}
    action_vars = {}
    obs_vars = {}
    for t in range(T):
        state_t = RV(f"state_{t}", len(state_domain))
        state_vars[t] = state_t
        all_rvs.append(state_t)
        obs_t = RV(f"obs_{t}", len(obs_domain))
        obs_vars[t] = obs_t
        all_rvs.append(obs_t)
        if t < T-1:
            action_t = RV(f"action_{t}", len(action_domain))
            action_vars[t] = action_t
            all_rvs.append(action_t)

    # The evidence-free tables are identical at every time step, so build
    # them once and copy them per step before incorporating evidence.
    num_states = len(state_domain)
    base_obs_table = np.zeros((len(obs_domain), num_states))
    # With 1.-dark prob, observe the state
    base_obs_table[:-1] = (1. - dark_prob)*np.eye(num_states)
    # With dark prob, observe None
    base_obs_table[-1] = dark_prob
    base_transition_table = _build_transition_table(delta_dist, state_domain,
                                                    action_domain)

    # Potentials
    pots = []
    obs_potentials = {}
    transition_potentials = {}
    for t in range(T):
        # Observation potential
        rvs = [obs_vars[t], state_vars[t]]
        table = incorporate_potential_evidence(rvs, base_obs_table.copy(),
            obs_vars[t], obs_domain.index(observations[t]))
        observation_potential_t = Potential(rvs, table)
        obs_potentials[t] = observation_potential_t
        pots.append(observation_potential_t)
        if t < T - 1:
            # Transition potential
            rvs = [state_vars[t], action_vars[t], state_vars[t+1]]
            table = incorporate_potential_evidence(rvs,
                base_transition_table.copy(),
                action_vars[t], action_domain.index(actions[t]))
            transition_potential_t = Potential(rvs, table)
            transition_potentials[t] = transition_potential_t
            pots.append(transition_potential_t)

    # Run BP
    msgs = run_belief_prop(all_rvs, pots)

    # Get marginal distribution over consecutive states
    marginal_state_pairs = []
    for t in range(T-1):
        # Start with the joint factor at the observed action. Copy it:
        # slicing returns a view, and the in-place products below must
        # not modify the table stored inside the potential.
        action_idx = action_domain.index(actions[t])
        joint = transition_potentials[t].table[:, action_idx].copy()
        # Get forward message
        if t > 0:
            forward_msg = msgs[transition_potentials[t-1]][state_vars[t]]
            joint *= forward_msg[:, np.newaxis]
        # Get backward message
        if t < T-2:
            backward_msg = msgs[transition_potentials[t+1]][state_vars[t+1]]
            joint *= backward_msg[np.newaxis, :]
        # Get observations
        obs_t_msg = msgs[obs_potentials[t]][state_vars[t]]
        joint *= obs_t_msg[:, np.newaxis]
        obs_t1_msg = msgs[obs_potentials[t+1]][state_vars[t+1]]
        joint *= obs_t1_msg[np.newaxis, :]
        # Normalize
        joint = safe_divide(joint, joint.sum())
        marginal_state_pairs.append(joint)

    # Which (state, next state) pairs correspond to a legal delta, and to
    # which (cell_type, delta) key, does not depend on t: compute it once.
    legal_deltas = set(FlickeringSAREnv.DELTAS)
    pair_keys = []
    for i, s_t in enumerate(state_domain):
        cell_type = grid[s_t[0], s_t[1]]
        for j, s_t1 in enumerate(state_domain):
            delta = (s_t1[0] - s_t[0], s_t1[1] - s_t[1])
            # If not in deltas, not possible
            if delta in legal_deltas:
                pair_keys.append((i, j, (cell_type, delta)))

    # Convert into marginal distribution over (cell type, delta)
    beliefs = []
    for joint in marginal_state_pairs:
        belief_t = defaultdict(float)
        for i, j, key in pair_keys:
            belief_t[key] += joint[i, j]
        # Renormalize due to deltas
        z = sum(belief_t.values())
        belief_t = {k : v/z for k, v in belief_t.items()}
        beliefs.append(belief_t)

    return beliefs

def incorporate_potential_evidence(rvs, table, rv, val):
    """Zero out any potential values that are inconsistent with evidence

    Helper for E step. The table is modified in place and also returned.
    If rv is not one of the table's variables, the table is returned
    untouched.

    Parameters
    ----------
    rvs : [ RV ]
        List of random variables, axes of the table.
    table : np.ndarray
        The potential values
    rv : RV
        The random variable for which we have evidence
    val : int
        The value for the evidence rv.

    Returns
    -------
    table : np.ndarray
        With inconsistent entries zeroed out.
    """
    if rv not in rvs:
        return table
    evidence_axis = rvs.index(rv)
    # Every value of rv except the observed one
    values = list(range(rv.dim))
    ruled_out = values[:val]
    ruled_out.extend(values[val+1:])
    # Select everything along the other axes, and only the ruled-out
    # values along the evidence axis.
    selector = [slice(None)] * len(rvs)
    selector[evidence_axis] = ruled_out
    table[tuple(selector)] = 0.
    # The evidence must leave at least some mass in the potential
    assert table.sum() > 0.
    return table

### M Step

In [None]:
def compute_mle_delta_dist(marginal_seq, action_seq, laplace=0.01):
    """Helper for the M step of EM.

    This is the part of the starter code left for you to write: the
    body currently only raises NotImplementedError. learn_model calls it
    once per EM iteration with the marginals and actions of all episodes
    concatenated, and expects the updated model parameters back.

    Parameters
    ----------
    marginal_seq : [{(int, (int, int)) : float}]
        A sequence of marginals, where each marginal is a joint
        distribution of (cell_type, delta), represented as a dict
        so that marginal_seq[t][(cell_type, delta)] is the marginal
        probability that the state at time t had cell type cell_type
        and that the delta between times t and t+1 was delta.
    action_seq : [int]
        Sequence of actions of the same length as marginal_seq;
        action_seq[t] is the action taken at the transition that
        marginal_seq[t] describes.
    laplace : float
        Additive smoothing parameter.

    Returns
    -------
    delta_dist : {int : {int : {(int, int) : float}}}
        The learned distribution.
        delta_dist[cell_type][action][delta] is the probability
        of moving delta after being in cell_type and taking action.

    Raises
    ------
    NotImplementedError
        Always, until the function is implemented.

    Notes:
        - FlickeringSAREnv.ACTIONS are the actions
        - FlickeringSAREnv.DELTAS are all possible local moves
        - FlickeringSAREnv.GRID holds the cell type of each state
    """
    raise NotImplementedError("Implement me!")

## Pipeline

In [None]:
def create_dataset(darkness_prob, num_transitions,
                   max_transitions_per_episode, seed=0):
    """Gather a dataset of transitions by acting randomly in the env.

    A new episode is started whenever the environment reports done or the
    current episode has reached max_transitions_per_episode transitions.
    Exactly num_transitions transitions are collected in total, so the
    last episode may be cut short.

    Returns
    -------
    transitions : [([observations], [actions])]
        One entry per episode. Observation are hashable, actions are
        integers; each episode has one more observation than actions.
    """
    env = FlickeringSAREnv(darkness_prob=darkness_prob)
    env.seed(seed)

    episodes = []
    start_new_episode = True
    for _ in range(num_transitions):
        if start_new_episode:
            steps_in_episode = 0
            # Each episode starts with the observation returned by reset
            episode_obs, episode_actions = [env.reset()], []
            episodes.append([episode_obs, episode_actions])
        action = env.action_space.sample()
        episode_actions.append(action)
        next_obs, _, start_new_episode, _ = env.step(action)
        episode_obs.append(next_obs)
        steps_in_episode += 1
        if steps_in_episode >= max_transitions_per_episode:
            start_new_episode = True
    return episodes

def evaluate_model(model, verbose=False):
    """Compute the average distance of the model to the ground truth.

    For every (cell type, action) pair, the distance is the largest
    absolute difference between the model's and the ground truth's
    probability of any single delta; the mean over all pairs is returned.
    NOTE: the plots label this "total variational distance", but what is
    computed is the maximum per-delta gap, not half the L1 distance.

    With verbose=True, every compared entry is printed as
    "cell_type, action, delta: model_prob [ground_truth_prob]".
    """
    distances = []
    for cell_type in FlickeringSAREnv.CELL_TYPES:
        for action in FlickeringSAREnv.ACTIONS:
            truth = FlickeringSAREnv.get_delta_distribution(
                cell_type, action)
            learned = model[cell_type][action]
            worst_gap = 0.
            for delta in FlickeringSAREnv.DELTAS:
                # A delta absent from a distribution has probability 0
                gt = truth.get(delta, 0.)
                md = learned.get(delta, 0.)
                if verbose:
                    print(f"{cell_type}, {action}, {delta}: {md} [{gt}]")
                gap = abs(gt - md)
                if gap > worst_gap:
                    worst_gap = gap
            distances.append(worst_gap)
    return np.mean(distances)

def main():
    """Gather data, learn a model, evaluate the model.

    For every darkness probability and every trial seed, a dataset of
    random-action transitions is collected, a model is learned with EM,
    and the model of every EM step is evaluated. The results are plotted
    as darkness probability versus distance from the ground truth, with
    one curve (mean and one standard deviation band) per plotted EM step.

    This will take 5-10 minutes to run from start to finish, so you will
    probably want to mess with the hyperparameters to make it faster as you
    are developing your code.
    """
    start_time = time.time()

    # Hyperparameters
    darkness_probs = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
    max_em_iters = 10
    iters_to_plot = [0, 1, 2, 3, 5, 10]
    num_trials_per_darkness_prob = 5
    num_transitions_per_trial = 300
    max_transitions_per_episode = 10

    # results[darkness_prob][em_step] is a list with one score per trial
    results = {}
    for darkness_prob in darkness_probs:
        scores_by_step = {i : [] for i in range(max_em_iters+1)}
        results[darkness_prob] = scores_by_step
        for seed in range(num_trials_per_darkness_prob):
            # Gather demonstration data (random actions)
            data = create_dataset(darkness_prob, num_transitions_per_trial,
                                  max_transitions_per_episode, seed=seed)
            # Learn a model, then score the model of every EM step
            em_step_to_model = learn_model(data, darkness_prob,
                max_num_iters=max_em_iters)
            for em_step, model in em_step_to_model.items():
                scores_by_step[em_step].append(evaluate_model(model))

    # Plot darkness prob versus total variational distance
    x = np.array(darkness_probs)
    plt.figure()
    plt.title("Learning Sequential Hidden State Models for S&R")
    plt.xlabel("Darkness Probability")
    plt.ylabel("Total Variational Distance from Ground Truth")
    for it in iters_to_plot:
        # One row per darkness probability, one column per trial
        scores = [results[p][it] for p in x]
        y_mean = np.mean(scores, axis=1)
        y_std = np.std(scores, axis=1)
        lines = plt.plot(x, y_mean, marker='s', label=f'{it} EM iters')
        # Shade the +/- one standard deviation band in the curve's color
        color = lines[0].get_color()
        plt.fill_between(x, y_mean+y_std, y_mean-y_std, facecolor=color,
                         alpha=0.25)
    plt.legend()
    plt.tight_layout()
    plt.show()

    print(f"Finished run in {time.time() - start_time} seconds.")

### Fire Away

In [None]:
# Only run the (multi-minute) experiment when this code is executed directly,
# e.g. as a notebook cell or a script, and not when it is imported.
if __name__ == "__main__":
    main()