from copy import deepcopy
from typing import List, Tuple
import heapq

import numpy as np
import torch

from allennlp.common.util import JsonDict, sanitize
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.interpret.attackers import utils
from allennlp.interpret.attackers.attacker import Attacker
from allennlp.predictors import Predictor


@Attacker.register("input-reduction")
class InputReduction(Attacker):
    """
    Runs the input reduction method from [Pathologies of Neural Models Make Interpretations
    Difficult](https://arxiv.org/abs/1804.07781), which removes as many words as possible from
    the input without changing the model's prediction.

    The functions on this class handle a special case for NER by looking for a field called "tags".
    This check is brittle, i.e., the code could break if the name of this field has changed, or if
    a non-NER model has a field called "tags".

    Registered as an `Attacker` with name "input-reduction".
    """

    def __init__(self, predictor: Predictor, beam_size: int = 3) -> None:
        super().__init__(predictor)
        self.beam_size = beam_size

    def attack_from_json(
        self,
        inputs: JsonDict,
        input_field_to_attack: str = "tokens",
        grad_input_field: str = "grad_input_1",
        ignore_tokens: List[str] = None,
        target: JsonDict = None,
    ):
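        """
        Removes as many tokens as possible from each labeled instance produced from
        `inputs` without changing the model's prediction, and returns the reduced
        ("final") tokens alongside the original tokens. Targeted attacks are not
        supported.
        """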
        if target is not None:
            raise ValueError("Input reduction does not implement targeted attacks")
        ignore_tokens = ["@@NULL@@"] if ignore_tokens is None else ignore_tokens
        original_instances = self.predictor.json_to_labeled_instances(inputs)
        original_text_field: TextField = original_instances[0][  # type: ignore
            input_field_to_attack
        ]
        original_tokens = deepcopy(original_text_field.tokens)
        final_tokens = []
        for instance in original_instances:
            final_tokens.append(
                self._attack_instance(
                    inputs, instance, input_field_to_attack, grad_input_field, ignore_tokens
                )
            )
        return sanitize({"final": final_tokens, "original": original_tokens})

    def _attack_instance(
        self,
        inputs: JsonDict,
        instance: Instance,
        input_field_to_attack: str,
        grad_input_field: str,
        ignore_tokens: List[str],
    ):
        # Save fields that must be checked for equality.
        fields_to_compare = utils.get_fields_to_compare(inputs, instance, input_field_to_attack)

        # Set num_ignore_tokens, which tells input reduction when to stop.
        # We keep at least one token for input reduction on classification/entailment/etc.
        if "tags" not in instance:
            num_ignore_tokens = 1
            tag_mask = None
        # Set num_ignore_tokens for NER and build the token mask.
        else:
            num_ignore_tokens, tag_mask, original_tags = _get_ner_tags_and_mask(
                instance, input_field_to_attack, ignore_tokens
            )

        text_field: TextField = instance[input_field_to_attack]  # type: ignore
        current_tokens = deepcopy(text_field.tokens)
        # Each candidate is (instance, index of the token removed last, tag mask);
        # -1 marks the initial, unreduced instance.
        candidates = [(instance, -1, tag_mask)]

        # Keep removing tokens until the prediction is about to change.
        while len(current_tokens) > num_ignore_tokens and candidates:
            # Keep the `beam_size` shortest candidates (we want to remove as many tokens as possible).
            def get_length(input_instance: Instance):
                input_text_field: TextField = input_instance[input_field_to_attack]  # type: ignore
                return len(input_text_field.tokens)

            candidates = heapq.nsmallest(self.beam_size, candidates, key=lambda x: get_length(x[0]))

            beam_candidates = deepcopy(candidates)
            candidates = []
            for beam_instance, smallest_idx, tag_mask in beam_candidates:
                # Get gradients and predictions.
                beam_tag_mask = deepcopy(tag_mask)
                grads, outputs = self.predictor.get_gradients([beam_instance])

                for output in outputs:
                    if isinstance(outputs[output], torch.Tensor):
                        outputs[output] = outputs[output].detach().cpu().numpy().squeeze().squeeze()
                    elif isinstance(outputs[output], list):
                        outputs[output] = outputs[output][0]

                # Check whether any fields have changed; if so, skip to the next beam.
                if "tags" not in instance:
                    # Relabel beam_instance since the last iteration removed an input token.
                    beam_instance = self.predictor.predictions_to_labeled_instances(
                        beam_instance, outputs
                    )[0]
                    if utils.instance_has_changed(beam_instance, fields_to_compare):
                        continue

                # Special case for sentence tagging (we have tested NER).
                else:
                    # Remove the mask entry at the position where the previous
                    # iteration removed an input token.
                    if smallest_idx != -1:  # Don't delete on the very first iteration.
                        del beam_tag_mask[smallest_idx]  # type: ignore
                    cur_tags = [
                        outputs["tags"][x] for x in range(len(outputs["tags"])) if beam_tag_mask[x]  # type: ignore
                    ]
                    if cur_tags != original_tags:
                        continue

                # Remove a token from the input.
                text_field: TextField = beam_instance[input_field_to_attack]  # type: ignore
                current_tokens = deepcopy(text_field.tokens)
                reduced_instances_and_smallest = _remove_one_token(
                    beam_instance,
                    input_field_to_attack,
                    grads[grad_input_field][0],
                    ignore_tokens,
                    self.beam_size,
                    beam_tag_mask,  # type: ignore
                )
                candidates.extend(reduced_instances_and_smallest)

        return current_tokens


def _remove_one_token(
    instance: Instance,
    input_field_to_attack: str,
    grads: np.ndarray,
    ignore_tokens: List[str],
    beam_size: int,
    tag_mask: List[int],
) -> List[Tuple[Instance, int, List[int]]]:
    """
    For each of up to `beam_size` beams, finds the remaining token with the smallest
    gradient norm and removes it, returning the reduced candidate instances.
    """
    # Compute the L2 norm of all grads.
    grads_mag = [np.sqrt(grad.dot(grad)) for grad in grads]

    # Skip all ignore_tokens by setting their grad to infinity.
    text_field: TextField = instance[input_field_to_attack]  # type: ignore
    for token_idx, token in enumerate(text_field.tokens):
        if token in ignore_tokens:
            grads_mag[token_idx] = float("inf")

    # For NER, skip all tokens that are not tagged "O" (outside), so entity
    # tokens are never removed.
    if "tags" in instance:
        tag_field: SequenceLabelField = instance["tags"]  # type: ignore
        labels: List[str] = tag_field.labels  # type: ignore
        for idx, label in enumerate(labels):
            if label != "O":
                grads_mag[idx] = float("inf")

    reduced_instances_and_smallest: List[Tuple[Instance, int, List[int]]] = []
    for _ in range(beam_size):
        # Copy the instance and edit the copy.
        copied_instance = deepcopy(instance)
        copied_text_field: TextField = copied_instance[input_field_to_attack]  # type: ignore

        # Find the token with the smallest gradient norm.
        smallest = np.argmin(grads_mag)
        if grads_mag[smallest] == float("inf"):  # If all remaining tokens are ignored, stop.
            break
        grads_mag[smallest] = float("inf")  # So the other beams don't remove this token.

        # Remove the smallest.
        inputs_before_smallest = copied_text_field.tokens[0:smallest]
        inputs_after_smallest = copied_text_field.tokens[smallest + 1 :]
        copied_text_field.tokens = inputs_before_smallest + inputs_after_smallest

        if "tags" in instance:
            tag_field: SequenceLabelField = copied_instance["tags"]  # type: ignore
            tag_field_before_smallest = tag_field.labels[0:smallest]
            tag_field_after_smallest = tag_field.labels[smallest + 1 :]
            tag_field.labels = tag_field_before_smallest + tag_field_after_smallest  # type: ignore
            tag_field.sequence_field = copied_text_field

        copied_instance.indexed = False
        reduced_instances_and_smallest.append((copied_instance, smallest, tag_mask))

    return reduced_instances_and_smallest
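

# Worked example of the beam expansion above (hypothetical values): with
# grads_mag = [0.2, inf, 0.1, 0.4] and beam_size = 2, the first pass removes
# token 2 (norm 0.1) and the second pass removes token 0 (norm 0.2), producing
# two candidate instances, each one token shorter than the input.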


def _get_ner_tags_and_mask(
    instance: Instance, input_field_to_attack: str, ignore_tokens: List[str]
):
    """
    Used for the NER task. Computes `num_ignore_tokens`, saves the original
    predicted tags, and builds a 0/1 mask marking the positions of the tags.
    """
    # Set num_ignore_tokens.
    num_ignore_tokens = 0
    input_field: TextField = instance[input_field_to_attack]  # type: ignore
    for token in input_field.tokens:
        if str(token) in ignore_tokens:
            num_ignore_tokens += 1

    # Save the original tags and a 0/1 mask where the tags are.
    tag_mask = []
    original_tags = []
    tag_field: SequenceLabelField = instance["tags"]  # type: ignore
    for label in tag_field.labels:
        if label != "O":
            tag_mask.append(1)
            original_tags.append(label)
            num_ignore_tokens += 1
        else:
            tag_mask.append(0)

    return num_ignore_tokens, tag_mask, original_tags
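

# Worked example of the mask construction above (hypothetical labels): for
# tag_field.labels = ["O", "B-PER", "I-PER", "O"], this returns
# tag_mask = [0, 1, 1, 0] and original_tags = ["B-PER", "I-PER"], and adds 2 to
# num_ignore_tokens, so the two entity tokens are never candidates for removal.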