Added coverage to WikiTables ERM parser (#1181)
* wikitables coverage parser

* minor comment updates and bug fixes

* updated table question kg test and coverage computation

* removed unused imports

* copy decoder step's output projection weights properly

* fixed new variable creation

* addressed pr comments

* mypy issues

* addressed remaining PR comments
pdasigi committed May 10, 2018
1 parent 69ae074 commit 8ba5867
Showing 20 changed files with 718 additions and 248 deletions.
38 changes: 28 additions & 10 deletions allennlp/data/dataset_readers/wikitables.py
@@ -121,6 +121,9 @@ class WikiTablesDatasetReader(DatasetReader):
         usage of the table representations, truncating cells with really long text.  We specify a
         total number of tokens, not a max cell text length, because the number of table entities
         varies.
+    output_agendas : ``bool``, optional (default=False)
+        Should we output agenda fields? This needs to be true if you want to train a
+        coverage-based parser.
     """
     def __init__(self,
                  lazy: bool = False,
@@ -136,7 +139,8 @@ def __init__(self,
                  use_table_for_vocab: bool = False,
                  linking_feature_extractors: List[str] = None,
                  include_table_metadata: bool = False,
-                 max_table_tokens: int = None) -> None:
+                 max_table_tokens: int = None,
+                 output_agendas: bool = False) -> None:
         super().__init__(lazy=lazy)
         self._tables_directory = tables_directory
         self._dpd_output_directory = dpd_output_directory
@@ -152,6 +156,7 @@ def __init__(self,
         self._include_table_metadata = include_table_metadata
         self._basic_types = set(str(type_) for type_ in wt_types.BASIC_TYPES)
         self._max_table_tokens = max_table_tokens
+        self._output_agendas = output_agendas
 
     @overrides
     def _read(self, file_path: str):
@@ -297,13 +302,12 @@ def text_to_instance(self,  # type: ignore
         if example_lisp_string:
             fields['example_lisp_string'] = MetadataField(example_lisp_string)
 
+        # We'll make each target action sequence a List[IndexField], where the index is into
+        # the action list we made above.  We need to ignore the type here because mypy doesn't
+        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
+        # ProductionRuleFields.
+        action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
         if dpd_output:
-            # We'll make each target action sequence a List[IndexField], where the index is into
-            # the action list we made above.  We need to ignore the type here because mypy doesn't
-            # like `action.rule` - it's hard to tell mypy that the ListField is made up of
-            # ProductionRuleFields.
-            action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
-
             action_sequence_fields: List[Field] = []
             for logical_form in dpd_output:
                 if not self._should_keep_logical_form(logical_form):
@@ -344,6 +348,13 @@ def text_to_instance(self,  # type: ignore
                 # full test data.
                 return None
             fields['target_action_sequences'] = ListField(action_sequence_fields)
+        if self._output_agendas:
+            agenda_index_fields: List[Field] = []
+            for agenda_string in world.get_agenda():
+                agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
+            if not agenda_index_fields:
+                agenda_index_fields = [IndexField(-1, action_field)]
+            fields['agenda'] = ListField(agenda_index_fields)
         return Instance(fields)
 
     def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
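As a rough illustration of what the new agenda block above produces: each production-rule string in the world's agenda is looked up in the instance's action list, and an index of -1 stands in for an empty agenda so the field is never empty. The helper below is a toy sketch written for this note, not part of the reader, and the example rule strings are only illustrative.

# Toy sketch of the agenda-indexing logic above; `agenda_to_indices` is a hypothetical helper.
from typing import Dict, List

def agenda_to_indices(agenda: List[str], actions: List[str]) -> List[int]:
    # Map each production rule string to its position in the instance's action list.
    action_map: Dict[str, int] = {rule: i for i, rule in enumerate(actions)}
    indices = [action_map[rule] for rule in agenda if rule in action_map]
    # An empty agenda is padded with -1, mirroring IndexField(-1, action_field) above.
    return indices or [-1]

actions = ["@start@ -> r", "r -> ['fb:row.row.year']", "n -> ['2013']"]
print(agenda_to_indices(["n -> ['2013']"], actions))  # [2]
print(agenda_to_indices([], actions))                 # [-1]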
@@ -380,16 +391,21 @@ def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
                   'actions': action_field,
                   'example_lisp_string': example_string_field}
 
-        if 'target_action_sequences' in json_obj:
+        if 'target_action_sequences' in json_obj or 'agenda' in json_obj:
             action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
+        if 'target_action_sequences' in json_obj:
             action_sequence_fields: List[Field] = []
             for sequence in json_obj['target_action_sequences']:
                 index_fields: List[Field] = []
                 for production_rule in sequence:
                     index_fields.append(IndexField(action_map[production_rule], action_field))
                 action_sequence_fields.append(ListField(index_fields))
             fields['target_action_sequences'] = ListField(action_sequence_fields)
+        if 'agenda' in json_obj:
+            agenda_index_fields: List[Field] = []
+            for agenda_action in json_obj['agenda']:
+                agenda_index_fields.append(IndexField(action_map[agenda_action], action_field))
+            fields['agenda'] = ListField(agenda_index_fields)
         return Instance(fields)
 
     @staticmethod
@@ -442,6 +458,7 @@ def from_params(cls, params: Params) -> 'WikiTablesDatasetReader':
         linking_feature_extracters = params.pop('linking_feature_extractors', None)
         include_table_metadata = params.pop_bool('include_table_metadata', False)
         max_table_tokens = params.pop_int('max_table_tokens', None)
+        output_agendas = params.pop_bool('output_agendas', False)
         params.assert_empty(cls.__name__)
         return WikiTablesDatasetReader(lazy=lazy,
                                        tables_directory=tables_directory,
@@ -456,4 +473,5 @@ def from_params(cls, params: Params) -> 'WikiTablesDatasetReader':
                                        use_table_for_vocab=use_table_for_vocab,
                                        linking_feature_extractors=linking_feature_extracters,
                                        include_table_metadata=include_table_metadata,
-                                       max_table_tokens=max_table_tokens)
+                                       max_table_tokens=max_table_tokens,
+                                       output_agendas=output_agendas)
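For reference, a minimal sketch of how the new flag might be switched on through ``from_params``. The directory paths are placeholders, and only parameter names visible in this diff are assumed to exist.

from allennlp.common import Params
from allennlp.data.dataset_readers.wikitables import WikiTablesDatasetReader

params = Params({
        "tables_directory": "/path/to/tables",    # placeholder path
        "dpd_output_directory": "/path/to/dpd",   # placeholder path
        "output_agendas": True,                   # new flag added in this commit
})
reader = WikiTablesDatasetReader.from_params(params)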
5 changes: 5 additions & 0 deletions allennlp/data/tokenizers/token.py
@@ -53,3 +53,8 @@ def __str__(self):
 
     def __repr__(self):
         return self.__str__()
+
+    def __eq__(self, other):
+        if isinstance(self, other.__class__):
+            return self.__dict__ == other.__dict__
+        return NotImplemented
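A quick sketch of what the new ``Token.__eq__`` buys, for example in tests that compare tokenized fields directly:

from allennlp.data.tokenizers import Token

# Tokens with identical attributes now compare equal attribute-by-attribute,
# instead of falling back to identity comparison.
assert Token("2013") == Token("2013")
assert Token("2013") != Token("2014")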
allennlp/models/semantic_parsing/nlvr/nlvr_coverage_semantic_parser.py
@@ -11,7 +11,7 @@
 from allennlp.data.vocabulary import Vocabulary
 from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
 from allennlp.modules.similarity_functions import SimilarityFunction
-from allennlp.nn.decoding import DecoderTrainer
+from allennlp.nn.decoding import DecoderTrainer, ChecklistState
 from allennlp.nn.decoding.decoder_trainers import ExpectedRiskMinimization
 from allennlp.nn import util
 from allennlp.models.archival import load_archive, Archive
@@ -58,7 +58,7 @@ class NlvrCoverageSemanticParser(NlvrSemanticParser):
         and shouldn't be penalized, while we will mostly want to penalize longer logical forms.
     max_decoding_steps : ``int``
         Maximum number of steps for the beam search during training.
-    checklist_cost_weight : ``float``, optional (default=0.8)
+    checklist_cost_weight : ``float``, optional (default=0.6)
         Mixture weight (0-1) for combining coverage cost and denotation cost. As this increases, we
         weigh the coverage cost higher, with a value of 1.0 meaning that we do not care about
         denotation accuracy.
@@ -83,7 +83,7 @@ def __init__(self,
                  beam_size: int,
                  max_decoding_steps: int,
                  normalize_beam_score_by_length: bool = False,
-                 checklist_cost_weight: float = 0.8,
+                 checklist_cost_weight: float = 0.6,
                  dynamic_cost_weight: Dict[str, Union[int, float]] = None,
                  penalize_non_agenda_actions: bool = False,
                  initial_mml_model_file: str = None) -> None:
@@ -215,19 +215,17 @@ def forward(self,  # type: ignore
         label_strings = self._get_label_strings(labels) if labels is not None else None
         # Each instance's agenda is of size (agenda_size, 1)
         agenda_list = [agenda[i] for i in range(batch_size)]
-        checklist_targets = []
-        all_terminal_actions = []
-        checklist_masks = []
-        initial_checklist_list = []
+        initial_checklist_states = []
         for instance_actions, instance_agenda in zip(actions, agenda_list):
             checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
             checklist_target, terminal_actions, checklist_mask = checklist_info
-            checklist_targets.append(checklist_target)
-            all_terminal_actions.append(terminal_actions)
-            checklist_masks.append(checklist_mask)
-            initial_checklist_list.append(util.new_variable_with_size(checklist_target,
-                                                                      checklist_target.size(),
-                                                                      0))
+            initial_checklist = util.new_variable_with_size(checklist_target,
+                                                            checklist_target.size(),
+                                                            0)
+            initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
+                                                           checklist_target=checklist_target,
+                                                           checklist_mask=checklist_mask,
+                                                           checklist=initial_checklist))
         initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
@@ -238,10 +236,7 @@
                                          possible_actions=actions,
                                          worlds=worlds,
                                          label_strings=label_strings,
-                                         terminal_actions=all_terminal_actions,
-                                         checklist_target=checklist_targets,
-                                         checklist_masks=checklist_masks,
-                                         checklist=initial_checklist_list)
+                                         checklist_state=initial_checklist_states)
 
         agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
         outputs = self._decoder_trainer.decode(initial_state,
@@ -363,21 +358,16 @@ def _get_state_cost(self, state: NlvrDecoderState) -> torch.Tensor:
         """
         if not state.is_finished():
             raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
-        instance_checklist_target = state.checklist_target[0]
-        instance_checklist = state.checklist[0]
-        instance_checklist_mask = state.checklist_mask[0]
-
         # Our checklist cost is a sum of squared error from where we want to be, making sure we
         # take into account the mask.
-        checklist_balance = instance_checklist_target - instance_checklist
-        checklist_balance = checklist_balance * instance_checklist_mask
+        checklist_balance = state.checklist_state[0].get_balance()
         checklist_cost = torch.sum((checklist_balance) ** 2)
 
         # This is the number of items on the agenda that we want to see in the decoded sequence.
         # We use this as the denotation cost if the path is incorrect.
         # Note: If we are penalizing the model for producing non-agenda actions, this is not the
         # upper limit on the checklist cost.  That would be the number of terminal actions.
-        denotation_cost = torch.sum(instance_checklist_target.float())
+        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
         checklist_cost = self._checklist_cost_weight * checklist_cost
         # TODO (pradeep): The denotation based cost below is strict. May be define a cost based on
         # how many worlds the logical form is correct in?
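A hedged sketch of how the pieces above are assumed to combine into the final state cost, following the comments in the hunk: the checklist (coverage) term is always paid, and the denotation term is added only when the decoded logical form is wrong, weighted by (1 - checklist_cost_weight). The function name and exact branching are illustrative assumptions, not the parser's literal code.

import torch

def state_cost_sketch(checklist_balance: torch.Tensor,
                      checklist_target: torch.Tensor,
                      denotation_is_correct: bool,
                      checklist_cost_weight: float = 0.6) -> torch.Tensor:
    # Coverage term: squared distance from the agenda target, already masked.
    checklist_cost = checklist_cost_weight * torch.sum(checklist_balance ** 2)
    if denotation_is_correct:
        return checklist_cost
    # Denotation term: number of agenda items we wanted to see, per the comment above.
    denotation_cost = torch.sum(checklist_target.float())
    return checklist_cost + (1 - checklist_cost_weight) * denotation_cost

balance = torch.tensor([1.0, 0.0, 1.0])   # two agenda items still missing
target = torch.tensor([1.0, 0.0, 1.0])
print(state_cost_sketch(balance, target, denotation_is_correct=False))  # tensor(2.)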
@@ -435,7 +425,7 @@ def from_params(cls, vocab, params: Params) -> 'NlvrCoverageSemanticParser':
         beam_size = params.pop_int('beam_size')
         normalize_beam_score_by_length = params.pop_bool('normalize_beam_score_by_length', False)
         max_decoding_steps = params.pop_int("max_decoding_steps")
-        checklist_cost_weight = params.pop_float("checklist_cost_weight", 0.8)
+        checklist_cost_weight = params.pop_float("checklist_cost_weight", 0.6)
         dynamic_cost_weight = params.pop("dynamic_cost_weight", None)
         penalize_non_agenda_actions = params.pop_bool("penalize_non_agenda_actions", False)
         initial_mml_model_file = params.pop("initial_mml_model_file", None)
60 changes: 11 additions & 49 deletions allennlp/models/semantic_parsing/nlvr/nlvr_decoder_state.py
@@ -1,10 +1,8 @@
 from typing import List, Dict, Tuple
 
 import torch
-from torch.autograd import Variable
 
 from allennlp.data.fields.production_rule_field import ProductionRuleArray
-from allennlp.nn.decoding import DecoderState, GrammarState, RnnState
+from allennlp.nn.decoding import DecoderState, GrammarState, RnnState, ChecklistState
 from allennlp.semparse.worlds import NlvrWorld
 
 
@@ -44,30 +42,11 @@ class NlvrDecoderState(DecoderState['NlvrDecoderState']):
         String representations of labels for the elements provided. When scoring finished states, we
         will compare the denotations of their action sequences against these labels. For each
         element, there will be as many labels as there are worlds.
-    terminal_actions : ``List[torch.Tensor]``, optional
-        Each element in the list is a vector containing the indices of terminal actions. Currently
-        the vectors are the same for all instances, because we consider all terminals for each
-        instance. In the future, we may want to include only world-specific terminal actions here.
-        Each of these vectors is needed for computing checklists for next states, only if this state
-        is being used while training a parser without logical forms.
-    checklist_target : ``List[torch.LongTensor]``, optional
-        List of targets corresponding to agendas that indicate the states we want the checklists to
-        ideally be. Each element in this list is the same size as the corresponding element in
-        ``agenda_relevant_actions``, and it contains 1 for each corresponding action in the relevant
-        actions list that we want to see in the final logical form, and 0 for each corresponding
-        action that we do not. Needed only if this state is being used while training a parser
-        without logical forms.
-    checklist_masks : ``List[torch.Tensor]``, optional
-        Masks corresponding to ``terminal_actions``, indicating which of those actions are relevant
-        for checklist computation. For example, if the parser is penalizing non-agenda terminal
-        actions, all the terminal actions are relevant. Needed only if this state is being used
-        while training a parser without logical forms.
-    checklist : ``List[Variable]``, optional
-        A checklist for each instance indicating how many times each action in its agenda has
-        been chosen previously. It contains the actual counts of the agenda actions. Needed only if
-        this state is being used while training a parser without logical forms.
+    checklist_state : ``List[ChecklistState]``, optional (default=None)
+        If you are using this state within a parser being trained for coverage, we need to store a
+        ``ChecklistState`` which keeps track of the coverage information. Not needed if you are
+        using a non-coverage based training algorithm.
     """
-    # TODO(pradeep): Group checklist related pieces into a checklist state.
     def __init__(self,
                  batch_indices: List[int],
                  action_history: List[List[int]],
@@ -79,17 +58,13 @@ def __init__(self,
                  possible_actions: List[List[ProductionRuleArray]],
                  worlds: List[List[NlvrWorld]],
                  label_strings: List[List[str]],
-                 terminal_actions: List[torch.Tensor] = None,
-                 checklist_target: List[torch.Tensor] = None,
-                 checklist_masks: List[torch.Tensor] = None,
-                 checklist: List[Variable] = None) -> None:
+                 checklist_state: List[ChecklistState] = None) -> None:
         super(NlvrDecoderState, self).__init__(batch_indices, action_history, score)
         self.rnn_state = rnn_state
         self.grammar_state = grammar_state
-        self.terminal_actions = terminal_actions
-        self.checklist_target = checklist_target
-        self.checklist_mask = checklist_masks
-        self.checklist = checklist
+        # Converting None to list of Nones if needed, to simplify state operations.
+        self.checklist_state = checklist_state if checklist_state is not None else [None for _ in
+                                                                                     batch_indices]
         self.action_embeddings = action_embeddings
         self.action_indices = action_indices
         self.possible_actions = possible_actions
@@ -116,17 +91,7 @@ def combine_states(cls, states) -> 'NlvrDecoderState':
         scores = [score for state in states for score in state.score]
         rnn_states = [rnn_state for state in states for rnn_state in state.rnn_state]
         grammar_states = [grammar_state for state in states for grammar_state in state.grammar_state]
-        if states[0].terminal_actions is not None:
-            terminal_actions = [actions for state in states for actions in state.terminal_actions]
-            checklist_target = [target_list for state in states for target_list in
-                                state.checklist_target]
-            checklist_masks = [mask for state in states for mask in state.checklist_mask]
-            checklist = [checklist_list for state in states for checklist_list in state.checklist]
-        else:
-            terminal_actions = None
-            checklist_target = None
-            checklist_masks = None
-            checklist = None
+        checklist_states = [checklist_state for state in states for checklist_state in state.checklist_state]
         return NlvrDecoderState(batch_indices=batch_indices,
                                 action_history=action_histories,
                                 score=scores,
@@ -137,7 +102,4 @@ def combine_states(cls, states) -> 'NlvrDecoderState':
                                 possible_actions=states[0].possible_actions,
                                 worlds=states[0].worlds,
                                 label_strings=states[0].label_strings,
-                                terminal_actions=terminal_actions,
-                                checklist_target=checklist_target,
-                                checklist_masks=checklist_masks,
-                                checklist=checklist)
+                                checklist_state=checklist_states)
