# Project 4: Semantic Parsing

This project will have you implement a neural semantic parser for the GeoQA dataset of [Krishnamurthy and Kollar, 2013](http://rtw.ml.cmu.edu/tacl2013_lsp/tacl2013-krishnamurthy-kollar.pdf), which consists of a database of simple geographic facts about 10 US states, questions and answers about the database, and annotated logical forms for the questions. Your final system will go from natural language questions to their answers, computed from the database, via logical forms that are executed on the database. 

First, you'll implement a method for executing logical forms on the database. Then, you'll implement some components of a constrained sequence-to-sequence model for producing logical forms from questions. You will first train it using paired questions and logical forms. Then, you will train it from question-answer pairs, by searching over latent logical forms.

Note: this dataset is small enough that we will be able to train on CPU; you don't need a GPU instance.

## Setup

The dependencies for this project include:
* `torch` for modeling and training
* `sexpdata` for loading logical forms from [S-expressions](https://en.wikipedia.org/wiki/S-expression)
* `geoqa`: support code for loading and preprocessing the GeoQA database and for evaluating predicted answers, available at this [GitHub repo](https://github.com/dpfried/geoqa-release) (you can treat the code as a black box and not worry about the details).

In [0]:
%%capture
!pip install --upgrade torch tqdm sexpdata editdistance
# this provides the packages available here: https://github.com/dpfried/geoqa-release
!wget https://github.com/dpfried/geoqa-release/archive/master.zip -O geoqa.zip
!unzip geoqa.zip
!rm geoqa.zip
!mv geoqa-release-master/data.tgz .
!mv geoqa-release-master/geoqa .
!mv geoqa-release-master/vecs .
!rm -r geoqa-release-master
!tar xf data.tgz

# Standard library imports
import math
import random
import pprint
import pickle
import time
from typing import Union, List, Set, Tuple
from collections import namedtuple

# Third party imports
import editdistance
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import tqdm.notebook
import sexpdata

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# geoqa imports; we'll examine these shortly
import geoqa
import geoqa.geo
import geoqa.utils

from geoqa.utils import Stack, Index, logical_form_to_str
from geoqa.geo import STATES, World
from geoqa.geo import CATS, RELS, ENTITY_ACTIONS_TO_ENTITIES
from geoqa.geo import LAMBDA, EXISTS, AND, VARS, MAX_LITERALS
from geoqa.dataset import GeoDataset
from geoqa.evaluation import evaluate_predictions

BIG_NEG = -1e9

## Data

First, let's examine the dataset. GeoQA contains 10 databases, one for each of 10 states.

In [0]:
STATES

['fl', 'ga', 'mi', 'nc', 'ok', 'pa', 'sc', 'tn', 'va', 'wv']

In [0]:
DEV_STATES = ["va", "wv"]
TEST_STATES = ["fl", "ga"]
TRAIN_STATES = [env for env in STATES if env not in DEV_STATES + TEST_STATES]

A database consists of a set of entities, and a collection of facts about these entities. Facts are either categories, which are unary predicates on entities, or relations, which are binary predicates between entity pairs.
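
To make the distinction between categories and relations concrete, here is a minimal sketch of a toy database built from plain Python sets. The names `categories`, `relations`, `holds_category`, and `holds_relation` are illustrative assumptions and are not the data structures used by `geoqa`.

```python
# A toy database (illustrative only; not the geoqa representation).
entities = {"virginia", "richmond", "virginia_beach", "atlantic_ocean"}

# Categories are unary predicates: store the set of entities each one holds for.
categories = {
    "city": {"richmond", "virginia_beach"},
    "state": {"virginia"},
    "ocean": {"atlantic_ocean"},
}

# Relations are binary predicates: store sets of ordered entity pairs.
relations = {
    "in-rel": {("richmond", "virginia"), ("virginia_beach", "virginia")},
    "on-rel": {("virginia_beach", "atlantic_ocean")},
}

def holds_category(category, entity):
    """True if the unary predicate `category` holds for `entity`."""
    return entity in categories.get(category, set())

def holds_relation(relation, entity_1, entity_2):
    """True if the binary predicate `relation` holds for the ordered pair."""
    return (entity_1, entity_2) in relations.get(relation, set())

print(holds_category("city", "richmond"))                  # True
print(holds_relation("in-rel", "richmond", "virginia"))    # True
print(holds_relation("in-rel", "virginia", "richmond"))    # False: order matters
```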

In [0]:
# all possible categories that can be used across states
' '.join(CATS)

'city state park island beach ocean lake forest major peninsula capital body water salt fresh place'

In [0]:
# all possible relations that can be used across states
' '.join(RELS)

'in-rel on-rel north-rel south-rel east-rel west-rel northeast-rel southeast-rel southwest-rel northwest-rel border-rel capital-rel near-rel close-rel surround-rel bigger-rel abut-rel along-rel inside-rel contain-rel'

Not all categories and relations will be used in a given database, but here is a representation of the entities, categories, and relations for one state. This dataset is quite small, and the state data below represents one of the larger databases.

In [0]:
world = geoqa.geo.read_world("va")
world.print()

name: va
entities:
	virginia
	virginia_beach
	richmond
	atlantic_ocean
	west_virginia
	north_carolina
categories:
	city(virginia_beach)
	city(richmond)
	state(virginia)
	state(west_virginia)
	state(north_carolina)
	ocean(atlantic_ocean)
relations:
	in-rel(virginia_beach, virginia)
	in-rel(richmond, virginia)
	on-rel(virginia, atlantic_ocean)
	on-rel(virginia_beach, atlantic_ocean)
	on-rel(west_virginia, atlantic_ocean)
	on-rel(north_carolina, atlantic_ocean)
	south-rel(virginia, west_virginia)
	south-rel(north_carolina, virginia)
	south-rel(north_carolina, virginia_beach)
	south-rel(north_carolina, richmond)
	east-rel(virginia, west_virginia)
	east-rel(virginia_beach, richmond)
	east-rel(atlantic_ocean, virginia)
	east-rel(atlantic_ocean, virginia_beach)
	east-rel(atlantic_ocean, richmond)
	east-rel(atlantic_ocean, west_virginia)
	east-rel(atlantic_ocean, north_carolina)
	west-rel(virginia, atlantic_ocean)
	west-rel(virginia_beach, atlantic_ocean)
	west-rel(richmond, virginia_beach)
	w

Each state has a set of question-answer pairs, with an associated logical form for each question.

In [0]:
questions, answers, logical_forms = geoqa.geo.read_data("va")

instance_index = 10

In [0]:
questions[instance_index]

'what cities are in virginia ?'

Each answer is a set of entities from the database, stored as a comma-separated string:

In [0]:
answers[instance_index]

'virginia_beach,richmond'

Logical forms are lambda calculus functions represented as tuples:

In [0]:
logical_forms[instance_index]

('lambda',
 '$w',
 ('exists',
  '$x',
  ('and', ('city', '$w'), ('in-rel', '$w', '$x'), ('kb-virginia', '$x'))))

In [0]:
print(logical_form_to_str(logical_forms[instance_index]))

(lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-virginia $x))))
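
The correspondence between the tuple form and the S-expression string can be sketched in a few lines. This is a rough illustration only: `to_sexp` and `from_sexp` are hypothetical helpers written for this example, not the `logical_form_to_str` or `parse_tree` functions from `geoqa`.

```python
def to_sexp(tree):
    """Render a nested tuple of strings as an S-expression string."""
    if isinstance(tree, tuple):
        return "(" + " ".join(to_sexp(child) for child in tree) + ")"
    return tree

def from_sexp(text):
    """Parse an S-expression string into nested tuples of strings."""
    tokens = text.replace("(", " ( ").replace(")", " ) ").split()
    position = 0

    def parse():
        nonlocal position
        token = tokens[position]
        position += 1
        if token == "(":
            children = []
            while tokens[position] != ")":
                children.append(parse())
            position += 1  # consume the closing parenthesis
            return tuple(children)
        return token

    return parse()

lf = ('lambda', '$w',
      ('exists', '$x',
       ('and', ('city', '$w'), ('in-rel', '$w', '$x'), ('kb-virginia', '$x'))))
text = to_sexp(lf)
print(text)
# (lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-virginia $x))))
print(from_sexp(text) == lf)  # True: the two representations round-trip
```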


This dataset has fairly simple lambda calculus expressions, which always consist of a single function and a conjunction of predicates. Because of this, every logical form contains:
1. **Variables** (`$x, $y, $z,` and `$w`), which can take on entity values from the database.
2. **Predicates**, of which there are three types:
    - **categories** (e.g. `city`), which take 1 variable as an argument, and return true if the category is true in the database for the entity assigned to the variable.
    - **relations** (e.g. `in-rel`), which take 2 variable arguments, and return true if the relation is true in the database for the ordered pair of entities assigned to the variables.
    - **entity predicates** (e.g. `kb-virginia`), which take 1 variable argument and return true if the proper entity (e.g. `virginia`) is assigned to the variable.
3. A **conjunction** (`and`) of multiple **literals** (e.g. `(city $w)` is a literal). A literal consists of a predicate with some number of variables as arguments.
4. [optional] An **existential quantifier** (`exists`), which takes as arguments all but one of the variables used within the conjunction, and the conjunction. `exists` is used if and only if more than one variable is used inside the `and` expression.
5. A **lambda expression** (`lambda`), which takes as arguments a single variable and either a conjunction (if the conjunction contains no other variables, other than the lambda's argument) or an existential quantifier (if it does).
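
The structure described in the list above can be checked mechanically. The sketch below assumes the tuple representation shown earlier; `literals_of` and `variables_of` are hypothetical helper names introduced only for this illustration.

```python
def literals_of(logical_form):
    """Return the literals inside the single `and` of a logical form."""
    body = logical_form[2]        # the lambda's body
    if body[0] == 'exists':
        body = body[-1]           # the conjunction is the last argument of exists
    assert body[0] == 'and'
    return list(body[1:])

def variables_of(logical_form):
    """Return every variable used as an argument of some literal."""
    return {arg for literal in literals_of(logical_form) for arg in literal[1:]}

lf = ('lambda', '$w',
      ('exists', '$x',
       ('and', ('city', '$w'), ('in-rel', '$w', '$x'), ('kb-virginia', '$x'))))

lambda_var = lf[1]
quantified = set(lf[2][1:-1]) if lf[2][0] == 'exists' else set()

print(literals_of(lf))
# every variable is either the lambda's argument or existentially quantified
print(variables_of(lf) == {lambda_var} | quantified)  # True
```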

Note: an alternative, more compact logical form representation would use entities directly when applicable, e.g. `(lambda $w (and (city $w) (in-rel $w virginia)))`. This would make the implementation of the logical form executor (which you'll complete in the next section) a bit more complex, but it would likely reduce the difficulty of learning the logical forms' structure. If you're interested, see [Liang 2013](https://arxiv.org/abs/1309.4408) for one approach used in many recent systems which takes this even further to eliminate all variables and make existential quantification implicit, more closely paralleling natural language.

In general, semantic parsing systems must trade off between the complexity of the lexicon, the syntax/semantics interface, compactness of the logical form structure, and the difficulty of inference and learning. For this dataset, we'll find that a neural model is able to adequately learn to predict these relatively verbose logical forms directly from sentences, given some lexical and structural constraints.

In [0]:
# obtain all predicates (categories, relations, and entity predicates) from the dataset
PREDICATES = set(CATS) | set(RELS) | set(ENTITY_ACTIONS_TO_ENTITIES.keys())

In [0]:
' '.join(sorted(PREDICATES))

'abut-rel along-rel beach bigger-rel body border-rel capital capital-rel city close-rel contain-rel east-rel forest fresh in-rel inside-rel island kb-alabama kb-arkansas kb-atlanta kb-atlantic_ocean kb-birmingham kb-charleston kb-charlotte kb-daytona_beach kb-detroit kb-everglades_national_park kb-florida kb-francis_marion_national_forest kb-georgia kb-grand_rapids kb-great_smoky_mountains_national_park kb-greensboro kb-greenville kb-harrisburg kb-hilton_head_island kb-key_largo kb-knoxville kb-lake_huron kb-lake_michigan kb-lake_moultrie kb-louisiana kb-macon kb-memphis kb-miami kb-michigan kb-milwaukee kb-mississippi kb-monongahela_national_forest kb-montgomery kb-myrtle_beach kb-nashville kb-new_jersey kb-newark kb-north_carolina kb-oklahoma kb-pennsylvania kb-pittsburgh kb-raleigh kb-richmond kb-south_carolina kb-tallahassee kb-tennessee kb-texas kb-trenton kb-tygart_lake kb-uwharrie_national_forest kb-virginia kb-virginia_beach kb-west_virginia kb-wisconsin lake major near-rel nor

## Logical Form Executor

All logical forms in this dataset are lambda functions, which take database entities and return truth values. Each logical form's denotation (corresponding to the answer to its question) is the set of entities in the database for which the lambda function is true.
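
To make the definition of a denotation concrete, here is a brute-force sketch that tries every assignment of entities to variables and keeps the values of the lambda variable for which all literals hold. It is exponential in the number of variables and is meant only as an illustration on a toy database; the dictionaries and the `brute_force_denotation` helper are assumptions made for this example, not part of `geoqa`.

```python
from itertools import product

# A small toy database (a few facts about West Virginia and Virginia).
entities = ["west_virginia", "charleston", "richmond", "virginia"]
categories = {"city": {"charleston", "richmond"}}
relations = {"in-rel": {("charleston", "west_virginia"), ("richmond", "virginia")}}
entity_predicates = {"kb-west_virginia": "west_virginia"}

def literal_holds(literal, assignment):
    """Check one literal under a complete variable assignment."""
    predicate, *variables = literal
    values = tuple(assignment[v] for v in variables)
    if predicate in categories:
        return values[0] in categories[predicate]
    if predicate in relations:
        return values in relations[predicate]
    return values[0] == entity_predicates[predicate]

def brute_force_denotation(lambda_var, literals):
    """Entities that the lambda variable takes in some satisfying assignment."""
    variables = sorted({v for literal in literals for v in literal[1:]})
    denotation = set()
    for values in product(entities, repeat=len(variables)):
        assignment = dict(zip(variables, values))
        if all(literal_holds(literal, assignment) for literal in literals):
            denotation.add(assignment[lambda_var])
    return denotation

# (lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-west_virginia $x))))
literals = [("city", "$w"), ("in-rel", "$w", "$x"), ("kb-west_virginia", "$x")]
print(brute_force_denotation("$w", literals))  # {'charleston'}
```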

In [0]:
questions[instance_index]

'what cities are in virginia ?'

In [0]:
logical_form_to_str(logical_forms[instance_index])

'(lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-virginia $x))))'

In [0]:
answers[instance_index]

'virginia_beach,richmond'

To evaluate a logical form's denotation in a database, we'll run a simple unification algorithm to build a denotation up recursively from sub-trees in the logical form's tree. Each sub-tree in the logical form will be associated with a `VariableAssignments` object, which contains the set of all possible assignments of entities to variables within the sub-tree that would satisfy the logical form:

In [0]:
# VariableAssignments: a collection of possible assignments of entities to
#   variables that satisfy the logical form, as processed so far
VariableAssignments = namedtuple('VariableAssignments', [
  'variables',
  # variables: Set[str]: variables that are assigned
  'assignments'
  # assignments: List[Dict[str, str]]: a list of variable assignments.
  #   each variable assignment dictionary should map every variable in 
  #   `variables` to an entity
])

In the cell below, we've defined operations to produce satisfying assignments for each of the three literal types (which occur at the leaves of the tree):

In [0]:
def entity_op(entity_predicate: str, variable: str, world: World) -> VariableAssignments:
  entity_name = ENTITY_ACTIONS_TO_ENTITIES[entity_predicate]
  assert entity_name in world.entities
  return VariableAssignments({variable}, [{variable: entity_name}])

def category_op(category: str, variable: str, world: World) -> VariableAssignments:
  index = CATS.index(category)
  values = (world.categories[index] == 1).nonzero()[0]
  return VariableAssignments(
    {variable},
    [{variable: world.index_to_entities[value]}
     for value in values],
  )

def relation_op(relation: str, variable_1: str, variable_2: str, world: World) -> VariableAssignments:
  index = RELS.index(relation)
  val_1s, val_2s = (world.relations[index] == 1).nonzero()
  assignments = [
    {variable_1: world.index_to_entities[val_1], variable_2: world.index_to_entities[val_2]}
    for val_1, val_2 in zip(
      val_1s, val_2s
    )
  ]
  return VariableAssignments(
    {variable_1, variable_2},
    assignments
  )

These literal operations are demonstrated below:

In [0]:
world = geoqa.geo.read_world("wv")
world.print()

name: wv
entities:
	west_virginia
	charleston
	monongahela_national_forest
	tygart_lake
	richmond
	virginia
categories:
	city(charleston)
	city(richmond)
	state(west_virginia)
	state(virginia)
	park(monongahela_national_forest)
	lake(tygart_lake)
relations:
	in-rel(charleston, west_virginia)
	in-rel(monongahela_national_forest, west_virginia)
	in-rel(tygart_lake, west_virginia)
	in-rel(richmond, virginia)
	east-rel(monongahela_national_forest, charleston)
	east-rel(richmond, west_virginia)
	east-rel(richmond, monongahela_national_forest)
	east-rel(virginia, west_virginia)
	east-rel(virginia, monongahela_national_forest)
	east-rel(virginia, tygart_lake)
	west-rel(west_virginia, richmond)
	west-rel(west_virginia, virginia)
	west-rel(charleston, monongahela_national_forest)
	west-rel(monongahela_national_forest, richmond)
	west-rel(monongahela_national_forest, virginia)
	west-rel(tygart_lake, virginia)
	capital-rel(charleston, west_virginia)
	capital-rel(richmond, virginia)


In [0]:
logical_form = geoqa.utils.parse_tree('(lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-west_virginia $x))))')

In [0]:
entity_op('kb-west_virginia', '$x', world)

VariableAssignments(variables={'$x'}, assignments=[{'$x': 'west_virginia'}])

In [0]:
relation_op('in-rel', '$w', '$x', world)

VariableAssignments(variables={'$w', '$x'}, assignments=[{'$w': 'charleston', '$x': 'west_virginia'}, {'$w': 'monongahela_national_forest', '$x': 'west_virginia'}, {'$w': 'tygart_lake', '$x': 'west_virginia'}, {'$w': 'richmond', '$x': 'virginia'}])

In [0]:
category_op('city', '$w', world)

VariableAssignments(variables={'$w'}, assignments=[{'$w': 'charleston'}, {'$w': 'richmond'}])

In the cell below, we've defined an `execute` function which will return the entity set denotation for a logical form by recursively traversing the logical form's tree and combining variable assignments from the leaves upward, using _sub-tree operations_.  The root node of the tree, the `lambda` node, has a `lambda_op` sub-tree operation which takes a `VariableAssignments` containing possible satisfying assignments, and returns the set of entities that the lambda variable takes on.

All other sub-tree operations take `VariableAssignments` returned by their children, and return a `VariableAssignments` giving the assignments that satisfy the sub-tree.

Some of the sub-tree operations (`entity_op`, `category_op`, `relation_op`) for literals were defined above; you'll define the others (`lambda_op`, `exists_op`, `and_op`) in the cells afterward.

In [0]:
def execute(logical_form: Tuple, world: World):
  # recursively process the tree from the leaves upward. Each subtree should return:
  #   lambda nodes: a `set` of entities that satisfy the logical form
  #   all other node types: a VariableAssignments that contains possible assignments
  #       of entities to variables that would satisfy the subtree
  assert isinstance(logical_form, tuple)
  predicate = logical_form[0]
  if predicate in ENTITY_ACTIONS_TO_ENTITIES:
    var = logical_form[1]
    return entity_op(predicate, var, world)
  elif predicate in CATS:
    category = predicate
    var = logical_form[1]
    return category_op(category, var, world)
  elif predicate in RELS:
    rel = predicate
    var_1, var_2 = logical_form[1:]
    return relation_op(rel, var_1, var_2, world)
  elif predicate == LAMBDA:
    var = logical_form[1]
    var_assignments = execute(logical_form[2], world)
    return lambda_op(var, var_assignments)
  elif predicate == EXISTS:
    vars = logical_form[1:-1]
    var_assignments = execute(logical_form[-1], world)
    return exists_op(vars, var_assignments)
  elif predicate == AND:
    return and_op([
      execute(sub_lf, world)
      for sub_lf in logical_form[1:]
    ])
  else:
    raise NotImplementedError("invalid predicate {}\n{}".format(predicate, logical_form))

In [0]:
# helper function to use when testing results from the sub-tree operations
def are_equal(va1: VariableAssignments, va2: VariableAssignments) -> bool:
  def canonicalize(assignments):
    # compare sets of assignments, since ordering (and duplication) of assignments within the list does not affect meaning
    # convert each assignment dict to a string in json format since dictionaries are unhashable;
    # sort_keys=True makes the string independent of each dict's key insertion order
    import json
    return set(json.dumps(d, sort_keys=True) for d in assignments)
  return va1.variables == va2.variables and canonicalize(va1.assignments) == canonicalize(va2.assignments)

Complete the `lambda_op`, `exists_op`, and `and_op` sub-tree operations in the cells below. Each operation has a test cell after it with sample inputs and outputs.

In [0]:
def lambda_op(lambda_variable: str, variable_assignments: VariableAssignments) -> Set[str]:
  # return the entities that variable takes on in variable_assignments
  L = []
  for assignment in variable_assignments.assignments:
      L.append(assignment[lambda_variable])
  return set(L)

In [0]:
assert lambda_op('$y', VariableAssignments(
    {'$x', '$y'}, 
    [{'$x': 'virginia', '$y': 'richmond'}, {'$x': 'virginia', '$y': 'charleston'}]
  )) == {'richmond', 'charleston'}

In [0]:
def exists_op(existentially_quantified_vars: List[str], child_assignments: VariableAssignments) -> VariableAssignments:
  # remove the existentially-quantified variables from the assignments in child_assignments
  new_variables = child_assignments.variables - set(existentially_quantified_vars)
  new_assignments = []
  for assignment in child_assignments.assignments:
    # keep only the variables that are not quantified away
    row = {var: value for var, value in assignment.items() if var in new_variables}
    # append once per assignment, skipping rows that became duplicates after projection
    if row not in new_assignments:
      new_assignments.append(row)
  return VariableAssignments(new_variables, new_assignments)

In [0]:
exists_test_1 = exists_op(['$x'], VariableAssignments(
    {'$x', '$y'}, [{'$x': 'virginia', '$y': 'richmond'}, {'$x': 'virginia', '$y': 'charleston'}]
  ))
assert are_equal(exists_test_1, VariableAssignments(
    {'$y'}, [{'$y': 'richmond'}, {'$y': 'charleston'}]
)), exists_test_1
                                             
exists_test_2 = exists_op(['$y'], VariableAssignments(
    {'$x', '$y'}, [{'$x': 'virginia', '$y': 'richmond'}, {'$x': 'virginia', '$y': 'charleston'}]
  ))
assert are_equal(exists_test_2, VariableAssignments(
    {'$x'},
    [{'$x': 'virginia'}]
)), exists_test_2

In [0]:
def and_op(children_assignments: List[VariableAssignments]) -> VariableAssignments:
  # join the assignments in children_assignments, to return possible assignments that satisfy all children
  def merge(assignment1, assignment2, common_vars):
    # two assignments are compatible only if they agree on every shared variable
    for var in common_vars:
      if assignment1[var] != assignment2[var]:
        return None
    # shared variables agree, so assignment2 only contributes new variables
    return {**assignment1, **assignment2}

  assert len(children_assignments) > 0, "and_op requires at least one child"

  # fold the children together pairwise, from left to right
  result = children_assignments[0]
  for child in children_assignments[1:]:
    common_vars = result.variables & child.variables
    joined = []
    for assignment1 in result.assignments:
      for assignment2 in child.assignments:
        new_assignment = merge(assignment1, assignment2, common_vars)
        if new_assignment is not None:
          joined.append(new_assignment)
    result = VariableAssignments(result.variables | child.variables, joined)
  return result

In [0]:
and_test = and_op([
  VariableAssignments({'$x', '$y'}, [{'$x': 'x1', '$y': 'y1'}, {'$x': 'x2', '$y': 'y2'}]),
  VariableAssignments({'$z', '$y'}, [{'$z': 'z2', '$y': 'y2'}, {'$z': 'x3', '$y': 'y3'}]),                 
])
assert are_equal(
  and_test,
  VariableAssignments(variables={'$x', '$y', '$z'}, assignments=[{'$x': 'x2', '$y': 'y2', '$z': 'z2'}])
)

With correct sub-tree operation definitions, the following test cell should run without errors. It checks that the entity set denotations returned by the `execute` function match the answers in the database.

In [0]:
for state in STATES:
  world = geoqa.geo.read_world(state)
  for question, answer, logical_form in zip(*geoqa.geo.read_data(state)):
    if logical_form is None:
      continue
    true_denotation = world.get_denotation_from_answer(answer)
    executed_denotation = execute(logical_form, world)
    
    if true_denotation != executed_denotation:
      print("execution failure!")
      print("question: {}".format(question))
      print("state: {}".format(world))
      print("logical form: {}".format(logical_form))
      print("true denotation: {}".format(true_denotation))
      print("executed denotation: {}".format(executed_denotation))
      raise ValueError()

## Parser

Now that we have a method to execute logical forms on the database, we'll spend the rest of this assignment constructing parsers to produce logical forms from questions. 

Our question-to-logical-form parser will be a neural sequence-to-sequence model that will encode the question and decode a sequence of actions to construct a logical form. We will use _constrained decoding_ to prevent the model from constructing invalid logical forms, enforcing structural constraints rather than relying on the model to learn them.

To construct this parser, we first need to define a transition system: a set of actions that build up a logical form incrementally, and the effects each action has on a partially constructed logical form. This transition system will constrain the available actions that can be taken by the sequence-to-sequence model at any point in time to guarantee valid logical forms.
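
One common way to apply such constraints at decoding time is to push the scores of invalid actions toward negative infinity before normalizing, so that they receive essentially zero probability. The sketch below illustrates the idea in plain Python. It assumes the `BIG_NEG` constant from the setup cell is used this way, which this section does not confirm, and `masked_log_softmax` is a hypothetical helper rather than part of the model you will implement.

```python
import math

BIG_NEG = -1e9  # same value as the constant defined in the setup cell

def masked_log_softmax(scores, valid):
    """scores: dict mapping action -> unnormalized score.
    valid: set of actions the transition system currently allows."""
    # add a very large negative number to every invalid action's score
    masked = {a: (s if a in valid else s + BIG_NEG) for a, s in scores.items()}
    # numerically stable log-sum-exp
    max_score = max(masked.values())
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in masked.values()))
    return {a: s - log_norm for a, s in masked.items()}

# At the start of a logical form, only a variable can be produced,
# even if the unconstrained model prefers a predicate.
scores = {"$w": 0.5, "city": 2.0, "DONE": 1.0}
log_probs = masked_log_softmax(scores, valid={"$w"})
best = max(log_probs, key=log_probs.get)
print(best)  # $w
```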

### Transition system

We'll construct logical forms using a depth-first, left-to-right traversal of the logical form's tree structure, with actions in a post-fix notation (so a predicate comes after the arguments it takes), and a final `DONE` action after the logical form is constructed. For example, the logical form

`(lambda $x (exists $w (and (state $w) (city $x) (in-rel $x $w))))`

will be constructed in the traversal order:

`$w state $x city $x $w in-rel and $w exists $x lambda DONE`
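
The traversal above can be sketched as a short recursive function. Note that for `lambda` and `exists` the body is emitted first, followed by the bound variables and then the operator itself. `to_postfix` is a hypothetical helper written for this illustration.

```python
def to_postfix(tree):
    """Flatten a logical form (nested tuples) into the post-fix order shown above."""
    head, *args = tree
    if head in ('lambda', 'exists'):
        *variables, body = args
        # the body is built first, then the bound variables, then the operator
        return to_postfix(body) + list(variables) + [head]
    if head == 'and':
        # each literal in turn, then the conjunction
        return [token for literal in args for token in to_postfix(literal)] + [head]
    # a literal: its variable arguments, then the predicate
    return list(args) + [head]

lf = ('lambda', '$x',
      ('exists', '$w',
       ('and', ('state', '$w'), ('city', '$x'), ('in-rel', '$x', '$w'))))
actions = to_postfix(lf) + ['DONE']
print(' '.join(actions))
# $w state $x city $x $w in-rel and $w exists $x lambda DONE
```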

We will have one action for each of the items in the order above, with one change to simplify what our model needs to learn. In this dataset, every logical form combines its literals with a single `and`, passes exactly one variable as an argument to the lambda expression, and existentially quantifies all other variables. Because of this, the actions between `and` and `lambda` are fully determined by the lambda's variable argument (`$x`) and the other variables that have previously been introduced (all variables other than the lambda variable must be existentially quantified). This allows us to replace these actions with a single action, `complete_lambda_$x`. For the example above:

`$w state $x city $x $w in-rel complete_lambda_$x DONE`

If the question had been about the variable `$w`, we would instead use `complete_lambda_$w` so that `$w` was used with `lambda` and everything else was existentially quantified.

Note that other datasets with more complex logical forms would need to generate conjunctions and quantifiers individually, rather than using these simplifying `complete_lambda_` actions.
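
As a rough sketch of this simplification, the function below converts a logical form directly into the shortened action sequence. `to_actions` is a hypothetical helper for illustration, and it uses the uppercase `COMPLETE_LAMBDA_` spelling that the action definitions in the next section use.

```python
def to_actions(logical_form):
    """Convert a logical form (nested tuples) into the simplified action sequence."""
    lambda_var, body = logical_form[1], logical_form[2]
    # skip over the existential quantifier, if there is one
    conjunction = body[-1] if body[0] == 'exists' else body
    actions = []
    for predicate, *variables in conjunction[1:]:
        actions.extend(variables)   # arguments first (post-fix order)
        actions.append(predicate)   # then the predicate that consumes them
    # one action stands in for `and`, `exists` and its variables, and `lambda`
    actions.append('COMPLETE_LAMBDA_{}'.format(lambda_var))
    actions.append('DONE')
    return actions

lf = ('lambda', '$x',
      ('exists', '$w',
       ('and', ('state', '$w'), ('city', '$x'), ('in-rel', '$x', '$w'))))
print(' '.join(to_actions(lf)))
# $w state $x city $x $w in-rel COMPLETE_LAMBDA_$x DONE
```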

### Actions

We'll now define the actions used by our transition system. Variables and predicates will be taken directly from the items in the logical forms. We'll define new actions for `COMPLETE_LAMBDA_*` and `DONE`:

In [0]:
# mapping from complete lambda action to the lambda's variable
COMPLETE_LAMBDA_ACTIONS_TO_VARS = {'COMPLETE_LAMBDA_{}'.format(var): var for var in VARS}
# and reverse
VARS_TO_COMPLETE_LAMBDA_ACTIONS = {v: k for k, v in COMPLETE_LAMBDA_ACTIONS_TO_VARS.items()}

DONE_ACTION = 'DONE'

Actions will be divided into three types, which will be produced differently by our model.

In [0]:
STRUCTURAL_ACTIONS = VARS + sorted(COMPLETE_LAMBDA_ACTIONS_TO_VARS.keys()) + [DONE_ACTION]
CROSS_DATABASE_PREDICATE_ACTIONS = CATS + RELS
DATABASE_SPECIFIC_PREDICATE_ACTIONS = list(ENTITY_ACTIONS_TO_ENTITIES.keys())

ACTIONS = set(STRUCTURAL_ACTIONS) | set(CROSS_DATABASE_PREDICATE_ACTIONS) | set(ENTITY_ACTIONS_TO_ENTITIES.keys())

In [0]:
print(STRUCTURAL_ACTIONS)

['$w', '$x', '$y', '$z', 'COMPLETE_LAMBDA_$w', 'COMPLETE_LAMBDA_$x', 'COMPLETE_LAMBDA_$y', 'COMPLETE_LAMBDA_$z', 'DONE']


In [0]:
print(CROSS_DATABASE_PREDICATE_ACTIONS[:5])

['city', 'state', 'park', 'island', 'beach']


In [0]:
print(DATABASE_SPECIFIC_PREDICATE_ACTIONS[:5])

['kb-florida', 'kb-tallahassee', 'kb-miami', 'kb-everglades_national_park', 'kb-key_largo']


### From actions to logical forms

This `ParseConstraints` class will contain possible options to constrain the logical form being constructed, which will become important later when we begin searching over logical forms.

In [0]:
ParseConstraints = namedtuple('ParseConstraints', [
  'possible_predicates',
  # set[str]: set of predicates to allow in this logical form
  'max_vars',
  # int: maximum number of distinct variables to allow in the logical form
  #   the maximum number in the dataset is 4 (len(VARS))
  'no_repeated_literals',
  # bool: if True, don't allow the same literal to appear more than once
  #   e.g. (and (state $x) (state $x)) is disallowed, but (and (state $x) (state $y)) is ok
  'max_literals'
  # int: disallow more than this many literals
  #   e.g. if max_literals == 1, (and (state $x) (city $y)) is disallowed
  #   the maximum number in the dataset is 7 (MAX_LITERALS)
])
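
To illustrate what each field means, the sketch below checks a finished list of literals against a set of constraints. This is only an illustration: the parser presumably enforces these constraints incrementally when computing valid actions, and the `Constraints` tuple and `satisfies` function here are hypothetical stand-ins that mirror the fields above.

```python
from collections import namedtuple

# Stand-in with the same fields as ParseConstraints (defined here so the sketch is self-contained).
Constraints = namedtuple('Constraints', [
    'possible_predicates', 'max_vars', 'no_repeated_literals', 'max_literals'])

def satisfies(literals, constraints):
    """Check a list of literal tuples, e.g. [('state', '$x')], against the constraints."""
    predicates_ok = all(lit[0] in constraints.possible_predicates for lit in literals)
    variables = {var for lit in literals for var in lit[1:]}
    vars_ok = len(variables) <= constraints.max_vars
    repeats_ok = (not constraints.no_repeated_literals) or len(set(literals)) == len(literals)
    length_ok = len(literals) <= constraints.max_literals
    return predicates_ok and vars_ok and repeats_ok and length_ok

constraints = Constraints({'state', 'city'}, max_vars=2, no_repeated_literals=True, max_literals=2)
print(satisfies([('state', '$x'), ('state', '$y')], constraints))  # True
print(satisfies([('state', '$x'), ('state', '$x')], constraints))  # False: repeated literal
print(satisfies([('state', '$x'), ('city', '$y')],
                constraints._replace(max_literals=1)))             # False: too many literals
```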

To build up a logical form via a sequence of actions, we define a `ParserState` class several blocks below. It represents the logical form as constructed so far and computes the valid actions that can be taken to continue constructing it. The `_ParserState` class here defines the actual state variables; you will then fill in the methods in `ParserState` (without the underscore).

In [0]:
_ParserState = namedtuple('_ParserState', [
  'variable_argument_stack',
  # Stack[str]: variables to be put into the current literal
  'completed_literal_stack',
  # Stack[tuple]: all literals produced so far
  'has_lambda',
  # bool: whether the lambda has been generated
  'lambda_var',
  # str or None: the lambda's variable argument, if lambda has been generated; or None otherwise
  'is_complete',
  # bool: whether the DONE action has been generated
  'all_vars_introduced',
  # Stack[str]: variables should be appended to this each time they are used
  'past_actions',
  # Stack[str]: all previous actions taken
  'parse_constraints',
  # ParseConstraints: options to constrain the logical form being constructed
])

`variable_argument_stack` contains any variables that will be put into the literal that is currently being constructed. This is reset to empty when a literal is completed (by a predicate action). All literals, as they are completed, are added to `completed_literal_stack`. This should not be emptied at any point, as it will be used to construct the final logical form.

For example, when generating the logical form 

`(lambda $x (exists $w (and (state $w) (city $x) (in-rel $x $w))))`

using the action sequence

`$w state $x city $x $w in-rel complete_lambda_$x DONE`,

the `ParserState` will look as follows after the first `$x` action:
```
ParserState(
  variable_argument_stack=['$x'],
  completed_literal_stack=[('state', '$w')],
  has_lambda=False,
  lambda_var=None,
  is_complete=False,
  all_vars_introduced=['$w', '$x'],
  past_actions=['$w', 'state', '$x'],
  parse_constraints=...
)
```
and as follows after the `DONE` action:
```
ParserState(
  variable_argument_stack=[],
  completed_literal_stack=[('state', '$w'), ('city', '$x'), ('in-rel', '$x', '$w')],
  has_lambda=True,
  lambda_var='$x',
  is_complete=True,
  all_vars_introduced=['$w', '$x', '$x', '$w'],
  past_actions=['$w', 'state', '$x', 'city', '$x', '$w', 'in-rel', 'COMPLETE_LAMBDA_$x', 'DONE'],
  parse_constraints=...
)
```
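
The two snapshots above can be reproduced with a simplified replay of the action sequence. The sketch below uses plain Python lists instead of the `Stack` class, tracks only a subset of the state fields, and performs no validity checks. The `replay` generator is a hypothetical illustration, not the `take_action` method you will complete.

```python
def replay(actions):
    """Yield a snapshot of a simplified parser state after each action."""
    argument_stack, literals = [], []
    lambda_var, is_complete = None, False
    for action in actions:
        if action == 'DONE':
            is_complete = True
        elif action.startswith('COMPLETE_LAMBDA_'):
            lambda_var = action[len('COMPLETE_LAMBDA_'):]
        elif action.startswith('$'):
            # a variable: becomes an argument of the literal under construction
            argument_stack.append(action)
        else:
            # a predicate: completes a literal and resets the argument stack
            literals.append((action, *argument_stack))
            argument_stack = []
        yield {
            'action': action,
            'variable_argument_stack': list(argument_stack),
            'completed_literal_stack': list(literals),
            'lambda_var': lambda_var,
            'is_complete': is_complete,
        }

actions = '$w state $x city $x $w in-rel COMPLETE_LAMBDA_$x DONE'.split()
snapshots = list(replay(actions))
print(snapshots[2])   # state after the first `$x` action
print(snapshots[-1])  # state after the `DONE` action
```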

We'll now define a `ParserState` wrapper class with methods. Complete the `take_action` method in the `ParserState` class below.

Implementation tip: `ParserState` objects are immutable, and the `take_action` method should return a new `ParserState` object, rather than attempting to update the object (this will allow `ParserState` objects to be used in beam search). The [Stack](https://github.com/dpfried/geoqa-release/blob/master/geoqa/utils.py#L7) class used in `ParserState` for `variable_argument_stack`, `completed_literal_stack`, `all_vars_introduced`, and `past_actions` is an immutable data structure that is essentially like a list, but will allow sharing some memory across multiple `ParserState` objects as search is performed, for efficiency and to prevent copying. See the cell below for a demonstration:


In [0]:
# demonstration of the Stack class used in _ParserState
actions = Stack.empty().append('$w')
print(actions.tolist())
actions_plus_city = actions.append('city')
actions_plus_state = actions.append('state')
print(actions.tolist()) # hasn't changed
print(actions_plus_city.tolist()) # shares '$w' with actions
print(actions_plus_state.tolist()) # shares '$w' with actions
print(actions_plus_city.size) # 2

['$w']
['$w']
['$w', 'city']
['$w', 'state']
2
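
For intuition about how such sharing can work, here is a minimal sketch of an immutable stack built as a singly linked list, where `append` returns a new node that points at the existing one. This is an assumption about the general technique, not the actual `geoqa.utils.Stack` implementation; `PersistentStack` is a hypothetical class.

```python
class PersistentStack:
    """Immutable stack: each node stores one item and a reference to its parent."""

    def __init__(self, item, parent, size):
        self._item = item
        self._parent = parent
        self.size = size

    @classmethod
    def empty(cls):
        return cls(None, None, 0)

    def append(self, item):
        # O(1): the new stack reuses `self` as its tail instead of copying it
        return PersistentStack(item, self, self.size + 1)

    def tolist(self):
        items, node = [], self
        while node.size > 0:
            items.append(node._item)
            node = node._parent
        return items[::-1]  # oldest item first

base = PersistentStack.empty().append('$w')
plus_city = base.append('city')
plus_state = base.append('state')
print(base.tolist())        # ['$w']
print(plus_city.tolist())   # ['$w', 'city']
print(plus_state.tolist())  # ['$w', 'state']
print(plus_city._parent is plus_state._parent)  # True: both share the same tail
```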


In [0]:
class ParserState(_ParserState):
  """
  Wrapper class to add methods and constants to _ParserState
  """
  PREDICATE_ARITIES = {}
  for cat in CATS:
    PREDICATE_ARITIES[cat] = 1
  for rel in RELS:
    PREDICATE_ARITIES[rel] = 2
  for entity in ENTITY_ACTIONS_TO_ENTITIES.keys():
    PREDICATE_ARITIES[entity] = 1

  MAX_ARITY = max(PREDICATE_ARITIES.values())

  @staticmethod
  def initial_state(parse_constraints):
    return ParserState(
      variable_argument_stack=Stack.empty(),
      completed_literal_stack=Stack.empty(),
      has_lambda=False,
      lambda_var=None,
      is_complete=False,
      all_vars_introduced=Stack.empty(),
      past_actions=Stack.empty(),
      parse_constraints=parse_constraints,
    )

  def take_action(self, action_symbol: str):
    assert action_symbol in ACTIONS, "invalid action {}".format(action_symbol)
    assert not self.is_complete

    assert action_symbol in self.valid_actions(), "trying to take {} but only valid actions are {}.\npast actions: {}".format(
      action_symbol, self.valid_actions(), ' '.join(self.past_actions.tolist()))

    # initialize all variables that might be updated
    past_actions = self.past_actions
    has_lambda = self.has_lambda
    is_complete = self.is_complete
    completed_literal_stack = self.completed_literal_stack
    variable_argument_stack = self.variable_argument_stack
    all_vars_introduced = self.all_vars_introduced
    lambda_var = self.lambda_var

    past_actions = past_actions.append(action_symbol)

    if action_symbol in PREDICATES:
      action = [action_symbol] + variable_argument_stack.tolist()
      completed_literal_stack = completed_literal_stack.append(tuple(action))
      variable_argument_stack = Stack.empty()

    elif action_symbol in VARS:
      variable_argument_stack = variable_argument_stack.append(action_symbol)
      # record every use of a variable, as documented for all_vars_introduced
      all_vars_introduced = all_vars_introduced.append(action_symbol)

    elif action_symbol in COMPLETE_LAMBDA_ACTIONS_TO_VARS:
      # look the variable up rather than slicing it out of the action name
      lambda_var = COMPLETE_LAMBDA_ACTIONS_TO_VARS[action_symbol]
      has_lambda = True

    elif action_symbol == DONE_ACTION:
      is_complete = True

    else:
      raise ValueError("invalid action {}".format(action_symbol))

    return ParserState(
      variable_argument_stack=variable_argument_stack,
      completed_literal_stack=completed_literal_stack,
      has_lambda=has_lambda,
      lambda_var=lambda_var,
      is_complete=is_complete,
      all_vars_introduced=all_vars_introduced,
      past_actions=past_actions,
      parse_constraints=self.parse_constraints,
    )

  def variables_introduced(self) -> Set[str]:
    """
    :return: variables that have been used in the logical form, as constructed so far
    """
    variables = set(self.all_vars_introduced.tolist())
    # note: format with `variables`, not the builtin `vars`
    assert variables == set(VARS[:len(variables)]), (
      "variables should not be introduced out of the sequential order {}, but {} were used".format(VARS, variables)
    )
    return variables

  def variables_usable(self) -> Set[str]:
    """
    To reduce the set of possible logical forms, enforce that new variables are introduced in alphabetical order, e.g. $x cannot be used before $w has been used
    Usable variables include all previously used variables and the next in alphabetical order
    :return: set of possible next variables to use
    """
    num_vars_introduced = len(self.variables_introduced())
    return set(VARS[:min(self.parse_constraints.max_vars, num_vars_introduced + 1)])

  def to_logical_form(self) -> Tuple:
    """
    Convert a completed parser state to a logical form
    :return: logical form as a tuple
    """
    assert self.is_complete

    # list of tuples
    clause_literals = self.completed_literal_stack.tolist()

    assert len(clause_literals) > 0
    logical_form = (AND,) + tuple(clause_literals)

    assert self.lambda_var is not None
    vars_to_quantify = tuple(sorted(v for v in self.variables_introduced() if v != self.lambda_var))
    if len(vars_to_quantify) > 0:
      logical_form = (EXISTS,) + vars_to_quantify + (logical_form,)
    logical_form = (LAMBDA, self.lambda_var, logical_form)

    return logical_form

  @staticmethod
  def from_logical_form(logical_form: Tuple, parse_constraints: ParseConstraints=None):
    """
    :param logical_form:
    :param parse_constraints: used to check the logical form
    :return:
    """
    # depth-first post-fix traversal, where actions are node labels, except we
    # (1) collapse (lambda $LAMBDA_VAR (exists $VAR1 $VAR2 to COMPLETE_LAMBDA_$LAMBDA_VAR
    # (2) do not produce an action for the AND label
    # (3) add a DONE_ACTION action at the end (root)
    if parse_constraints is None:
      parse_constraints = ParseConstraints(
        possible_predicates=None, max_vars=len(VARS), no_repeated_literals=False,
        max_literals=MAX_LITERALS
      )

    def _traverse(node: Union[tuple, str], state: ParserState):
      if isinstance(node, tuple):
        label = node[0]
        if label == LAMBDA:
          lambda_var = node[1]
          assert lambda_var in VARS
          children = node[2:]
          assert children
        elif label == EXISTS:
          children = node[1:]
          while children[0] in VARS:
            children = children[1:]
          assert children
        else:
          children = node[1:]
      else:  # is a leaf; should be a variable
        label = node
        children = []
      for child in children:
        state = _traverse(child, state)
      if label == LAMBDA:
        state = state.take_action(VARS_TO_COMPLETE_LAMBDA_ACTIONS[lambda_var])
      elif label != AND and label != EXISTS:
        state = state.take_action(label)
      return state

    state = ParserState.initial_state(parse_constraints)
    state = _traverse(logical_form, state)
    state = state.take_action(DONE_ACTION)
    assert state.is_complete
    return state

  def actions(self) -> List[str]:
    return self.past_actions.tolist()

  def valid_actions(self) -> List[str]:
    # actions that can be taken if we've finished building the lambda expression
    if self.is_complete:
      return []
    if self.has_lambda:
      return [DONE_ACTION]

    # otherwise, deal with each possible action type in turn
    valid_actions = []

    # complete the lambda expression
    if self.variable_argument_stack.size == 0 and self.completed_literal_stack.size > 0:
      for var in self.variables_introduced():
        valid_actions.append(VARS_TO_COMPLETE_LAMBDA_ACTIONS[var])

    # introduce a predicate, with arguments from argument_stack
    if self.variable_argument_stack.size > 0:
      possible_predicates = self.parse_constraints.possible_predicates
      if possible_predicates is None:
        possible_predicates = CATS + RELS + list(ENTITY_ACTIONS_TO_ENTITIES.keys())
      possible_predicates = list(sorted(possible_predicates))

      for predicate in possible_predicates:
        if ParserState.PREDICATE_ARITIES[predicate] == self.variable_argument_stack.size:
          valid_actions.append(predicate)

    # add a variable to the argument stack
    if self.variable_argument_stack.size < ParserState.MAX_ARITY and (
        # only start a new literal if we're strictly less than the maximum
        (self.variable_argument_stack.size == 0 and self.completed_literal_stack.size < self.parse_constraints.max_literals) or
        # only continue an existing literal if we're no greater than the maximum
        (self.variable_argument_stack.size > 0 and self.completed_literal_stack.size <= self.parse_constraints.max_literals)
    ):
      valid_actions.extend(self.variables_usable())

    return valid_actions

  def __repr__(self):
    return """ParserState(
variable_argument_stack={},
completed_literal_stack={},
has_lambda={},
lambda_var={},
is_complete={},
all_vars_introduced={},
past_actions={},
parse_constraints={},
)""".format(*self)

With a correct implementation of `take_action`, the following three test cells should run without errors.

In [0]:
permissive_parse_constraints = ParseConstraints(
    possible_predicates=PREDICATES, 
    max_vars=len(VARS),
    no_repeated_literals=False, 
    max_literals=MAX_LITERALS,
)

In [0]:
blank_parser_state = ParserState.initial_state(permissive_parse_constraints)
parser_state = blank_parser_state.take_action('$w')
assert parser_state.variable_argument_stack.tolist() == ['$w']
parser_state = parser_state.take_action('$x')
assert parser_state.variable_argument_stack.tolist() == ['$w', '$x']
parser_state = parser_state.take_action('in-rel')
assert parser_state.variable_argument_stack.tolist() == []
assert parser_state.completed_literal_stack.tolist() == [('in-rel', '$w', '$x')]
assert parser_state.actions() == ['$w', '$x', 'in-rel']

In [0]:
for state in STATES:
  world = geoqa.geo.read_world(state)
  for question, answer, logical_form in zip(*geoqa.geo.read_data(state)):
    if logical_form is None:
      continue
    parser_state = ParserState.from_logical_form(logical_form, permissive_parse_constraints)
    round_trip_logical_form = parser_state.to_logical_form()
    assert logical_form == round_trip_logical_form, "\n{} !=\n{}".format(logical_form, round_trip_logical_form)
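
To make the round trip above more concrete, here is a minimal, self-contained sketch of the post-fix traversal that `from_logical_form` describes in its comments. It does not use `ParserState`; the label strings `'lambda'`, `'exists'`, and `'and'`, the `'DONE'` string, and the `COMPLETE_LAMBDA_` prefix are assumptions made for illustration, and the notebook's own constants may differ.

```python
from typing import List, Union

# Hypothetical action names, for illustration only; the notebook's constants
# (DONE_ACTION, VARS_TO_COMPLETE_LAMBDA_ACTIONS) may use different strings.
DONE = 'DONE'

def complete_lambda(var: str) -> str:
    return 'COMPLETE_LAMBDA_' + var

def linearize(node: Union[tuple, str]) -> List[str]:
    """Post-fix traversal of a logical form into an action sequence."""
    if isinstance(node, str):
        return [node]  # a variable leaf
    label, *rest = node
    if label == 'lambda':
        lambda_var, body = rest
        # the lambda (and any exists) collapses into one action after the body
        return linearize(body) + [complete_lambda(lambda_var)]
    if label == 'exists':
        # quantified variables are implicit; only the body produces actions
        return linearize(rest[-1])
    actions = []
    for child in rest:
        actions.extend(linearize(child))
    if label != 'and':
        actions.append(label)  # a predicate comes after its arguments
    return actions

logical_form = ('lambda', '$w', ('exists', '$x', ('and',
    ('ocean', '$w'), ('border-rel', '$w', '$x'), ('kb-virginia', '$x'))))
actions = linearize(logical_form) + [DONE]
print(actions)
```

Each predicate appears immediately after the variables it takes as arguments, which is why `take_action` can build a literal from whatever is currently on the argument stack.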

## Preprocessing

Next we'll define a dataset class. Since this dataset is small, and to allow evaluating on databases with entities we didn't see in training, we'll use pre-trained embeddings to represent both words and predicates in the databases. We've downloaded and pre-filtered [fasttext](https://fasttext.cc/) word embeddings trained on English Wikipedia and news, and we'll use the first 50 dimensions to keep our models small.

In [0]:
EMBEDDING_DIM = 50
word_to_embedding = geoqa.utils.get_word_vectors(
    'vecs/wiki-news-300d-1M-subword-filtered.vec', max_dim=EMBEDDING_DIM
)

In [0]:
word_to_embedding['virginia']

array([ 0.0349, -0.0109,  0.015 ,  0.0392,  0.0173,  0.0048, -0.0055,
       -0.0322, -0.0006,  0.0285, -0.0186, -0.0159,  0.0106, -0.014 ,
       -0.0161, -0.0116, -0.0037, -0.0122,  0.0025,  0.0176, -0.0033,
       -0.0165, -0.0129,  0.0204,  0.0006, -0.0073, -0.0179,  0.029 ,
        0.0154,  0.0115, -0.0019, -0.0179,  0.0169,  0.0223, -0.0021,
        0.0126,  0.0087, -0.0016,  0.0091,  0.001 ,  0.0006, -0.0081,
       -0.0206, -0.0103, -0.0138,  0.0247,  0.0036, -0.0016,  0.0018,
        0.0123])

We've defined a [`GeoDataset`](https://github.com/dpfried/geoqa-release/blob/master/geoqa/dataset.py#L16) class for you in the support code to preprocess the data (don't worry about reading the code, though; we'll demonstrate it below). This class tokenizes each question and associates with each token a list of predicates that could possibly be used in the logical form, based on string matching predicate names against the word. To constrain the search space over possible logical forms, we will only allow predicates for a given question from the union of all its tokens' predicates. That we can do this at all is a somewhat unusual feature of this dataset; more complex semantic parsing problems also require learning more of the lexicon.
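
As a rough illustration of the idea, the sketch below matches words against predicate names with a simple prefix rule and then takes the union over all tokens. The matching rule and the helper names (`name_parts`, `predicates_for_word`) are assumptions made for this example; the actual rule used inside `GeoDataset` may differ.

```python
from typing import List

def name_parts(predicate: str) -> List[str]:
    # strip the 'kb-' prefix and '-rel' suffix, then split multiword names
    name = predicate
    if name.startswith('kb-'):
        name = name[len('kb-'):]
    if name.endswith('-rel'):
        name = name[:-len('-rel')]
    return name.split('_')

def predicates_for_word(word: str, predicates: List[str]) -> List[str]:
    # hypothetical rule: a word matches if it starts with any part of the name
    return [p for p in predicates
            if any(word.startswith(part) for part in name_parts(p))]

# a toy predicate inventory (not the full database)
toy_predicates = ['border-rel', 'capital-rel', 'city', 'kb-atlantic_ocean',
                  'kb-virginia', 'kb-virginia_beach', 'kb-west_virginia', 'ocean']
words = 'what ocean borders virginia ?'.split()

per_word = {w: predicates_for_word(w, toy_predicates) for w in words}
# the predicates allowed for the question are the union over all its tokens
possible_predicates = sorted({p for preds in per_word.values() for p in preds})
print(possible_predicates)
```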

In [0]:
sample_dataset = GeoDataset(state_names=['wv', 'va'], word_to_embedding=word_to_embedding)

In [0]:
def print_instance(instance, print_tokens=False):
  for key in ['question', 'answer', 'logical_form', 'world', 'possible_predicates', 'denotation']:
    if key == 'logical_form':
      rep = logical_form_to_str(instance[key])
    else:
      rep = instance[key]
    print('{:<19}: {}'.format(key, rep))
  if print_tokens:
    print()
    print('{:<10}: {}'.format('word', 'predicates'))
    print('='*23)
    for word, predicates, embedded_predicates in zip(
      instance['words'], 
      instance['predicates_at_each_word_position'], 
      instance['embedded_predicates_at_each_word_position']
    ):
      print('{:<10}: {}'.format(word, ', '.join(predicates)))

In [0]:
sample_instance = sample_dataset[30]

In [0]:
print_instance(sample_instance, print_tokens=True)

question           : what ocean borders virginia ?
answer             : atlantic_ocean
logical_form       : (lambda $w (exists $x (and (ocean $w) (border-rel $w $x) (kb-virginia $x))))
world              : <World: va with 6 entities, 6 categories, 72 relations>
possible_predicates: ['border-rel', 'kb-atlantic_ocean', 'kb-virginia', 'kb-virginia_beach', 'kb-west_virginia', 'ocean']
denotation         : {'atlantic_ocean'}

word      : predicates
what      : 
ocean     : ocean, kb-atlantic_ocean
borders   : border-rel
virginia  : kb-virginia, kb-virginia_beach, kb-west_virginia
?         : 


## Model

In this section we'll define a neural sequence-to-sequence model that encodes a question and outputs the actions needed to produce its logical form.

We will use a `_ModelState` class to wrap a `ParserState` and also contain a model's hidden state (for the sequence-to-sequence decoder) and the log probabilities of the next possible actions that can be taken from the parser state.

In [0]:
_ModelState = namedtuple('_ModelState', (
  'parser_state',
  # ParserState: the current parser state
  'hidden_state',
  # a tuple (h, c) of pytorch tensors
  'action_log_probs'
  # Dict[str, tensor]: a mapping from next possible actions (parser_state.valid_actions) to
  # a pytorch scalar giving the log probability of that action
))

The `ModelState` class will inherit from `_ModelState` and add a `take_action` method which creates a new `ModelState`, updated using a forward pass of the model and the `take_action` method of the `ParserState`.

Here's a DummyModel which always produces a uniform probability distribution over all available actions:

In [0]:
class DummyModel(nn.Module):
  # model that always assigns a uniform distribution over the valid actions; use to explore logical forms
  def __init__(self, max_vars=len(VARS), no_repeated_literals=False, max_literals=MAX_LITERALS):
    super(DummyModel, self).__init__()
    self.max_vars = max_vars
    self.max_literals = max_literals
    self.no_repeated_literals = no_repeated_literals

  def forward(self, actions_to_score):
    action_log_probs = {
      # a scalar tensor (dimension 0)
      action: torch.tensor(1.0/len(actions_to_score)).log()
      for action in actions_to_score
    }
    hidden_state = None
    return action_log_probs, hidden_state

  def initialize_model_state(self, instance):
    parse_constraints = ParseConstraints(
      possible_predicates=instance['possible_predicates'],
      max_vars=self.max_vars,
      no_repeated_literals=self.no_repeated_literals,
      max_literals=self.max_literals,
    )
    # bind self to a local name so that DummyModelState (defined below) can reference the model
    model = self

    class DummyModelState(_ModelState):
      @staticmethod
      def _create_from_parser_state(parser_state):
        possible_actions = parser_state.valid_actions()
        action_log_probs, hidden_state = model.forward(possible_actions)
        return DummyModelState(parser_state, hidden_state, action_log_probs)

      def take_action(self, action: str):
        new_parser_state = self.parser_state.take_action(action)
        return DummyModelState._create_from_parser_state(new_parser_state)

    parser_state = ParserState.initial_state(parse_constraints)
    return DummyModelState._create_from_parser_state(parser_state)

We can use this DummyModel to explore the space of possible (constrained) logical forms for a given instance.

In [0]:
def greedy_search(model, instance):
  state = model.initialize_model_state(instance)
  log_probs = []
  # when we disallow repeated literals, we can get to parser states with no available actions
  # so we need to check to ensure there are scored items in `state.action_log_probs`
  while state.action_log_probs.items() and not state.parser_state.is_complete:
    # pick the highest-scoring (action, log_prob) pair; avoid shadowing the builtin `tuple`
    action, log_prob = max(state.action_log_probs.items(), key=lambda item: item[1].item())
    log_probs.append(log_prob)
    state = state.take_action(action)
  return state, sum(log_probs)

In [0]:
print_instance(sample_instance, print_tokens=True)

question           : what ocean borders virginia ?
answer             : atlantic_ocean
logical_form       : (lambda $w (exists $x (and (ocean $w) (border-rel $w $x) (kb-virginia $x))))
world              : <World: va with 6 entities, 6 categories, 72 relations>
possible_predicates: ['border-rel', 'kb-atlantic_ocean', 'kb-virginia', 'kb-virginia_beach', 'kb-west_virginia', 'ocean']
denotation         : {'atlantic_ocean'}

word      : predicates
what      : 
ocean     : ocean, kb-atlantic_ocean
borders   : border-rel
virginia  : kb-virginia, kb-virginia_beach, kb-west_virginia
?         : 


In [0]:
model_state, total_log_prob = greedy_search(
    DummyModel(max_vars=4, no_repeated_literals=False),
    sample_instance
    )
print(logical_form_to_str(model_state.parser_state.to_logical_form()))
print(total_log_prob)

(lambda $w (and (kb-atlantic_ocean $w)))
tensor(-3.0445)
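
The total log probability above can be checked by hand. Under the uniform model, each step contributes `log(1/n)`, where `n` is the number of valid actions at that step. The counts below are our reading of `valid_actions` for this instance with `max_vars=4`, assuming `VARS` begins with `$w`, `$x`; treat them as a worked example rather than output from the notebook.

```python
import math

# number of valid actions at each step of the greedy derivation:
# '$w' -> 'kb-atlantic_ocean' -> complete the lambda for $w -> done
choices_per_step = [
    1,  # start: only $w can be introduced
    7,  # 5 unary predicates, or push another variable ($w or $x)
    3,  # complete the lambda for $w, or start a new literal with $w or $x
    1,  # only the done action remains
]
total_log_prob = sum(math.log(1.0 / n) for n in choices_per_step)
print(round(total_log_prob, 4))  # -3.0445
```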


Now, we've defined an (unbatched) beam search for you in the two cells below. Only a small number of actions will be allowed to be taken at any point in time by our transition system, so we've chosen to use dictionaries to map actions to scores:

In [0]:
def k_best(score_dict, k):
  k = min(k, len(score_dict))
  if k == 0:
    return {}
  actions, logits = zip(*score_dict.items())
  logits = torch.stack(logits, dim=-1).flatten()
  chosen_scores, indices = logits.topk(k, dim=-1)
  return {
    actions[index.item()]: score
    for score, index in zip(chosen_scores, indices)
  }

In [0]:
def beam_search(model, instance, beam_size):
  state = model.initialize_model_state(instance)
  completed = []
  # each beam item will be a tuple (state, cumulative log probs)
  beam = [(state, torch.tensor(0.0))]
  while beam and len(completed) < beam_size:
    successors = []
    for state, log_prob in beam:
      scored_successors = k_best(state.action_log_probs, beam_size)
      for action, action_log_prob in scored_successors.items():
        successor_log_prob = log_prob + action_log_prob
        successors.append((successor_log_prob, state, action))
    # when we disallow repeated literals, we can get to parser states with no available actions
    # so we need to check to see if there are successors available
    if not successors:
      break
    log_probs = torch.stack([log_prob for log_prob, _, _ in successors], dim=-1).flatten()
    if beam_size is None:
      indices = torch.arange(len(successors))
    else:
      _, indices = log_probs.topk(min(len(successors), beam_size), dim=-1)
    new_beam = []
    for ix in indices:
      log_prob, prev_state, action = successors[ix.item()]
      state = prev_state.take_action(action)
      if state.parser_state.is_complete:
        completed.append((state, log_prob))
      else:
        new_beam.append((state, log_prob))
    beam = new_beam

  completed = sorted(completed, key=lambda t: t[1].item(), reverse=True)
  completed = completed[:beam_size]
  return completed

In [0]:
def sample_search_model(model, beam_size):
  """
  A convenience method for printing out beam search outputs
  """
  for model_state, total_log_prob in beam_search(model, sample_instance, beam_size):
    # with the transition system we've defined,
    # it's possible to get into a parser state with no available actions
    # so we need to check to ensure the parser state is complete before
    # converting it to a logical form
    if model_state.parser_state.is_complete:
      logical_form = model_state.parser_state.to_logical_form()
      print("{:.2f} {}".format(total_log_prob.item(), logical_form_to_str(logical_form)))
    else:
      print("incomplete logical form")

In [0]:
print_instance(sample_instance)

question           : what ocean borders virginia ?
answer             : atlantic_ocean
logical_form       : (lambda $w (exists $x (and (ocean $w) (border-rel $w $x) (kb-virginia $x))))
world              : <World: va with 6 entities, 6 categories, 72 relations>
possible_predicates: ['border-rel', 'kb-atlantic_ocean', 'kb-virginia', 'kb-virginia_beach', 'kb-west_virginia', 'ocean']
denotation         : {'atlantic_ocean'}


In [0]:
sample_search_model(DummyModel(max_vars=1, no_repeated_literals=False), beam_size=20)

-2.48 (lambda $w (and (kb-virginia_beach $w)))
-2.48 (lambda $w (and (kb-atlantic_ocean $w)))
-2.48 (lambda $w (and (kb-virginia $w)))
-2.48 (lambda $w (and (kb-west_virginia $w)))
-2.48 (lambda $w (and (ocean $w)))
-2.48 (lambda $w (and (border-rel $w $w)))
-4.97 (lambda $w (and (kb-atlantic_ocean $w) (kb-virginia $w)))
-4.97 (lambda $w (and (kb-virginia $w) (kb-west_virginia $w)))
-4.97 (lambda $w (and (kb-virginia $w) (ocean $w)))
-4.97 (lambda $w (and (kb-virginia $w) (kb-atlantic_ocean $w)))
-4.97 (lambda $w (and (kb-virginia_beach $w) (ocean $w)))
-4.97 (lambda $w (and (kb-virginia $w) (border-rel $w $w)))
-4.97 (lambda $w (and (kb-atlantic_ocean $w) (border-rel $w $w)))
-4.97 (lambda $w (and (border-rel $w $w) (ocean $w)))
-4.97 (lambda $w (and (border-rel $w $w) (kb-atlantic_ocean $w)))
-4.97 (lambda $w (and (kb-virginia_beach $w) (border-rel $w $w)))
-4.97 (lambda $w (and (border-rel $w $w) (kb-virginia $w)))
-4.97 (lambda $w (and (border-rel $w $w) (kb-west_virginia $w)))
-4.

You may notice that some literals are repeated within a single logical form, e.g. `... (border-rel $w $w) (border-rel $w $w) ...` (due to randomness, you may not see this exact example in the candidates above). Since these identical literals occur inside a conjunction, they're logically redundant. Because of this, we can prevent duplicate literals to reduce the search space, without affecting the expressiveness of the logical forms.
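
To see why a repeated literal is redundant, here is a small sketch with a toy knowledge base and a brute-force evaluator. The facts, entity list, and function names are hypothetical and are not the notebook's `execute` function.

```python
from typing import Dict, List, Set, Tuple

# toy knowledge base (hypothetical contents, for illustration only)
FACTS: Set[Tuple[str, ...]] = {
    ('ocean', 'atlantic_ocean'),
    ('border-rel', 'atlantic_ocean', 'virginia'),
    ('kb-virginia', 'virginia'),
}
ENTITIES = ['atlantic_ocean', 'virginia']

def holds(literals: List[Tuple[str, ...]], binding: Dict[str, str]) -> bool:
    # a conjunction holds when every literal, with its variables bound, is a known fact
    return all((lit[0],) + tuple(binding[v] for v in lit[1:]) in FACTS
               for lit in literals)

def denotation(literals: List[Tuple[str, ...]]) -> Set[str]:
    # values of $w for which some $x satisfies the conjunction
    return {w for w in ENTITIES for x in ENTITIES
            if holds(literals, {'$w': w, '$x': x})}

literals = [('ocean', '$w'), ('border-rel', '$w', '$x'), ('kb-virginia', '$x')]
with_duplicate = literals + [('border-rel', '$w', '$x')]

# the duplicate literal does not change the denotation
print(denotation(literals), denotation(with_duplicate))
```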

Modify the `valid_actions` function below to prevent duplicate literals, if `self.parse_constraints.no_repeated_literals` is `True`. (We've copied this function from `ParserState.valid_actions`, and will monkey-patch `ParserState` with it, so that you don't have to scroll back and forth.)

In [0]:
def valid_actions(self: ParserState) -> List[str]:
    # actions that can be taken if we've finished building the lambda expression
    if self.is_complete:
      return []
    if self.has_lambda:
      return [DONE_ACTION]

    # otherwise, deal with each possible action type in turn
    valid_actions = []

    # complete the lambda expression
    if self.variable_argument_stack.size == 0 and self.completed_literal_stack.size > 0:
      for var in self.variables_introduced():
        valid_actions.append(VARS_TO_COMPLETE_LAMBDA_ACTIONS[var])

    # introduce a predicate, with arguments from argument_stack
    if self.variable_argument_stack.size > 0:
      possible_predicates = self.parse_constraints.possible_predicates
      if possible_predicates is None:
        possible_predicates = CATS + RELS + list(ENTITY_ACTIONS_TO_ENTITIES.keys())
      possible_predicates = list(sorted(possible_predicates))

      for predicate in possible_predicates:
        if ParserState.PREDICATE_ARITIES[predicate] == self.variable_argument_stack.size:
          if self.parse_constraints.no_repeated_literals:
            # skip predicates that would duplicate an already-completed literal
            literal = (predicate,) + tuple(self.variable_argument_stack.tolist())
            if literal in self.completed_literal_stack:
              continue
          # otherwise (or when repeats are allowed) the predicate is a valid action
          valid_actions.append(predicate)

    # add a variable to the argument stack
    if self.variable_argument_stack.size < ParserState.MAX_ARITY and (
        # only start a new literal if we're strictly less than the maximum
        (self.variable_argument_stack.size == 0 and self.completed_literal_stack.size < self.parse_constraints.max_literals) or
        # only continue an existing literal if we're no greater than the maximum
        (self.variable_argument_stack.size > 0 and self.completed_literal_stack.size <= self.parse_constraints.max_literals)
    ):
      valid_actions.extend(self.variables_usable())
    return valid_actions

ParserState.valid_actions = valid_actions

With a correct implementation, you should now see that literals are not repeated:

In [0]:
sample_search_model(DummyModel(max_vars=1, no_repeated_literals=True), beam_size=20)

-2.48 (lambda $w (and (kb-virginia_beach $w)))
-2.48 (lambda $w (and (kb-atlantic_ocean $w)))
-2.48 (lambda $w (and (kb-virginia $w)))
-2.48 (lambda $w (and (kb-west_virginia $w)))
-2.48 (lambda $w (and (ocean $w)))
-2.48 (lambda $w (and (border-rel $w $w)))
-4.79 (lambda $w (and (kb-virginia $w) (kb-atlantic_ocean $w)))
-4.79 (lambda $w (and (kb-atlantic_ocean $w) (kb-virginia_beach $w)))
-4.79 (lambda $w (and (kb-virginia_beach $w) (kb-west_virginia $w)))
-4.79 (lambda $w (and (kb-virginia $w) (kb-west_virginia $w)))
-4.79 (lambda $w (and (ocean $w) (kb-virginia $w)))
-4.79 (lambda $w (and (kb-virginia_beach $w) (border-rel $w $w)))
-4.79 (lambda $w (and (kb-atlantic_ocean $w) (border-rel $w $w)))
-4.79 (lambda $w (and (kb-virginia $w) (border-rel $w $w)))
-4.97 (lambda $w (and (border-rel $w $w) (ocean $w)))
-4.97 (lambda $w (and (border-rel $w $w) (kb-atlantic_ocean $w)))
-4.97 (lambda $w (and (border-rel $w $w) (kb-virginia $w)))
-4.97 (lambda $w (and (border-rel $w $w) (kb-virgin

The test cell below, which checks to ensure that there are no repeated literals, should pass.

In [0]:
for model_state, score in beam_search(DummyModel(max_vars=1, no_repeated_literals=True), sample_instance, 1000):
  completed_literals = model_state.parser_state.completed_literal_stack.tolist()
  # check for duplicate literals
  assert len(completed_literals) == len(set(completed_literals))

And the test from before of logical form -> parser state -> logical form (now using `no_repeated_literals=True`) should still pass:

In [0]:
no_repeats_parse_options = ParseConstraints(
    possible_predicates=PREDICATES, 
    max_vars=len(VARS), 
    no_repeated_literals=True, 
    max_literals=MAX_LITERALS
)
for logical_form in logical_forms:
  if logical_form is None:
    continue
  parser_state = ParserState.from_logical_form(logical_form, no_repeats_parse_options)
  round_trip_logical_form = parser_state.to_logical_form()
  assert logical_form == round_trip_logical_form, "{} != {}".format(logical_form, round_trip_logical_form)


We'll now evaluate this `DummyModel` as a very simple baseline, using [geoqa.evaluation.evaluate_predictions](https://github.com/dpfried/geoqa-release/blob/master/geoqa/evaluation.py), in the two cells below. You might be surprised at the accuracy; a few structural and lexical constraints can go a long way on this dataset, and it's also sometimes possible to get the correct denotation through an incorrect logical form (this will complicate things later on, when we learn from denotations alone). However, examining some instances will show that this baseline model has very poor performance on complex examples.

Note: some of the questions are yes/no questions (e.g. _is virginia east of west virginia_). If the denotation for a question is empty (`set()`), it should be interpreted as a "no" answer; if the denotation is non-empty (`{'virginia'}`), it should be interpreted as a "yes" answer. A different version of this dataset, introduced by [Andreas et al. 2016](https://arxiv.org/abs/1601.01705), adds an `any` operation to the logical forms for yes/no questions, which maps empty denotations to "no" and non-empty to "yes", but this makes training from denotations (which we'll do later on) more difficult.
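
The convention for yes/no questions can be written as a one-line helper. The function name `denotation_to_yes_no` is hypothetical and is not part of the support code.

```python
from typing import Set

def denotation_to_yes_no(denotation: Set[str]) -> str:
    # an empty denotation means "no"; any non-empty denotation means "yes"
    return 'yes' if denotation else 'no'

print(denotation_to_yes_no(set()))         # no
print(denotation_to_yes_no({'virginia'}))  # yes
```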

In [0]:
def make_greedy_prediction_function(model):
  # helper function to make a prediction function for evaluate_predictions for a model
  def prediction_function(instance):
    state, _ = greedy_search(model, instance)
    if state.parser_state.is_complete:
      return state.parser_state.to_logical_form()
    else:
      return None
  return prediction_function

In [0]:
dummy_model = DummyModel(max_vars=4, no_repeated_literals=True)
evaluate_predictions(
  GeoDataset(DEV_STATES, word_to_embedding), 'dev', execute,
  prediction_function=make_greedy_prediction_function(dummy_model),
  display_predictions_frequency=5
)

dev example 1
dev question: what states are there ?
dev true LF: (lambda $w (and (state $w)))
dev pred LF: (lambda $w (and (state $w)))
dev true denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev pred denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev denotation match: True

dev example 6
dev question: what is the capital of virginia ?
dev true LF: (lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))
dev pred LF: (lambda $w (and (capital $w)))
dev true denotation: {'richmond'}
dev pred denotation: set()
dev denotation match: False

dev example 11
dev question: what cities in west virginia ?
dev true LF: (lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-west_virginia $x))))
dev pred LF: (lambda $w (and (city $w)))
dev true denotation: set()
dev pred denotation: {'richmond', 'virginia_beach'}
dev denotation match: False

dev example 16
dev question: what state is south of west virginia ?
dev true LF: (lambda $w (exists $x (and (state $w) (sou

{'dev_denotation_acc': 0.4146341463414634}

Now we'll implement a neural model to train on the dataset. We've defined the encoder for you in the cell below, following one possible implementation for the encoder from project 2. The dataset is small enough that batching and using a GPU are unnecessary, so we will leave them out for simplicity.

In [0]:
class Encoder(nn.Module):
  def __init__(self, hidden_dim, word_vector_dim, dropout):
    # word_vector_dim: dimensionality of the pretrained word vectors fed to the LSTM
    super(Encoder, self).__init__()
    self.hidden_dim = hidden_dim
    self.word_vector_dim = word_vector_dim
    self.dropout = nn.Dropout(dropout)
    self.lstm = nn.LSTM(self.word_vector_dim, hidden_dim,
                        bidirectional=True, batch_first=True)

  def forward(self, embedded_sentence):
    # embedded_sentence: T x word_vector_dim
    lstm_out, (h_n, c_n) = self.lstm(embedded_sentence.unsqueeze(0))

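    # h_n, c_n: (n_layers * directions) x batch x hidden_dim, i.e. 2 x 1 x hidden_dim
    # eliminate batch dimension (dim 1)
    h_n = h_n.squeeze(1)
    c_n = c_n.squeeze(1)

    # n_layers x directions x hidden_dim; average the forward and backward states
    h_n = h_n.view(1, 2, -1).mean(1)
    c_n = c_n.view(1, 2, -1).mean(1)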

    # eliminate batch dimension
    lstm_out = lstm_out.squeeze(0)
    return self.dropout(lstm_out), (h_n, c_n)


Before building the decoder, we'll define some helper functions to work with probability distributions over actions in log space.

In [0]:
# helper functions to deal with probabilities in log space
def sum_log_probs(log_probs):
  # log_probs: List[tensor]
  # compute \log \sum_x \exp x
  return torch.logsumexp(torch.stack(
    list(log_probs), dim=-1
  ), -1)

def renormalize_log_prob_dict(log_prob_dict):
  # log_prob_dict: Dict[str, tensor]
  # rescale the log probs so that when exponentiated, they sum to 1
  if not log_prob_dict:
    return log_prob_dict
  Z = sum_log_probs(log_prob_dict.values())
  return {k: v - Z for k, v in log_prob_dict.items()}

def filter_log_prob_dict(log_prob_dict, valid_actions):
  return {
    k: v for k, v in log_prob_dict.items()
    if k in valid_actions
  }

def is_normalized(log_prob_dict):
  # log_prob_dict: Dict[str, tensor]
  # check that this dictionary is normalized in log space
  # e.g. as returned by renormalize_log_prob_dict
  if not log_prob_dict:
    return True
  return torch.allclose(
    sum_log_probs(log_prob_dict.values()).exp(),
    torch.tensor(1.0),
    atol=1e-3
  )

You'll now define part of the Decoder. You'll implement an attention-based pointer-generator model (based roughly on section 3.2 of [Jia and Liang 2016](https://arxiv.org/pdf/1606.03622.pdf) and section 2.2 of [See et al. 2017](https://arxiv.org/pdf/1704.04368.pdf)), which is useful for sequence-to-sequence tasks where parts of the output are copied from the input (in our case, predicates will be copied from the tokens they are associated with). 

Actions can be produced either through a `generate` action (given by predicting them from a standard projection layer, similar to the vocab prediction in the previous assignments) or a `copy` action: "pointing" to them in the set of predicates associated with each token in the question, using an attention mechanism. 

An action is produced by marginalizing over a binary choice to `copy` or `generate`, with $p(copy) = 1 - p(generate)$:

$p(action) = p(action | generate) * p(generate) + p(action | copy) * p(copy)$

`STRUCTURAL_ACTIONS` can only be produced by generating. `DATABASE_SPECIFIC_PREDICATE_ACTIONS` can only be produced by copying (since all of these will be unseen when evaluating on a new database). `CROSS_DATABASE_PREDICATE_ACTIONS` can be produced either by generating or copying.
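
The mixture above can be sketched in log space with plain Python floats. This is only an illustration of the arithmetic: the `combine` and `log_add` helpers and the toy distributions are assumptions, and your `Decoder.forward()` will work with torch tensors instead. An action that one mechanism cannot produce simply gets log probability `-inf` under that mechanism.

```python
import math
from typing import Dict

NEG_INF = float('-inf')

def log_add(a: float, b: float) -> float:
    # log(exp(a) + exp(b)), computed stably
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def combine(p_generate: float,
            generate_log_probs: Dict[str, float],
            copy_log_probs: Dict[str, float]) -> Dict[str, float]:
    # assumes 0 < p_generate < 1, with p(copy) = 1 - p(generate)
    log_p_generate = math.log(p_generate)
    log_p_copy = math.log(1.0 - p_generate)
    actions = set(generate_log_probs) | set(copy_log_probs)
    return {
        a: log_add(log_p_generate + generate_log_probs.get(a, NEG_INF),
                   log_p_copy + copy_log_probs.get(a, NEG_INF))
        for a in actions
    }

# toy distributions: 'DONE' can only be generated, 'kb-virginia' can only be
# copied, and 'ocean' can be produced either way
generate = {'DONE': math.log(0.5), 'ocean': math.log(0.5)}
copy = {'ocean': math.log(0.25), 'kb-virginia': math.log(0.75)}
combined = combine(0.6, generate, copy)
print({a: round(math.exp(lp), 3) for a, lp in sorted(combined.items())})
```

With these numbers, `ocean` receives 0.6 * 0.5 + 0.4 * 0.25 = 0.4, and the combined probabilities still sum to 1.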


You'll implement this in the `forward()` method of `Decoder`, computing $p(action)$ for each action in `actions_to_score` by combining the probabilities as above. Below are details about the generation mechanism (which we've implemented) and copy mechanism (which you'll implement).

__Generation mechanism__

$p(action | generate)$ is produced using a standard projection layer similar to the vocab prediction in the previous assignments, assigning scores to actions in `STRUCTURAL_ACTIONS` + `CROSS_DATABASE_PREDICATE_ACTIONS`. We've implemented this for you, in `_generate_log_probs()`.

__Copy mechanism__

Actions for predicates that are associated with words in the sentence can be produced by "pointing" to them with an attention mechanism. We will reuse the decoder's attention mechanism, masked and renormalized to only be over those tokens that have associated predicates, to give a distribution $p(position | copy)$. A given predicate can appear with multiple tokens, so marginalize over the positions in the sentence:

$p(action | copy) = \sum_{position} p(action | position, copy) p(position | copy)$

A given position can have zero, one, or multiple predicates associated with it (e.g. the token "ocean" can have both the `ocean` and the `kb-atlantic_ocean` predicates associated with it); produce a distribution over the predicates at each position using a softmax of the dot product between a `predicate_key` vector (a projection of the decoder state, which we've computed for you) and the pretrained embedding for the action's predicate:

$p(action | position, copy) = softmax(\texttt{predicate_embedding}^\top \texttt{predicate_key})$

You'll implement this in the `_copy_log_probs()` method of `Decoder` below.

Fill in the `._copy_log_probs()` and `.forward()` methods in the `Decoder` class below.
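
The two-stage copy computation described above can be traced by hand on a toy input. The sketch below uses plain Python and probability space for readability (the real implementation should work in log space with tensors); the sentence, scores, and attention logits are all invented for illustration.

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical three-token sentence: predicates associated with each position.
predicates_at_position = [[], ["ocean", "kb-atlantic_ocean"], ["ocean"]]
# Made-up dot products between predicate_key and each predicate embedding.
predicate_scores = [[], [0.0, 0.0], [2.0]]
# Made-up attention logits, one per position.
attention_logits = [5.0, 1.0, 1.0]

# p(position | copy): attention renormalized over positions that have predicates.
valid = [i for i, preds in enumerate(predicates_at_position) if preds]
position_probs = dict(zip(valid, softmax([attention_logits[i] for i in valid])))

# p(action | copy) = sum over positions of p(action | position, copy) * p(position | copy)
p_action_given_copy = {}
for i in valid:
    within_position = softmax(predicate_scores[i])
    for predicate, p in zip(predicates_at_position[i], within_position):
        p_action_given_copy[predicate] = (
            p_action_given_copy.get(predicate, 0.0) + p * position_probs[i]
        )

print(p_action_given_copy)
```

Position 0 has the largest attention logit but no predicates, so it is masked out and receives no mass. The result sums to one without an explicit renormalization step, since it is a mixture of normalized per-position distributions weighted by a normalized distribution over positions.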

__Implementation tip__: Do probability calculations in log space: `p*q` can be computed using `+` when values are represented as log probabilities. In the same way, the `torch.logsumexp` function computes the sum `p+q` when `p` and `q` (and the returned value) are represented as log probabilities.
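
The two identities in the tip above can be checked with a few lines of plain Python. This is only a sketch: the `logsumexp` helper here is a hand-written stand-in operating on Python floats, not the `torch.logsumexp` function itself.

```python
import math

def logsumexp(log_values):
    # log(sum(exp(v))) computed stably by factoring out the maximum.
    m = max(log_values)
    return m + math.log(sum(math.exp(v - m) for v in log_values))

p, q = 0.2, 0.3
log_p, log_q = math.log(p), math.log(q)

# Multiplying probabilities corresponds to adding log probabilities.
log_product = log_p + log_q
# Adding probabilities corresponds to logsumexp of log probabilities.
log_sum = logsumexp([log_p, log_q])

print(math.exp(log_product), math.exp(log_sum))

# Very negative log probabilities underflow to 0.0 if exponentiated directly,
# but logsumexp still returns a finite value.
tiny = logsumexp([-1000.0, -1000.0])
print(tiny)
```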

In [0]:
import torch.nn.functional as F

In [0]:
class Decoder(nn.Module):
  START_ACTION = "<START>"
  DATABASE_SPECIFIC_ACTION = "<KB_ENTITY>"

  def __init__(self, hidden_dim, action_embedding_dim, input_action_vec_dim, dropout):
    super(Decoder, self).__init__()
    self.hidden_dim = hidden_dim
    self.action_embedding_dim = action_embedding_dim
    self.input_action_vec_dim = input_action_vec_dim

    self.action_index = Index()
    generable_indices = []
    for action in STRUCTURAL_ACTIONS:
      generable_indices.append(self.action_index.index(action))
    for action in CROSS_DATABASE_PREDICATE_ACTIONS:
      generable_indices.append(self.action_index.index(action))
    self.generable_indices = generable_indices
    self.START_INDEX = self.action_index.index(Decoder.START_ACTION)
    self.DATABASE_SPECIFIC_ACTION_INDEX = self.action_index.index(Decoder.DATABASE_SPECIFIC_ACTION)
    self.action_index.frozen = True

    self.action_embeddings = nn.Embedding(self.action_index.size(), self.action_embedding_dim)

    self.dropout = nn.Dropout(dropout)
    self.lstm_cell = nn.LSTMCell(self.action_embedding_dim, self.hidden_dim)

    self.action_scoring_layer = nn.Linear(hidden_dim, self.input_action_vec_dim)

    self.encoder_projection = nn.Linear(hidden_dim * 2, hidden_dim)
    self.copy_or_generate_layer = nn.Linear(hidden_dim, 2)
    self.generate_layer = nn.Linear(hidden_dim, len(generable_indices))

  def _pre_prediction(self, encoder_output, last_action, last_hidden):
    if last_action in self.action_index:
      action_index = self.action_index.index(last_action)
    elif last_action is None:
      action_index = self.START_INDEX
    else:
      action_index = self.DATABASE_SPECIFIC_ACTION_INDEX

    embedded_action = self.action_embeddings(torch.LongTensor([action_index]))
    h, c = self.lstm_cell(embedded_action, last_hidden)
    hidden_state = (h, c)
    decoder_out = h.squeeze(0)
    decoder_out = self.dropout(decoder_out)

    encoder_output_projected = self.encoder_projection(encoder_output)

    attention_logits = torch.einsum("th,h->t", [encoder_output_projected, decoder_out])
    attention_dist = torch.softmax(attention_logits, dim=-1)
    pooled_inputs = torch.einsum("th,t->h", [encoder_output_projected, attention_dist])

    output_with_attention = decoder_out + pooled_inputs

    return hidden_state, output_with_attention, attention_logits

  def _position_attention(self, predicates_at_each_word_position, attention_logits):
    positions_without_actions = [ix for ix, l in enumerate(predicates_at_each_word_position) if not l]
    position_attention_logits = attention_logits.clone()
    position_attention_logits[positions_without_actions] = BIG_NEG
    if len(positions_without_actions) == len(predicates_at_each_word_position):
      raise ValueError('no valid positions with predicates to copy to')
    return torch.log_softmax(position_attention_logits, dim=-1)

  def _generate_log_probs(self, actions_to_score, output_with_attention):
    # return log p(action | generate) for actions in actions_to_score,
    # as a dictionary mapping actions to torch scalars
    # log probabilities will be normalized over all generable actions; *NOT* over just actions_to_score
    # (renormalization will happen at the end of forward(), after all marginalizations are complete)
    logits = self.generate_layer(output_with_attention)
    log_probs = torch.log_softmax(logits, dim=-1)
    return {
      action: log_probs[self.action_index.index(action)]
      for action in actions_to_score
      if action in self.action_index
    }

  def _copy_log_probs(self,
                         predicate_key,
                         predicates_at_each_word_position,
                         embedded_predicates_at_each_word_position,
                         position_attention_log_probs,
                         copyable_actions):
    """

    :param predicate_key:
      tensor of size (embedded_predicate_dim,)
    :param predicates_at_each_word_position:
      List[List[str]], containing the list of predicates associated with each position in the sentence
    :param embedded_predicates_at_each_word_position:
      List[tensor], containing the predicate embeddings associated with each position in the sentence.
      embedded_predicates_at_each_word_position[i].size() == (len(predicates_at_each_word_position[i]), embedding_dim)
    :param position_attention_log_probs:
      tensor of size (len(sentence)), containing log p(position | copy)
    :param copyable_actions:
      Set[str], containing all predicates that can be copied to in this sentence
    :return:
      log p(action | copy) for actions in copyable_actions,
      as a dictionary mapping actions to torch scalars
      this dictionary should be normalized (in log space), over all copyable actions
      Implementation tip: you shouldn't have to normalize this yourself, it should happen
      by construction
    """
    p_action_given_position =  [F.log_softmax(embedding @ predicate_key, dim = -1) for embedding in embedded_predicates_at_each_word_position] # (len(sentence), len(predicates_at_each_word_position[i]))
    action_prod = [p_action_given_position[i] + position_attention_log_probs[i] for i in range(len(p_action_given_position))] #(len(sentence), number of predicates at each position)
    #position_attention_log_probs[i] will be repeated along the number of predicates at each position

    log_action = {}
    for i,predicate in enumerate(predicates_at_each_word_position):
        for j,action in enumerate(predicate):
            if action not in log_action.keys():
                log_action[action] = [action_prod[i][j]]
            else:
                log_action[action].append(action_prod[i][j])
    
    for action in log_action.keys():
        if action in copyable_actions:
            log_action[action] = sum_log_probs(log_action[action])
        else:
            log_action[action] = torch.tensor(-1e9)

    log_action = renormalize_log_prob_dict(log_action)
    return log_action

  def forward(self,
              encoder_output,
              predicates_at_each_word_position,  # [[action, vector]]
              embedded_predicates_at_each_word_position,
              last_action,
              last_hidden_state,
              actions_to_score):
    """
    :return:
      * action_log_probs: a Dict[str, tensor], mapping each action in actions_to_score to a scalar tensor, giving the action's log probability
      * hidden_state: (h, c), decoder's updated hidden state after taking last_action
    """
    T, hidden_dim = encoder_output.size()

    # hidden_state: (h, c), each a tensor of size (hidden_dim,)
    # output_with_attention: tensor of size (hidden_dim,)
    # attention_logits: tensor of size (T,)
    hidden_state, output_with_attention, attention_logits = self._pre_prediction(
      encoder_output, last_action, last_hidden_state
    )

    all_predicates = set([p for ps in predicates_at_each_word_position for p in ps])
    #

    # copy_lp and generate_lp are each scalar tensors, giving log p(copy) and log p(generate)
    copy_lp, generate_lp = torch.log_softmax(self.copy_or_generate_layer(output_with_attention), dim=-1)

    # generate_action_log_probs. log p(action | generate), for generable actions in actions_to_score
    # a Dict[str, tensor], mapping each action in actions_to_score to a scalar tensor, giving the action's log probability
    generate_action_log_probs = self._generate_log_probs(actions_to_score, output_with_attention)

    # position_attention_log_probs: tensor of size (T,) giving log p(position | copy)
    position_attention_log_probs = self._position_attention(predicates_at_each_word_position, attention_logits)

    # compute log p(action | copy) for actions in all_predicates (the copyable actions)
    # a Dict[str, tensor], mapping each action in all_predicates to a scalar tensor, giving the action's log probability
    copy_action_log_probs = self._copy_log_probs(
      self.action_scoring_layer(output_with_attention),
      predicates_at_each_word_position,
      embedded_predicates_at_each_word_position,
      position_attention_log_probs,
      all_predicates
    )

    if copy_action_log_probs:
      assert is_normalized(copy_action_log_probs)
    # now, filter this down to only those actions we want to score
    copy_action_log_probs = filter_log_prob_dict(copy_action_log_probs, actions_to_score)

    # compute the probability of each action, p(action) as:
    # p(action) = p(action | generate) * p(generate) + p(action | copy) * p(copy)
    # p(position | copy) is given by position_attention_log_probs (a tensor)
    # p(action | generate) and p(action | copy) are given by generate_action_log_probs and copy_action_log_probs (dictionaries), above

    action_log_probs = {}
    for action in actions_to_score:
        if action not in copy_action_log_probs.keys():
            copy_action_log_probs[action] = torch.tensor(-1e9)
        if action not in generate_action_log_probs.keys():
            generate_action_log_probs[action] = torch.tensor(-1e9)
        L = [copy_lp + copy_action_log_probs[action], generate_lp + generate_action_log_probs[action]]
        action_log_probs[action] = torch.logsumexp(torch.stack(L, -1), -1)

    assert all(a in action_log_probs for a in actions_to_score)
    # since this contains only a subset of actions, renormalize (in log-space)
    filtered = renormalize_log_prob_dict(
      filter_log_prob_dict(action_log_probs, valid_actions=actions_to_score)
    )
    for v in filtered.values():
      assert v > -1e8
    return filtered, hidden_state

The cell below defines a `Model` class using this `Encoder` and `Decoder`, with the same interface we used in `DummyModel`. We've hardcoded some hyperparameters for size and dropout that we found to work well.

In [0]:
class Model(nn.Module):
  def __init__(self, max_vars=len(VARS), no_repeated_literals=True, max_literals=MAX_LITERALS):
    super(Model, self).__init__()
    self.max_vars = max_vars
    self.max_literals = max_literals
    self.no_repeated_literals = no_repeated_literals

    self.hidden_dim = 50
    self.action_embedding_dim = 50
    self.dropout = 0.2

    self.encoder = Encoder(self.hidden_dim, EMBEDDING_DIM, self.dropout)
    self.decoder = Decoder(self.hidden_dim, self.action_embedding_dim, EMBEDDING_DIM, self.dropout)
    self.max_vars = max_vars

  def initialize_model_state(self, instance):
    parse_constraints = ParseConstraints(
      possible_predicates=instance['possible_predicates'],
      max_vars=self.max_vars,
      no_repeated_literals=self.no_repeated_literals,
      max_literals=self.max_literals,
    )
    parser_state = ParserState.initial_state(parse_constraints)

    encoder_output, encoder_hidden = self.encoder(instance['embedded_words'])

    # define a decoder variable so that we can reference it inside the ModelState class without using self;
    decoder = self.decoder

    # embed this class so that we don't have to reencode the input every time we take an action
    class ModelState(_ModelState):
      @staticmethod
      def _create_from_hidden_and_action(last_hidden_state, last_action, parser_state: ParserState):
        possible_actions = parser_state.valid_actions()
        action_log_probs, hidden = decoder(
          encoder_output,
          instance['predicates_at_each_word_position'],
          instance['embedded_predicates_at_each_word_position'],
          last_action,
          last_hidden_state,
          actions_to_score=possible_actions
        )
        return ModelState(parser_state, hidden, action_log_probs)

      def take_action(self, action: str):
        parser_state = self.parser_state.take_action(action)
        return ModelState._create_from_hidden_and_action(self.hidden_state, action, parser_state)

    # the decoder expects None to start the action sequence
    last_action = None
    return ModelState._create_from_hidden_and_action(encoder_hidden, last_action, parser_state)

## Supervised Training

We'll first assume a supervised setting, where logical forms are provided for each question and used to directly supervise the model outputs. Later on, we'll relax this assumption and train from questions and denotations (answers) alone, using latent logical forms. In the two cells below, we've provided training code that will use the model and parser states you've defined.

In [0]:
def make_beam_prediction_function(model, beam_size=5):
  # helper function to make a prediction function for evaluate_predictions for a model
  def prediction_function(instance):
    candidates = beam_search(model, instance, beam_size)
    if not candidates:
      return None
    state, _ = candidates[0]
    if state.parser_state.is_complete:
      return state.parser_state.to_logical_form()
    else:
      return None
  return prediction_function

In [0]:
def train(model, train_dataset, dev_dataset, model_file, num_epochs=10, latent_logical_forms=False,
          training_beam_size=None, learning_rate=1e-3, display_predictions_frequency=None):
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  stats_by_epoch = {}

  train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1, collate_fn=lambda x: x)

  if latent_logical_forms:
    assert training_beam_size is not None

  if not latent_logical_forms:
    loss_function_for_eval = lambda instance: supervised_loss(model, instance).item()
  else:
    loss_function_for_eval = None

  prediction_function = make_beam_prediction_function(model, 5)

  best_metric = 0.0
  for epoch in tqdm.notebook.trange(num_epochs, desc="training", unit="epoch"):
    with tqdm.notebook.tqdm(
        train_dataloader,
        desc="epoch {}".format(epoch + 1),
        unit="instance",
        total=len(train_dataloader)
    ) as batch_iterator:
      model.train()

      total_num_correct_lfs = 0
      total_num_lfs = 0

      total_loss = 0.0

      for i, batch in enumerate(batch_iterator, start=1):
        assert len(batch) == 1
        instance = batch[0]

        optimizer.zero_grad()
        if latent_logical_forms:
          loss, num_correct_lfs, num_lfs = latent_loss(model, instance, training_beam_size)
          total_num_lfs += num_lfs
          total_num_correct_lfs += num_correct_lfs
          batch_iterator.set_postfix(
            train_correct_candidates_per_instance=total_num_correct_lfs/total_num_lfs*training_beam_size,
            train_loss=total_loss / i
          )
        else:
          loss = supervised_loss(model, instance)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
      model.eval()
      train_stats = evaluate_predictions(
        train_dataset, 'train', execute,
        prediction_function=prediction_function,
        loss_function=loss_function_for_eval,
        display_predictions_frequency=display_predictions_frequency
      )
      dev_stats = evaluate_predictions(
        dev_dataset, 'dev', execute,
        prediction_function=prediction_function,
        loss_function=loss_function_for_eval,
        display_predictions_frequency=display_predictions_frequency
      )
      dev_metric = dev_stats['dev_denotation_acc']
      batch_iterator.set_postfix({'train_loss': (total_loss / len(batch_iterator)), **train_stats, **dev_stats})
      if dev_metric > best_metric:
        best_epoch = epoch
        print("Obtained a new best development accuracy of {:.3f}, saving model "
              "checkpoint to {}...".format(dev_metric, model_file))
        torch.save(model.state_dict(), model_file)
        best_metric = dev_metric
      stats_by_epoch[epoch] = dev_stats
  print("Maximal development accuracy of {:.3f}".format(best_metric))
  print("Reloading best model checkpoint from {}...".format(model_file))
  model.load_state_dict(torch.load(model_file))

The supervised loss takes each action in the traversal of the correct logical form, collecting the log probabilities of these actions along the way.
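
Summing the negative log probabilities of the gold actions is the same as taking the negative log probability of the whole gold action sequence. A minimal sketch with invented per-step probabilities (not model outputs):

```python
import math

# Hypothetical probabilities the model assigns to each gold action, in order.
gold_action_probs = [0.9, 0.8, 0.5]

# Loss accumulated step by step, as in a traversal of the gold action sequence.
loss = sum(-math.log(p) for p in gold_action_probs)

# Probability of the full sequence under the chain rule.
sequence_prob = math.prod(gold_action_probs)

print(loss, -math.log(sequence_prob))
```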

In [0]:
def supervised_loss(model, instance):
  state = model.initialize_model_state(instance)
  gold_actions = ParserState.from_logical_form(instance['logical_form']).actions()

  loss = torch.tensor(0.0)
  for action in gold_actions:
    loss += -1.0 * state.action_log_probs[action]
    state = state.take_action(action)

  return loss

In [0]:
train_dataset = GeoDataset(TRAIN_STATES, word_to_embedding)
dev_dataset = GeoDataset(DEV_STATES, word_to_embedding)
test_dataset = GeoDataset(TEST_STATES, word_to_embedding)

print('{} train instances'.format(len(train_dataset)))
print('{} dev instances'.format(len(dev_dataset)))
print('{} test instances'.format(len(test_dataset)))


133 train instances
41 dev instances
51 test instances


In [0]:
supervised_model = Model()

We'll now train the model in the cell below. You should obtain a maximal denotation accuracy of at least 78% on the development set. This is a very small dataset and will have some variance. With our reference implementation we obtained 82-83% maximal dev denotation accuracy on average, with a range of 78% to 85% (over 40 trials).

In [0]:
train(supervised_model, train_dataset, dev_dataset, 'supervised_model.pt', num_epochs=10)

Obtained a new best development accuracy of 0.780, saving model checkpoint to supervised_model.pt...
Obtained a new best development accuracy of 0.805, saving model checkpoint to supervised_model.pt...
Maximal development accuracy of 0.805
Reloading best model checkpoint from supervised_model.pt...

We'll now view the model predictions on the dev set.

In [0]:
evaluate_predictions(
  dev_dataset, 'dev', execute,
  prediction_function=make_beam_prediction_function(supervised_model),
  display_predictions_frequency=5
)

dev example 1
dev question: what states are there ?
dev true LF: (lambda $w (and (state $w)))
dev pred LF: (lambda $w (and (state $w) (state $w)))
dev true denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev pred denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev denotation match: True

dev example 6
dev question: what is the capital of virginia ?
dev true LF: (lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))
dev pred LF: (lambda $w (exists $x (and (capital $w) (capital-rel $w $x) (kb-virginia $x))))
dev true denotation: {'richmond'}
dev pred denotation: set()
dev denotation match: False

dev example 11
dev question: what cities in west virginia ?
dev true LF: (lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-west_virginia $x))))
dev pred LF: (lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-virginia $x))))
dev true denotation: set()
dev pred denotation: {'richmond', 'virginia_beach'}
dev denotation match: False

dev example 16
dev 

{'dev_denotation_acc': 0.8048780487804879}

## Training with Latent Logical Forms

Now, we will relax the assumption that we have logical forms available, and train our model only from question--answer pairs, the databases, and the execution function. This setting, _training from denotations_, requires less supervision but introduces problems of search, delayed reward, and ambiguity, as there are many possible logical forms that can execute to a given denotation. We'll use a classic method, _maximum marginal likelihood_; if you're interested in other methods, see [Misra et al. 2018](https://dipendramisra.com/papers/mchy-emnlp.2018.pdf) and [Guu et al. 2017](https://arxiv.org/abs/1704.07926) for a survey and connections to policy gradient from reinforcement learning.

Let $x$ be a question, $z$ be the denotation for the question, and $d$ be a database. Our model gives a distribution $p_\theta(y \mid x)$ for logical forms $y$.  Maximum marginal likelihood (approximately) maximizes the log probability of producing the correct denotation $z$, marginalized over logical forms. In practice, we can't sum over all possible logical forms, so we'll approximate using a set of logical forms $\mathcal{Y}$ returned by beam search:

$$J_{MML} = \log p(z \mid x, d) \approx \log \sum_{y \in \mathcal{Y}} p(z \mid y, d) p_\theta(y \mid x)$$

Our execution function (the `execute` function we've defined above) is deterministic, so $p(z \mid y, d)$ is either $1$ or $0$, depending on whether $y$ denotes $z$ in the database. We can use this fact to simplify the objective: letting $\mathcal{C}$ be the subset of logical forms in $\mathcal{Y}$ that are _correct_ (i.e., that evaluate to $z$), we have

$$J_{MML} = \log \sum_{y \in \mathcal{C}} p_\theta(y \mid x)$$

Implement this in the `latent_loss` function below, returning $-J_{MML}$ as the loss, along with the number of logical forms in $\mathcal{C}$ (`num_correct`) and the number of logical forms found by search (`num_instances`).

__Implementation tip__: for numerical stability, perform calculations in log space using the logical forms' log probabilities and `torch.logsumexp`.
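
To make the simplified objective concrete, here is a toy computation of $-J_{MML}$ over a made-up beam. The candidate probabilities and correctness flags are invented, and the `logsumexp` helper is a plain-Python stand-in for `torch.logsumexp`.

```python
import math

def logsumexp(log_values):
    # log(sum(exp(v))) computed stably by factoring out the maximum.
    m = max(log_values)
    return m + math.log(sum(math.exp(v - m) for v in log_values))

# Hypothetical beam: (log p(y | x), whether y executes to the true denotation).
beam = [
    (math.log(0.5), False),
    (math.log(0.3), True),
    (math.log(0.1), True),
    (math.log(0.1), False),
]

# Keep only the correct logical forms (the set C) and marginalize over them.
correct_log_probs = [log_prob for log_prob, is_correct in beam if is_correct]
j_mml = logsumexp(correct_log_probs)
loss = -j_mml

print(len(correct_log_probs), loss)
```

Here the correct candidates carry a total probability of 0.4, so the loss is $-\log 0.4$; it shrinks toward zero as the model moves probability mass onto logical forms that execute to the correct denotation.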

In [0]:
def latent_loss(model, instance, beam_size):
  true_denotation = instance['denotation']
  world = instance['world']

  # contains tuples of model states and total log probabilities returned by beam search: [(model_state: ModelState, log_prob: tensor)]
  # some model states may correspond to incomplete logical forms, as when repeated literals are not allowed, it's possible for the model to reach a state where no valid actions are available
  # use model_state.parser_state.is_complete to check whether a parser state corresponds to a complete logical form, before attempting to convert it to a logical form
  candidates = beam_search(model, instance, beam_size)

  num_correct = 0
  num_instances = len(candidates)

  # log probabilities of the candidates that execute to the true denotation (the set C)
  correct_log_probs = []
  for model_state, total_log_prob in candidates:
    if model_state.parser_state.is_complete:
      logical_form = model_state.parser_state.to_logical_form()
      denotation = execute(logical_form, world)
      if denotation == true_denotation:
        num_correct += 1
        correct_log_probs.append(total_log_prob)
    else:
      # incomplete candidates are not logical forms, so don't count them
      num_instances -= 1

  if num_correct == 0:
    # no update if no correct examples were found; define a Variable to not break the backprop and update code
    return torch.autograd.Variable(torch.tensor(0.0), requires_grad=True), num_correct, num_instances

  # loss = -J_MML = -log sum_{y in C} p(y | x), computed in log space
  loss = -torch.logsumexp(torch.stack(correct_log_probs), dim=-1)
  return loss, num_correct, num_instances


You can sanity check your loss by training and evaluating on a restricted version of the dataset which contains only logical forms with at most two variables and two literals. These examples will be a bit easier, and the constraints on logical forms will allow us to reduce the search space and train faster.

In [0]:
def is_complex_lf(logical_form):
  if len(ParserState.from_logical_form(logical_form).variables_introduced()) > 2:
    return True
  if ParserState.from_logical_form(logical_form).completed_literal_stack.size > 2:
    return True
  return False

simple_train_dataset = GeoDataset(TRAIN_STATES, word_to_embedding, logical_form_exclude_function=is_complex_lf)
simple_dev_dataset = GeoDataset(DEV_STATES, word_to_embedding, logical_form_exclude_function=is_complex_lf)
print('{} train instances with maximum 2 variables and 2 literals'.format(len(simple_train_dataset)))
print('{} dev instances with maximum 2 variables and 2 literals'.format(len(simple_dev_dataset)))
print()
for i in range(3):
  print_instance(simple_dev_dataset[i])
  print()

40 train instances with maximum 2 variables and 2 literals
9 dev instances with maximum 2 variables and 2 literals

question           : what states are there ?
answer             : virginia,west_virginia,north_carolina
logical_form       : (lambda $w (and (state $w)))
world              : <World: va with 6 entities, 6 categories, 72 relations>
possible_predicates: ['state']
denotation         : {'virginia', 'west_virginia', 'north_carolina'}

question           : what 's the capital of virginia ?
answer             : richmond
logical_form       : (lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))
world              : <World: va with 6 entities, 6 categories, 72 relations>
possible_predicates: ['capital', 'capital-rel', 'kb-virginia', 'kb-virginia_beach', 'kb-west_virginia']
denotation         : {'richmond'}

question           : what oceans are there ?
answer             : atlantic_ocean
logical_form       : (lambda $w (and (ocean $w)))
world              : <World: va 

In [0]:
simple_latent_model = Model(max_vars=2, max_literals=2)

The following cell will train your model from denotations alone on these simplified logical forms. 

With beam size 100, if your `latent_loss` implementation is correct, you should expect to see `train_loss` values roughly between 0.5 and 2 at the end of 10 epochs (note that the magnitude of the loss will change with varying number of candidates).

Our implementation obtains _training_ denotation accuracies between 80% and 100%, and dev denotation accuracies of at least 60% by the end of 10 epochs (but note that there are only 9 dev examples when filtering the logical forms in this way, so don't be surprised if this score varies widely).

In [0]:
train(simple_latent_model, simple_train_dataset, simple_dev_dataset, 'simple_latent_model.pt', 
      num_epochs=10, latent_logical_forms=True, 
      training_beam_size=100)

Obtained a new best development accuracy of 0.556, saving model checkpoint to simple_latent_model.pt...
Obtained a new best development accuracy of 1.000, saving model checkpoint to simple_latent_model.pt...
Maximal development accuracy of 1.000
Reloading best model checkpoint from simple_latent_model.pt...


In [0]:
evaluate_predictions(
  simple_dev_dataset, 'dev', execute,
  prediction_function=make_beam_prediction_function(simple_latent_model),
  display_predictions_frequency=1
)

dev example 1
dev question: what states are there ?
dev true LF: (lambda $w (and (state $w)))
dev pred LF: (lambda $w (and (state $w) (state $w)))
dev true denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev pred denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev denotation match: True

dev example 2
dev question: what 's the capital of virginia ?
dev true LF: (lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))
dev pred LF: (lambda $x (exists $w (and (kb-virginia $w) (capital-rel $x $w))))
dev true denotation: {'richmond'}
dev pred denotation: {'richmond'}
dev denotation match: True

dev example 3
dev question: what oceans are there ?
dev true LF: (lambda $w (and (ocean $w)))
dev pred LF: (lambda $w (and (ocean $w) (ocean $w)))
dev true denotation: {'atlantic_ocean'}
dev pred denotation: {'atlantic_ocean'}
dev denotation match: True

dev example 4
dev question: what is the capital of virginia ?
dev true LF: (lambda $w (exists $x (and (capital-rel

{'dev_denotation_acc': 1.0}

Now, we'll proceed to training on the full dataset with an unrestricted model. The search space is larger here and it may take several epochs before performance starts to increase. While the training loss should consistently decrease, you may see fluctuations in denotation accuracy (both on train and dev) from epoch to epoch. However, you'll likely obtain training denotation accuracies >80% and dev denotation accuracies >70% at some point by epoch 5. Your maximal dev accuracy should be at least 63%, although there is again some variance due to the small dataset size, which is compounded by the variance from training from denotations. Our reference implementation obtained an average maximal dev accuracy of 80%, but ranged between 63% and 90% (over 40 trials). 10 epochs of training should take around 20-30 minutes.

In [0]:
latent_model = Model()

In [0]:
train(latent_model, train_dataset, dev_dataset, 'latent_model.pt', 
      num_epochs=10, latent_logical_forms=True, 
      training_beam_size=100)

epoch 1
Obtained a new best development accuracy of 0.439, saving model checkpoint to latent_model.pt...

epoch 2
Obtained a new best development accuracy of 0.659, saving model checkpoint to latent_model.pt...

epoch 3
Obtained a new best development accuracy of 0.732, saving model checkpoint to latent_model.pt...

epoch 4

epoch 5
Obtained a new best development accuracy of 0.780, saving model checkpoint to latent_model.pt...

epoch 6

epoch 7
Obtained a new best development accuracy of 0.805, saving model checkpoint to latent_model.pt...

epoch 8

epoch 9
Obtained a new best development accuracy of 0.854, saving model checkpoint to latent_model.pt...

epoch 10

Maximal development accuracy of 0.854
Reloading best model checkpoint from latent_model.pt...
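
To get a feel for the variance mentioned above, it can help to translate the logged accuracies into example counts. The sketch below assumes the dev set has 41 examples; that number is inferred from the printed accuracies (for instance, 0.854 is close to 35/41) and is not reported by the training code. The helper names `implied_correct` and `binomial_standard_error` are hypothetical and not part of the `geoqa` support code.

```python
import math

# Best-so-far dev accuracies copied from the training log above.
logged_accuracies = [0.439, 0.659, 0.732, 0.780, 0.805, 0.854]

# ASSUMPTION: the dev set has 41 examples (inferred from the log, not stated).
NUM_DEV_EXAMPLES = 41


def implied_correct(accuracy, num_examples):
    """Number of correct examples that best explains a rounded accuracy."""
    return round(accuracy * num_examples)


def binomial_standard_error(accuracy, num_examples):
    """Standard error of an accuracy estimate, treating examples as independent."""
    return math.sqrt(accuracy * (1.0 - accuracy) / num_examples)


for acc in logged_accuracies:
    correct = implied_correct(acc, NUM_DEV_EXAMPLES)
    se = binomial_standard_error(acc, NUM_DEV_EXAMPLES)
    print(f"{acc:.3f} ~ {correct}/{NUM_DEV_EXAMPLES} correct (standard error {se:.3f})")

# A single dev example moves the accuracy by this much.
step = 1.0 / NUM_DEV_EXAMPLES
print(f"one example = {step:.3f} accuracy")
```

Under that assumption, one example is worth roughly 2.4 accuracy points, so swings of a few points between epochs or between runs can come from just one or two questions.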


You can run the following cell to see the model predictions:

In [0]:
evaluate_predictions(
  dev_dataset, 'dev', execute,
  prediction_function=make_beam_prediction_function(latent_model),
  display_predictions_frequency=5
)

dev example 1
dev question: what states are there ?
dev true LF: (lambda $w (and (state $w)))
dev pred LF: (lambda $w (exists $x (and (state $w) (state $x) (state $w))))
dev true denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev pred denotation: {'virginia', 'west_virginia', 'north_carolina'}
dev denotation match: True

dev example 6
dev question: what is the capital of virginia ?
dev true LF: (lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))
dev pred LF: (lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))
dev true denotation: {'richmond'}
dev pred denotation: {'richmond'}
dev denotation match: True

dev example 11
dev question: what cities in west virginia ?
dev true LF: (lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-west_virginia $x))))
dev pred LF: (lambda $w (exists $x (and (west-rel $w $x) (kb-virginia $x) (city $w))))
dev true denotation: set()
dev pred denotation: set()
dev denotation match: True

dev example 16
dev question

{'dev_denotation_acc': 0.8536585365853658}
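
The predictions above show that a denotation match does not imply a matching logical form: example 1 reaches the right answer through a different form, and example 11 matches only because both denotations are empty. The following is a rough sketch for sorting predictions into these cases. The function `compare_prediction` is hypothetical (it is not part of the `geoqa` support code), and it compares logical forms as plain strings, so it will not recognize forms that are equivalent but written differently.

```python
def compare_prediction(true_lf, pred_lf, true_denotation, pred_denotation):
    """Label a prediction by how it relates to the gold annotation."""
    if true_denotation != pred_denotation:
        return "wrong denotation"
    if true_lf == pred_lf:
        return "same logical form"
    if not true_denotation:
        # Both denotations are empty, so any query returning nothing matches.
        return "different logical form, match on empty denotation"
    return "different logical form, same denotation"


# Values copied from dev examples 1, 6 and 11 in the output above.
states = {"virginia", "west_virginia", "north_carolina"}
example_1 = compare_prediction(
    "(lambda $w (and (state $w)))",
    "(lambda $w (exists $x (and (state $w) (state $x) (state $w))))",
    states, states)
example_6 = compare_prediction(
    "(lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))",
    "(lambda $w (exists $x (and (capital-rel $w $x) (kb-virginia $x))))",
    {"richmond"}, {"richmond"})
example_11 = compare_prediction(
    "(lambda $w (exists $x (and (city $w) (in-rel $w $x) (kb-west_virginia $x))))",
    "(lambda $w (exists $x (and (west-rel $w $x) (kb-virginia $x) (city $w))))",
    set(), set())

for name, label in [("1", example_1), ("6", example_6), ("11", example_11)]:
    print(f"dev example {name}: {label}")
```

Matches of the third kind are worth inspecting by hand, since a model trained only from denotations receives the same credit for them as for a correct parse.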

## Submission

In [0]:
def dump_logical_form_actions(dataset, fname):
  """Write one line per example: the space-separated actions that
  build the example's true logical form."""
  with open(fname, 'w') as f:
    for i in range(len(dataset)):
      parser_state = ParserState.from_logical_form(dataset[i]['logical_form'])
      f.write('{}\n'.format(' '.join(parser_state.actions())))

In [0]:
# to test your ParserState.take_action() method, we'll also save 
# the actions for the true logical forms
dump_logical_form_actions(dev_dataset, './dev_true_lf_actions.txt')
dump_logical_form_actions(test_dataset, './test_true_lf_actions.txt')

In [365]:
# Restore the trained models from their saved checkpoints.
supervised_model = Model(max_vars=4, no_repeated_literals=True)
supervised_model.load_state_dict(torch.load('supervised_model.pt'))
latent_model = Model(max_vars=4, no_repeated_literals=True)
latent_model.load_state_dict(torch.load('latent_model.pt'))

pprint.pprint(evaluate_predictions(
  dev_dataset, 'supervised_dev', execute,
  prediction_function=make_beam_prediction_function(supervised_model, 5),
  logical_form_output_file='./supervised_dev_lfs.txt'
))
pprint.pprint(evaluate_predictions(
  test_dataset, 'supervised_test', execute,
  prediction_function=make_beam_prediction_function(supervised_model, 5),
  logical_form_output_file='./supervised_test_lfs.txt'
))
pprint.pprint(evaluate_predictions(
  dev_dataset, 'latent_dev', execute,
  prediction_function=make_beam_prediction_function(latent_model, 5),
  logical_form_output_file='./latent_dev_lfs.txt'
))
pprint.pprint(evaluate_predictions(
  test_dataset, 'latent_test', execute,
  prediction_function=make_beam_prediction_function(latent_model, 5),
  logical_form_output_file='./latent_test_lfs.txt'
))

{'supervised_dev_denotation_acc': 0.8048780487804879}
{'supervised_test_denotation_acc': 0.6862745098039216}
{'latent_dev_denotation_acc': 0.8292682926829268}
{'latent_test_denotation_acc': 0.7058823529411765}


Your submission should consist of your code, the actions that generate the true logical forms for the dev and test sets (to test your `ParserState.take_action()` method), and the logical forms predicted by your model for both dev and test data in the supervised and latent settings. Turn in the following files on Gradescope:

- proj_4.ipynb (this file; please rename to match)
- dev_true_lf_actions.txt
- test_true_lf_actions.txt
- supervised_dev_lfs.txt
- supervised_test_lfs.txt
- latent_dev_lfs.txt
- latent_test_lfs.txt

Be sure to check the output of the autograder after it runs. It should confirm that no files are missing and that the output files have the correct format.
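
Before uploading, a quick local check can catch missing or empty output files. This is only a minimal sketch: the helper `check_submission_files` is hypothetical, it checks only that each file exists and is non-empty, and it says nothing about whether the contents have the format the autograder expects. The notebook itself is left out of the list because its file name depends on how you renamed it.

```python
from pathlib import Path

# Output files named in the submission list above (the notebook is excluded).
REQUIRED_FILES = [
    "dev_true_lf_actions.txt",
    "test_true_lf_actions.txt",
    "supervised_dev_lfs.txt",
    "supervised_test_lfs.txt",
    "latent_dev_lfs.txt",
    "latent_test_lfs.txt",
]


def check_submission_files(directory=".", required=REQUIRED_FILES):
    """Return a list of problems; an empty list means every file is present."""
    problems = []
    for name in required:
        path = Path(directory) / name
        if not path.is_file():
            problems.append(f"missing: {name}")
        elif path.stat().st_size == 0:
            problems.append(f"empty: {name}")
    return problems


problems = check_submission_files()
if problems:
    print("\n".join(problems))
else:
    print("All expected output files are present and non-empty.")
```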

In [0]:
from google.colab import files

files.download('dev_true_lf_actions.txt')
files.download('latent_dev_lfs.txt')
files.download('latent_test_lfs.txt')
files.download('supervised_dev_lfs.txt')
files.download('supervised_test_lfs.txt')
files.download('test_true_lf_actions.txt')