
Output predictions in CSV format for official eval (#1249)

* output predictions in csv

* mypy fixes
pdasigi committed May 18, 2018
1 parent d1c36de commit eabb37fc95a9bc14f40155f01e0a4a4f0a8c8399
@@ -105,6 +105,7 @@ def _read(self, file_path: str):
                 continue
             data = json.loads(line)
             sentence = data["sentence"]
+            identifier = data["identifier"] if "identifier" in data else data["id"]
             if "worlds" in data:
                 # This means that we are reading grouped nlvr data. There will be multiple
                 # worlds and corresponding labels per sentence.
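The new lookup falls back from the "identifier" key to "id", so both naming conventions in the NLVR data files are accepted. An equivalent one-liner using dict.get, shown only as an illustration and not part of the commit:

    # Note the behavioral difference: the commit's version raises KeyError
    # when both keys are missing, while this variant silently yields None.
    identifier = data.get("identifier", data.get("id"))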
@@ -127,7 +128,8 @@ def _read(self, file_path: str):
             instance = self.text_to_instance(sentence,
                                              structured_representations,
                                              labels,
-                                             target_sequences)
+                                             target_sequences,
+                                             identifier)
             if instance is not None:
                 yield instance

@@ -136,7 +138,8 @@ def text_to_instance(self, # type: ignore
                          sentence: str,
                          structured_representations: List[List[List[JsonDict]]],
                          labels: List[str] = None,
-                         target_sequences: List[List[str]] = None) -> Instance:
+                         target_sequences: List[List[str]] = None,
+                         identifier: str = None) -> Instance:
         """
         Parameters
         ----------
@@ -150,6 +153,8 @@ def text_to_instance(self, # type: ignore
         target_sequences : ``List[List[str]]`` (optional)
             List of target action sequences for each element which lead to the correct denotation in
             worlds corresponding to the structured representations.
+        identifier : ``str`` (optional)
+            The identifier from the dataset if available.
         """
         # pylint: disable=arguments-differ
         worlds = [NlvrWorld(data) for data in structured_representations]
@@ -165,9 +170,11 @@ def text_to_instance(self, # type: ignore
             production_rule_fields.append(field)
         action_field = ListField(production_rule_fields)
         worlds_field = ListField([MetadataField(world) for world in worlds])
-        fields = {"sentence": sentence_field,
-                  "worlds": worlds_field,
-                  "actions": action_field}
+        fields: Dict[str, Field] = {"sentence": sentence_field,
+                                    "worlds": worlds_field,
+                                    "actions": action_field}
+        if identifier is not None:
+            fields["identifier"] = MetadataField(identifier)
         # Depending on the type of supervision used for training the parser, we may want either
         # target action sequences or an agenda in our instance. We check if target sequences are
         # provided, and include them if they are. If not, we'll get an agenda for the sentence, and
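MetadataField is the right container here because it carries arbitrary, untensorized values through batching; the explicit Dict[str, Field] annotation also quiets the type checker now that the dict holds mixed field types, which is presumably the "mypy fixes" part of the commit message. A minimal sketch of the pass-through behavior, outside the diff:

    from allennlp.data.fields import MetadataField

    field = MetadataField("test-100-0-0")  # hypothetical NLVR identifier
    # Metadata needs no padding and never becomes a tensor, so after
    # batching, forward() sees these values as a plain List[str].
    assert field.get_padding_lengths() == {}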
@@ -100,10 +100,10 @@ def __init__(self,
                                                          dropout=dropout)
         self._agenda_coverage = Average()
         self._decoder_trainer: DecoderTrainer[Callable[[NlvrDecoderState], torch.Tensor]] = \
-                ExpectedRiskMinimization(beam_size,
-                                         normalize_beam_score_by_length,
-                                         max_decoding_steps,
-                                         max_num_finished_states)
+                ExpectedRiskMinimization(beam_size=beam_size,
+                                         normalize_by_length=normalize_beam_score_by_length,
+                                         max_decoding_steps=max_decoding_steps,
+                                         max_num_finished_states=max_num_finished_states)

         # Instantiating an empty NlvrWorld just to get the number of terminals.
         self._terminal_productions = set(NlvrWorld([]).terminal_productions.values())
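This hunk, in what appears to be the coverage-based NLVR parser (note self._agenda_coverage above), switches the ExpectedRiskMinimization call to keyword arguments. That makes the name mismatch explicit — the model's normalize_beam_score_by_length feeds the trainer's normalize_by_length parameter — and protects against silent argument-order bugs if the constructor signature ever changes.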
@@ -183,6 +183,7 @@ def forward(self, # type: ignore
                 worlds: List[List[NlvrWorld]],
                 actions: List[List[ProductionRuleArray]],
                 agenda: torch.LongTensor,
+                identifier: List[str] = None,
                 labels: torch.LongTensor = None,
                 epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
         # pylint: disable=arguments-differ
@@ -243,6 +244,8 @@ def forward(self, # type: ignore
         outputs = self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                self._get_state_cost)
+        if identifier is not None:
+            outputs['identifier'] = identifier
         best_action_sequences = outputs['best_action_sequences']
         batch_action_strings = self._get_action_strings(actions, best_action_sequences)
         batch_denotations = self._get_denotations(batch_action_strings, worlds)
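With this change the decoder outputs carry identifiers alongside denotations, so the predictor can pair them up later. A hypothetical shape for a batch of one sentence with one world (only the keys visible in this diff are shown; the values are made up):

    outputs = {
        "identifier": ["test-100-0-0"],              # one id per batch element
        "best_action_sequences": {0: [[2, 7, 11]]},  # batch index -> best sequences
        "denotations": [[["true"]]],                 # per instance, per sequence, per world
    }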
@@ -80,6 +80,7 @@ def forward(self, # type: ignore
                 sentence: Dict[str, torch.LongTensor],
                 worlds: List[List[NlvrWorld]],
                 actions: List[List[ProductionRuleArray]],
+                identifier: List[str] = None,
                 target_action_sequences: torch.LongTensor = None,
                 labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
         # pylint: disable=arguments-differ
@@ -119,6 +120,8 @@ def forward(self, # type: ignore
             target_mask = None

         outputs: Dict[str, torch.Tensor] = {}
+        if identifier is not None:
+            outputs["identifier"] = identifier
         if target_action_sequences is not None:
             outputs = self._decoder_trainer.decode(initial_state,
                                                    self._decoder_step,
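The direct-supervision parser mirrors the coverage parser here: identifier arrives from the MetadataField as an untensorized List[str], one entry per batch element, and is copied into outputs before decoding so it survives through to the predictor.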
@@ -1,4 +1,6 @@
 from typing import Tuple
+import json
+
 from overrides import overrides

 from allennlp.common.util import JsonDict
@@ -11,6 +13,23 @@ class NlvrParserPredictor(Predictor):
     @overrides
     def _json_to_instance(self, json_dict: JsonDict) -> Tuple[Instance, JsonDict]:
         sentence = json_dict['sentence']
-        worlds = json_dict['worlds']
-        instance = self._dataset_reader.text_to_instance(sentence, worlds)
+        if 'worlds' in json_dict:
+            # This is grouped data
+            worlds = json_dict['worlds']
+        else:
+            worlds = [json_dict['structured_rep']]
+        identifier = json_dict['identifier'] if 'identifier' in json_dict else None
+        instance = self._dataset_reader.text_to_instance(sentence=sentence, # type: ignore
+                                                         structured_representations=worlds,
+                                                         identifier=identifier)
         return instance, {}
+
+    @overrides
+    def dump_line(self, outputs: JsonDict) -> str: # pylint: disable=no-self-use
+        if "identifier" in outputs:
+            # Returning CSV lines for official evaluation
+            identifier = outputs["identifier"]
+            denotation = outputs["denotations"][0][0]
+            return f"{identifier},{denotation}\n"
+        else:
+            return json.dumps(outputs) + "\n"
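To make dump_line's two paths concrete, a hedged usage sketch (assuming predictor is an NlvrParserPredictor loaded from a trained archive, construction elided; the outputs dicts are made-up single-instance results):

    line = predictor.dump_line({"identifier": "test-100-0-0",
                                "denotations": [["true"]]})
    assert line == "test-100-0-0,true\n"  # CSV row for the official evaluator

    line = predictor.dump_line({"denotations": [["true"]]})
    assert line == '{"denotations": [["true"]]}\n'  # JSON fallback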
@@ -11,7 +11,8 @@ def test_reader_reads_ungrouped_data(self):
         instances = list(dataset)
         assert len(instances) == 3
         instance = instances[0]
-        assert instance.fields.keys() == {'sentence', 'agenda', 'worlds', 'actions', 'labels'}
+        assert instance.fields.keys() == {'sentence', 'agenda', 'worlds', 'actions', 'labels',
+                                          'identifier'}
         sentence_tokens = instance.fields["sentence"].tokens
         expected_tokens = ['There', 'is', 'a', 'circle', 'closely', 'touching', 'a', 'corner', 'of',
                            'a', 'box', '.']
@@ -49,7 +50,8 @@ def test_reader_reads_grouped_data(self):
         instances = list(dataset)
         assert len(instances) == 2
         instance = instances[0]
-        assert instance.fields.keys() == {'sentence', 'agenda', 'worlds', 'actions', 'labels'}
+        assert instance.fields.keys() == {'sentence', 'agenda', 'worlds', 'actions', 'labels',
+                                          'identifier'}
         sentence_tokens = instance.fields["sentence"].tokens
         expected_tokens = ['There', 'is', 'a', 'circle', 'closely', 'touching', 'a', 'corner', 'of',
                            'a', 'box', '.']
@@ -76,7 +78,7 @@ def test_reader_reads_processed_data(self):
         assert len(instances) == 2
         instance = instances[0]
         assert instance.fields.keys() == {"sentence", "target_action_sequences",
-                                          "worlds", "actions", "labels"}
+                                          "worlds", "actions", "labels", "identifier"}
         all_action_sequence_indices = instance.fields["target_action_sequences"].field_list
         assert len(all_action_sequence_indices) == 20
         action_sequence_indices = [item.sequence_index for item in
