Skip to content

Commit

Permalink
Merge pull request #689 from QData/at-typing-refactor
Browse files Browse the repository at this point in the history
Add typing to AttackedText class
  • Loading branch information
jxmorris12 committed Nov 2, 2022
2 parents 5043fb6 + 86e51cf commit 5ac125a
Showing 1 changed file with 61 additions and 48 deletions.
109 changes: 61 additions & 48 deletions textattack/shared/attacked_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
A helper class that represents a string that can be attacked.
"""

from __future__ import annotations

from collections import OrderedDict
import math
from typing import Dict, Iterable, List, Optional, Set, Tuple

import flair
from flair.data import Sentence
Expand Down Expand Up @@ -71,20 +74,21 @@ def __init__(self, text_input, attack_attrs=None):
# A list of all indices in *this* text that have been modified.
self.attack_attrs.setdefault("modified_indices", set())

def __eq__(self, other: AttackedText) -> bool:
    """Compares two AttackedText instances.

    Note: Does not compute true equality across attack attributes.
    We found this caused large performance issues with caching,
    and it's actually much faster (cache-wise) to just compare
    by the text, and this works for lots of use cases.
    """
    same_text = self.text == other.text
    same_attr_count = len(self.attack_attrs) == len(other.attack_attrs)
    return same_text and same_attr_count

def __hash__(self) -> int:
    """Hash on the underlying text only, consistent with ``__eq__``."""
    return hash(self.text)

def free_memory(self):
Expand All @@ -102,7 +106,7 @@ def free_memory(self):
if isinstance(self.attack_attrs[key], torch.Tensor):
self.attack_attrs.pop(key, None)

def text_window_around_index(self, index, window_size):
def text_window_around_index(self, index: int, window_size: int) -> str:
"""The text window of ``window_size`` words centered around
``index``."""
length = self.num_words
Expand All @@ -120,10 +124,12 @@ def text_window_around_index(self, index, window_size):
text_idx_end = self._text_index_of_word_index(end) + len(self.words[end])
return self.text[text_idx_start:text_idx_end]

def pos_of_word_index(self, desired_word_idx):
def pos_of_word_index(self, desired_word_idx: int) -> str:
"""Returns the part-of-speech of the word at index `word_idx`.
Uses FLAIR part-of-speech tagger.
Throws: ValueError, if no POS tag found for index.
"""
if not self._pos_tags:
sentence = Sentence(
Expand Down Expand Up @@ -151,10 +157,12 @@ def pos_of_word_index(self, desired_word_idx):
f"Did not find word from index {desired_word_idx} in flair POS tag"
)

def ner_of_word_index(self, desired_word_idx, model_name="ner"):
def ner_of_word_index(self, desired_word_idx: int, model_name="ner") -> str:
"""Returns the ner tag of the word at index `word_idx`.
Uses FLAIR ner tagger.
Throws: ValueError, if not NER tag found for index.
"""
if not self._ner_tags:
sentence = Sentence(
Expand All @@ -179,7 +187,7 @@ def ner_of_word_index(self, desired_word_idx, model_name="ner"):
f"Did not find word from index {desired_word_idx} in flair POS tag"
)

def _text_index_of_word_index(self, i):
def _text_index_of_word_index(self, i: int) -> int:
"""Returns the index of word ``i`` in self.text."""
pre_words = self.words[: i + 1]
lower_text = self.text.lower()
Expand All @@ -192,20 +200,20 @@ def _text_index_of_word_index(self, i):
look_after_index -= len(self.words[i])
return look_after_index

def text_until_word_index(self, i: int) -> str:
    """Returns the text before the beginning of word at index ``i``."""
    boundary = self._text_index_of_word_index(i)
    return self.text[:boundary]

def text_after_word_index(self, i: int) -> str:
    """Returns the text after the end of word at index ``i``."""
    # Start of word ``i`` plus its length gives the end of that word.
    end_of_word = self._text_index_of_word_index(i) + len(self.words[i])
    return self.text[end_of_word:]

def first_word_diff(self, other_attacked_text):
def first_word_diff(self, other_attacked_text: AttackedText) -> Optional[str]:
"""Returns the first word in self.words that differs from
other_attacked_text.
other_attacked_text, or None if all words are the same.
Useful for word swap strategies.
"""
Expand All @@ -216,7 +224,7 @@ def first_word_diff(self, other_attacked_text):
return w1[i]
return None

def first_word_diff_index(self, other_attacked_text):
def first_word_diff_index(self, other_attacked_text: AttackedText) -> Optional[int]:
"""Returns the index of the first word in self.words that differs from
other_attacked_text.
Expand All @@ -229,7 +237,7 @@ def first_word_diff_index(self, other_attacked_text):
return i
return None

def all_words_diff(self, other_attacked_text):
def all_words_diff(self, other_attacked_text: AttackedText) -> Set[int]:
"""Returns the set of indices for which this and other_attacked_text
have different words."""
indices = set()
Expand All @@ -240,16 +248,17 @@ def all_words_diff(self, other_attacked_text):
indices.add(i)
return indices

def ith_word_diff(self, other_attacked_text: AttackedText, i: int) -> bool:
    """Returns bool representing whether the word at index i differs from
    other_attacked_text."""
    own_words = self.words
    their_words = other_attacked_text.words
    # If either sequence is too short to contain index ``i``, count
    # the position as a difference.
    if i >= len(own_words) or i >= len(their_words):
        return True
    return own_words[i] != their_words[i]

def words_diff_num(self, other_attacked_text):
def words_diff_num(self, other_attacked_text: AttackedText) -> int:
"""The number of words different between two AttackedText objects."""
# using edit distance to calculate words diff num
def generate_tokens(words):
result = {}
Expand Down Expand Up @@ -295,7 +304,7 @@ def cal_dif(w1, w2):
w2 = other_attacked_text.words
return cal_dif(w1, w2)

def convert_from_original_idxs(self, idxs):
def convert_from_original_idxs(self, idxs: Iterable[int]) -> List[int]:
"""Takes indices of words from original string and converts them to
indices of the same words in the current string.
Expand All @@ -315,9 +324,16 @@ def convert_from_original_idxs(self, idxs):

return [self.attack_attrs["original_index_map"][i] for i in idxs]

def replace_words_at_indices(self, indices, new_words):
"""This code returns a new AttackedText object where the word at
``index`` is replaced with a new word."""
def get_deletion_indices(self) -> Iterable[int]:
    """Returns the entries of ``original_index_map`` equal to -1.

    A value of -1 in the map marks a word of the original text that no
    longer appears in this text (i.e. it was deleted).

    NOTE(review): assumes ``attack_attrs["original_index_map"]`` is a
    numpy array — the boolean-mask indexing below only works for
    arrays, not plain lists; confirm against ``__init__``.
    """
    return self.attack_attrs["original_index_map"][
        self.attack_attrs["original_index_map"] == -1
    ]

def replace_words_at_indices(
self, indices: Iterable[int], new_words: Iterable[str]
) -> AttackedText:
"""Returns a new AttackedText object where the word at ``index`` is
replaced with a new word."""
if len(indices) != len(new_words):
raise ValueError(
f"Cannot replace {len(new_words)} words at {len(indices)} indices."
Expand All @@ -333,21 +349,21 @@ def replace_words_at_indices(self, indices, new_words):
words[i] = new_word
return self.generate_new_attacked_text(words)

def replace_word_at_index(self, index: int, new_word: str) -> AttackedText:
    """Returns a new AttackedText object where the word at ``index`` is
    replaced with a new word."""
    # Guard clause: only accept a single ``str`` replacement.
    if isinstance(new_word, str):
        return self.replace_words_at_indices([index], [new_word])
    raise TypeError(
        f"replace_word_at_index requires ``str`` new_word, got {type(new_word)}"
    )

def delete_word_at_index(self, index: int) -> AttackedText:
    """Returns a new AttackedText object where the word at ``index`` is
    removed."""
    # Deletion is implemented as replacement with the empty string.
    return self.replace_word_at_index(index, "")

def insert_text_after_word_index(self, index, text):
def insert_text_after_word_index(self, index: int, text: str) -> AttackedText:
"""Inserts a string before word at index ``index`` and attempts to add
appropriate spacing."""
if not isinstance(text, str):
Expand All @@ -356,7 +372,7 @@ def insert_text_after_word_index(self, index, text):
new_text = " ".join((word_at_index, text))
return self.replace_word_at_index(index, new_text)

def insert_text_before_word_index(self, index, text):
def insert_text_before_word_index(self, index: int, text: str) -> AttackedText:
"""Inserts a string before word at index ``index`` and attempts to add
appropriate spacing."""
if not isinstance(text, str):
Expand All @@ -367,12 +383,7 @@ def insert_text_before_word_index(self, index, text):
new_text = " ".join((text, word_at_index))
return self.replace_word_at_index(index, new_text)

def get_deletion_indices(self):
return self.attack_attrs["original_index_map"][
self.attack_attrs["original_index_map"] == -1
]

def generate_new_attacked_text(self, new_words):
def generate_new_attacked_text(self, new_words: Iterable[str]) -> AttackedText:
"""Returns a new AttackedText object and replaces old list of words
with a new list of words, but preserves the punctuation and spacing of
the original message.
Expand Down Expand Up @@ -466,15 +477,17 @@ def generate_new_attacked_text(self, new_words):
)
return AttackedText(perturbed_input, attack_attrs=new_attack_attrs)

def words_diff_ratio(self, x: AttackedText) -> float:
    """Get the ratio of words difference between current text and `x`.

    Note that current text and `x` must have same number of words.

    Args:
        x: the text to compare against; must have the same word count.

    Returns:
        Fraction of word positions at which the two texts differ.
    """
    assert self.num_words == x.num_words
    # ``words`` is a plain Python list; comparing two lists with ``!=``
    # yields a single bool, so ``np.sum`` over it would give 0 or 1
    # rather than the per-position difference count. Convert to numpy
    # arrays first so the comparison is elementwise.
    num_different = np.sum(np.array(self.words) != np.array(x.words))
    return float(num_different) / self.num_words

def align_with_model_tokens(self, model_wrapper):
def align_with_model_tokens(
self, model_wrapper: textattack.models.wrappers.ModelWrapper
) -> Dict[int, Iterable[int]]:
"""Align AttackedText's `words` with target model's tokenization scheme
(e.g. word, character, subword). Specifically, we map each word to list
of indices of tokens that compose the word (e.g. embedding --> ["em",
Expand Down Expand Up @@ -511,7 +524,7 @@ def align_with_model_tokens(self, model_wrapper):
return word2token_mapping

@property
def tokenizer_input(self):
def tokenizer_input(self) -> Tuple[str]:
"""The tuple of inputs to be passed to the tokenizer."""
input_tuple = tuple(self._text_input.values())
# Prefer to return a string instead of a tuple with a single value.
Expand All @@ -521,15 +534,15 @@ def tokenizer_input(self):
return input_tuple

@property
def column_labels(self) -> List[str]:
    """Returns the labels for this text's columns.

    For single-sequence inputs, this simply returns ['text'].
    """
    return [key for key in self._text_input]

@property
def words_per_input(self):
def words_per_input(self) -> List[List[str]]:
"""Returns a list of lists of words corresponding to each input."""
if not self._words_per_input:
self._words_per_input = [
Expand All @@ -538,32 +551,32 @@ def words_per_input(self):
return self._words_per_input

@property
def words(self) -> List[str]:
    """The list of words in this text, computed lazily on first access."""
    cached = self._words
    if not cached:
        cached = words_from_text(self.text)
        self._words = cached
    return cached

@property
def text(self) -> str:
    """Represents full text input.

    Multiple inputs are joined with a line break.
    """
    segments = self._text_input.values()
    return "\n".join(segments)

@property
def num_words(self) -> int:
    """Returns the number of words in the sequence."""
    return len(self.words)

@property
def newly_swapped_words(self) -> List[str]:
    """Words of the previous text at the indices modified in this step."""
    previous = self.attack_attrs["prev_attacked_text"]
    modified = self.attack_attrs["newly_modified_indices"]
    return [previous.words[i] for i in modified]

def printable_text(self, key_color="bold", key_color_method=None):
def printable_text(self, key_color="bold", key_color_method=None) -> str:
"""Represents full text input. Adds field descriptions.
For example, entailment inputs look like:
Expand Down Expand Up @@ -595,5 +608,5 @@ def ck(k):
for key, value in self._text_input.items()
)

def __repr__(self) -> str:
    """Debug representation showing the full attacked text."""
    return '<AttackedText "{}">'.format(self.text)

0 comments on commit 5ac125a

Please sign in to comment.