Skip to content

Commit

Permalink
minor black format updates
Browse files Browse the repository at this point in the history
  • Loading branch information
qiyanjun committed Jun 6, 2022
1 parent 3e30662 commit 82d1ec5
Show file tree
Hide file tree
Showing 10 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions textattack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def __init__(
constraints: List[Union[Constraint, PreTransformationConstraint]],
transformation: Transformation,
search_method: SearchMethod,
transformation_cache_size=2**15,
constraint_cache_size=2**15,
transformation_cache_size=2 ** 15,
constraint_cache_size=2 ** 15,
):
"""Initialize an attack object.
Expand Down
4 changes: 2 additions & 2 deletions textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,8 @@ class _CommandLineAttackArgs:
interactive: bool = False
parallel: bool = False
model_batch_size: int = 32
model_cache_size: int = 2**18
constraint_cache_size: int = 2**18
model_cache_size: int = 2 ** 18
constraint_cache_size: int = 2 ** 18

@classmethod
def _add_parser_args(cls, parser):
Expand Down
2 changes: 1 addition & 1 deletion textattack/constraints/grammaticality/cola.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(

self.max_diff = max_diff
self.model_name = model_name
self._reference_score_cache = lru.LRU(2**10)
self._reference_score_cache = lru.LRU(2 ** 10)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = HuggingFaceModelWrapper(model, tokenizer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self):
self.sess, self.graph, self.PBTXT_PATH, self.CKPT_PATH
)

self.lm_cache = lru.LRU(2**18)
self.lm_cache = lru.LRU(2 ** 18)

def clear_cache(self):
self.lm_cache.clear()
Expand Down
2 changes: 1 addition & 1 deletion textattack/constraints/grammaticality/part_of_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.language_nltk = language_nltk
self.language_stanza = language_stanza

self._pos_tag_cache = lru.LRU(2**14)
self._pos_tag_cache = lru.LRU(2 ** 14)
if tagger_type == "flair":
if tagset == "universal":
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, embedding=None, **kwargs):
def clear_cache(self):
self._get_thought_vector.cache_clear()

@functools.lru_cache(maxsize=2**10)
@functools.lru_cache(maxsize=2 ** 10)
def _get_thought_vector(self, text):
"""Sums the embeddings of all the words in ``text`` into a "thought
vector"."""
Expand Down
2 changes: 1 addition & 1 deletion textattack/goal_functions/goal_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
use_cache=True,
query_budget=float("inf"),
model_batch_size=32,
model_cache_size=2**20,
model_cache_size=2 ** 20,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model_wrapper.model.__class__
Expand Down
2 changes: 1 addition & 1 deletion textattack/goal_functions/text/minimize_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def extra_repr_keys(self):
return ["maximizable", "target_bleu"]


@functools.lru_cache(maxsize=2**12)
@functools.lru_cache(maxsize=2 ** 12)
def get_bleu(a, b):
ref = a.words
hyp = b.words
Expand Down
4 changes: 2 additions & 2 deletions textattack/goal_functions/text/non_overlapping_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def _get_score(self, model_output, _):
return num_words_diff / len(get_words_cached(self.ground_truth_output))


@functools.lru_cache(maxsize=2**12)
@functools.lru_cache(maxsize=2 ** 12)
def get_words_cached(s):
return np.array(words_from_text(s))


@functools.lru_cache(maxsize=2**12)
@functools.lru_cache(maxsize=2 ** 12)
def word_difference_score(s1, s2):
"""Returns the number of words that are non-overlapping between s1 and
s2."""
Expand Down
2 changes: 1 addition & 1 deletion textattack/metrics/attack_metrics/words_perturbed.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def calculate(self, results):
self.total_attacks = len(self.results)
self.all_num_words = np.zeros(len(self.results))
self.perturbed_word_percentages = np.zeros(len(self.results))
self.num_words_changed_until_success = np.zeros(2**16)
self.num_words_changed_until_success = np.zeros(2 ** 16)
self.max_words_changed = 0

for i, result in enumerate(self.results):
Expand Down

0 comments on commit 82d1ec5

Please sign in to comment.