Skip to content

Commit

Permalink
Merge pull request #700 from QData/oct-bug-fixes
Browse files Browse the repository at this point in the history
fix bugs in AT, strings (Chinese), TedMulti dataset
  • Loading branch information
jxmorris12 committed Oct 26, 2022
2 parents fea5cb2 + 4179531 commit 3681f9d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 28 deletions.
5 changes: 4 additions & 1 deletion textattack/datasets/helpers/ted_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TedMultiTranslationDataset(HuggingFaceDataset):
dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/
"""

def __init__(self, source_lang="en", target_lang="de", split="test"):
def __init__(self, source_lang="en", target_lang="de", split="test", shuffle=False):
self._dataset = datasets.load_dataset("ted_multi")[split]
self.examples = self._dataset["translations"]
language_options = set(self.examples[0]["language"])
Expand All @@ -34,6 +34,9 @@ def __init__(self, source_lang="en", target_lang="de", split="test"):
)
self.source_lang = source_lang
self.target_lang = target_lang
self.shuffled = shuffle
if shuffle:
self._dataset.shuffle()

def _format_raw_example(self, raw_example):
translations = np.array(raw_example["translation"])
Expand Down
20 changes: 4 additions & 16 deletions textattack/shared/attacked_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,6 @@ def __eq__(self, other):
return False
if len(self.attack_attrs) != len(other.attack_attrs):
return False
for key in self.attack_attrs:
if key not in other.attack_attrs:
return False
elif isinstance(self.attack_attrs[key], np.ndarray):
if not (self.attack_attrs[key].shape == other.attack_attrs[key].shape):
return False
elif not (self.attack_attrs[key] == other.attack_attrs[key]).all():
return False
else:
if isinstance(self.attack_attrs[key], AttackedText):
if (
not self.attack_attrs[key]._text_input
== other.attack_attrs[key]._text_input
):
return False
return True

def __hash__(self):
Expand Down Expand Up @@ -576,7 +561,10 @@ def num_words(self):

@property
def newly_swapped_words(self):
return [self.words[i] for i in self.attack_attrs["newly_modified_indices"]]
return [
self.attack_attrs["prev_attacked_text"].words[i]
for i in self.attack_attrs["newly_modified_indices"]
]

def printable_text(self, key_color="bold", key_color_method=None):
"""Represents full text input. Adds field descriptions.
Expand Down
12 changes: 1 addition & 11 deletions textattack/shared/utils/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import string

import flair
import jieba
import pycld2 as cld2

from .importing import LazyLoader

Expand Down Expand Up @@ -32,15 +30,7 @@ def add_indent(s_, numSpaces):
def words_from_text(s, words_to_ignore=[]):
"""Lowercases a string, removes all non-alphanumeric characters, and splits
into words."""
try:
isReliable, textBytesFound, details = cld2.detect(s)
if details[0][0] == "Chinese" or details[0][0] == "ChineseT":
seg_list = jieba.cut(s, cut_all=False)
s = " ".join(seg_list)
else:
s = " ".join(s.split())
except Exception:
s = " ".join(s.split())
s = " ".join(s.split())

homos = """˗৭Ȣ𝟕бƼᏎƷᒿlO`ɑЬϲԁе𝚏ɡհіϳ𝒌ⅼmոорԛⲅѕ𝚝սѵԝ×уᴢ"""
exceptions = """'-_*@"""
Expand Down

0 comments on commit 3681f9d

Please sign in to comment.