Skip to content

Commit

Permalink
Merge pull request #644 from QData/fixes-and-improvements
Browse files Browse the repository at this point in the history
Bug fixes and tweaks
  • Loading branch information
jxmorris12 committed May 25, 2022
2 parents b433fcb + 9b20928 commit 6e7fda5
Show file tree
Hide file tree
Showing 22 changed files with 84 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-formatting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
pip install black flake8 isort # Testing packages
python setup.py install_egg_info # Workaround https://github.com/pypa/pip/issues/4537
pip install -e .[dev]
pip install black flake8 isort --upgrade # Testing packages
- name: Check code format with black and isort
run: |
make lint
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ Follow these steps to start contributing:
```bash
$ cd TextAttack
$ pip install -e . ".[dev]"
$ pip install black isort pytest pytest-xdist
$ pip install black docformatter isort pytest pytest-xdist
```

This will install `textattack` in editable mode and install `black` and
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
author = "UVA QData Lab"

# The full version, including alpha/beta/rc tags
release = "0.3.4"
release = "0.3.5"

# Set master doc to `index.rst`.
master_doc = "index"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_attacked_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_window_around_index(self, attacked_text):

def test_big_window_around_index(self, attacked_text):
assert (
attacked_text.text_window_around_index(0, 10 ** 5) + "."
attacked_text.text_window_around_index(0, 10**5) + "."
) == attacked_text.text

def test_window_around_index_start(self, attacked_text):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_word_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_embedding_paragramcf():
word_embedding = WordEmbedding.counterfitted_GLOVE_embedding()
assert pytest.approx(word_embedding[0][0]) == -0.022007
assert pytest.approx(word_embedding["fawn"][0]) == -0.022007
assert word_embedding[10 ** 9] is None
assert word_embedding[10**9] is None


def test_embedding_gensim():
Expand All @@ -37,7 +37,7 @@ def test_embedding_gensim():
word_embedding = GensimWordEmbedding(keyed_vectors)
assert pytest.approx(word_embedding[0][0]) == 1
assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2)
assert word_embedding[10 ** 9] is None
assert word_embedding[10**9] is None

# test query functionality
assert pytest.approx(word_embedding.get_cos_sim(1, 3)) == 0
Expand Down
11 changes: 6 additions & 5 deletions textattack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def __init__(
constraints: List[Union[Constraint, PreTransformationConstraint]],
transformation: Transformation,
search_method: SearchMethod,
transformation_cache_size=2 ** 15,
constraint_cache_size=2 ** 15,
transformation_cache_size=2**15,
constraint_cache_size=2**15,
):
"""Initialize an attack object.
Expand Down Expand Up @@ -371,22 +371,23 @@ def _attack(self, initial_result):
final_result = self.search_method(initial_result)
self.clear_cache()
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
return SuccessfulAttackResult(
result = SuccessfulAttackResult(
initial_result,
final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
return FailedAttackResult(
result = FailedAttackResult(
initial_result,
final_result,
)
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
return MaximizedAttackResult(
result = MaximizedAttackResult(
initial_result,
final_result,
)
else:
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
return result

def attack(self, example, ground_truth_output):
"""Attack a single example.
Expand Down
4 changes: 2 additions & 2 deletions textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,8 @@ class _CommandLineAttackArgs:
interactive: bool = False
parallel: bool = False
model_batch_size: int = 32
model_cache_size: int = 2 ** 18
constraint_cache_size: int = 2 ** 18
model_cache_size: int = 2**18
constraint_cache_size: int = 2**18

@classmethod
def _add_parser_args(cls, parser):
Expand Down
2 changes: 1 addition & 1 deletion textattack/constraints/grammaticality/cola.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(

self.max_diff = max_diff
self.model_name = model_name
self._reference_score_cache = lru.LRU(2 ** 10)
self._reference_score_cache = lru.LRU(2**10)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = HuggingFaceModelWrapper(model, tokenizer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self):
self.sess, self.graph, self.PBTXT_PATH, self.CKPT_PATH
)

self.lm_cache = lru.LRU(2 ** 18)
self.lm_cache = lru.LRU(2**18)

def clear_cache(self):
self.lm_cache.clear()
Expand Down
5 changes: 3 additions & 2 deletions textattack/constraints/grammaticality/part_of_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.language_nltk = language_nltk
self.language_stanza = language_stanza

self._pos_tag_cache = lru.LRU(2 ** 14)
self._pos_tag_cache = lru.LRU(2**14)
if tagger_type == "flair":
if tagset == "universal":
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
Expand Down Expand Up @@ -93,7 +93,8 @@ def _get_pos(self, before_ctx, word, after_ctx):

if self.tagger_type == "flair":
context_key_sentence = Sentence(
context_key, use_tokenizer=textattack.shared.utils.words_from_text
context_key,
use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
)
self._flair_pos_tagger.predict(context_key_sentence)
word_list, pos_list = textattack.shared.utils.zip_flair_result(
Expand Down
1 change: 1 addition & 0 deletions textattack/constraints/pre_transformation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
from .repeat_modification import RepeatModification
from .input_column_modification import InputColumnModification
from .max_word_index_modification import MaxWordIndexModification
from .max_num_words_modified import MaxNumWordsModified
from .min_word_length import MinWordLength
from .max_modification_rate import MaxModificationRate
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Max Modification Rate
-----------------------------
"""

from textattack.constraints import PreTransformationConstraint


class MaxNumWordsModified(PreTransformationConstraint):
def __init__(self, max_num_words: int):
self.max_num_words = max_num_words

def _get_modifiable_indices(self, current_text):
"""Returns the word indices in current_text which are able to be
modified."""

if len(current_text.attack_attrs["modified_indices"]) >= self.max_num_words:
return set()
else:
return set(range(len(current_text.words)))

def extra_repr_keys(self):
return ["max_num_words"]
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, embedding=None, **kwargs):
def clear_cache(self):
self._get_thought_vector.cache_clear()

@functools.lru_cache(maxsize=2 ** 10)
@functools.lru_cache(maxsize=2**10)
def _get_thought_vector(self, text):
"""Sums the embeddings of all the words in ``text`` into a "thought
vector"."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, threshold=0.8, large=False, metric="angular", **kwargs):
if large:
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
else:
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder/3"

self._tfhub_url = tfhub_url
# Lazily load the model
Expand All @@ -28,7 +28,12 @@ def __init__(self, threshold=0.8, large=False, metric="angular", **kwargs):
def encode(self, sentences):
if not self.model:
self.model = hub.load(self._tfhub_url)
return self.model(sentences).numpy()
encoding = self.model(sentences)

if isinstance(encoding, dict):
encoding = encoding["outputs"]

return encoding.numpy()

def __getstate__(self):
state = self.__dict__.copy()
Expand Down
2 changes: 1 addition & 1 deletion textattack/goal_functions/goal_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
use_cache=True,
query_budget=float("inf"),
model_batch_size=32,
model_cache_size=2 ** 20,
model_cache_size=2**20,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model_wrapper.model.__class__
Expand Down
2 changes: 1 addition & 1 deletion textattack/goal_functions/text/minimize_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def extra_repr_keys(self):
return ["maximizable", "target_bleu"]


@functools.lru_cache(maxsize=2 ** 12)
@functools.lru_cache(maxsize=2**12)
def get_bleu(a, b):
ref = a.words
hyp = b.words
Expand Down
4 changes: 2 additions & 2 deletions textattack/goal_functions/text/non_overlapping_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def _get_score(self, model_output, _):
return num_words_diff / len(get_words_cached(self.ground_truth_output))


@functools.lru_cache(maxsize=2 ** 12)
@functools.lru_cache(maxsize=2**12)
def get_words_cached(s):
return np.array(words_from_text(s))


@functools.lru_cache(maxsize=2 ** 12)
@functools.lru_cache(maxsize=2**12)
def word_difference_score(s1, s2):
"""Returns the number of words that are non-overlapping between s1 and
s2."""
Expand Down
2 changes: 1 addition & 1 deletion textattack/metrics/attack_metrics/words_perturbed.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def calculate(self, results):
self.total_attacks = len(self.results)
self.all_num_words = np.zeros(len(self.results))
self.perturbed_word_percentages = np.zeros(len(self.results))
self.num_words_changed_until_success = np.zeros(2 ** 16)
self.num_words_changed_until_success = np.zeros(2**16)
self.max_words_changed = 0

for i, result in enumerate(self.results):
Expand Down
9 changes: 6 additions & 3 deletions textattack/search_methods/greedy_word_swap_wir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class GreedyWordSwapWIR(SearchMethod):
model_wrapper: model wrapper used for gradient-based ranking
"""

def __init__(self, wir_method="unk"):
def __init__(self, wir_method="unk", unk_token="[UNK]"):
self.wir_method = wir_method
self.unk_token = unk_token

def _get_index_order(self, initial_text):
"""Returns word indices of ``initial_text`` in descending order of
Expand All @@ -41,15 +42,17 @@ def _get_index_order(self, initial_text):

if self.wir_method == "unk":
leave_one_texts = [
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
initial_text.replace_word_at_index(i, self.unk_token)
for i in range(len_text)
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])

elif self.wir_method == "weighted-saliency":
# first, compute word saliency
leave_one_texts = [
initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text)
initial_text.replace_word_at_index(i, self.unk_token)
for i in range(len_text)
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
saliency_scores = np.array([result.score for result in leave_one_results])
Expand Down
10 changes: 8 additions & 2 deletions textattack/shared/attacked_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def pos_of_word_index(self, desired_word_idx):
"""
if not self._pos_tags:
sentence = Sentence(
self.text, use_tokenizer=textattack.shared.utils.words_from_text
self.text,
use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
)
textattack.shared.utils.flair_tag(sentence)
self._pos_tags = sentence
Expand Down Expand Up @@ -168,7 +169,8 @@ def ner_of_word_index(self, desired_word_idx, model_name="ner"):
"""
if not self._ner_tags:
sentence = Sentence(
self.text, use_tokenizer=textattack.shared.utils.words_from_text
self.text,
use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
)
textattack.shared.utils.flair_tag(sentence, model_name)
self._ner_tags = sentence
Expand Down Expand Up @@ -467,6 +469,10 @@ def generate_new_attacked_text(self, new_words):
# Add substitute word(s) to new sentence.
perturbed_text += adv_word_seq
perturbed_text += original_text # Add all of the ending punctuation.

# Add pointer to self so chain of replacements can be reconstructed.
new_attack_attrs["prev_attacked_text"] = self

# Reform perturbed_text into an OrderedDict.
perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN)
perturbed_input = OrderedDict(
Expand Down
4 changes: 3 additions & 1 deletion textattack/shared/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import textattack

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = os.environ.get(
"TA_DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu")
)


def html_style_from_dict(style_dict):
Expand Down
12 changes: 9 additions & 3 deletions textattack/shared/utils/strings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import string

import flair
import jieba
import pycld2 as cld2

Expand Down Expand Up @@ -57,6 +58,11 @@ def words_from_text(s, words_to_ignore=[]):
return words


class TextAttackFlairTokenizer(flair.data.Tokenizer):
def tokenize(self, text: str):
return words_from_text(text)


def default_class_repr(self):
if hasattr(self, "extra_repr_keys"):
extra_params = []
Expand All @@ -82,7 +88,7 @@ def __repr__(self):
__str__ = __repr__

def extra_repr_keys(self):
"""extra fields to be included in the representation of a class"""
"""extra fields to be included in the representation of a class."""
return []


Expand Down Expand Up @@ -203,7 +209,7 @@ def flair_tag(sentence, tag_type="upos-fast"):
from flair.models import SequenceTagger

_flair_pos_tagger = SequenceTagger.load(tag_type)
_flair_pos_tagger.predict(sentence)
_flair_pos_tagger.predict(sentence, force_token_predictions=True)


def zip_flair_result(pred, tag_type="upos-fast"):
Expand All @@ -222,7 +228,7 @@ def zip_flair_result(pred, tag_type="upos-fast"):
if "pos" in tag_type:
pos_list.append(token.annotation_layers["pos"][0]._value)
elif tag_type == "ner":
pos_list.append(token.get_tag("ner"))
pos_list.append(token.get_label("ner"))

return word_list, pos_list

Expand Down

0 comments on commit 6e7fda5

Please sign in to comment.