# HW 2.3 Starter Code

See pset 1 for dependency installation instructions and see the problem set for deliverables.

As in the previous problem set, you may find this notebook useful as a reference for logical data structures: https://github.com/aimacode/aima-python/blob/master/logic.ipynb

In [None]:
# Install dependencies (run this once every 12 hours)
!git clone https://github.com/MIT-6-882/HW utils
!pip install tabulate
!pip install python-sat

In [None]:
from utils.logic_utils import Expr, expr, FolKB, implies
from collections import defaultdict
import abc
import itertools
import heapq as hq
import time

## First-order Horn Clause Set Learning
Learn a set of first-order Horn clauses using best-first search.

In [None]:
class FOLHornClauseSetLearner(FolKB):
    """Learn a set of first-order Horn clauses, and answer queries.

    Rules are learned greedily, one at a time; each rule is found with a
    best-first search over candidate rules. This base class drives the
    search but does not define the search space: subclasses must provide
    ``_get_initial_rules(training_data, output_predicates)`` and
    ``_get_successors(rule, input_predicates)``.

    Parameters
    ----------
    max_rule_size : int
        The maximum number of predicates in the body of any
        learned rule. Used to bound learning.
    max_rules : int
        The maximum number of rules that can be learned.
    max_search_iters : int
        The maximum number of nodes to expand during the search
        for a single new rule.
    terminate_early : bool
        Whether to stop learning as soon as a set of rules is
        found that perfectly cover the training examples.
    score_mode : str
        Name of the scoring heuristic used. Currently "coverage"
        is allowed.
    size_penalty_weight : float
        A weight used to regularize the size of rules, preferring
        smaller ones.
    """
    def __init__(self, max_rule_size=3, max_rules=10, max_search_iters=1000,
                 terminate_early=True, score_mode="coverage",
                 size_penalty_weight=0.1, *args, **kwargs):
        # Note that the learned rules are stored in self.clauses, which
        # is initialized in the parent class
        super().__init__(*args, **kwargs)
        self.max_rule_size = max_rule_size
        self.max_rules = max_rules
        self.max_search_iters = max_search_iters
        self.terminate_early = terminate_early
        self.score_mode = score_mode
        self.size_penalty_weight = size_penalty_weight

    ## Inference

    def check(self, query, assumptions=None):
        """Check whether a query is entailed by the learned rules and
        optionally, additional assumptions.

        The assumptions are told to the knowledge base only for the
        duration of the query; they are retracted again before returning,
        even if telling or asking raises.

        Parameters
        ----------
        query : expr
            A logical sentence to check
        assumptions : [ expr ]
            Logical sentences to add to the antecedent

        Returns
        -------
        entailed : bool
            True if query is entailed by the learned rules and
            assumptions.
        """
        # Materialize once so a one-shot iterable is not exhausted before
        # we get to retract what we told.
        assumptions = [] if assumptions is None else list(assumptions)
        told = []
        try:
            for s in assumptions:
                self.tell(s)
                told.append(s)
            return self.ask(query)
        finally:
            # Only retract what was actually added, so a failure part-way
            # through telling does not leave stray sentences behind.
            for s in told:
                self.retract(s)

    ## Main learning function

    def train(self, training_data):
        """Learn a set of first-order Horn clauses from training data

        Training proceeds by greedily finding one rule at a time to
        add to the overall rule set. The way that each rule is selected
        varies depending on the subclass, but always involves a search
        over rules with a heuristic scorer.

        Any previously learned rules are discarded first.

        Parameters
        ----------
        training_data : [({expr}, {expr}, {expr})]
            Each training datum is a triple of
            (input, positive outs, negative outs). For example:

            obs = {
                at("Main"),
                isYellow("Main"),
                isGreen("Park"),
                isYellow("Central"),
            }
            good_actions = { wait() }
            bad_actions = { go() }

            One datum would then be (obs, good_actions, bad_actions).
        """
        # Initialize rules
        self.clauses = []

        # Extract the predicates that are allowed in the head
        # and body of the rule
        input_predicates, output_predicates = self._extract_predicates(
            training_data)

        # Keep track of the overall score (lower is better)
        score = float("inf")

        # Repeat until all examples are covered or until max_rules
        for it in range(self.max_rules):

            # Find a new rule
            new_rule, new_score = self._find_new_rule(training_data,
                input_predicates, output_predicates)

            # Terminate if the search failed or did not improve on the
            # previous round
            if not new_rule or new_score >= score:
                break

            score = new_score

            # Add new rule to clauses
            new_clause = self._rule_to_clause(new_rule)
            assert new_clause not in self.clauses, \
                "Tried to add a rule that is already in the rule set"
            self.tell(new_clause)

            # Check whether we're done
            if self._all_examples_covered(training_data):
                print(f"Training finished after {it+1} iterations.")
                break
        else:
            # Only reached when max_rules rules were added without
            # covering all of the examples
            print("Training did not converge, giving up.")

        print("Final rules:\n ", "\n  ".join(map(str, self.clauses)))

    ## Learning a single rule

    def _find_new_rule(self, training_data, input_predicates, output_predicates):
        """Propose a rule to add to the rule set.

        Works by performing a search, where the state of the search is a single
        candidate rule (Horn clause). The initialization of the search and the
        successor function are what distinguish subclasses.

        Rules are represented as a tuple (body, head) where body is
        a frozenset containing the expr's that are conjoined in the body,
        and head is the single expr representing the consequent of the Horn clause.

        Parameters
        ----------
        training_data : [({expr}, {expr}, {expr})]
            See self.train docstring
        input_predicates : { (str, int) }
            A set of predicates, where each predicate is a
            (name, arity) [recall "arity" means # of args].
            These are the predicates that can be involved
            in the body of the rules.
        output_predicates : { (str, int) }
            These are the predicates that can be involved
            in the head of the rules.

        Returns
        -------
        new_rule : (frozenset(expr), expr) or None
            See above. None if no rule with a negative (i.e. improving)
            score was found.
        score : float
            Used to determine whether to terminate the outer search.
            Lower is better.
        """
        # Initialize search
        tiebreak = itertools.count()
        queue = []
        best_score = float("inf")
        best_rule = None
        visited = set()

        # Initial rules get the best possible priority so they are expanded
        # first. The tiebreak counter keeps heap entries totally ordered so
        # that the rules themselves never need to be compared.
        for rule in self._get_initial_rules(training_data, output_predicates):
            hq.heappush(queue, (-float("inf"), next(tiebreak), rule))

        # Run the search
        for it in range(self.max_search_iters):
            print(f"Iteration {it}", end='\r')
            # Check if we've exhausted the queue
            if not queue:
                break
            # Get a rule to extend
            _, _, rule = hq.heappop(queue)
            # Consider different extensions
            for child in self._get_successors(rule, input_predicates):
                # No need to reconsider children
                if child in visited:
                    continue
                # Don't consider children that are too large
                if len(child[0]) > self.max_rule_size:
                    continue
                # Score the child
                child_score = self._score_rule(child, training_data)
                # Update best score. Malformed rules can still be expanded
                # below, but are never returned.
                if child_score < best_score and \
                    not self._rule_is_malformed(child):
                    best_score = child_score
                    best_rule = child
                    print("Updating best rule:", best_rule)
                    # Perfect fit, terminate early
                    if self.terminate_early and child_score == float("-inf"):
                        print("\nFound perfect rule:")
                        print(best_rule)
                        return best_rule, best_score
                # Add to the queue
                hq.heappush(queue, (child_score, next(tiebreak), child))
                visited.add(child)

        print("\nTerminated without a perfect rule")
        if best_score < 0: # negative means some improvement
            print(f"Best rule found (score {best_score}):")
            print(best_rule)
            return best_rule, best_score
        print("Could not find any rule that improves overall")
        # Keep the (rule, score) shape so callers can always unpack
        return None, best_score

    def _score_rule(self, rule, training_data):
        """Give a heuristic score to a candidate rule that we are
        considering adding to our overall rule set.

        Rules are represented as a tuple (body, head) where body is
        a frozenset containing the expr's that are conjoined in the body,
        and head is the single expr representing the consequent of the Horn clause.

        Parameters
        ----------
        rule : (frozenset(expr), expr)
            See above
        training_data : [({expr}, {expr}, {expr})]
            See self.train docstring

        Returns
        -------
        score : float
            Lower is better

        Raises
        ------
        ValueError
            If self.score_mode is not a recognized scoring heuristic.
        """
        if self.score_mode == "coverage":
            score = self._score_rule_with_coverage(rule, training_data)
        else:
            raise ValueError(f"Unrecognized score mode {self.score_mode}")

        # Add size penalty
        rule_size = len(rule[0])
        score += self.size_penalty_weight*rule_size

        return score

    def _score_rule_with_coverage(self, rule, training_data):
        """Count the number of negative and positive examples
        covered by the rule in the training data, and give a score
        of simply # negatives - # positives.

        Malformed rules score 0. If self.terminate_early is set and the
        rule (together with the rules learned so far) gets every example
        right, the score is -inf.
        """
        # If the rule is malformed, it explains nothing
        if self._rule_is_malformed(rule):
            return 0

        # We will count the number of correct overall to see
        # if we should terminate early
        all_correct = True
        num_true_positives = 0
        num_false_positives = 0
        # Convert from manipulable representation to AIMA expr
        # so that we can call self.check
        new_clause = self._rule_to_clause(rule)
        for x, pos_y, neg_y in training_data:
            assumptions = [new_clause] + list(x)
            for y in pos_y:
                # Does the new clause and the training data
                # (together with the old clauses) imply the
                # positive example?
                if self.check(y, assumptions=assumptions):
                    num_true_positives += 1
                else:
                    all_correct = False
            for y in neg_y:
                # Does the new clause and the training data
                # (together with the old clauses) imply the
                # negative example?
                if self.check(y, assumptions=assumptions):
                    num_false_positives += 1
                    all_correct = False
        score = num_false_positives - num_true_positives
        # Check if perfect
        if self.terminate_early and all_correct:
            return -float("inf")
        return score

    def _get_possible_groundings(self, predicate, available_variables):
        """Get all possible groundings of the predicate over variables.
        Allow groundings over the available variables, or over fresh vars.

        Each distinct grounding is yielded exactly once.

        Helper for learning new rules.

        Parameters
        ----------
        predicate : (str, int)
            See self._extract_predicates
        available_variables : { int }
            See self._expr_to_variables

        Yields
        ------
        ground_predicate : expr
            See self._ground_predicate
        """
        _, arity = predicate
        max_var = 0 if len(available_variables) == 0 else max(available_variables)
        available_variables = sorted(available_variables)
        # The same choice of variables comes up again for every larger
        # pool of fresh variables; remember what was already produced.
        seen_choices = set()

        for num_new_variables in range(arity+1):
            # Allow groundings over "fresh" variables
            new_vars = list(range(max_var+1, max_var+1+num_new_variables))
            for choice in itertools.product(available_variables+new_vars,
                                            repeat=arity):
                if choice in seen_choices:
                    continue
                # Skip if new vars are not in order
                new_vars_in_choice = [v for v in choice if v in new_vars]
                if new_vars_in_choice != sorted(new_vars_in_choice):
                    continue
                # Skip if min new var is not min
                if len(new_vars_in_choice) and \
                    (min(new_vars_in_choice) != min(new_vars)):
                    continue
                seen_choices.add(choice)
                yield self._ground_predicate(predicate, choice)

    def _all_examples_covered(self, training_data):
        """Check whether all training data are covered (we'll stop if so)

        Covered means that every positive output is entailed by its input
        (with the rules) and no negative output is.

        Parameters
        ----------
        training_data : [({expr}, {expr}, {expr})]
            See self.train docstring

        Returns
        -------
        covered : bool
        """
        for x, pos_y, neg_y in training_data:
            facts = list(x)
            # Does x imply every positive y (with the rules)?
            if not all(self.check(y, assumptions=facts) for y in pos_y):
                return False
            # Does x imply any negative y (with the rules)?
            if any(self.check(y, assumptions=facts) for y in neg_y):
                return False
        return True

    ## Helpers for converting between AIMA and more manipulable representations

    def _extract_predicates(self, training_data):
        """Identify the predicates in the inputs and outputs of the data

        Parameters
        ----------
        training_data : [({expr}, {expr}, {expr})]
            See self.train docstring

        Returns
        -------
        input_predicates : { (str, int) }
            A set of predicates, where each predicate is a
            (name, arity) [recall "arity" means # of args].
            These are the predicates that can be involved
            in the body of the rules.
        output_predicates : { (str, int) }
            These are the predicates that can be involved
            in the head of the rules.
        """
        input_predicates = set()
        output_predicates = set()

        for x, pos_y, neg_y in training_data:
            # Input predicates
            for x_i in x:
                input_predicates.add(self._expr_to_predicate(x_i))
            # Output predicates (positives and negatives alike); chaining
            # avoids requiring that both collections are sets
            for y_i in itertools.chain(pos_y, neg_y):
                output_predicates.add(self._expr_to_predicate(y_i))

        return input_predicates, output_predicates

    def _expr_to_predicate(self, expr):
        """Extract a predicate from a AIMA logical expr.

        For example, if the expr is Green(Main), this
        will return ("Green", 1), where the first element
        is the name of the predicate and the second element
        is the arity, i.e., the number of arguments to the
        predicate.

        Parameters
        ----------
        expr : expr

        Returns
        -------
        predicate : (str, int)
        """
        arity = len(expr.args)
        name = expr.op
        return (name, arity)

    def _expr_to_variables(self, expr):
        """Extract all of the variables from an AIMA logical expr
        representing a predicate.

        Note that variables are represented as integers during
        learning.

        For example, if the expr is Green(x5), this
        will return {5}.

        Parameters
        ----------
        expr : expr
            Every argument must print as one character followed by an
            integer (e.g. x5); anything else raises ValueError.

        Returns
        -------
        variables : { int }
        """
        # Drop the leading "x" and keep the integer index
        variables = { int(str(a)[1:]) for a in expr.args }
        return variables

    def _ground_predicate(self, predicate, variables):
        """Create an AIMA logical expr for a ground predicate
        given the predicate and variables.

        For example, if predicate is ("Green", 1) and the
        variables are [5], then this will return the expr
        Green(x5).

        Parameters
        ----------
        predicate : (str, int)
            See _expr_to_predicate.
        variables : [ int ]
            See _expr_to_variables.

        Returns
        -------
        ground_predicate : expr
        """
        name, arity = predicate
        assert len(variables) == arity
        args = [f"x{i}" for i in variables]
        return expr(str(name) + "(" + ",".join(args) + ")")

    def _rule_to_clause(self, rule):
        """Convert from manipulable rule into AIMA expr.

        For example, if rule is ({Green(Main), Yellow(Park)}, Go()),
        then this will return Green(Main) & Yellow(Park) ==> Go().
        """
        body, head = rule
        query = implies(Expr("&", *body), head)
        return query

    ## Other helper methods

    def _rule_is_malformed(self, rule):
        """A rule is malformed if not all head variables appear in the body
        """
        body_vars = {v for a in rule[0] for v in self._expr_to_variables(a)}
        head_vars = self._expr_to_variables(rule[1])
        return not head_vars.issubset(body_vars)

    def _subst(self, subs, x):
        """Substitute into an expression x

        Lists, sets and frozensets are substituted element-wise and
        returned as a container of the same type.

        Parameters
        ----------
        subs : { expr : expr }
            Substitute keys to values.
        x : expr

        Returns
        -------
        x' : expr

        Raises
        ------
        Exception
            If x is neither a key of subs, a supported container, nor an
            object with an ``args`` attribute.
        """
        if isinstance(x, list):
            return [self._subst(subs, i) for i in x]
        if isinstance(x, frozenset):
            return frozenset({self._subst(subs, i) for i in x})
        if isinstance(x, set):
            return {self._subst(subs, i) for i in x}
        if x in subs:
            return subs[x]
        if hasattr(x, "args"):
            # Rebuild the expression with substituted arguments
            new_args = [self._subst(subs, a) for a in x.args]
            return Expr(x.op, *new_args)
        raise Exception("Substitution failed")

### Top-Down Learner

In [None]:
class TopDownLearner(FOLHornClauseSetLearner):
    """First-order Horn clause set learner that searches top-down.

    Each rule search begins from rules with an empty body and grows
    the body by one predicate at a time.
    """
    def _get_initial_rules(self, _, output_predicates):
        """Produce the starting points of the search for a new rule.

        A rule is a (body, head) tuple: body is a frozenset of the expr's
        conjoined in the rule body, head is the single expr that is the
        consequent of the Horn clause.

        Parameters
        ----------
        output_predicates : { (str, int) }
            The predicates that may appear in the head of a rule.

        Yields
        ------
        new_rule : (frozenset(expr), expr)
            One rule with an empty body per output predicate.
        """
        empty_body = frozenset()
        for head_predicate in output_predicates:
            arity = head_predicate[1]
            # The head is grounded over variables 0 .. arity-1
            head = self._ground_predicate(head_predicate, range(arity))
            yield (empty_body, head)

    def _get_successors(self, rule, input_predicates):
        """Extend the rule body by every possible single predicate.

        A rule is a (body, head) tuple: body is a frozenset of the expr's
        conjoined in the rule body, head is the single expr that is the
        consequent of the Horn clause.

        Parameters
        ----------
        rule : (frozenset(expr), expr)
            The rule to extend.
        input_predicates : { (str, int) }
            The (name, arity) predicates that may appear in the body
            of a rule.

        Yields
        ------
        new_rule : (frozenset(expr), expr)
            The rule with one additional body predicate.
        """
        body, head = rule
        # Collect every variable the rule already mentions
        known_vars = self._expr_to_variables(head)
        for literal in body:
            known_vars |= self._expr_to_variables(literal)
        # Try each grounding of each predicate that is not in the body yet
        for predicate in input_predicates:
            groundings = self._get_possible_groundings(predicate, known_vars)
            for literal in groundings:
                if literal not in body:
                    yield (frozenset([*body, literal]), head)

### Bottom-Up Learner

In [None]:
class BottomUpLearner(FOLHornClauseSetLearner):
    """A bottom-up FOL horn clause set learner.

    It is bottom-up because it searches over individual rules by
    starting from the most specific rules (the lifted training examples
    themselves) and gradually removes predicates from the body.
    """
    def __init__(self, *args, **kwargs):
        # We're going bottom up, so we don't need to specify a max rule
        # size. We also don't want to terminate early because we may
        # perfectly fit the data but want to still prune the rules.
        super().__init__(*args, max_rule_size=float("inf"),
            terminate_early=False, **kwargs)

    def _get_initial_rules(self, training_data, _):
        """Get rules to initialize the search for a new rule

        For bottom-up learning, these rules are just the data, but
        lifted, i.e., substituting the objects with variables.

        Rules are represented as a tuple (body, head) where body is
        a frozenset containing the expr's that are conjoined in the body,
        and head is the single expr representing the consequent of the Horn clause.

        Parameters
        ----------
        training_data : [({expr}, {expr}, {expr})]
            See self.train docstring

        Returns
        -------
        new_rules : [ (frozenset(expr), expr) ]
            The distinct lifted rules, in a deterministic order.
        """
        proposals = set()
        # Propose one rule per positive training example
        # by "lifting" the constants
        for x, pos_y, _ in training_data:
            for y in pos_y:
                constants = set(y.args)
                for x_i in x:
                    constants.update(x_i.args)
                # Number the variables from 1 in a stable (textual) order
                constant_to_var = {
                    c: expr(f"x{i+1}")
                    for i, c in enumerate(sorted(constants, key=str))
                }
                new_x = frozenset(self._subst(constant_to_var, x))
                new_y = self._subst(constant_to_var, y)
                proposals.add((new_x, new_y))

        def _sort_key(rule):
            # frozensets compare by subset, which is not a total order,
            # so sort on a textual form of the rule instead
            body, head = rule
            return (sorted(str(atom) for atom in body), str(head))

        return sorted(proposals, key=_sort_key)

    def _get_successors(self, rule, _):
        """Consider removing single predicates from the rule body

        Rules are represented as a tuple (body, head) where body is
        a frozenset containing the expr's that are conjoined in the body,
        and head is the single expr representing the consequent of the Horn clause.

        Parameters
        ----------
        rule : (frozenset(expr), expr)
            See above

        Yields
        ------
        new_rule : (frozenset(expr), expr)
            See above
        """
        body, head = rule
        # Keep bodies as frozensets: the same rule reached by removing
        # predicates in a different order must compare (and hash) equal so
        # that the search can recognize it as already visited.
        body = frozenset(body)
        # Propose children by deleting predicates from the body
        for dropped in body:
            yield (body - {dropped}, head)

### Utilities

In [None]:
def create_ilp_model(method):
    """Build a default-configured ILP learner from its name.

    Parameters
    ----------
    method : str
        Either "top_down" or "bottom_up".

    Returns
    -------
    model : TopDownLearner or BottomUpLearner

    Raises
    ------
    Exception
        If method is not one of the known names.
    """
    if method == "top_down":
        learner_cls = TopDownLearner
    elif method == "bottom_up":
        learner_cls = BottomUpLearner
    else:
        raise Exception(f"Unknown ILP method {method}")
    return learner_cls()
    
def _test_ilp_helper(training_data, test_data, methods=("top_down", "bottom_up")):
    """Train each named ILP method and assert it classifies the test data.

    Parameters
    ----------
    training_data : [({expr}, {expr}, {expr})]
        Triples of (observations, positive outputs, negative outputs)
        passed to the model's train method.
    test_data : [({expr}, {expr}, {expr})]
        Triples of the same shape. Every positive output must be
        entailed by its observations and no negative output may be.
    methods : (str, ...)
        Names accepted by create_ilp_model.
    """
    began_at = time.time()

    for method in methods:
        print(f"Testing method {method}")
        # A fresh learner per method, trained from scratch
        model = create_ilp_model(method)
        model.train(training_data)
        # Every test datum must be classified correctly
        for obs, good_outputs, bad_outputs in test_data:
            for y in good_outputs:
                entailed = model.check(y, assumptions=list(obs))
                assert entailed, \
                    f"Model checking failed for positive example {obs, y}"
            for y in bad_outputs:
                entailed = model.check(y, assumptions=list(obs))
                assert not entailed, \
                    f"Model checking failed for negative example {obs, y}"

    # The reported time covers all methods together
    elapsed = time.time() - began_at
    print(f"Test passed in {elapsed}")

## Demos

In [None]:
def test_traffic_light():
    """Learn what to do at a traffic light from a handful of examples.

    The car is always at one intersection and sees the light colors of
    several; only the light where the car is should matter.

    Target rules
    ------------
    at(x) & isGreen(x) => go()
    at(x) & isYellow(x) => wait()
    at(x) & isRed(x) => wait()
    """
    # Predicate constructors
    def at(x):
        return expr(f"At({x})")

    def isGreen(x):
        return expr(f"IsGreen({x})")

    def isYellow(x):
        return expr(f"IsYellow({x})")

    def isRed(x):
        return expr(f"IsRed({x})")

    def go():
        return expr("Go()")

    def wait():
        return expr("Wait()")

    # Datum constructors: (observations, positive actions, negative actions)
    def should_go(obs):
        return (obs, {go()}, {wait()})

    def should_wait(obs):
        return (obs, {wait()}, {go()})

    training_data = [
        should_go({
            at("Main"),
            isGreen("Main"),
            isYellow("Park"),
            isRed("Central"),
        }),
        should_wait({
            at("Main"),
            isYellow("Main"),
            isYellow("Park"),
            isRed("Central"),
        }),
        should_wait({
            at("Main"),
            isYellow("Main"),
            isGreen("Park"),
            isYellow("Central"),
        }),
        should_go({
            at("Main"),
            isGreen("Main"),
            isGreen("Park"),
            isYellow("Central"),
        }),
        should_wait({
            at("Main"),
            isRed("Main"),
            isGreen("Park"),
            isYellow("Central"),
        }),
    ]

    # The test intersection never appears in the training data
    test_data = [
        should_go({at("Glen"), isGreen("Glen")}),
        should_wait({at("Glen"), isYellow("Glen")}),
        should_wait({at("Glen"), isRed("Glen")}),
    ]

    _test_ilp_helper(training_data, test_data)


def test_search_and_rescue_policy_learning():
    """Learn a policy for search and rescue from positive and negative
    examples of each action.

    Each datum is a triple (observations, positive actions, negative
    actions). Locations are named "L<row>f<col>".

    Target rules
    ------------
    personAt(x) & robotNotAtPerson() & handEmpty() => goTo(x)
    personAt(x) & robotAt(x) & handEmpty() => pickUp()
    hospitalAt(x) & robotNotAtHospital() & handFull() => goTo(x)
    hospitalAt(x) & robotAt(x) & handFull() => putDown()
    """
    def personAt(x):
        return expr(f"PersonAt({x})")

    def robotAt(x):
        return expr(f"RobotAt({x})")

    def robotNotAtPerson():
        return expr("RobotNotAtPerson()")

    def robotNotAtHospital():
        # Written with parentheses like every other zero-argument
        # predicate here, so that it is built the same way as the
        # predicates the learner grounds itself
        return expr("RobotNotAtHospital()")

    def hospitalAt(x):
        return expr(f"HospitalAt({x})")

    def handEmpty():
        return expr("HandEmpty()")

    def handFull():
        return expr("HandFull()")

    def goTo(x):
        return expr(f"GoTo({x})")

    def pickUp():
        return expr("PickUp()")

    def putDown():
        return expr("PutDown()")

    # Observations, positive actions, negative actions

    training_data = []
    # Time 0: there are people at locations (0, 0) and (0, 3),
    # the robot is at (0, 1) with empty hands
    # and the hospital is at (0, 2).
    obs = {
        personAt("L0f0"),
        personAt("L0f3"),
        robotNotAtPerson(),
        robotNotAtHospital(),
        robotAt("L0f1"),
        hospitalAt("L0f2"),
        handEmpty(),
    }
    good_actions = { goTo("L0f0") }
    bad_actions = {
        goTo("L0f1"),
        goTo("L0f2"),
        # Importantly, do not include goTo("L0f3")
        # as a negative example, because it would work
        # just as well as the positive example
        pickUp(),
        putDown(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # Time 1: there are people at locations (0, 0) and (0, 3),
    # the robot is at (0, 0) with empty hands
    # and the hospital is at (0, 2).
    obs = {
        personAt("L0f0"),
        personAt("L0f3"),
        robotAt("L0f0"),
        hospitalAt("L0f2"),
        robotNotAtHospital(),
        handEmpty(),
    }
    good_actions = { pickUp() }
    bad_actions = {
        goTo("L0f0"),
        goTo("L0f1"),
        goTo("L0f2"),
        goTo("L0f3"),
        putDown(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # Time 2: there's a person at location (0, 3),
    # the robot is at (0, 0) with full hands
    # and the hospital is at (0, 2).
    obs = {
        personAt("L0f3"),
        robotAt("L0f0"),
        robotNotAtPerson(),
        robotNotAtHospital(),
        hospitalAt("L0f2"),
        handFull(),
    }
    good_actions = { goTo("L0f2") }
    bad_actions = {
        goTo("L0f0"),
        goTo("L0f1"),
        goTo("L0f3"),
        pickUp(),
        putDown(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # Time 3: there's a person at location (0, 3),
    # the robot is at (0, 2) with full hands
    # and the hospital is at (0, 2).
    obs = {
        personAt("L0f3"),
        robotNotAtPerson(),
        robotAt("L0f2"),
        hospitalAt("L0f2"),
        handFull(),
    }
    good_actions = { putDown() }
    bad_actions = {
        goTo("L0f0"),
        goTo("L0f1"),
        goTo("L0f2"),
        goTo("L0f3"),
        pickUp(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # Time 4: there's a person at location (0, 3),
    # the robot is at (0, 2) with empty hands
    # and the hospital is at (0, 2).
    obs = {
        personAt("L0f3"),
        robotAt("L0f2"),
        robotNotAtPerson(),
        hospitalAt("L0f2"),
        handEmpty(),
    }
    good_actions = { goTo("L0f3") }
    bad_actions = {
        goTo("L0f0"),
        goTo("L0f1"),
        goTo("L0f2"),
        pickUp(),
        putDown(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # Time 5: there's a person at location (0, 3),
    # the robot is at (0, 3) with empty hands
    # and the hospital is at (0, 2).
    obs = {
        personAt("L0f3"),
        robotAt("L0f3"),
        hospitalAt("L0f2"),
        robotNotAtHospital(),
        handEmpty(),
    }
    good_actions = { pickUp() }
    bad_actions = {
        goTo("L0f0"),
        goTo("L0f1"),
        goTo("L0f2"),
        goTo("L0f3"),
        putDown(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # Time 6: no person is observed at any location,
    # the robot is at (0, 3) with full hands
    # and the hospital is at (0, 2).
    obs = {
        robotAt("L0f3"),
        hospitalAt("L0f2"),
        robotNotAtHospital(),
        handFull(),
    }
    good_actions = { goTo("L0f2") }
    bad_actions = {
        goTo("L0f0"),
        goTo("L0f1"),
        goTo("L0f3"),
        putDown(),
        pickUp(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # Time 7: no person is observed at any location,
    # the robot is at (0, 2) with full hands
    # and the hospital is at (0, 2).
    obs = {
        robotAt("L0f2"),
        hospitalAt("L0f2"),
        handFull(),
    }
    good_actions = { putDown() }
    bad_actions = {
        goTo("L0f0"),
        goTo("L0f1"),
        goTo("L0f2"),
        goTo("L0f3"),
        pickUp(),
    }
    training_data.append((obs, good_actions, bad_actions))

    # The test data use locations that never appear in training.
    # These bad actions are not exhaustive, just a few to test
    test_data = []
    # Empty hands, away from every person: go to any person
    obs = {
        personAt("L10f10"),
        personAt("L11f10"),
        personAt("L12f10"),
        hospitalAt("L5f5"),
        robotNotAtHospital(),
        robotAt("L0f0"),
        robotNotAtPerson(),
        handEmpty(),
    }
    good_actions = { goTo("L10f10"), goTo("L11f10"), goTo("L12f10") }
    bad_actions = { putDown(), pickUp() }
    test_data.append((obs, good_actions, bad_actions))

    # Empty hands, at a person: pick up
    obs = {
        personAt("L10f10"),
        personAt("L11f10"),
        personAt("L12f10"),
        hospitalAt("L5f5"),
        robotNotAtHospital(),
        robotAt("L10f10"),
        handEmpty(),
    }
    good_actions = { pickUp() }
    bad_actions = { putDown(), goTo("L11f10"), goTo("L5f5") }
    test_data.append((obs, good_actions, bad_actions))

    # Full hands, away from the hospital: go to the hospital
    obs = {
        personAt("L11f10"),
        personAt("L12f10"),
        hospitalAt("L5f5"),
        robotAt("L10f10"),
        robotNotAtHospital(),
        handFull(),
    }
    good_actions = { goTo("L5f5") }
    bad_actions = { pickUp(), putDown(), goTo("L11f10") }
    test_data.append((obs, good_actions, bad_actions))

    # Full hands, at the hospital: put down
    obs = {
        personAt("L11f10"),
        personAt("L12f10"),
        hospitalAt("L5f5"),
        robotNotAtPerson(),
        robotAt("L5f5"),
        handFull(),
    }
    good_actions = { putDown() }
    bad_actions = { pickUp(), goTo("L11f10") }
    test_data.append((obs, good_actions, bad_actions))

    # Empty hands, at the hospital: go to any remaining person
    obs = {
        personAt("L11f10"),
        personAt("L12f10"),
        hospitalAt("L5f5"),
        robotNotAtPerson(),
        robotAt("L5f5"),
        handEmpty(),
    }
    good_actions = { goTo("L11f10"), goTo("L12f10") }
    bad_actions = { pickUp(), putDown() }
    test_data.append((obs, good_actions, bad_actions))

    _test_ilp_helper(training_data, test_data)
    
def my_new_test():
    """Placeholder for a new test case; add one here!

    Raises
    ------
    NotImplementedError
        Always, until this body is replaced with a real test.
    """
    reminder = "Implement me!"
    raise NotImplementedError(reminder)

### Fire Away

In [None]:
# Run the two provided demos, then the placeholder test.
# Note: my_new_test() raises NotImplementedError until it is implemented.
test_traffic_light()
test_search_and_rescue_policy_learning()
my_new_test()