This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Pass instance metadata through to forward, use it to compute official BiDAF metrics #216

Merged: 3 commits, Aug 30, 2017
3 changes: 3 additions & 0 deletions allennlp/commands/evaluate.py
@@ -1,3 +1,4 @@
from inspect import signature
from typing import Dict, Any
import argparse
import logging
@@ -44,6 +45,8 @@ def evaluate(model: Model,
logger.info("Iterating over dataset")
for batch in tqdm.tqdm(generator, total=iterator.get_num_batches(dataset)):
tensor_batch = arrays_to_variables(batch, cuda_device, for_training=False)
if 'metadata' in tensor_batch and 'metadata' not in signature(model.forward).parameters:
del tensor_batch['metadata']
model.forward(**tensor_batch)

return model.get_metrics()
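
The check above uses inspect.signature so that models whose forward() does not accept a metadata argument keep working unchanged. A self-contained sketch of the same pattern, using a toy model rather than an actual AllenNLP Model:

    from inspect import signature

    class ToyModel:
        def forward(self, tokens, spans):  # note: no 'metadata' parameter
            return {'loss': 0.0}

    model = ToyModel()
    tensor_batch = {'tokens': [1, 2, 3], 'spans': [0, 2], 'metadata': {'id': 'q1'}}
    # Drop metadata for models whose forward() cannot accept it, mirroring the check above.
    if 'metadata' in tensor_batch and 'metadata' not in signature(model.forward).parameters:
        del tensor_batch['metadata']
    model.forward(**tensor_batch)
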
26 changes: 3 additions & 23 deletions scripts/squad_eval.py → allennlp/common/squad_eval.py
@@ -1,17 +1,14 @@
""" Official evaluation script for v1.1 of the SQuAD dataset. """
# pylint: skip-file
from __future__ import print_function
from collections import Counter
import editdistance
import string
import re
import argparse
import json
import sys


verbosity = 0


def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
@@ -44,24 +41,7 @@ def f1_score(prediction, ground_truth):


def exact_match_score(prediction, ground_truth):
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)
correct = normalized_prediction == normalized_ground_truth
if verbosity > 0 and not correct:
pred_no_whitespace = normalized_prediction.replace(' ', '')
truth_no_whitespace = normalized_ground_truth.replace(' ', '')
if pred_no_whitespace == truth_no_whitespace:
print("Prediction and truth differ only in whitespace!")
print("Normalized ground truth:", normalized_ground_truth)
print("Normalized prediction:", normalized_prediction)
print()
if verbosity > 1:
if editdistance.eval(normalized_prediction, normalized_ground_truth) < 5:
print("Small edit distance between truth and prediction; could be a tokenization error")
print("Normalized ground truth:", normalized_ground_truth)
print("Normalized prediction:", normalized_prediction)
print()
return correct
return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
Expand All @@ -81,7 +61,7 @@ def evaluate(dataset, predictions):
if qa['id'] not in predictions:
message = 'Unanswered question ' + qa['id'] + \
' will receive score 0.'
#print(message, file=sys.stderr)
print(message, file=sys.stderr)
continue
ground_truths = list(map(lambda x: x['text'], qa['answers']))
prediction = predictions[qa['id']]
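
The simplified exact_match_score above just compares normalized strings; the whitespace and edit-distance debugging output is gone. For reference, a self-contained sketch of how that comparison behaves (the normalization here is paraphrased from the official script and is only illustrative):

    import re
    import string

    def normalize_answer(s):
        # lowercase, strip punctuation, drop articles, collapse whitespace
        s = s.lower()
        s = ''.join(ch for ch in s if ch not in set(string.punctuation))
        s = re.sub(r'\b(a|an|the)\b', ' ', s)
        return ' '.join(s.split())

    def exact_match_score(prediction, ground_truth):
        return normalize_answer(prediction) == normalize_answer(ground_truth)

    print(exact_match_score("The Eiffel Tower!", "eiffel tower"))  # True after normalization
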
19 changes: 15 additions & 4 deletions allennlp/common/testing/model_test_case.py
@@ -1,6 +1,8 @@
import os
from inspect import signature

from numpy.testing import assert_allclose
import torch

from allennlp.commands.train import train_model_from_file
from allennlp.common import Params
@@ -56,6 +58,9 @@ def ensure_model_can_train_save_and_load(self, param_file: str):

# The datasets themselves should be identical.
for key in model_batch.keys():
if key == 'metadata':
assert model_batch[key] == loaded_batch[key]
continue
field = model_batch[key]
if isinstance(field, dict):
for subfield in field:
@@ -72,6 +77,9 @@ def ensure_model_can_train_save_and_load(self, param_file: str):
# Set eval mode, to turn off things like dropout, then get predictions.
model.eval()
loaded_model.eval()
if 'metadata' in model_batch and 'metadata' not in signature(model.forward).parameters:
del model_batch['metadata']
del loaded_batch['metadata']
model_predictions = model.forward(**model_batch)
loaded_model_predictions = loaded_model.forward(**loaded_batch)

@@ -82,9 +90,12 @@ def ensure_model_can_train_save_and_load(self, param_file: str):

# Both outputs should have the same keys and the values for these keys should be close.
for key in model_predictions.keys():
assert_allclose(model_predictions[key].data.numpy(),
loaded_model_predictions[key].data.numpy(),
rtol=1e-5,
err_msg=key)
if isinstance(model_predictions[key], torch.autograd.Variable):
assert_allclose(model_predictions[key].data.numpy(),
loaded_model_predictions[key].data.numpy(),
rtol=1e-5,
err_msg=key)
else:
assert model_predictions[key] == loaded_model_predictions[key]

return model, loaded_model
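
The isinstance branch above exists because forward() can now return non-tensor values (for example, strings recovered via the metadata), which assert_allclose cannot handle. A minimal sketch of the comparison, with made-up output keys:

    import torch
    from numpy.testing import assert_allclose

    predictions = {'span_start_probs': torch.autograd.Variable(torch.Tensor([0.1, 0.9])),
                   'best_span_str': 'Super Bowl 50'}
    reloaded = {'span_start_probs': torch.autograd.Variable(torch.Tensor([0.1, 0.9])),
                'best_span_str': 'Super Bowl 50'}
    for key in predictions:
        if isinstance(predictions[key], torch.autograd.Variable):
            # numeric outputs: compare with a tolerance
            assert_allclose(predictions[key].data.numpy(), reloaded[key].data.numpy(),
                            rtol=1e-5, err_msg=key)
        else:
            # everything else (strings, metadata pass-throughs): compare exactly
            assert predictions[key] == reloaded[key]
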
12 changes: 5 additions & 7 deletions allennlp/data/dataset.py
@@ -140,17 +140,15 @@ def as_array_dict(self,
field_arrays = defaultdict(list) # type: Dict[str, list]
if verbose:
logger.info("Now actually padding instances to length: %s", str(lengths_to_use))
for instance in tqdm.tqdm(self.instances):
for field, arrays in instance.as_array_dict(lengths_to_use).items():
field_arrays[field].append(arrays)
else:
for instance in self.instances:
for field, arrays in instance.as_array_dict(lengths_to_use).items():
field_arrays[field].append(arrays)
for instance in self.instances:
for field, arrays in instance.as_array_dict(lengths_to_use).items():
field_arrays[field].append(arrays)

# Finally, we combine the arrays that we got for each instance into one big array (or set
# of arrays) per field.
for field_name, field_array_list in field_arrays.items():
if field_name == 'metadata':
continue
if isinstance(field_array_list[0], dict):
# This is creating a dict of {token_indexer_key: batch_array} for each
# token indexer used to index this field. This is mostly utilised by TextFields.
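
With the 'metadata' key skipped here, metadata is never turned into a batch array; it stays a plain list of per-instance dicts. A toy illustration of the effect (field names and values invented for the example):

    from collections import defaultdict
    import numpy

    instance_arrays = [
        {'tokens': numpy.array([1, 2, 3]), 'metadata': {'id': 'q1'}},
        {'tokens': numpy.array([4, 5, 6]), 'metadata': {'id': 'q2'}},
    ]
    field_arrays = defaultdict(list)
    for arrays in instance_arrays:
        for field, value in arrays.items():
            field_arrays[field].append(value)
    batched = {}
    for field_name, values in field_arrays.items():
        if field_name == 'metadata':
            batched[field_name] = values           # left as a list of dicts
            continue
        batched[field_name] = numpy.stack(values)  # regular fields become one stacked array
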
4 changes: 2 additions & 2 deletions allennlp/data/dataset_readers/language_modeling.py
@@ -62,14 +62,14 @@ def read(self, file_path: str):
instance_strings = text_file.readlines()
if self._tokens_per_instance is not None:
all_text = " ".join([x.replace("\n", " ").strip() for x in instance_strings])
tokenized_text = self._tokenizer.tokenize(all_text)
tokenized_text, _ = self._tokenizer.tokenize(all_text)
num_tokens = self._tokens_per_instance
tokenized_strings = []
logger.info("Creating dataset from all text in file: %s", file_path)
for index in tqdm.tqdm(range(0, len(tokenized_text) - num_tokens, num_tokens)):
tokenized_strings.append(tokenized_text[index:index + num_tokens])
else:
tokenized_strings = [self._tokenizer.tokenize(s) for s in instance_strings]
tokenized_strings = [self._tokenizer.tokenize(s)[0] for s in instance_strings]

# TODO(matt): this isn't quite right, because you really want to split on sentences,
# tokenize the sentences, add the start and end tokens per sentence, then change the tokens
4 changes: 2 additions & 2 deletions allennlp/data/dataset_readers/snli.py
@@ -68,9 +68,9 @@ def read(self, file_path: str):
label_field = LabelField(label)

premise = example["sentence1"]
premise_tokens = self._tokenizer.tokenize(premise)
premise_tokens, _ = self._tokenizer.tokenize(premise)
hypothesis = example["sentence2"]
hypothesis_tokens = self._tokenizer.tokenize(hypothesis)
hypothesis_tokens, _ = self._tokenizer.tokenize(hypothesis)
if self._append_null:
premise_tokens.append(self.null_token)
hypothesis_tokens.append(self.null_token)
33 changes: 20 additions & 13 deletions allennlp/data/dataset_readers/squad.py
@@ -46,7 +46,7 @@ def _char_span_to_token_span(sentence: str,
# First we'll tokenize the span and the sentence, so we can count tokens and check for
# matches.
span_chars = sentence[span[0]:span[1]]
tokenized_span = tokenizer.tokenize(span_chars)
tokenized_span, _ = tokenizer.tokenize(span_chars)
# Then we'll find what we think is the first token in the span
chars_seen = 0
index = 0
@@ -85,9 +85,10 @@ class SquadReader(DatasetReader):
fields: ``question``, a ``TextField``, ``passage``, another ``TextField``, and ``span_start``
and ``span_end``, both ``IndexFields`` into the ``passage`` ``TextField``.

The ``Instances`` also store their ID and the original passage text in the instance metadata,
accessible as ``instance.metadata['id']`` and ``instance.metadata['original_passage']``. You
will want this if you want to use the official SQuAD evaluation script.
The ``Instances`` also store their ID, the original passage text, and token offsets into the
original passage in the instance metadata, accessible as ``instance.metadata['id']``,
``instance.metadata['original_passage']``, and ``instance.metadata['token_offsets']``. This is
so that we can more easily use the official SQuAD evaluation script to get metrics.

Parameters
----------
@@ -127,12 +128,14 @@ def read(self, file_path: str):
# labels are end-exclusive, and we do a softmax over the passage to determine span
# end. So if we want to be able to include the last token of the passage, we need
# to have a special symbol at the end.
tokenized_paragraph = self._tokenizer.tokenize(cleaned_paragraph) + [self.STOP_TOKEN]
paragraph_tokens, paragraph_offsets = self._tokenizer.tokenize(cleaned_paragraph)
paragraph_tokens.append(self.STOP_TOKEN)
paragraph_offsets.append((-1, -1))

for question_answer in paragraph_json['qas']:
question_text = question_answer["question"].strip().replace("\n", "")
question_id = question_answer['id'].strip()
tokenized_question = self._tokenizer.tokenize(question_text)
question_tokens, _ = self._tokenizer.tokenize(question_text)

# There may be multiple answer annotations, so pick the one that occurs the
# most.
@@ -145,7 +148,7 @@ def read(self, file_path: str):
# we need a token index for our models. We convert them here.
char_span_end = char_span_start + len(answer_text)
span_start, span_end = _char_span_to_token_span(paragraph,
tokenized_paragraph,
paragraph_tokens,
(char_span_start, char_span_end),
self._tokenizer)

@@ -154,8 +157,8 @@ def read(self, file_path: str):
# when indexing is done, and when padding is done). I _think_ all of those
# operations would be safe with shared objects, but I'd rather just be safe by
# doing a copy here. Extra memory usage should be minimal.
paragraph_field = TextField(deepcopy(tokenized_paragraph), self._token_indexers)
question_field = TextField(tokenized_question, self._token_indexers)
paragraph_field = TextField(deepcopy(paragraph_tokens), self._token_indexers)
question_field = TextField(question_tokens, self._token_indexers)
span_start_field = IndexField(span_start, paragraph_field)
span_end_field = IndexField(span_end, paragraph_field)
fields = {
@@ -164,7 +167,11 @@ def read(self, file_path: str):
'span_start': span_start_field,
'span_end': span_end_field
}
metadata = {'id': question_id, 'original_passage': paragraph}
metadata = {
'id': question_id,
'original_passage': paragraph,
'token_offsets': paragraph_offsets
}
instance = Instance(fields, metadata)
instances.append(instance)
if not instances:
@@ -381,12 +388,12 @@ def read(self, file_path: str):
question_text = self._id_to_question[question_id]
sentence_fields = [] # type: List[Field]
for sentence in sentence_choices:
tokenized_sentence = self._tokenizer.tokenize(sentence)
tokenized_sentence, _ = self._tokenizer.tokenize(sentence)
sentence_field = TextField(tokenized_sentence, self._token_indexers)
sentence_fields.append(sentence_field)
sentences_field = ListField(sentence_fields)
tokenized_question = self._tokenizer.tokenize(question_text)
question_field = TextField(tokenized_question, self._token_indexers)
question_tokens, _ = self._tokenizer.tokenize(question_text)
question_field = TextField(question_tokens, self._token_indexers)
correct_sentence_field = IndexField(correct_choice, sentences_field)
instances.append(Instance({'question': question_field,
'sentences': sentences_field,
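
The new paragraph_offsets give, for each paragraph token, its (start, end) character positions in the original passage; the (-1, -1) entry lines up with the appended STOP_TOKEN. Storing these in the metadata lets a model turn a predicted token span back into the exact passage substring that the official SQuAD script scores. A simplified, self-contained sketch of that kind of conversion (not the project's _char_span_to_token_span, just the idea):

    def char_span_to_token_span(offsets, char_start, char_end):
        # offsets: one (start, end) character span per token
        start_token = next(i for i, (s, e) in enumerate(offsets) if e > char_start)
        end_token = next(i for i, (s, e) in enumerate(offsets) if e >= char_end)
        return start_token, end_token

    passage = "Denver Broncos won Super Bowl 50"
    offsets = [(0, 6), (7, 14), (15, 18), (19, 24), (25, 29), (30, 32)]
    answer = "Super Bowl 50"
    char_start = passage.index(answer)
    char_end = char_start + len(answer)
    print(char_span_to_token_span(offsets, char_start, char_end))  # (3, 5)
    # And going the other way, the token span recovers the original string:
    print(passage[offsets[3][0]:offsets[5][1]])                    # 'Super Bowl 50'
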
5 changes: 5 additions & 0 deletions allennlp/data/instance.py
@@ -66,9 +66,14 @@ def as_array_dict(self, padding_lengths: Dict[str, Dict[str, int]] = None) -> Di

If ``padding_lengths`` is omitted, we will call ``self.get_padding_lengths()`` to get the
sizes of the arrays to create.

In the array dictionary, we also pass along the instance metadata, if any is given. This
is contained in the ``'metadata'`` key.
"""
padding_lengths = padding_lengths or self.get_padding_lengths()
arrays = {}
for field_name, field in self.fields.items():
arrays[field_name] = field.as_array(padding_lengths[field_name])
if self.metadata:
arrays['metadata'] = self.metadata
return arrays
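
To make the docstring addition concrete, the returned dictionary for a hypothetical SQuAD instance has roughly this shape (names and values are illustrative, not taken from a real run):

    arrays = {
        'question': {'tokens': [12, 7, 431, 9]},   # a TextField, already indexed and padded
        'span_start': [3],                          # an IndexField
        'span_end': [5],
        'metadata': {'id': 'q1',                    # metadata dict passed through untouched
                     'original_passage': 'Denver Broncos won Super Bowl 50.'},
    }
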
4 changes: 2 additions & 2 deletions allennlp/data/token_indexers/token_characters_indexer.py
@@ -35,13 +35,13 @@ def __init__(self,

@overrides
def count_vocab_items(self, token: str, counter: Dict[str, Dict[str, int]]):
for character in self.character_tokenizer.tokenize(token):
for character in self.character_tokenizer.tokenize(token)[0]:
counter[self.namespace][character] += 1

@overrides
def token_to_indices(self, token: str, vocabulary: Vocabulary) -> List[int]:
indices = []
for character in self.character_tokenizer.tokenize(token):
for character in self.character_tokenizer.tokenize(token)[0]:
indices.append(vocabulary.get_token_index(character, self.namespace))
return indices

8 changes: 4 additions & 4 deletions allennlp/data/tokenizers/character_tokenizer.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple

from overrides import overrides

@@ -32,12 +32,12 @@ def __init__(self, byte_encoding: str = None, lowercase_characters: bool = False
self.lowercase_characters = lowercase_characters

@overrides
def tokenize(self, text: str) -> List[str]:
def tokenize(self, text: str) -> Tuple[List[str], List[Tuple[int, int]]]:
if self.lowercase_characters:
text = text.lower()
if self.byte_encoding is not None:
return [chr(x) for x in text.encode(self.byte_encoding)]
return list(text)
return [chr(x) for x in text.encode(self.byte_encoding)], None
return list(text), None

@classmethod
def from_params(cls, params: Params) -> 'CharacterTokenizer':
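
A small usage example of the new return value, assuming this PR's version of the module; character tokenizers do not track offsets, so the second element is None:

    from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer

    tokens, offsets = CharacterTokenizer().tokenize("Hi!")
    assert tokens == ['H', 'i', '!']
    assert offsets is None
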
12 changes: 10 additions & 2 deletions allennlp/data/tokenizers/tokenizer.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple

from allennlp.common import Params, Registrable

@@ -22,9 +22,17 @@ class Tokenizer(Registrable):
"""
default_implementation = 'word'

def tokenize(self, text: str) -> List[str]:
def tokenize(self, text: str) -> Tuple[List[str], List[Tuple[int, int]]]:
"""
The only public method for this class. Actually implements splitting words into tokens.

Returns
-------
tokens : ``List[str]``
offsets : ``List[Tuple[int, int]]``
A list of the same lengths as ``tokens``, giving character offsets into the original
string for each token. Not all tokenizers implement this, so this value could be
``None``.
"""
raise NotImplementedError

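To illustrate the (tokens, offsets) contract documented above, here is a naive whitespace tokenizer that satisfies it (purely illustrative; it is not the project's WordTokenizer):

    from typing import List, Tuple

    def whitespace_tokenize(text: str) -> Tuple[List[str], List[Tuple[int, int]]]:
        tokens, offsets, start = [], [], 0
        for token in text.split():
            start = text.index(token, start)
            end = start + len(token)
            tokens.append(token)
            offsets.append((start, end))
            start = end
        return tokens, offsets

    tokens, offsets = whitespace_tokenize("Who wrote Hamlet ?")
    assert tokens[2] == "Hamlet"
    assert "Who wrote Hamlet ?"[offsets[2][0]:offsets[2][1]] == "Hamlet"
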
16 changes: 10 additions & 6 deletions allennlp/data/tokenizers/word_filter.py
@@ -15,8 +15,12 @@ class WordFilter(Registrable):
"""
default_implementation = 'pass_through'

def filter_words(self, words: List[str]) -> List[str]:
"""Filters words from the given word list"""
def should_keep_words(self, words: List[str]) -> List[bool]:
"""
Decides whether to remove words from the given list. To make it easier to deal with data
associated with the word list (like character offsets), we return a list of boolean
decisions for each word, which the caller can process to actually filter the list.
"""
raise NotImplementedError

@classmethod
@@ -32,8 +36,8 @@ class PassThroughWordFilter(WordFilter):
Does not filter words; it's a no-op. This is the default word filter.
"""
@overrides
def filter_words(self, words: List[str]) -> List[str]:
return words
def should_keep_words(self, words: List[str]) -> List[bool]:
return [True] * len(words)


@WordFilter.register('stopwords')
@@ -69,5 +73,5 @@ def __init__(self):
"'", '"', '&', '$', '#', '@', '(', ')', '?'])

@overrides
def filter_words(self, words: List[str]) -> List[str]:
return [word for word in words if word not in self.stopwords]
def should_keep_words(self, words: List[str]) -> List[bool]:
return [word not in self.stopwords for word in words]
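
Returning per-word booleans instead of a filtered list lets callers drop the same positions from any aligned data, such as the new character offsets. A short sketch of how a caller might apply the mask (names and values here are illustrative, not from the tokenizer code):

    words = ['the', 'cat', 'sat', '.']
    offsets = [(0, 3), (4, 7), (8, 11), (11, 12)]
    keep = [word not in {'the', '.'} for word in words]   # stand-in for a StopwordFilter decision
    filtered_words = [w for w, k in zip(words, keep) if k]
    filtered_offsets = [o for o, k in zip(offsets, keep) if k]
    assert filtered_words == ['cat', 'sat']
    assert filtered_offsets == [(4, 7), (8, 11)]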