Skip to content

Commit

Permalink
Merge pull request #665 from VijayKalmath/Optimize-GreedyWordSwapWir-…
Browse files Browse the repository at this point in the history
…withInputColumnModification

Optimize GreedyWordSwapWIR with pretransformation constraints
  • Loading branch information
jxmorris12 committed Jun 16, 2022
2 parents e072794 + 7afc7c0 commit 6423f6e
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ lovingly photographed in the manner of a golden book sprung to [[ife]] , stuart
| Attack success rate: | 100.0% |
| Average perturbed word %: | 45.0% |
| Average num. words per input: | 12.0 |
| Avg num queries: | 25.0 |
| Avg num queries: | 22.0 |
+-------------------------------+--------+
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,5 @@ mostly , [goldbacher] just lets her complicated characters be [[haphazard]] , co
| Attack success rate: | 100.0% |
| Average perturbed word %: | 17.56% |
| Average num. words per input: | 16.25 |
| Avg num queries: | 45.5 |
| Avg num queries: | 38.5 |
+-------------------------------+--------+
2 changes: 1 addition & 1 deletion tests/sample_outputs/run_attack_transformers_datasets.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ that [[lodes]] its characters and communicates [[somethNng]] [[rathrer]] [[beaut
| Attack success rate: | 66.67% |
| Average perturbed word %: | 30.95% |
| Average num. words per input: | 8.33 |
| Avg num queries: | 22.67 |
| Avg num queries: | 20.0 |
+-------------------------------+--------+
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ that [[lodes]] its characters and communicates [[somethNng]] [[rathrer]] [[beaut
| Attack success rate: | 66.67% |
| Average perturbed word %: | 30.95% |
| Average num. words per input: | 8.33 |
| Avg num queries: | 22.67 |
| Avg num queries: | 20.0 |
| Average Original Perplexity: | 734/.*/ |
| Average Attack Perplexity: | 1744/.*/|
| Average Attack USE Score: | 0.76 |
+-------------------------------+---------+
+-------------------------------+---------+
25 changes: 25 additions & 0 deletions textattack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def __init__(
# by the attack class when checking whether to skip the sample
self.search_method.get_goal_results = self.goal_function.get_results

# Give search method access to get indices which need to be ordered / searched
self.search_method.get_indices_to_order = self.get_indices_to_order

self.search_method.filter_transformations = self.filter_transformations

def clear_cache(self, recursive=True):
Expand Down Expand Up @@ -233,6 +236,28 @@ def to_cuda(obj):

to_cuda(self)

def get_indices_to_order(self, current_text, **kwargs):
    """Collect every word index of ``current_text`` that is eligible for
    searching/ordering after applying ``pre_transformation_constraints``.

    Args:
        current_text: The current ``AttackedText`` whose modifiable word
            indices should be determined.
    Returns:
        A tuple ``(len_text, indices)`` where ``indices`` is the filtered
        list of word indices a search method may search/order, and
        ``len_text`` is the number of such indices.
    """

    eligible_indices = self.transformation(
        current_text,
        pre_transformation_constraints=self.pre_transformation_constraints,
        return_indices=True,
        **kwargs,
    )

    # Return a list (not a set) so search methods can shuffle it in place.
    return len(eligible_indices), list(eligible_indices)

def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
"""Applies ``self.transformation`` to ``text``, then filters the list
of possible transformations through the applicable constraints.
Expand Down
22 changes: 11 additions & 11 deletions textattack/search_methods/greedy_word_swap_wir.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ def __init__(self, wir_method="unk", unk_token="[UNK]"):
def _get_index_order(self, initial_text):
"""Returns word indices of ``initial_text`` in descending order of
importance."""
len_text = len(initial_text.words)

len_text, indices_to_order = self.get_indices_to_order(initial_text)

if self.wir_method == "unk":
leave_one_texts = [
initial_text.replace_word_at_index(i, self.unk_token)
for i in range(len_text)
for i in indices_to_order
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])
Expand All @@ -52,7 +53,7 @@ def _get_index_order(self, initial_text):
# first, compute word saliency
leave_one_texts = [
initial_text.replace_word_at_index(i, self.unk_token)
for i in range(len_text)
for i in indices_to_order
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
saliency_scores = np.array([result.score for result in leave_one_results])
Expand All @@ -63,7 +64,7 @@ def _get_index_order(self, initial_text):

# compute the largest change in score we can find by swapping each word
delta_ps = []
for idx in range(len_text):
for idx in indices_to_order:
transformed_text_candidates = self.get_transformations(
initial_text,
original_text=initial_text,
Expand Down Expand Up @@ -93,19 +94,19 @@ def _get_index_order(self, initial_text):

elif self.wir_method == "delete":
leave_one_texts = [
initial_text.delete_word_at_index(i) for i in range(len_text)
initial_text.delete_word_at_index(i) for i in indices_to_order
]
leave_one_results, search_over = self.get_goal_results(leave_one_texts)
index_scores = np.array([result.score for result in leave_one_results])

elif self.wir_method == "gradient":
victim_model = self.get_victim_model()
index_scores = np.zeros(initial_text.num_words)
index_scores = np.zeros(len_text)
grad_output = victim_model.get_grad(initial_text.tokenizer_input)
gradient = grad_output["gradient"]
word2token_mapping = initial_text.align_with_model_tokens(victim_model)
for i, word in enumerate(initial_text.words):
matched_tokens = word2token_mapping[i]
for i, index in enumerate(indices_to_order):
matched_tokens = word2token_mapping[index]
if not matched_tokens:
index_scores[i] = 0.0
else:
Expand All @@ -115,14 +116,14 @@ def _get_index_order(self, initial_text):
search_over = False

elif self.wir_method == "random":
index_order = np.arange(len_text)
index_order = indices_to_order
np.random.shuffle(index_order)
search_over = False
else:
raise ValueError(f"Unsupported WIR method {self.wir_method}")

if self.wir_method != "random":
index_order = (-index_scores).argsort()
index_order = np.array(indices_to_order)[(-index_scores).argsort()]

return index_order, search_over

Expand All @@ -131,7 +132,6 @@ def perform_search(self, initial_result):

# Sort words by order of importance
index_order, search_over = self._get_index_order(attacked_text)

i = 0
cur_result = initial_result
results = None
Expand Down
7 changes: 7 additions & 0 deletions textattack/transformations/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __call__(
pre_transformation_constraints=[],
indices_to_modify=None,
shifted_idxs=False,
return_indices=False,
):
"""Returns a list of all possible transformations for ``current_text``.
Applies the ``pre_transformation_constraints`` then calls
Expand All @@ -32,6 +33,8 @@ def __call__(
``SearchMethod``.
shifted_idxs (bool): Whether indices could have been shifted from
their original position in the text.
return_indices (bool): Whether the function returns indices_to_modify
instead of the transformed_texts.
"""
if indices_to_modify is None:
indices_to_modify = set(range(len(current_text.words)))
Expand All @@ -47,6 +50,10 @@ def __call__(

for constraint in pre_transformation_constraints:
indices_to_modify = indices_to_modify & constraint(current_text, self)

if return_indices:
return indices_to_modify

transformed_texts = self._get_transformations(current_text, indices_to_modify)
for text in transformed_texts:
text.attack_attrs["last_transformation"] = self
Expand Down

0 comments on commit 6423f6e

Please sign in to comment.