# HW 3.3 Solutions

See the problem set (pset) for the list of deliverables.

In [None]:
from collections import defaultdict, namedtuple
from tabulate import tabulate
import numpy as np
import time
# Silence NumPy's divide-by-zero and invalid-operation (e.g. 0/0) warnings for
# the whole process, so normalizations that divide by a zero sum yield inf/nan
# quietly instead of printing a RuntimeWarning.
# NOTE(review): this also hides such warnings from unrelated NumPy code; consider
# scoping it with `np.errstate(...)` around the specific divisions instead.
np.seterr(divide='ignore', invalid='ignore')

## Data Structures

In [None]:
# A discrete random variable: a display name plus its number of possible values.
RV = namedtuple("RV", ["name", "dim"])

# A factor over an ordered list of RVs; axis i of `table` corresponds to `rvs[i]`.
Potential = namedtuple("Potential", ["rvs", "table"])


def _hash_potential(pot):
    """Hash a Potential from its RV sequence and the raw bytes of its table.

    The inherited tuple hash cannot be used because `rvs` is a list and
    `table` is a NumPy array, and neither of those is hashable.
    """
    rv_part = hash(tuple(pot.rvs))
    table_part = hash(pot.table.tobytes())
    return rv_part ^ table_part


Potential.__hash__ = _hash_potential

# An inference task: all RVs, the potentials over them, and dicts mapping
# RV -> value index for the query and for the observed evidence.
InferenceProblem = namedtuple("InferenceProblem", ["rvs", "potentials", "query", "evidence"])

def validate_inference_problem(problem, potentials_are_cpts=False):
    """Validate an inference problem

    Checks that:

    - the problem has at least one potential;
    - its RVs are distinct and are exactly the RVs used by the potentials;
    - every potential's table has one axis per RV, each sized to that RV's dim;
    - every RV has more than one value;
    - query and evidence are dicts over disjoint sets of problem RVs, with
      in-range values.

    Parameters
    ----------
    problem : InferenceProblem
    potentials_are_cpts : bool
        If True, additionally require every table to sum to 1 along its
        first axis.

    Raises
    ------
    AssertionError
        If not valid
    """
    assert len(problem.potentials) > 0, "Inference problem must have Potentials"
    problem_rvs = set(problem.rvs)
    # Duplicates would slip past the set-equality check below but break
    # position-based lookups of RVs in problem.rvs.
    assert len(problem_rvs) == len(problem.rvs), "Problem RVs must be distinct"
    all_pot_rvs = set()
    for pot in problem.potentials:
        # Without this, extra trailing table axes go unnoticed and missing
        # axes surface as an IndexError instead of a clear message.
        assert pot.table.ndim == len(pot.rvs), \
            f"Potential table has {pot.table.ndim} dims but {len(pot.rvs)} RVs"
        for i, rv in enumerate(pot.rvs):
            assert rv.dim == pot.table.shape[i], \
                f"Potential table dim {i} does not match RV {rv} dimension {rv.dim}"
        if potentials_are_cpts:
            assert np.allclose(pot.table.sum(axis=0), 1.), "CPT probs must sum to 1 over first axis"
        all_pot_rvs.update(pot.rvs)
    assert problem_rvs == all_pot_rvs, "Potential RVs must be exactly the RVs in the problem"
    assert all(rv.dim > 1 for rv in problem.rvs), "All RVs should have nontrivial dimension"
    assert isinstance(problem.query, dict), "Problem query must be a dict"
    assert all(k in problem_rvs for k in problem.query), "Query keys must be RVs"
    assert all(v in range(k.dim) for k, v in problem.query.items()), "Query values must be RV values"
    assert isinstance(problem.evidence, dict), "Problem evidence must be a dict"
    assert all(k in problem_rvs for k in problem.evidence), "Evidence keys must be RVs"
    assert all(v in range(k.dim) for k, v in problem.evidence.items()), "Evidence values must be RV values"
    assert set(problem.query.keys()) & set(problem.evidence.keys()) == set(), "Query and evidence RVs must be disjoint"

## Inference Methods

### Brute Force

In [None]:
def run_brute_force_inference(problem, max_num_rows=2**25):
    """Run inference by creating the full joint distribution
    and then marginalizing.

    Parameters
    ----------
    problem : InferenceProblem
    max_num_rows : int
        If there are to be more than this number of rows
        in the joint table, terminate early with no answer.

    Returns
    -------
    result : float or None
        Answer to the query in the problem, P(query | evidence), or None
        if the joint table would exceed max_num_rows.

    Raises
    ------
    AssertionError
        If the evidence has zero probability under the joint.
    """
    # Give up right away if the problem is too large. Multiply exact Python
    # ints, exiting as soon as the limit is passed, rather than comparing
    # summed float logs, whose rounding can misjudge a joint whose size is
    # exactly max_num_rows.
    num_rows = 1
    for rv in problem.rvs:
        num_rows *= rv.dim
        if num_rows > max_num_rows:
            print("Problem too large! Brute force inference terminating")
            return None

    # Create joint table with a single einsum whose integer labels are the
    # RVs' positions in problem.rvs; the output keeps every label, in order.
    axis_of = {rv: i for i, rv in enumerate(problem.rvs)}
    einsum_args = []
    for pot in problem.potentials:
        einsum_args.append(pot.table)
        einsum_args.append([axis_of[rv] for rv in pot.rvs])
    einsum_args.append(list(range(len(problem.rvs))))
    joint_table = np.einsum(*einsum_args)
    joint_table = joint_table / joint_table.sum()

    # Compute marginal for evidence and query: fix observed/queried axes to
    # their values and sum over every remaining axis.
    evidence_and_query = {**problem.query, **problem.evidence}
    idxs = [evidence_and_query.get(rv, slice(None)) for rv in problem.rvs]
    p_query_and_evidence = joint_table[tuple(idxs)].sum()

    # Compute marginal for evidence alone
    idxs = [problem.evidence.get(rv, slice(None)) for rv in problem.rvs]
    p_evidence = joint_table[tuple(idxs)].sum()

    assert p_evidence > 0, "Evidence has zero probability"

    return p_query_and_evidence / p_evidence

### (Loopy) Belief Propagation (aka Sum-Product)

In [None]:
def _normalize_message(msg):
    """Scale a message to sum to 1, leaving an all-zero message unchanged.

    An all-zero message means every state has been ruled out; keeping it at
    zero (instead of computing 0 / 0) lets the final belief report that.
    """
    total = msg.sum()
    return msg / total if total > 0 else msg


def _bp_marginal(problem, evidence, target, max_iters, tol):
    """Run sum-product on the factor graph of `problem` and return the
    normalized belief vector for `target`, or None if it is all zeros.

    `evidence` maps RV -> observed value index and is applied as a one-hot
    indicator on that RV. All messages are updated synchronously each sweep
    until none changes by more than `tol`, or `max_iters` sweeps have run.
    Assumes no RV appears twice within a single potential.
    """
    pots = problem.potentials

    # Factor indices adjacent to each RV.
    factors_of = {rv: [] for rv in problem.rvs}
    for f, pot in enumerate(pots):
        for rv in pot.rvs:
            factors_of.setdefault(rv, []).append(f)

    # Indicator vector per RV: all ones, or one-hot on the observed value.
    unary = {}
    for rv in factors_of:
        if rv in evidence:
            vec = np.zeros(rv.dim)
            vec[evidence[rv]] = 1.0
        else:
            vec = np.ones(rv.dim)
        unary[rv] = vec

    # Factor->variable messages are keyed (f, rv); variable->factor messages
    # are keyed (rv, f). Both start uniform.
    fac_to_var = {}
    var_to_fac = {}
    for f, pot in enumerate(pots):
        for rv in pot.rvs:
            fac_to_var[f, rv] = np.full(rv.dim, 1.0 / rv.dim)
            var_to_fac[rv, f] = np.full(rv.dim, 1.0 / rv.dim)

    for _ in range(max_iters):
        new_fac_to_var = {}
        for f, pot in enumerate(pots):
            axes = list(range(len(pot.rvs)))
            for i, rv in enumerate(pot.rvs):
                # Multiply the table by every other neighbour's incoming
                # message and sum those axes out, keeping only axis i.
                operands = [pot.table, axes]
                for j, other in enumerate(pot.rvs):
                    if j != i:
                        operands += [var_to_fac[other, f], [j]]
                operands.append([i])
                new_fac_to_var[f, rv] = _normalize_message(np.einsum(*operands))

        new_var_to_fac = {}
        for rv, adjacent in factors_of.items():
            for f in adjacent:
                msg = unary[rv]
                for g in adjacent:
                    if g != f:
                        msg = msg * new_fac_to_var[g, rv]
                new_var_to_fac[rv, f] = _normalize_message(msg)

        # Check both directions: on the first sweep the factor messages can
        # stay uniform while evidence has only just changed the variable
        # messages.
        delta = max(
            max((np.abs(new_fac_to_var[k] - fac_to_var[k]).max() for k in fac_to_var), default=0.0),
            max((np.abs(new_var_to_fac[k] - var_to_fac[k]).max() for k in var_to_fac), default=0.0),
        )
        fac_to_var, var_to_fac = new_fac_to_var, new_var_to_fac
        if delta <= tol:
            break

    belief = unary[target]
    for f in factors_of.get(target, []):
        belief = belief * fac_to_var[f, target]
    total = belief.sum()
    if not total > 0:
        return None
    return belief / total


def run_belief_prop(problem, max_iters=1000, tol=1e-10):
    """Run inference with belief propagation.

    Single-RV marginals come from (loopy) sum-product. A query over several
    RVs is answered with the chain rule:

        P(q1, q2, ... | e) = P(q1 | e) * P(q2 | e, q1) * ...

    BP is run once per query RV, with the previously queried values added
    to the evidence.

    Parameters
    ----------
    problem : InferenceProblem
    max_iters : int
        Maximum number of iterations for each BP call.
    tol : float
        A BP call stops once no message changes by more than this amount.

    Returns
    -------
    result : float or None
        Answer to the query in the problem.
        - Exact when the factor graph is a tree and BP converges within
          max_iters; approximate on graphs with cycles.
        - None when BP assigns zero belief to every value of a queried RV
          (e.g. contradictory evidence).
    """
    evidence = dict(problem.evidence)
    result = 1.0
    for rv, value in problem.query.items():
        marginal = _bp_marginal(problem, evidence, rv, max_iters, tol)
        if marginal is None:
            return None
        result *= float(marginal[value])
        evidence[rv] = value
    return result

## Problems

In [None]:
def create_debug_2vars_problem(version):
    """Build a tiny two-RV problem (A with 2 values, B with 3) for debugging.

    Version 1 asks P(A=1 | B=1), version 2 asks P(B=1 | A=1), and version 3
    asks the joint P(A=1, B=1) with no evidence. Any other version fails an
    assertion.
    """
    A, B = RV("A", 2), RV("B", 3)
    # Rows index A's value, columns index B's value.
    # NOTE(review): the columns sum to 0.95, 0.95 and 0.1, so despite the name
    # this is an unnormalized potential rather than a proper CPT P(A | B).
    a_given_b = Potential([A, B], np.array([
        [0.9, 0.05, 0.05],
        [0.05, 0.9, 0.05],
    ]))
    prior_b = Potential([B], np.array([0.7, 0.2, 0.1]))

    if version == 1:
        query, evidence = {A: 1}, {B: 1}
    elif version == 2:
        query, evidence = {B: 1}, {A: 1}
    else:
        assert version == 3
        query, evidence = {A: 1, B: 1}, {}
    return InferenceProblem([A, B], [a_given_b, prior_b], query, evidence)

def create_california_problem(version):
    """Burglar / earthquake / alarm / radio network (Holmes and Watson).

    All four RVs are binary.
    - version "alarm" queries P(Burglar=1 | Alarm=1);
    - version "alarm and earthquake" queries P(Burglar=1 | Alarm=1, Radio=1);
    - any other version fails an assertion.
    """
    # P(Alarm=1 | Earthquake=e, Burglar=b), keyed by (e, b); Alarm=0 gets
    # the complement.
    alarm_on = {(0, 0): 0.01, (0, 1): 0.2, (1, 0): 0.95, (1, 1): 0.96}
    p_aeb = np.zeros((2, 2, 2))
    for (e, b), p_on in alarm_on.items():
        p_aeb[1, e, b] = p_on
        p_aeb[0, e, b] = 1. - p_on

    alarm = RV("Alarm", 2)
    burglar = RV("Burglar", 2)
    quake = RV("Earthquake", 2)
    radio = RV("Radio", 2)
    pots = [
        Potential([burglar], np.array([0.99, 0.01])),
        Potential([quake], np.array([0.97, 0.03])),
        # Rows index Radio's value, columns index Earthquake's value.
        Potential([radio, quake], np.array([
            [0.98, 0.01],
            [0.02, 0.99],
        ])),
        Potential([alarm, quake, burglar], p_aeb),
    ]

    if version == "alarm":
        query, evidence = {burglar: 1}, {alarm: 1}
    else:
        assert version == "alarm and earthquake"
        query, evidence = {burglar: 1}, {alarm: 1, radio: 1}
    return InferenceProblem([alarm, burglar, quake, radio], pots, query, evidence)

def create_binary_chain_problem(num_vars):
    """Build a chain X0 - X1 - ... of binary RVs to stress-test inference.

    Each adjacent pair shares a potential favouring equal values (0.9 on the
    diagonal, 0.1 off it). The query is X0=0 and there is no evidence.
    """
    rvs = [RV(f"X{i}", 2) for i in range(num_vars)]
    # A fresh table per link, so no two potentials share an array.
    pots = [
        Potential([left, right], np.array([
            [0.9, 0.1],
            [0.1, 0.9],
        ]))
        for left, right in zip(rvs, rvs[1:])
    ]
    return InferenceProblem(rvs, pots, {rvs[0]: 0}, {})


def create_sar_fires_problem(grid, queries):
    """Infer whether there is fire based on smoke and fire observations.

    Each cell has a binary fire RV and a binary smoke RV.

    Potentials:
    - each fire has a 0.9 / 0.1 prior;
    - a fire forces smoke in every in-bounds 4-neighbour;
    - smoke in a cell requires a fire in at least one in-bounds 4-neighbour.

    `grid` cells mark evidence:
    - "F": fire observed;
    - "S": smoke observed;
    - "E": neither fire nor smoke;
    - "?": unobserved.

    `queries` is a list of (kind, (row, col)) using the same letters ("?"
    excluded).
    """
    height, width = len(grid), len(grid[0])
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    def neighbors(r, c):
        """Yield in-bounds 4-neighbours of (r, c) in up, down, left, right order."""
        for dr, dc in offsets:
            nr, nc = r + dr, c + dc
            if 0 <= nr < height and 0 <= nc < width:
                yield nr, nc

    cells = [(r, c) for r in range(height) for c in range(width)]

    # Create RVs: fire then smoke for each cell, in row-major order.
    fire_of = {}
    smoke_of = {}
    rvs = []
    for r, c in cells:
        fire_of[r, c] = RV(f"fire({r},{c})", 2)
        smoke_of[r, c] = RV(f"smoke({r},{c})", 2)
        rvs += [fire_of[r, c], smoke_of[r, c]]

    # Create potentials
    pots = []
    for r, c in cells:
        fire = fire_of[r, c]
        pots.append(Potential([fire], np.array([0.9, 0.1])))
        # Fire (row 1) forces smoke=1 in each neighbour; no fire is neutral.
        for cell in neighbors(r, c):
            pots.append(Potential([fire, smoke_of[cell]], np.array([
                [0.5, 0.5],
                [0.0, 1.0],
            ])))
        # Rule out "smoke here, yet no neighbouring fire".
        neighbor_fires = [fire_of[cell] for cell in neighbors(r, c)]
        table = np.ones([2] * (1 + len(neighbor_fires)))
        table[(1,) + (0,) * len(neighbor_fires)] = 0.
        pots.append(Potential([smoke_of[r, c]] + neighbor_fires, table))

    # Create query
    query = {}
    for kind, (r, c) in queries:
        if kind == "F":
            query[fire_of[r, c]] = 1
        elif kind == "S":
            query[smoke_of[r, c]] = 1
        else:
            assert kind == "E"
            query[fire_of[r, c]] = 0
            query[smoke_of[r, c]] = 0

    # Create evidence
    evidence = {}
    for r, c in cells:
        mark = grid[r][c]
        if mark == "F":
            evidence[fire_of[r, c]] = 1
        elif mark == "S":
            evidence[smoke_of[r, c]] = 1
        elif mark == "E":
            evidence[fire_of[r, c]] = 0
            evidence[smoke_of[r, c]] = 0
        else:
            assert mark == "?"

    return InferenceProblem(rvs, pots, query, evidence)

## Main

In [None]:
def main():
    """Run every inference method on every benchmark problem and print a table.

    Each problem is validated first. For every (problem, method) pair the
    answer and the elapsed time are recorded. Rows are printed sorted by
    problem name, then by method name.
    """
    # Create inference problems. Keys must be unique: a repeated key in a
    # dict literal silently replaces the earlier entry.
    problems = {
        'Debug 2 Vars 1' : create_debug_2vars_problem(1),
        'Debug 2 Vars 2' : create_debug_2vars_problem(2),
        'Debug 2 Vars 3' : create_debug_2vars_problem(3),
        'Burglar Given Alarm' : create_california_problem("alarm"),
        'Burglar Given Alarm and Earthquake' : create_california_problem("alarm and earthquake"),
        'Binary Chain (# Vars = 005)' : create_binary_chain_problem(5),
        'Binary Chain (# Vars = 010)' : create_binary_chain_problem(10),
        'Binary Chain (# Vars = 015)' : create_binary_chain_problem(15),
        'Binary Chain (# Vars = 020)' : create_binary_chain_problem(20),
        'Binary Chain (# Vars = 025)' : create_binary_chain_problem(25),
        'Binary Chain (# Vars = 050)' : create_binary_chain_problem(50),
        'Binary Chain (# Vars = 100)' : create_binary_chain_problem(100),
        'Binary Chain (# Vars = 1000)' : create_binary_chain_problem(1000),
        'SAR Fires 1x2-0' : create_sar_fires_problem([["?", "?"]], [("F", (0, 0)), ("E", (0, 1))]),
        'SAR Fires 1x5-1' : create_sar_fires_problem([["?", "S", "E", "?", "?"]], [("F", (0, 0))]),
        'SAR Fires 1x5-2' : create_sar_fires_problem([["?", "S", "E", "?", "?"]], [("F", (0, 3))]),
        'SAR Fires 1x5-3' : create_sar_fires_problem([["?", "S", "E", "?", "?"]], [("F", (0, 4))]),
        'SAR Fires 3x5-1' : create_sar_fires_problem([["?", "?", "E", "?", "?"],
                                                      ["?", "?", "S", "E", "?"],
                                                      ["?", "?", "?", "?", "?"]], [("F", (2, 1))]),
        'SAR Fires 3x5-2' : create_sar_fires_problem([["?", "?", "E", "?", "?"],
                                                      ["?", "?", "S", "E", "?"],
                                                      ["?", "?", "?", "?", "?"]], [("F", (0, 0))]),
        'SAR Fires 5x5-1' : create_sar_fires_problem([["?", "?", "?", "?", "?"],
                                                      ["?", "?", "E", "?", "?"],
                                                      ["E", "S", "?", "?", "?"],
                                                      ["?", "?", "S", "?", "?"],
                                                      ["?", "?", "?", "?", "?"]], [("F", (3, 1))]),
    }
    for problem in problems.values():
        validate_inference_problem(problem)

    # Create inference methods
    inference_methods = {
        'Brute' : run_brute_force_inference,
        'BP' : run_belief_prop,
    }

    # Evaluate each inference method in each problem
    all_times = {}
    all_answers = {}
    for problem_name, problem in problems.items():
        all_times[problem_name] = {}
        all_answers[problem_name] = {}
        for method_name, method in inference_methods.items():
            # Run inference; perf_counter is the monotonic clock meant for
            # measuring durations.
            start_time = time.perf_counter()
            try:
                answer = method(problem)
            except NotImplementedError:
                # A placeholder method is reported as a missing (None) answer
                # instead of aborting the whole benchmark.
                answer = None
            duration = time.perf_counter() - start_time
            all_times[problem_name][method_name] = duration
            all_answers[problem_name][method_name] = answer

    # Tabulate and print results
    headers = ["Problem", "Method", "Answer", "Time (s)"]
    table = []
    for problem_name in sorted(problems):
        for method_name in sorted(inference_methods):
            table.append((problem_name, method_name,
                          all_answers[problem_name][method_name],
                          all_times[problem_name][method_name]))
    print(tabulate(table, headers=headers))

In [None]:
# Run the benchmark: validate and solve every problem with every inference
# method, then print the results table.
main()