add MaxWordsModified constraint, misc fixes and tweaks, linting
jxmorris12 committed May 25, 2022
1 parent b86d0f8 commit 8e6b08d
Showing 21 changed files with 80 additions and 30 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -118,7 +118,7 @@ Follow these steps to start contributing:
```bash
$ cd TextAttack
$ pip install -e . ".[dev]"
-$ pip install black isort pytest pytest-xdist
+$ pip install black docformatter isort pytest pytest-xdist
```

This will install `textattack` in editable mode and install `black` and
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -21,7 +21,7 @@
author = "UVA QData Lab"

# The full version, including alpha/beta/rc tags
release = "0.3.4"
release = "0.3.5"

# Set master doc to `index.rst`.
master_doc = "index"
2 changes: 1 addition & 1 deletion tests/test_attacked_text.py
@@ -61,7 +61,7 @@ def test_window_around_index(self, attacked_text):

def test_big_window_around_index(self, attacked_text):
assert (
-attacked_text.text_window_around_index(0, 10 ** 5) + "."
+attacked_text.text_window_around_index(0, 10**5) + "."
) == attacked_text.text

def test_window_around_index_start(self, attacked_text):
4 changes: 2 additions & 2 deletions tests/test_word_embedding.py
@@ -10,7 +10,7 @@ def test_embedding_paragramcf():
word_embedding = WordEmbedding.counterfitted_GLOVE_embedding()
assert pytest.approx(word_embedding[0][0]) == -0.022007
assert pytest.approx(word_embedding["fawn"][0]) == -0.022007
-assert word_embedding[10 ** 9] is None
+assert word_embedding[10**9] is None


def test_embedding_gensim():
@@ -37,7 +37,7 @@ def test_embedding_gensim():
word_embedding = GensimWordEmbedding(keyed_vectors)
assert pytest.approx(word_embedding[0][0]) == 1
assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2)
-assert word_embedding[10 ** 9] is None
+assert word_embedding[10**9] is None

# test query functionality
assert pytest.approx(word_embedding.get_cos_sim(1, 3)) == 0
11 changes: 6 additions & 5 deletions textattack/attack.py
@@ -81,8 +81,8 @@ def __init__(
constraints: List[Union[Constraint, PreTransformationConstraint]],
transformation: Transformation,
search_method: SearchMethod,
-transformation_cache_size=2 ** 15,
-constraint_cache_size=2 ** 15,
+transformation_cache_size=2**15,
+constraint_cache_size=2**15,
):
"""Initialize an attack object.
@@ -371,22 +371,23 @@ def _attack(self, initial_result):
final_result = self.search_method(initial_result)
self.clear_cache()
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
-return SuccessfulAttackResult(
+result = SuccessfulAttackResult(
initial_result,
final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
-return FailedAttackResult(
+result = FailedAttackResult(
initial_result,
final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
-return MaximizedAttackResult(
+result = MaximizedAttackResult(
initial_result,
final_result,
)
else:
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
+return result

def attack(self, example, ground_truth_output):
"""Attack a single example.
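After this refactor every goal status funnels through a single `result` variable, so unrecognized statuses still raise before anything is returned. A minimal sketch (not part of this commit) of consuming the returned result types, using an off-the-shelf recipe and a public TextAttack model as illustrative choices:

```python
import transformers

import textattack
from textattack.attack_results import SuccessfulAttackResult

# Illustrative victim model; any sequence classifier works the same way.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-imdb"
)
tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)

result = attack.attack("this movie was surprisingly good", 1)
if isinstance(result, SuccessfulAttackResult):
    print("adversarial text:", result.perturbed_result.attacked_text.text)
else:
    print("attack did not succeed:", type(result).__name__)
```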
4 changes: 2 additions & 2 deletions textattack/attack_args.py
@@ -478,8 +478,8 @@ class _CommandLineAttackArgs:
interactive: bool = False
parallel: bool = False
model_batch_size: int = 32
-model_cache_size: int = 2 ** 18
-constraint_cache_size: int = 2 ** 18
+model_cache_size: int = 2**18
+constraint_cache_size: int = 2**18

@classmethod
def _add_parser_args(cls, parser):
2 changes: 1 addition & 1 deletion textattack/constraints/grammaticality/cola.py
@@ -43,7 +43,7 @@ def __init__(

self.max_diff = max_diff
self.model_name = model_name
-self._reference_score_cache = lru.LRU(2 ** 10)
+self._reference_score_cache = lru.LRU(2**10)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = HuggingFaceModelWrapper(model, tokenizer)
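The `lru` package used here (lru-dict) provides a dict-like mapping with a fixed capacity; the same caching pattern recurs throughout this commit. A quick sketch (key and value are illustrative):

```python
import lru

# A bounded dict: once 2**10 entries exist, the least-recently-used
# entry is evicted to make room for new ones.
cache = lru.LRU(2**10)
cache["raw sentence"] = 0.97  # e.g., a cached acceptability score
print(cache["raw sentence"])  # 0.97
```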
@@ -49,7 +49,7 @@ def __init__(self):
self.sess, self.graph, self.PBTXT_PATH, self.CKPT_PATH
)

-self.lm_cache = lru.LRU(2 ** 18)
+self.lm_cache = lru.LRU(2**18)

def clear_cache(self):
self.lm_cache.clear()
5 changes: 3 additions & 2 deletions textattack/constraints/grammaticality/part_of_speech.py
@@ -56,7 +56,7 @@ def __init__(
self.language_nltk = language_nltk
self.language_stanza = language_stanza

-self._pos_tag_cache = lru.LRU(2 ** 14)
+self._pos_tag_cache = lru.LRU(2**14)
if tagger_type == "flair":
if tagset == "universal":
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
@@ -93,7 +93,8 @@ def _get_pos(self, before_ctx, word, after_ctx):

if self.tagger_type == "flair":
context_key_sentence = Sentence(
-context_key, use_tokenizer=textattack.shared.utils.words_from_text
+context_key,
+use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
)
self._flair_pos_tagger.predict(context_key_sentence)
word_list, pos_list = textattack.shared.utils.zip_flair_result(
1 change: 1 addition & 0 deletions textattack/constraints/pre_transformation/__init__.py
@@ -9,5 +9,6 @@
from .repeat_modification import RepeatModification
from .input_column_modification import InputColumnModification
from .max_word_index_modification import MaxWordIndexModification
+from .max_num_words_modified import MaxNumWordsModified
from .min_word_length import MinWordLength
from .max_modification_rate import MaxModificationRate
25 changes: 25 additions & 0 deletions textattack/constraints/pre_transformation/max_num_words_modified.py
@@ -0,0 +1,25 @@
"""
Max Num Words Modified
-----------------------------
"""

from textattack.constraints import PreTransformationConstraint


class MaxNumWordsModified(PreTransformationConstraint):
    def __init__(self, max_num_words: int):
        self.max_num_words = max_num_words

    def _get_modifiable_indices(self, current_text):
        """Returns the word indices in current_text which are able to be
        modified."""

        if len(current_text.attack_attrs["modified_indices"]) >= self.max_num_words:
            return set()
        else:
            return set(range(len(current_text.words)))

    def extra_repr_keys(self):
        return ["max_num_words"]
@@ -32,7 +32,7 @@ def __init__(self, embedding=None, **kwargs):
def clear_cache(self):
self._get_thought_vector.cache_clear()

-@functools.lru_cache(maxsize=2 ** 10)
+@functools.lru_cache(maxsize=2**10)
def _get_thought_vector(self, text):
"""Sums the embeddings of all the words in ``text`` into a "thought
vector"."""
@@ -19,7 +19,7 @@ def __init__(self, threshold=0.8, large=False, metric="angular", **kwargs):
if large:
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
else:
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder/3"

self._tfhub_url = tfhub_url
# Lazily load the model
@@ -28,7 +28,12 @@ def encode(self, sentences):
def encode(self, sentences):
if not self.model:
self.model = hub.load(self._tfhub_url)
-return self.model(sentences).numpy()
+encoding = self.model(sentences)
+
+if isinstance(encoding, dict):
+    encoding = encoding["outputs"]
+
+return encoding.numpy()

def __getstate__(self):
state = self.__dict__.copy()
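Context for the new `isinstance` check: the v3 module is in the older TF1 Hub format, and calling it yields a dict keyed by `"outputs"`, whereas v4/v5 return a plain tensor. A sketch mirroring the `encode()` logic above (that the standalone call behaves exactly like this outside TextAttack is an assumption; it follows the diff's own usage):

```python
import tensorflow_hub as hub

model = hub.load("https://tfhub.dev/google/universal-sentence-encoder/3")
encoding = model(["a sentence to embed"])
if isinstance(encoding, dict):  # v3-style module output
    encoding = encoding["outputs"]
print(encoding.numpy().shape)  # (1, 512): USE produces 512-dim embeddings
```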
2 changes: 1 addition & 1 deletion textattack/goal_functions/goal_function.py
@@ -40,7 +40,7 @@ def __init__(
use_cache=True,
query_budget=float("inf"),
model_batch_size=32,
-model_cache_size=2 ** 20,
+model_cache_size=2**20,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model_wrapper.model.__class__
2 changes: 1 addition & 1 deletion textattack/goal_functions/text/minimize_bleu.py
@@ -59,7 +59,7 @@ def extra_repr_keys(self):
return ["maximizable", "target_bleu"]


-@functools.lru_cache(maxsize=2 ** 12)
+@functools.lru_cache(maxsize=2**12)
def get_bleu(a, b):
ref = a.words
hyp = b.words
4 changes: 2 additions & 2 deletions textattack/goal_functions/text/non_overlapping_output.py
@@ -38,12 +38,12 @@ def _get_score(self, model_output, _):
return num_words_diff / len(get_words_cached(self.ground_truth_output))


-@functools.lru_cache(maxsize=2 ** 12)
+@functools.lru_cache(maxsize=2**12)
def get_words_cached(s):
return np.array(words_from_text(s))


-@functools.lru_cache(maxsize=2 ** 12)
+@functools.lru_cache(maxsize=2**12)
def word_difference_score(s1, s2):
"""Returns the number of words that are non-overlapping between s1 and
s2."""
2 changes: 1 addition & 1 deletion textattack/metrics/attack_metrics/words_perturbed.py
@@ -31,7 +31,7 @@ def calculate(self, results):
self.total_attacks = len(self.results)
self.all_num_words = np.zeros(len(self.results))
self.perturbed_word_percentages = np.zeros(len(self.results))
-self.num_words_changed_until_success = np.zeros(2 ** 16)
+self.num_words_changed_until_success = np.zeros(2**16)
self.max_words_changed = 0

for i, result in enumerate(self.results):
9 changes: 6 additions & 3 deletions textattack/search_methods/greedy_word_swap_wir.py
@@ -31,8 +31,9 @@ class GreedyWordSwapWIR(SearchMethod):
model_wrapper: model wrapper used for gradient-based ranking
"""

def __init__(self, wir_method="unk"):
def __init__(self, wir_method="unk", unk_token="[UNK]"):
self.wir_method = wir_method
self.unk_token = unk_token

def _get_index_order(self, initial_text):
"""Returns word indices of ``initial_text`` in descending order of
@@ -41,15 +42,17 @@ def _get_index_order(self, initial_text):

if self.wir_method == "unk":
leave_one_texts = [
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
initial_text.replace_word_at_index(i, self.unk_token)
for i in range(len_text)
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])

elif self.wir_method == "weighted-saliency":
# first, compute word saliency
leave_one_texts = [
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
initial_text.replace_word_at_index(i, self.unk_token)
for i in range(len_text)
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
saliency_scores = np.array([result.score for result in leave_one_results])
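The hard-coded `"[UNK]"` matched BERT-style vocabularies only; the new `unk_token` parameter lets the word-importance ranking use whatever unknown token the victim model expects. A sketch (the `"<unk>"` string is an illustrative assumption for a SentencePiece-style vocabulary):

```python
from textattack.search_methods import GreedyWordSwapWIR

search_method = GreedyWordSwapWIR(wir_method="unk", unk_token="<unk>")
```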
10 changes: 8 additions & 2 deletions textattack/shared/attacked_text.py
@@ -138,7 +138,8 @@ def pos_of_word_index(self, desired_word_idx):
"""
if not self._pos_tags:
sentence = Sentence(
-self.text, use_tokenizer=textattack.shared.utils.words_from_text
+self.text,
+use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
)
textattack.shared.utils.flair_tag(sentence)
self._pos_tags = sentence
@@ -168,7 +169,8 @@ def ner_of_word_index(self, desired_word_idx, model_name="ner"):
"""
if not self._ner_tags:
sentence = Sentence(
-self.text, use_tokenizer=textattack.shared.utils.words_from_text
+self.text,
+use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
)
textattack.shared.utils.flair_tag(sentence, model_name)
self._ner_tags = sentence
@@ -467,6 +469,10 @@ def generate_new_attacked_text(self, new_words):
# Add substitute word(s) to new sentence.
perturbed_text += adv_word_seq
perturbed_text += original_text # Add all of the ending punctuation.

+# Add pointer to self so chain of replacements can be reconstructed.
+new_attack_attrs["prev_attacked_text"] = self

# Reform perturbed_text into an OrderedDict.
perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
perturbed_input = OrderedDict(
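Because each generated text now points back at the text it came from, the full chain of edits can be reconstructed after an attack. A small sketch (the helper below is ours, not part of the commit):

```python
def perturbation_chain(attacked_text):
    """Walk `prev_attacked_text` pointers back to the original input,
    returning the chain oldest-first."""
    chain = [attacked_text]
    while "prev_attacked_text" in attacked_text.attack_attrs:
        attacked_text = attacked_text.attack_attrs["prev_attacked_text"]
        chain.append(attacked_text)
    return list(reversed(chain))
```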
4 changes: 3 additions & 1 deletion textattack/shared/utils/misc.py
@@ -7,7 +7,9 @@

import textattack

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = os.environ.get(
+    "TA_DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu")
+)


def html_style_from_dict(style_dict):
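With this change the default device can be pinned through a `TA_DEVICE` environment variable, read once when `textattack` is imported. Note that `os.environ.get` returns the value as a plain string (e.g. `"cuda:1"`), not a `torch.device`. A sketch:

```python
import os

# Must be set before `import textattack`, since `device` is resolved at
# module import time.
os.environ["TA_DEVICE"] = "cpu"

import textattack

print(textattack.shared.utils.device)  # -> "cpu"
```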
6 changes: 6 additions & 0 deletions textattack/shared/utils/strings.py
@@ -1,5 +1,6 @@
import string

+import flair
import jieba
import pycld2 as cld2

@@ -106,6 +107,11 @@ def words_from_text(s, words_to_ignore=[]):
return words


+class TextAttackFlairTokenizer(flair.data.Tokenizer):
+    def tokenize(self, text: str):
+        return words_from_text(text)


def default_class_repr(self):
if hasattr(self, "extra_repr_keys"):
extra_params = []
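A sketch of the new tokenizer in use, matching the `Sentence(...)` call sites updated above; it makes flair tokenize with TextAttack's own `words_from_text`, so flair's tags stay aligned with TextAttack's word indices:

```python
from flair.data import Sentence

from textattack.shared.utils import TextAttackFlairTokenizer

sentence = Sentence(
    "TextAttack and flair, tokenizing consistently!",
    use_tokenizer=TextAttackFlairTokenizer(),
)
print([token.text for token in sentence.tokens])
```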
