Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
add equality check for index field; allennlp interpret (#3073)
Browse files Browse the repository at this point in the history
* add equality check for index field; allennlp interpret

* add test

* change hotflip to use equals method

* tests per matt

* newline

* change input reduction to eq also

* undo

* add newline

* fix pylint
  • Loading branch information
Eric-Wallace authored and matt-gardner committed Jul 18, 2019
1 parent 5014d02 commit a1476c0
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 5 deletions.
4 changes: 1 addition & 3 deletions allennlp/data/fields/index_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,4 @@ def __eq__(self, other) -> bool:
# Allow equality checks to ints that are the sequence index
if isinstance(other, int):
return self.sequence_index == other
# Otherwise it has to be the same object
else:
return id(other) == id(self)
return super().__eq__(other)
5 changes: 5 additions & 0 deletions allennlp/data/tokenizers/character_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,8 @@ def tokenize(self, text: str) -> List[Token]:
token = Token(text=end_token, idx=0)
tokens.append(token)
return tokens

# Structural equality: two objects compare equal when `self` is an instance of
# `other`'s class and their instance attribute dictionaries are identical.
def __eq__(self, other) -> bool:
# Only compare attributes when self is an instance of other's class
# (this also holds when other's class is a base class of self's class).
# NOTE(review): if `other` has no `__dict__` (e.g. a bare `object()`), this
# branch is still entered and the attribute access below would raise
# AttributeError - confirm callers never compare against such values.
if isinstance(self, other.__class__):
return self.__dict__ == other.__dict__
# Returning NotImplemented (rather than False) lets Python fall back to the
# reflected comparison on `other` before deciding the objects are unequal.
return NotImplemented
8 changes: 7 additions & 1 deletion allennlp/tests/data/fields/index_field_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def test_printing_doesnt_crash(self):
def test_equality(self):
index_field1 = IndexField(4, self.text)
index_field2 = IndexField(4, self.text)
index_field3 = IndexField(4, TextField([Token(t) for t in ["AllenNLP", "is", "the", "bomb", "!"]],
{"words": SingleIdTokenIndexer("words")}))

assert index_field1 == 4
assert index_field1 == index_field1
assert index_field1 != index_field2
assert index_field1 == index_field2

assert index_field1 != index_field3
assert index_field2 != index_field3
assert index_field3 == index_field3
45 changes: 44 additions & 1 deletion allennlp/tests/interpret/hotflip_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=no-self-use,invalid-name
# pylint: disable=no-self-use,invalid-name,protected-access
from allennlp.common.testing import AllenNlpTestCase
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
Expand All @@ -22,3 +22,46 @@ def test_hotflip(self):
assert 'original' in attack
assert 'outputs' in attack
assert len(attack['final'][0]) == len(attack['original']) # hotflip replaces words without removing

# test using SQuAD model (tests different equals method)
inputs = {
"question": "OMG, I heard you coded a test that succeeded on its first attempt, is that true?",
"passage": "Bro, never doubt a coding wizard! I am the king of software, MWAHAHAHA"
}

archive = load_archive(self.FIXTURES_ROOT / 'bidaf' / 'serialization' / 'model.tar.gz')
predictor = Predictor.from_archive(archive, 'machine-comprehension')

hotflipper = Hotflip(predictor)
hotflipper.initialize()
ignore_tokens = ["@@NULL@@", '.', ',', ';', '!', '?']
attack = hotflipper.attack_from_json(inputs,
'question',
'grad_input_2')
assert attack is not None
assert 'final' in attack
assert 'original' in attack
assert 'outputs' in attack
assert len(attack['final'][0]) == len(attack['original']) # hotflip replaces words without removing

instance = predictor._json_to_instance(inputs)
assert instance['question'] != attack['final'][0] # check that the input has changed.

outputs = predictor._model.forward_on_instance(instance)
original_labeled_instance = predictor.predictions_to_labeled_instances(instance, outputs)[0]
original_span_start = original_labeled_instance['span_start'].sequence_index
original_span_end = original_labeled_instance['span_end'].sequence_index

flipped_span_start = attack['outputs']['best_span'][0]
flipped_span_end = attack['outputs']['best_span'][1]

for token in instance['question']:
token = str(token)
if token in ignore_tokens:
assert token in attack['final'][0] # ignore tokens should not be changed
# HotFlip keeps changing tokens until either the predictions changes or all tokens have
# been changed. If there are tokens in the HotFlip final result that were in the original
# (i.e., not all tokens were flipped), then the prediction should be different.
else:
if token in attack['final'][0]:
assert original_span_start != flipped_span_start or original_span_end != flipped_span_end

0 comments on commit a1476c0

Please sign in to comment.