From 8e6b08db72db49e827505aaceb6fa9415ab68546 Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Wed, 25 May 2022 13:39:30 -0400 Subject: [PATCH] add MaxNumWordsModified constraint, misc fixes and tweaks, linting --- CONTRIBUTING.md | 2 +- docs/conf.py | 2 +- tests/test_attacked_text.py | 2 +- tests/test_word_embedding.py | 4 +-- textattack/attack.py | 11 ++++---- textattack/attack_args.py | 4 +-- textattack/constraints/grammaticality/cola.py | 2 +- .../google_language_model/alzantot_goog_lm.py | 2 +- .../grammaticality/part_of_speech.py | 5 ++-- .../pre_transformation/__init__.py | 1 + .../max_num_words_modified.py | 25 +++++++++++++++++++ .../sentence_encoders/thought_vector.py | 2 +- .../universal_sentence_encoder.py | 9 +++++-- textattack/goal_functions/goal_function.py | 2 +- .../goal_functions/text/minimize_bleu.py | 2 +- .../text/non_overlapping_output.py | 4 +-- .../metrics/attack_metrics/words_perturbed.py | 2 +- .../search_methods/greedy_word_swap_wir.py | 9 ++++--- textattack/shared/attacked_text.py | 10 ++++++-- textattack/shared/utils/misc.py | 4 ++- textattack/shared/utils/strings.py | 6 +++++ 21 files changed, 80 insertions(+), 30 deletions(-) create mode 100644 textattack/constraints/pre_transformation/max_num_words_modified.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9a5bffe1..398884a1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -118,7 +118,7 @@ Follow these steps to start contributing: ```bash $ cd TextAttack $ pip install -e . ".[dev]" - $ pip install black isort pytest pytest-xdist + $ pip install black docformatter isort pytest pytest-xdist ``` This will install `textattack` in editable mode and install `black` and diff --git a/docs/conf.py b/docs/conf.py index 590c78f7..31c2a1bb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,7 @@ author = "UVA QData Lab" # The full version, including alpha/beta/rc tags -release = "0.3.4" +release = "0.3.5" # Set master doc to `index.rst`. master_doc = "index" diff --git a/tests/test_attacked_text.py b/tests/test_attacked_text.py index f62988a4..698a89ed 100644 --- a/tests/test_attacked_text.py +++ b/tests/test_attacked_text.py @@ -61,7 +61,7 @@ def test_window_around_index(self, attacked_text): def test_big_window_around_index(self, attacked_text): assert ( - attacked_text.text_window_around_index(0, 10 ** 5) + "." + attacked_text.text_window_around_index(0, 10**5) + "."
) == attacked_text.text def test_window_around_index_start(self, attacked_text): diff --git a/tests/test_word_embedding.py b/tests/test_word_embedding.py index 5232e8fa..4772c27d 100644 --- a/tests/test_word_embedding.py +++ b/tests/test_word_embedding.py @@ -10,7 +10,7 @@ def test_embedding_paragramcf(): word_embedding = WordEmbedding.counterfitted_GLOVE_embedding() assert pytest.approx(word_embedding[0][0]) == -0.022007 assert pytest.approx(word_embedding["fawn"][0]) == -0.022007 - assert word_embedding[10 ** 9] is None + assert word_embedding[10**9] is None def test_embedding_gensim(): @@ -37,7 +37,7 @@ def test_embedding_gensim(): word_embedding = GensimWordEmbedding(keyed_vectors) assert pytest.approx(word_embedding[0][0]) == 1 assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2) - assert word_embedding[10 ** 9] is None + assert word_embedding[10**9] is None # test query functionality assert pytest.approx(word_embedding.get_cos_sim(1, 3)) == 0 diff --git a/textattack/attack.py b/textattack/attack.py index a7a072b2..f835a248 100644 --- a/textattack/attack.py +++ b/textattack/attack.py @@ -81,8 +81,8 @@ def __init__( constraints: List[Union[Constraint, PreTransformationConstraint]], transformation: Transformation, search_method: SearchMethod, - transformation_cache_size=2 ** 15, - constraint_cache_size=2 ** 15, + transformation_cache_size=2**15, + constraint_cache_size=2**15, ): """Initialize an attack object. @@ -371,22 +371,23 @@ def _attack(self, initial_result): final_result = self.search_method(initial_result) self.clear_cache() if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED: - return SuccessfulAttackResult( + result = SuccessfulAttackResult( initial_result, final_result, ) elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING: - return FailedAttackResult( + result = FailedAttackResult( initial_result, final_result, ) elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING: - return MaximizedAttackResult( + result = MaximizedAttackResult( initial_result, final_result, ) else: raise ValueError(f"Unrecognized goal status {final_result.goal_status}") + return result def attack(self, example, ground_truth_output): """Attack a single example. 
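With the refactor above, `_attack` funnels every goal status through the single `return result` at the end, so the `else: raise ValueError(...)` branch catches any unrecognized status and each `AttackResult` subclass corresponds to exactly one status. A minimal sketch of branching on those result types downstream (the `summarize` helper is hypothetical, not part of this patch; the result classes are the ones `attack.py` already imports):

```python
from textattack.attack_results import (
    FailedAttackResult,
    MaximizedAttackResult,
    SuccessfulAttackResult,
)


def summarize(result):
    """Map an AttackResult subclass back to the goal status that produced it."""
    if isinstance(result, SuccessfulAttackResult):
        return "succeeded"  # GoalFunctionResultStatus.SUCCEEDED
    elif isinstance(result, MaximizedAttackResult):
        return "maximized"  # GoalFunctionResultStatus.MAXIMIZING
    elif isinstance(result, FailedAttackResult):
        return "failed"  # GoalFunctionResultStatus.SEARCHING (budget exhausted)
    raise ValueError(f"Unexpected result type {type(result)}")
```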
diff --git a/textattack/attack_args.py b/textattack/attack_args.py index c3724141..3bf706f8 100644 --- a/textattack/attack_args.py +++ b/textattack/attack_args.py @@ -478,8 +478,8 @@ class _CommandLineAttackArgs: interactive: bool = False parallel: bool = False model_batch_size: int = 32 - model_cache_size: int = 2 ** 18 - constraint_cache_size: int = 2 ** 18 + model_cache_size: int = 2**18 + constraint_cache_size: int = 2**18 @classmethod def _add_parser_args(cls, parser): diff --git a/textattack/constraints/grammaticality/cola.py b/textattack/constraints/grammaticality/cola.py index beb7c30a..190bad25 100644 --- a/textattack/constraints/grammaticality/cola.py +++ b/textattack/constraints/grammaticality/cola.py @@ -43,7 +43,7 @@ def __init__( self.max_diff = max_diff self.model_name = model_name - self._reference_score_cache = lru.LRU(2 ** 10) + self._reference_score_cache = lru.LRU(2**10) model = AutoModelForSequenceClassification.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = HuggingFaceModelWrapper(model, tokenizer) diff --git a/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py b/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py index d47bfd6e..005dda55 100644 --- a/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py +++ b/textattack/constraints/grammaticality/language_models/google_language_model/alzantot_goog_lm.py @@ -49,7 +49,7 @@ def __init__(self): self.sess, self.graph, self.PBTXT_PATH, self.CKPT_PATH ) - self.lm_cache = lru.LRU(2 ** 18) + self.lm_cache = lru.LRU(2**18) def clear_cache(self): self.lm_cache.clear() diff --git a/textattack/constraints/grammaticality/part_of_speech.py b/textattack/constraints/grammaticality/part_of_speech.py index 19388ed0..f531f33c 100644 --- a/textattack/constraints/grammaticality/part_of_speech.py +++ b/textattack/constraints/grammaticality/part_of_speech.py @@ -56,7 +56,7 @@ def __init__( self.language_nltk = language_nltk self.language_stanza = language_stanza - self._pos_tag_cache = lru.LRU(2 ** 14) + self._pos_tag_cache = lru.LRU(2**14) if tagger_type == "flair": if tagset == "universal": self._flair_pos_tagger = SequenceTagger.load("upos-fast") @@ -93,7 +93,8 @@ def _get_pos(self, before_ctx, word, after_ctx): if self.tagger_type == "flair": context_key_sentence = Sentence( - context_key, use_tokenizer=textattack.shared.utils.words_from_text + context_key, + use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(), ) self._flair_pos_tagger.predict(context_key_sentence) word_list, pos_list = textattack.shared.utils.zip_flair_result( diff --git a/textattack/constraints/pre_transformation/__init__.py b/textattack/constraints/pre_transformation/__init__.py index 73131a9e..9cfdd8cf 100644 --- a/textattack/constraints/pre_transformation/__init__.py +++ b/textattack/constraints/pre_transformation/__init__.py @@ -9,5 +9,6 @@ from .repeat_modification import RepeatModification from .input_column_modification import InputColumnModification from .max_word_index_modification import MaxWordIndexModification +from .max_num_words_modified import MaxNumWordsModified from .min_word_length import MinWordLength from .max_modification_rate import MaxModificationRate diff --git a/textattack/constraints/pre_transformation/max_num_words_modified.py b/textattack/constraints/pre_transformation/max_num_words_modified.py new file mode 100644 index 00000000..3d851050 --- 
/dev/null +++ b/textattack/constraints/pre_transformation/max_num_words_modified.py @@ -0,0 +1,25 @@ +""" + +Max Num Words Modified +----------------------------- + +""" + +from textattack.constraints import PreTransformationConstraint + + +class MaxNumWordsModified(PreTransformationConstraint): + def __init__(self, max_num_words: int): + self.max_num_words = max_num_words + + def _get_modifiable_indices(self, current_text): + """Returns the word indices in current_text which are able to be + modified.""" + + if len(current_text.attack_attrs["modified_indices"]) >= self.max_num_words: + return set() + else: + return set(range(len(current_text.words))) + + def extra_repr_keys(self): + return ["max_num_words"] diff --git a/textattack/constraints/semantics/sentence_encoders/thought_vector.py b/textattack/constraints/semantics/sentence_encoders/thought_vector.py index 60bac23b..4a7978b0 100644 --- a/textattack/constraints/semantics/sentence_encoders/thought_vector.py +++ b/textattack/constraints/semantics/sentence_encoders/thought_vector.py @@ -32,7 +32,7 @@ def __init__(self, embedding=None, **kwargs): def clear_cache(self): self._get_thought_vector.cache_clear() - @functools.lru_cache(maxsize=2 ** 10) + @functools.lru_cache(maxsize=2**10) def _get_thought_vector(self, text): """Sums the embeddings of all the words in ``text`` into a "thought vector".""" diff --git a/textattack/constraints/semantics/sentence_encoders/universal_sentence_encoder/universal_sentence_encoder.py b/textattack/constraints/semantics/sentence_encoders/universal_sentence_encoder/universal_sentence_encoder.py index 68731310..c4017a5f 100644 --- a/textattack/constraints/semantics/sentence_encoders/universal_sentence_encoder/universal_sentence_encoder.py +++ b/textattack/constraints/semantics/sentence_encoders/universal_sentence_encoder/universal_sentence_encoder.py @@ -19,7 +19,7 @@ def __init__(self, threshold=0.8, large=False, metric="angular", **kwargs): if large: tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5" else: - tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder/4" + tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder/3" self._tfhub_url = tfhub_url # Lazily load the model @@ -28,7 +28,12 @@ def __init__(self, threshold=0.8, large=False, metric="angular", **kwargs): def encode(self, sentences): if not self.model: self.model = hub.load(self._tfhub_url) - return self.model(sentences).numpy() + encoding = self.model(sentences) + + if isinstance(encoding, dict): + encoding = encoding["outputs"] + + return encoding.numpy() def __getstate__(self): state = self.__dict__.copy() diff --git a/textattack/goal_functions/goal_function.py b/textattack/goal_functions/goal_function.py index e2d4b11a..665f9b78 100644 --- a/textattack/goal_functions/goal_function.py +++ b/textattack/goal_functions/goal_function.py @@ -40,7 +40,7 @@ def __init__( use_cache=True, query_budget=float("inf"), model_batch_size=32, - model_cache_size=2 ** 20, + model_cache_size=2**20, ): validators.validate_model_goal_function_compatibility( self.__class__, model_wrapper.model.__class__ diff --git a/textattack/goal_functions/text/minimize_bleu.py b/textattack/goal_functions/text/minimize_bleu.py index 33999577..92613be5 100644 --- a/textattack/goal_functions/text/minimize_bleu.py +++ b/textattack/goal_functions/text/minimize_bleu.py @@ -59,7 +59,7 @@ def extra_repr_keys(self): return ["maximizable", "target_bleu"] -@functools.lru_cache(maxsize=2 ** 12) +@functools.lru_cache(maxsize=2**12) def
get_bleu(a, b): ref = a.words hyp = b.words diff --git a/textattack/goal_functions/text/non_overlapping_output.py b/textattack/goal_functions/text/non_overlapping_output.py index 443aa236..e2cb4982 100644 --- a/textattack/goal_functions/text/non_overlapping_output.py +++ b/textattack/goal_functions/text/non_overlapping_output.py @@ -38,12 +38,12 @@ def _get_score(self, model_output, _): return num_words_diff / len(get_words_cached(self.ground_truth_output)) -@functools.lru_cache(maxsize=2 ** 12) +@functools.lru_cache(maxsize=2**12) def get_words_cached(s): return np.array(words_from_text(s)) -@functools.lru_cache(maxsize=2 ** 12) +@functools.lru_cache(maxsize=2**12) def word_difference_score(s1, s2): """Returns the number of words that are non-overlapping between s1 and s2.""" diff --git a/textattack/metrics/attack_metrics/words_perturbed.py b/textattack/metrics/attack_metrics/words_perturbed.py index d4b12824..6104de1b 100644 --- a/textattack/metrics/attack_metrics/words_perturbed.py +++ b/textattack/metrics/attack_metrics/words_perturbed.py @@ -31,7 +31,7 @@ def calculate(self, results): self.total_attacks = len(self.results) self.all_num_words = np.zeros(len(self.results)) self.perturbed_word_percentages = np.zeros(len(self.results)) - self.num_words_changed_until_success = np.zeros(2 ** 16) + self.num_words_changed_until_success = np.zeros(2**16) self.max_words_changed = 0 for i, result in enumerate(self.results): diff --git a/textattack/search_methods/greedy_word_swap_wir.py b/textattack/search_methods/greedy_word_swap_wir.py index d859fa12..b33d33eb 100644 --- a/textattack/search_methods/greedy_word_swap_wir.py +++ b/textattack/search_methods/greedy_word_swap_wir.py @@ -31,8 +31,9 @@ class GreedyWordSwapWIR(SearchMethod): model_wrapper: model wrapper used for gradient-based ranking """ - def __init__(self, wir_method="unk"): + def __init__(self, wir_method="unk", unk_token="[UNK]"): self.wir_method = wir_method + self.unk_token = unk_token def _get_index_order(self, initial_text): """Returns word indices of ``initial_text`` in descending order of @@ -41,7 +42,8 @@ def _get_index_order(self, initial_text): if self.wir_method == "unk": leave_one_texts = [ - initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text) + initial_text.replace_word_at_index(i, self.unk_token) + for i in range(len_text) ] leave_one_results, search_over = self.get_goal_results(leave_one_texts) index_scores = np.array([result.score for result in leave_one_results]) @@ -49,7 +51,8 @@ def _get_index_order(self, initial_text): elif self.wir_method == "weighted-saliency": # first, compute word saliency leave_one_texts = [ - initial_text.replace_word_at_index(i, "[UNK]") for i in range(len_text) + initial_text.replace_word_at_index(i, self.unk_token) + for i in range(len_text) ] leave_one_results, search_over = self.get_goal_results(leave_one_texts) saliency_scores = np.array([result.score for result in leave_one_results]) diff --git a/textattack/shared/attacked_text.py b/textattack/shared/attacked_text.py index 13f8c55f..a56d0cc0 100644 --- a/textattack/shared/attacked_text.py +++ b/textattack/shared/attacked_text.py @@ -138,7 +138,8 @@ def pos_of_word_index(self, desired_word_idx): """ if not self._pos_tags: sentence = Sentence( - self.text, use_tokenizer=textattack.shared.utils.words_from_text + self.text, + use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(), ) textattack.shared.utils.flair_tag(sentence) self._pos_tags = sentence @@ -168,7 +169,8 @@ def ner_of_word_index(self, 
desired_word_idx, model_name="ner"): """ if not self._ner_tags: sentence = Sentence( - self.text, use_tokenizer=textattack.shared.utils.words_from_text + self.text, + use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(), ) textattack.shared.utils.flair_tag(sentence, model_name) self._ner_tags = sentence @@ -467,6 +469,10 @@ def generate_new_attacked_text(self, new_words): # Add substitute word(s) to new sentence. perturbed_text += adv_word_seq perturbed_text += original_text # Add all of the ending punctuation. + + # Add pointer to self so chain of replacements can be reconstructed. + new_attack_attrs["prev_attacked_text"] = self + # Reform perturbed_text into an OrderedDict. perturbed_input_texts = perturbed_text.split(AttackedText.SPLIT_TOKEN) perturbed_input = OrderedDict( diff --git a/textattack/shared/utils/misc.py b/textattack/shared/utils/misc.py index 0f65ffd8..18511f80 100644 --- a/textattack/shared/utils/misc.py +++ b/textattack/shared/utils/misc.py @@ -7,7 +7,9 @@ import textattack -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = os.environ.get( + "TA_DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu") +) def html_style_from_dict(style_dict): diff --git a/textattack/shared/utils/strings.py b/textattack/shared/utils/strings.py index 91431f49..0cd278bf 100644 --- a/textattack/shared/utils/strings.py +++ b/textattack/shared/utils/strings.py @@ -1,5 +1,6 @@ import string +import flair import jieba import pycld2 as cld2 @@ -106,6 +107,11 @@ def words_from_text(s, words_to_ignore=[]): return words +class TextAttackFlairTokenizer(flair.data.Tokenizer): + def tokenize(self, text: str): + return words_from_text(text) + + def default_class_repr(self): if hasattr(self, "extra_repr_keys"): extra_params = []
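Taken together, a hedged end-to-end sketch of the pieces this patch adds: `MaxNumWordsModified` capping how many words an attack may perturb, the now-configurable `unk_token` on `GreedyWordSwapWIR`, and the `TA_DEVICE` override. The Hugging Face checkpoint, `max_num_words=3`, and `max_candidates=20` are illustrative choices, not values from this patch:

```python
# TA_DEVICE is read once at import time (see textattack/shared/utils/misc.py
# above), so it must be set before `import textattack`, e.g.
#   TA_DEVICE=cpu python my_attack.py   # script name is a placeholder
import transformers

from textattack import Attack
from textattack.constraints.pre_transformation import (
    MaxNumWordsModified,
    RepeatModification,
    StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.models.wrappers import HuggingFaceModelWrapper
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapEmbedding

# Any sequence-classification checkpoint works; this one is just an example.
checkpoint = "textattack/bert-base-uncased-imdb"
model = transformers.AutoModelForSequenceClassification.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

goal_function = UntargetedClassification(model_wrapper)

# Stop offering transformations once 3 distinct words have been modified.
constraints = [
    RepeatModification(),
    StopwordModification(),
    MaxNumWordsModified(max_num_words=3),
]
transformation = WordSwapEmbedding(max_candidates=20)

# "[UNK]" matches the previously hard-coded token; pass e.g. "<unk>" for
# models whose vocabulary uses a different unknown-word marker.
search_method = GreedyWordSwapWIR(wir_method="unk", unk_token="[UNK]")

attack = Attack(goal_function, constraints, transformation, search_method)
result = attack.attack("This movie was surprisingly wonderful.", 1)
print(result.__str__(color_method="ansi"))
```

Because `MaxNumWordsModified` is a `PreTransformationConstraint`, it returns an empty modifiable-index set once `attack_attrs["modified_indices"]` reaches the cap, pruning the search up front rather than rejecting candidate texts after they are generated.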