Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Minor WTQ ERM model and dataset reader fixes for demo (#3068)
Browse files Browse the repository at this point in the history
* model dataset reader fixes for demo

* fixed a typo

* mypy ignore

* make mypy happy
  • Loading branch information
pdasigi committed Jul 16, 2019
1 parent ec30c90 commit 7728b12
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def text_to_instance(self, # type: ignore
question_field = TextField(tokenized_question, self._question_token_indexers)
metadata: Dict[str, Any] = {"question_tokens": [x.text for x in tokenized_question]}
table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
target_values_field = MetadataField(target_values)
world = WikiTablesLanguage(table_context)
world_field = MetadataField(world)
# Note: Not passing any feature extractors when instantiating the field below. This will make
Expand All @@ -233,8 +232,11 @@ def text_to_instance(self, # type: ignore
'metadata': MetadataField(metadata),
'table': table_field,
'world': world_field,
'actions': action_field,
'target_values': target_values_field}
'actions': action_field}

if target_values is not None:
target_values_field = MetadataField(target_values)
fields['target_values'] = target_values_field

# We'll make each target action sequence a List[IndexField], where the index is into
# the action list we made above. We need to ignore the type here because mypy doesn't
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from allennlp.models.model import Model
from allennlp.models.semantic_parsing.wikitables.wikitables_semantic_parser import WikiTablesSemanticParser
from allennlp.modules import Attention, FeedForward, Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
from allennlp.state_machines import BeamSearch
from allennlp.state_machines.states import CoverageState, ChecklistStatelet
from allennlp.state_machines.trainers import ExpectedRiskMinimization
from allennlp.state_machines.transition_functions import LinkingCoverageTransitionFunction
Expand Down Expand Up @@ -127,6 +128,10 @@ def __init__(self,
dropout=dropout)
self._checklist_cost_weight = checklist_cost_weight
self._agenda_coverage = Average()
# We don't need a separate beam search since the trainer does that already. But we're defining one just to
# be able to use interactive beam search (a functionality that's only implemented in the ``BeamSearch``
# class) in the demo. We'll use this only at test time.
self._beam_search: BeamSearch = BeamSearch(beam_size=decoder_beam_size)
# TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
# copied a trained ERM model from a different machine and the original MML model that was
# used to initialize it does not exist on the current machine. This may not be the best
Expand Down Expand Up @@ -262,21 +267,24 @@ def forward(self, # type: ignore
extras=target_values,
debug_info=None)

if not self.training:
if target_values is not None:
logger.warning(f"TARGET VALUES: {target_values}")
trainer_outputs = self._decoder_trainer.decode(initial_state, # type: ignore
self._decoder_step,
partial(self._get_state_cost, world))
outputs.update(trainer_outputs)
else:
initial_state.debug_info = [[] for _ in range(batch_size)]

outputs = self._decoder_trainer.decode(initial_state, # type: ignore
self._decoder_step,
partial(self._get_state_cost, world))
best_final_states = outputs['best_final_states']

if not self.training:
batch_size = len(actions)
agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda]
action_mapping = {}
for batch_index, batch_actions in enumerate(actions):
for action_index, action in enumerate(batch_actions):
action_mapping[(batch_index, action_index)] = action[0]
best_final_states = self._beam_search.search(self._max_decoding_steps,
initial_state,
self._decoder_step,
keep_final_unfinished_states=False)
for i in range(batch_size):
in_agenda_ratio = 0.0
# Decoding may not have terminated with any completed logical forms, if `num_steps`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def _compute_validation_outputs(self,
has_logical_form = True
except ParsingError:
logical_form = 'Error producing logical form'
if target_list[0] is not None:
if target_list is not None:
denotation_correct = world[i].evaluate_logical_form(logical_form, target_list[i])
else:
denotation_correct = False
Expand Down

0 comments on commit 7728b12

Please sign in to comment.