
Added coverage to WikiTables ERM parser (#1181)

* wikitables coverage parser

* minor comment updates and bug fixes

* updated table question kg test and coverage computation

* removed unused imports

* copy decoder step's output projection weights properly

* fixed new variable creation

* addressed PR comments

* mypy issues

* addressed remaining PR comments
pdasigi committed May 10, 2018
1 parent 69ae074 commit 8ba58675175e91d306f55380833458acfcb38cdd
@@ -121,6 +121,9 @@ class WikiTablesDatasetReader(DatasetReader):
usage of the table representations, truncating cells with really long text. We specify a
total number of tokens, not a max cell text length, because the number of table entities
varies.
output_agendas : ``bool``, optional (default=False)
Should we output agenda fields? This needs to be true if you want to train a coverage-based
parser.
"""
def __init__(self,
lazy: bool = False,
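For context, a minimal sketch of constructing the reader with the new flag (paths are illustrative and other arguments keep their defaults; the import location is an assumption about the package layout):

from allennlp.data.dataset_readers import WikiTablesDatasetReader

# Hypothetical paths; output_agendas=True makes text_to_instance emit an
# 'agenda' field alongside the usual ones.
reader = WikiTablesDatasetReader(tables_directory="/path/to/tables",
                                 dpd_output_directory="/path/to/dpd_output",
                                 output_agendas=True)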
@@ -136,7 +139,8 @@ def __init__(self,
use_table_for_vocab: bool = False,
linking_feature_extractors: List[str] = None,
include_table_metadata: bool = False,
max_table_tokens: int = None) -> None:
max_table_tokens: int = None,
output_agendas: bool = False) -> None:
super().__init__(lazy=lazy)
self._tables_directory = tables_directory
self._dpd_output_directory = dpd_output_directory
@@ -152,6 +156,7 @@ def __init__(self,
self._include_table_metadata = include_table_metadata
self._basic_types = set(str(type_) for type_ in wt_types.BASIC_TYPES)
self._max_table_tokens = max_table_tokens
self._output_agendas = output_agendas

@overrides
def _read(self, file_path: str):
@@ -297,13 +302,12 @@ def text_to_instance(self, # type: ignore
if example_lisp_string:
fields['example_lisp_string'] = MetadataField(example_lisp_string)

# We'll make each target action sequence a List[IndexField], where the index is into
# the action list we made above. We need to ignore the type here because mypy doesn't
# like `action.rule` - it's hard to tell mypy that the ListField is made up of
# ProductionRuleFields.
action_map = {action.rule: i for i, action in enumerate(action_field.field_list)} # type: ignore
if dpd_output:
# We'll make each target action sequence a List[IndexField], where the index is into
# the action list we made above. We need to ignore the type here because mypy doesn't
# like `action.rule` - it's hard to tell mypy that the ListField is made up of
# ProductionRuleFields.
action_map = {action.rule: i for i, action in enumerate(action_field.field_list)} # type: ignore

action_sequence_fields: List[Field] = []
for logical_form in dpd_output:
if not self._should_keep_logical_form(logical_form):
@@ -344,6 +348,13 @@ def text_to_instance(self, # type: ignore
# full test data.
return None
fields['target_action_sequences'] = ListField(action_sequence_fields)
if self._output_agendas:
agenda_index_fields: List[Field] = []
for agenda_string in world.get_agenda():
agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
if not agenda_index_fields:
agenda_index_fields = [IndexField(-1, action_field)]
fields['agenda'] = ListField(agenda_index_fields)
return Instance(fields)
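The agenda added above is a ``ListField`` of ``IndexField``s pointing into the global action list, with a single ``-1`` index as a padding sentinel when the agenda is empty. A standalone illustration of that encoding, using plain-Python stand-ins for the AllenNLP field objects (rules are hypothetical):

# Hypothetical production rules; in the reader these come from the world's
# action list and world.get_agenda().
rules = ["@start@ -> d", "d -> [<e,d>, e]", "<e,t> -> fb:row.row.league"]
action_map = {rule: i for i, rule in enumerate(rules)}

agenda = ["<e,t> -> fb:row.row.league"]                 # actions we want the parser to use
agenda_indices = [action_map[rule] for rule in agenda]  # -> [2]
if not agenda_indices:
    agenda_indices = [-1]                               # sentinel: no agenda for this instance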

def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
@@ -380,16 +391,21 @@ def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
'actions': action_field,
'example_lisp_string': example_string_field}

if 'target_action_sequences' in json_obj:
if 'target_action_sequences' in json_obj or 'agenda' in json_obj:
action_map = {action.rule: i for i, action in enumerate(action_field.field_list)} # type: ignore
if 'target_action_sequences' in json_obj:
action_sequence_fields: List[Field] = []
for sequence in json_obj['target_action_sequences']:
index_fields: List[Field] = []
for production_rule in sequence:
index_fields.append(IndexField(action_map[production_rule], action_field))
action_sequence_fields.append(ListField(index_fields))
fields['target_action_sequences'] = ListField(action_sequence_fields)

if 'agenda' in json_obj:
agenda_index_fields: List[Field] = []
for agenda_action in json_obj['agenda']:
agenda_index_fields.append(IndexField(action_map[agenda_action], action_field))
fields['agenda'] = ListField(agenda_index_fields)
return Instance(fields)

@staticmethod
@@ -442,6 +458,7 @@ def from_params(cls, params: Params) -> 'WikiTablesDatasetReader':
linking_feature_extractors = params.pop('linking_feature_extractors', None)
include_table_metadata = params.pop_bool('include_table_metadata', False)
max_table_tokens = params.pop_int('max_table_tokens', None)
output_agendas = params.pop_bool('output_agendas', False)
params.assert_empty(cls.__name__)
return WikiTablesDatasetReader(lazy=lazy,
tables_directory=tables_directory,
@@ -456,4 +473,5 @@ def from_params(cls, params: Params) -> 'WikiTablesDatasetReader':
use_table_for_vocab=use_table_for_vocab,
linking_feature_extractors=linking_feature_extractors,
include_table_metadata=include_table_metadata,
max_table_tokens=max_table_tokens)
max_table_tokens=max_table_tokens,
output_agendas=output_agendas)
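Correspondingly, a hedged sketch of enabling the flag through ``from_params`` (only the keys relevant here are shown; values are illustrative):

from allennlp.common import Params
from allennlp.data.dataset_readers import WikiTablesDatasetReader

reader = WikiTablesDatasetReader.from_params(Params({
        "tables_directory": "/path/to/tables",
        "output_agendas": True,
}))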
@@ -53,3 +53,8 @@ def __str__(self):

def __repr__(self):
return self.__str__()

def __eq__(self, other):
if isinstance(self, other.__class__):
return self.__dict__ == other.__dict__
return NotImplemented
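The ``__eq__`` added here implements value equality over all instance attributes and returns ``NotImplemented`` for unrelated types, letting Python fall back to the reflected comparison. A self-contained sketch of the same idiom on a hypothetical class:

class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    def __eq__(self, other):
        if isinstance(self, other.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented  # lets Python try the reflected comparison

assert Point(1, 2) == Point(1, 2)  # same attributes compare equal
assert Point(1, 2) != Point(0, 0)  # Python 3 derives __ne__ from __eq__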
@@ -11,7 +11,7 @@
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
from allennlp.modules.similarity_functions import SimilarityFunction
from allennlp.nn.decoding import DecoderTrainer
from allennlp.nn.decoding import DecoderTrainer, ChecklistState
from allennlp.nn.decoding.decoder_trainers import ExpectedRiskMinimization
from allennlp.nn import util
from allennlp.models.archival import load_archive, Archive
@@ -58,7 +58,7 @@ class NlvrCoverageSemanticParser(NlvrSemanticParser):
and shouldn't be penalized, while we will mostly want to penalize longer logical forms.
max_decoding_steps : ``int``
Maximum number of steps for the beam search during training.
checklist_cost_weight : ``float``, optional (default=0.8)
checklist_cost_weight : ``float``, optional (default=0.6)
Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we
weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
denotation accuracy.
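As a hedged sketch of the mixture this weight controls (the complementary (1 - w) term on the denotation cost is an assumption consistent with the "mixture" wording; the actual combination lives in ``_get_state_cost`` below):

def mixed_cost(checklist_cost: float, denotation_cost: float, w: float = 0.6) -> float:
    # Larger w emphasizes agenda coverage; w = 1.0 ignores denotation accuracy.
    return w * checklist_cost + (1 - w) * denotation_cost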
@@ -83,7 +83,7 @@ def __init__(self,
beam_size: int,
max_decoding_steps: int,
normalize_beam_score_by_length: bool = False,
checklist_cost_weight: float = 0.8,
checklist_cost_weight: float = 0.6,
dynamic_cost_weight: Dict[str, Union[int, float]] = None,
penalize_non_agenda_actions: bool = False,
initial_mml_model_file: str = None) -> None:
@@ -215,19 +215,17 @@ def forward(self, # type: ignore
label_strings = self._get_label_strings(labels) if labels is not None else None
# Each instance's agenda is of size (agenda_size, 1)
agenda_list = [agenda[i] for i in range(batch_size)]
checklist_targets = []
all_terminal_actions = []
checklist_masks = []
initial_checklist_list = []
initial_checklist_states = []
for instance_actions, instance_agenda in zip(actions, agenda_list):
checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
checklist_target, terminal_actions, checklist_mask = checklist_info
checklist_targets.append(checklist_target)
all_terminal_actions.append(terminal_actions)
checklist_masks.append(checklist_mask)
initial_checklist_list.append(util.new_variable_with_size(checklist_target,
checklist_target.size(),
0))
initial_checklist = util.new_variable_with_size(checklist_target,
checklist_target.size(),
0)
initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
checklist_target=checklist_target,
checklist_mask=checklist_mask,
checklist=initial_checklist))
initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
action_history=[[] for _ in range(batch_size)],
score=initial_score_list,
@@ -238,10 +236,7 @@ def forward(self, # type: ignore
possible_actions=actions,
worlds=worlds,
label_strings=label_strings,
terminal_actions=all_terminal_actions,
checklist_target=checklist_targets,
checklist_masks=checklist_masks,
checklist=initial_checklist_list)
checklist_state=initial_checklist_states)

agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
outputs = self._decoder_trainer.decode(initial_state,
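For concreteness, a hedged sketch of the per-instance pieces returned by ``_get_checklist_info`` and the all-zero initial checklist built above (tensor shapes follow the code; values are hypothetical):

import torch

terminal_actions = torch.LongTensor([[3], [7], [9]])    # indices of terminal actions
checklist_target = torch.Tensor([[1.0], [1.0], [0.0]])  # 1 = action is on the agenda
checklist_mask = torch.Tensor([[1.0], [1.0], [1.0]])    # rows counted toward the cost
initial_checklist = torch.zeros_like(checklist_target)  # nothing decoded yet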
@@ -363,21 +358,16 @@ def _get_state_cost(self, state: NlvrDecoderState) -> torch.Tensor:
"""
if not state.is_finished():
raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
instance_checklist_target = state.checklist_target[0]
instance_checklist = state.checklist[0]
instance_checklist_mask = state.checklist_mask[0]

# Our checklist cost is a sum of squared error from where we want to be, making sure we
# take into account the mask.
checklist_balance = instance_checklist_target - instance_checklist
checklist_balance = checklist_balance * instance_checklist_mask
checklist_balance = state.checklist_state[0].get_balance()
checklist_cost = torch.sum((checklist_balance) ** 2)

# This is the number of items on the agenda that we want to see in the decoded sequence.
# We use this as the denotation cost if the path is incorrect.
# Note: If we are penalizing the model for producing non-agenda actions, this is not the
# upper limit on the checklist cost. That would be the number of terminal actions.
denotation_cost = torch.sum(instance_checklist_target.float())
denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
checklist_cost = self._checklist_cost_weight * checklist_cost
# TODO (pradeep): The denotation-based cost below is strict. Maybe define a cost based on
# how many worlds the logical form is correct in?
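Putting those pieces together, a hedged numeric sketch of the cost above (``get_balance()`` mirrors the two removed lines it replaces; w stands for ``checklist_cost_weight``):

import torch

w = 0.6
checklist_target = torch.Tensor([[1.0], [1.0], [0.0]])  # agenda wants the first two actions
checklist = torch.Tensor([[1.0], [0.0], [1.0]])         # counts actually decoded
checklist_mask = torch.Tensor([[1.0], [1.0], [1.0]])

balance = checklist_mask * (checklist_target - checklist)  # what get_balance() returns
checklist_cost = w * torch.sum(balance ** 2)               # coverage term: 0.6 * 2.0
denotation_cost = torch.sum(checklist_target)              # fallback if the denotation is wrong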
@@ -435,7 +425,7 @@ def from_params(cls, vocab, params: Params) -> 'NlvrCoverageSemanticParser':
beam_size = params.pop_int('beam_size')
normalize_beam_score_by_length = params.pop_bool('normalize_beam_score_by_length', False)
max_decoding_steps = params.pop_int("max_decoding_steps")
checklist_cost_weight = params.pop_float("checklist_cost_weight", 0.8)
checklist_cost_weight = params.pop_float("checklist_cost_weight", 0.6)
dynamic_cost_weight = params.pop("dynamic_cost_weight", None)
penalize_non_agenda_actions = params.pop_bool("penalize_non_agenda_actions", False)
initial_mml_model_file = params.pop("initial_mml_model_file", None)
@@ -1,10 +1,8 @@
from typing import List, Dict, Tuple

import torch
from torch.autograd import Variable

from allennlp.data.fields.production_rule_field import ProductionRuleArray
from allennlp.nn.decoding import DecoderState, GrammarState, RnnState
from allennlp.nn.decoding import DecoderState, GrammarState, RnnState, ChecklistState
from allennlp.semparse.worlds import NlvrWorld


@@ -44,30 +42,11 @@ class NlvrDecoderState(DecoderState['NlvrDecoderState']):
String representations of labels for the elements provided. When scoring finished states, we
will compare the denotations of their action sequences against these labels. For each
element, there will be as many labels as there are worlds.
terminal_actions : ``List[torch.Tensor]``, optional
Each element in the list is a vector containing the indices of terminal actions. Currently
the vectors are the same for all instances, because we consider all terminals for each
instance. In the future, we may want to include only world-specific terminal actions here.
Each of these vectors is needed for computing checklists for next states, only if this state
is being used while training a parser without logical forms.
checklist_target : ``List[torch.LongTensor]``, optional
List of targets corresponding to agendas, indicating what we ideally want each checklist to
look like. Each element in this list is the same size as the corresponding element in
``terminal_actions``, and it contains 1 for each corresponding action in the relevant
actions list that we want to see in the final logical form, and 0 for each corresponding
action that we do not. Needed only if this state is being used while training a parser
without logical forms.
checklist_masks : ``List[torch.Tensor]``, optional
Masks corresponding to ``terminal_actions``, indicating which of those actions are relevant
for checklist computation. For example, if the parser is penalizing non-agenda terminal
actions, all the terminal actions are relevant. Needed only if this state is being used
while training a parser without logical forms.
checklist : ``List[Variable]``, optional
A checklist for each instance indicating how many times each action in its agenda has
been chosen previously. It contains the actual counts of the agenda actions. Needed only if
this state is being used while training a parser without logical forms.
checklist_state : ``List[ChecklistState]``, optional (default=None)
If you are using this state within a parser being trained for coverage, we need to store a
``ChecklistState`` which keeps track of the coverage information. Not needed if you are
using a non-coverage based training algorithm.
"""
# TODO(pradeep): Group checklist related pieces into a checklist state.
def __init__(self,
batch_indices: List[int],
action_history: List[List[int]],
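A hedged, simplified stand-in for the ``ChecklistState`` referenced above (the four fields match the constructor call in ``NlvrCoverageSemanticParser.forward``; the real class lives in ``allennlp.nn.decoding``):

import torch

class ChecklistStateSketch:
    def __init__(self,
                 terminal_actions: torch.Tensor,
                 checklist_target: torch.Tensor,
                 checklist_mask: torch.Tensor,
                 checklist: torch.Tensor) -> None:
        self.terminal_actions = terminal_actions
        self.checklist_target = checklist_target
        self.checklist_mask = checklist_mask
        self.checklist = checklist

    def get_balance(self) -> torch.Tensor:
        # Mirrors the balance computation removed from _get_state_cost.
        return self.checklist_mask * (self.checklist_target - self.checklist)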
@@ -79,17 +58,13 @@ def __init__(self,
possible_actions: List[List[ProductionRuleArray]],
worlds: List[List[NlvrWorld]],
label_strings: List[List[str]],
terminal_actions: List[torch.Tensor] = None,
checklist_target: List[torch.Tensor] = None,
checklist_masks: List[torch.Tensor] = None,
checklist: List[Variable] = None) -> None:
checklist_state: List[ChecklistState] = None) -> None:
super(NlvrDecoderState, self).__init__(batch_indices, action_history, score)
self.rnn_state = rnn_state
self.grammar_state = grammar_state
self.terminal_actions = terminal_actions
self.checklist_target = checklist_target
self.checklist_mask = checklist_masks
self.checklist = checklist
# Converting None to list of Nones if needed, to simplify state operations.
self.checklist_state = checklist_state if checklist_state is not None else [None for _ in
batch_indices]
self.action_embeddings = action_embeddings
self.action_indices = action_indices
self.possible_actions = possible_actions
@@ -116,17 +91,7 @@ def combine_states(cls, states) -> 'NlvrDecoderState':
scores = [score for state in states for score in state.score]
rnn_states = [rnn_state for state in states for rnn_state in state.rnn_state]
grammar_states = [grammar_state for state in states for grammar_state in state.grammar_state]
if states[0].terminal_actions is not None:
terminal_actions = [actions for state in states for actions in state.terminal_actions]
checklist_target = [target_list for state in states for target_list in
state.checklist_target]
checklist_masks = [mask for state in states for mask in state.checklist_mask]
checklist = [checklist_list for state in states for checklist_list in state.checklist]
else:
terminal_actions = None
checklist_target = None
checklist_masks = None
checklist = None
checklist_states = [checklist_state for state in states for checklist_state in state.checklist_state]
return NlvrDecoderState(batch_indices=batch_indices,
action_history=action_histories,
score=scores,
@@ -137,7 +102,4 @@ def combine_states(cls, states) -> 'NlvrDecoderState':
possible_actions=states[0].possible_actions,
worlds=states[0].worlds,
label_strings=states[0].label_strings,
terminal_actions=terminal_actions,
checklist_target=checklist_target,
checklist_masks=checklist_masks,
checklist=checklist)
checklist_state=checklist_states)
