Skip to content

Commit

Permalink
Merge pull request #591 from QData/fix-StopIteration-bug
Browse files Browse the repository at this point in the history
Fix CLARE StopIteration Bug
  • Loading branch information
qiyanjun committed Dec 3, 2021
2 parents a0ec175 + dbd31e8 commit 5f825f6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
WordInsertionMaskedLM Class
-------------------------------
"""

import re

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
Expand Down Expand Up @@ -162,7 +162,7 @@ def _get_transformations(self, current_text, indices_to_modify):
word_at_index = current_text.words[index_to_modify]
for word in new_words[i]:
word = word.strip("臓")
if word != word_at_index:
if word != word_at_index and re.search("[a-zA-Z]", word):
transformed_texts.append(
current_text.insert_text_before_word_index(
index_to_modify, word
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
------------------------------------------------
"""
import re

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
Expand Down Expand Up @@ -169,7 +170,7 @@ def _get_transformations(self, current_text, indices_to_modify):
word_at_index = current_text.words[index_to_modify]
for word in merged_words[i]:
word = word.strip("臓")
if word != word_at_index:
if word != word_at_index and re.search("[a-zA-Z]", word):
temp_text = current_text.delete_word_at_index(index_to_modify + 1)
transformed_texts.append(
temp_text.replace_word_at_index(index_to_modify, word)
Expand Down
7 changes: 6 additions & 1 deletion textattack/transformations/word_swaps/word_swap_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


import itertools
import re

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
Expand Down Expand Up @@ -290,7 +291,11 @@ def _get_transformations(self, current_text, indices_to_modify):
word_at_index = current_text.words[index_to_modify]
for word in replacement_words[i]:
word = word.strip("臓")
if word != word_at_index and len(utils.words_from_text(word)) == 1:
if (
word != word_at_index
and re.search("[a-zA-Z]", word)
and len(utils.words_from_text(word)) == 1
):
transformed_texts.append(
current_text.replace_word_at_index(index_to_modify, word)
)
Expand Down

0 comments on commit 5f825f6

Please sign in to comment.