Consistent word swap #752

Merged
merged 6 commits on Nov 5, 2023
Changes from 4 commits
56 changes: 56 additions & 0 deletions tests/test_transformations.py
@@ -33,6 +33,34 @@ def test_word_swap_change_location():
assert entity_original == entity_augmented


def test_word_swap_change_location_consistent():
from flair.data import Sentence
from flair.models import SequenceTagger

from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps import WordSwapChangeLocation

augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True))
s = "I am in New York. I love living in New York."
s_augmented = augmenter.augment(s)
augmented_text = Sentence(s_augmented[0])
tagger = SequenceTagger.load("flair/ner-english")
original_text = Sentence(s)
tagger.predict(original_text)
tagger.predict(augmented_text)

entity_original = []
entity_augmented = []

for entity in original_text.get_spans("ner"):
entity_original.append(entity.tag)
for entity in augmented_text.get_spans("ner"):
entity_augmented.append(entity.tag)

assert entity_original == entity_augmented
assert s_augmented[0].count("New York") == 0


def test_word_swap_change_name():
from flair.data import Sentence
from flair.models import SequenceTagger
@@ -59,6 +87,34 @@ def test_word_swap_change_name():
assert entity_original == entity_augmented


def test_word_swap_change_name_consistent():
from flair.data import Sentence
from flair.models import SequenceTagger

from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps import WordSwapChangeName

augmenter = Augmenter(transformation=WordSwapChangeName(consistent=True))
s = "My name is Anthony Davis. Anthony Davis plays basketball."
s_augmented = augmenter.augment(s)
augmented_text = Sentence(s_augmented[0])
tagger = SequenceTagger.load("flair/ner-english")
original_text = Sentence(s)
tagger.predict(original_text)
tagger.predict(augmented_text)

entity_original = []
entity_augmented = []

for entity in original_text.get_spans("ner"):
entity_original.append(entity.tag)
for entity in augmented_text.get_spans("ner"):
entity_augmented.append(entity.tag)

assert entity_original == entity_augmented
assert s_augmented[0].count("Anthony") == 0 or s_augmented[0].count("Davis") == 0


def test_chinese_morphonym_character_swap():
from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps.chn_transformations import (
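
For reference, a minimal usage sketch of the new consistent flag, mirroring the location test above (assumes a local TextAttack install; the first call downloads the flair NER model, and the replacement city is sampled at random, so the exact output will vary):

from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps import WordSwapChangeLocation

# With consistent=True, every mention of "New York" should be swapped
# to the same replacement location within a given augmented output.
augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True))
print(augmenter.augment("I am in New York. I love living in New York."))
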
61 changes: 49 additions & 12 deletions textattack/transformations/word_swaps/word_swap_change_location.py
@@ -2,6 +2,8 @@
Word Swap by Changing Location
-------------------------------
"""
from collections import defaultdict

import more_itertools as mit
import numpy as np

@@ -25,12 +27,15 @@ def idx_to_words(ls, words):


class WordSwapChangeLocation(WordSwap):
def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
def __init__(
self, n=3, confidence_score=0.7, language="en", consistent=False, **kwargs
):
"""Transformation that changes recognized locations of a sentence to
another location that is given in the location map.

:param n: Number of new locations to generate
:param confidence_score: Location will only be changed if it's above the confidence score
:param consistent: Whether to change all instances of the same location to the same new location

>>> from textattack.transformations import WordSwapChangeLocation
>>> from textattack.augmentation import Augmenter
@@ -44,6 +49,7 @@ def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
self.n = n
self.confidence_score = confidence_score
self.language = language
self.consistent = consistent

def _get_transformations(self, current_text, indices_to_modify):
words = current_text.words
@@ -64,26 +70,53 @@ def _get_transformations(self, current_text, indices_to_modify):
location_idx = [list(group) for group in mit.consecutive_groups(location_idx)]
location_words = idx_to_words(location_idx, words)

if self.consistent:
location_to_indices = defaultdict(list)
for idx, location in location_words:
location_to_indices[self._capitalize(location)].append(idx[0])

transformed_texts = []
for location in location_words:
idx = location[0]
word = location[1].capitalize()
word = self._capitalize(location[1])
replacement_words = self._get_new_location(word)
for r in replacement_words:
if r == word:
continue
text = current_text

# if original location is more than a single word, remain only the starting word
if len(idx) > 1:
index = idx[1]
for i in idx[1:]:
text = text.delete_word_at_index(index)
if self.consistent:
# If we're doing consistent replacements, only replace the word
# if it hasn't already been replaced in a previous iteration
if word not in location_to_indices:
continue

indices_to_delete = []
if len(idx) > 1:
for i in location_to_indices[word]:
for j in range(1, len(idx)):
indices_to_delete.append(i + j)

transformed_texts.append(
current_text.replace_words_at_indices(
location_to_indices[word] + indices_to_delete,
([r] * len(location_to_indices[word]))
+ ([""] * len(indices_to_delete)),
)
)

# Delete this word to mark it as replaced
del location_to_indices[word]
else:
# If the original location is more than a single word, keep only the starting word
# and replace the starting word with the new word
indices_to_delete = idx[1:]
transformed_texts.append(
current_text.replace_words_at_indices(
[idx[0]] + indices_to_delete,
[r] + [""] * len(indices_to_delete),
)
)

# replace the starting word with new location
text = text.replace_word_at_index(idx[0], r)

transformed_texts.append(text)
return transformed_texts

def _get_new_location(self, word):
@@ -101,3 +134,7 @@ def _get_new_location(self, word):
elif word in NAMED_ENTITIES["city"]:
return np.random.choice(NAMED_ENTITIES["city"], self.n)
return []

def _capitalize(self, string):
"""Capitalizes all words in the string."""
return " ".join(word.capitalize() for word in string.split())
31 changes: 30 additions & 1 deletion textattack/transformations/word_swaps/word_swap_change_name.py
@@ -3,6 +3,8 @@
-------------------------------
"""

from collections import defaultdict

import numpy as np

from textattack.shared.data import PERSON_NAMES
@@ -18,6 +20,7 @@ def __init__(
last_only=False,
confidence_score=0.7,
language="en",
consistent=False,
**kwargs
):
"""Transforms an input by replacing names of recognized name entity.
@@ -26,6 +29,7 @@
:param first_only: Whether to change first name only
:param last_only: Whether to change last name only
:param confidence_score: Name will only be changed when it's above confidence score
:param consistent: Whether to change all instances of the same name to the same new name
>>> from textattack.transformations import WordSwapChangeName
>>> from textattack.augmentation import Augmenter

@@ -42,6 +46,7 @@ def __init__(
self.last_only = last_only
self.confidence_score = confidence_score
self.language = language
self.consistent = consistent

def _get_transformations(self, current_text, indices_to_modify):
transformed_texts = []
@@ -52,14 +57,38 @@ def _get_transformations(self, current_text, indices_to_modify):
else:
model_name = "flair/ner-multi-fast"

if self.consistent:
word_to_indices = defaultdict(list)
for i in indices_to_modify:
word_to_replace = current_text.words[i].capitalize()
word_to_indices[word_to_replace].append(i)

for i in indices_to_modify:
word_to_replace = current_text.words[i].capitalize()
# If we're doing consistent replacements, only replace the word
# if it hasn't already been replaced in a previous iteration
if self.consistent and word_to_replace not in word_to_indices:
continue
word_to_replace_ner = current_text.ner_of_word_index(i, model_name)

replacement_words = self._get_replacement_words(
word_to_replace, word_to_replace_ner
)

for r in replacement_words:
transformed_texts.append(current_text.replace_word_at_index(i, r))
if self.consistent:
transformed_texts.append(
current_text.replace_words_at_indices(
word_to_indices[word_to_replace],
[r] * len(word_to_indices[word_to_replace]),
)
)
else:
transformed_texts.append(current_text.replace_word_at_index(i, r))

# Delete this word to mark it as replaced
if self.consistent and len(replacement_words) != 0:
del word_to_indices[word_to_replace]

return transformed_texts
