Skip to content

Commit

Permalink
Merge pull request #567 from QData/add-language-options-to-transformation
Browse files Browse the repository at this point in the history

add language options
  • Loading branch information
qiyanjun committed Dec 16, 2021
2 parents 7675c17 + f349846 commit 33c9873
Show file tree
Hide file tree
Showing 12 changed files with 9,479 additions and 15 deletions.
59 changes: 59 additions & 0 deletions tests/test_transformations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
def test_imports():
    """Smoke test: `flair`, `torch` and `textattack` can all be imported."""
    import flair
    import torch

    import textattack

    # Release the local references one by one once the imports succeeded.
    del textattack
    del torch
    del flair


def test_word_swap_change_location():
    """Compare NER tags before and after a `WordSwapChangeLocation` augmentation.

    The sentence is augmented with an `Augmenter` wrapping
    `WordSwapChangeLocation`; the original sentence and the first augmented
    result are then tagged with the "flair/ner-english" model, and the test
    asserts that both produce the same list of "ner" span tags.
    """
    from flair.data import Sentence
    from flair.models import SequenceTagger

    from textattack.augmentation import Augmenter
    from textattack.transformations.word_swaps import WordSwapChangeLocation

    source = "I am in Dallas."
    location_augmenter = Augmenter(transformation=WordSwapChangeLocation())
    # Only the first augmented candidate is inspected.
    first_candidate = location_augmenter.augment(source)[0]
    augmented_sentence = Sentence(first_candidate)

    ner_tagger = SequenceTagger.load("flair/ner-english")
    original_sentence = Sentence(source)
    for sentence in (original_sentence, augmented_sentence):
        ner_tagger.predict(sentence)

    original_tags = [span.tag for span in original_sentence.get_spans("ner")]
    augmented_tags = [span.tag for span in augmented_sentence.get_spans("ner")]
    assert original_tags == augmented_tags


def test_word_swap_change_name():
    """Compare NER tags before and after a `WordSwapChangeName` augmentation.

    The sentence is augmented with an `Augmenter` wrapping
    `WordSwapChangeName`; the original sentence and the first augmented
    result are tagged with the "flair/ner-english" model, and the resulting
    lists of "ner" span tags must be equal.
    """
    from flair.data import Sentence
    from flair.models import SequenceTagger

    from textattack.augmentation import Augmenter
    from textattack.transformations.word_swaps import WordSwapChangeName

    def ner_tags(tagged_sentence):
        # Collect the tag of every "ner" span, in the order flair returns them.
        return [span.tag for span in tagged_sentence.get_spans("ner")]

    text = "My name is Anthony Davis."
    name_augmenter = Augmenter(transformation=WordSwapChangeName())
    candidates = name_augmenter.augment(text)
    # Only the first augmented candidate is inspected.
    augmented = Sentence(candidates[0])

    model = SequenceTagger.load("flair/ner-english")
    original = Sentence(text)
    model.predict(original)
    model.predict(augmented)

    tags_before = ner_tags(original)
    tags_after = ner_tags(augmented)
    assert tags_before == tags_after
2 changes: 2 additions & 0 deletions textattack/attack_recipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@
from .pso_zang_2020 import PSOZang2020
from .checklist_ribeiro_2020 import CheckList2020
from .clare_li_2020 import CLARE2020
from .french_recipe import FrenchRecipe
from .spanish_recipe import SpanishRecipe
31 changes: 31 additions & 0 deletions textattack/attack_recipes/french_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from textattack import Attack
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import (
CompositeTransformation,
WordSwapChangeLocation,
WordSwapChangeName,
WordSwapWordNet,
)

from .attack_recipe import AttackRecipe


class FrenchRecipe(AttackRecipe):
    """Attack recipe whose word-swap transformations are configured with "fra"."""

    @staticmethod
    def build(model_wrapper):
        """Assemble the attack for `model_wrapper`.

        Combines `WordSwapWordNet`, `WordSwapChangeLocation` and
        `WordSwapChangeName` (each given ``language="fra"``) into one
        `CompositeTransformation`, constrains it with `RepeatModification`
        and `StopwordModification("french")`, and pairs an
        `UntargetedClassification` goal with a `GreedyWordSwapWIR` search.

        Args:
            model_wrapper: Handed unchanged to `UntargetedClassification`.

        Returns:
            The `Attack` built from the four components above.
        """
        language_code = "fra"
        word_swaps = [
            WordSwapWordNet(language=language_code),
            WordSwapChangeLocation(language=language_code),
            WordSwapChangeName(language=language_code),
        ]
        combined_swaps = CompositeTransformation(word_swaps)
        pre_checks = [RepeatModification(), StopwordModification("french")]
        objective = UntargetedClassification(model_wrapper)
        searcher = GreedyWordSwapWIR()
        return Attack(objective, pre_checks, combined_swaps, searcher)
31 changes: 31 additions & 0 deletions textattack/attack_recipes/spanish_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from textattack import Attack
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import (
CompositeTransformation,
WordSwapChangeLocation,
WordSwapChangeName,
WordSwapWordNet,
)

from .attack_recipe import AttackRecipe


class SpanishRecipe(AttackRecipe):
    """Attack recipe whose word-swap transformations are configured for Spanish."""

    @staticmethod
    def build(model_wrapper):
        """Assemble the attack for `model_wrapper`.

        Combines `WordSwapWordNet`, `WordSwapChangeLocation` and
        `WordSwapChangeName` into one `CompositeTransformation`, constrains
        it with `RepeatModification` and `StopwordModification("spanish")`,
        and pairs an `UntargetedClassification` goal with a
        `GreedyWordSwapWIR` search.

        Args:
            model_wrapper: Handed unchanged to `UntargetedClassification`.

        Returns:
            The `Attack` built from the four components above.
        """
        transformation = CompositeTransformation(
            [
                # NLTK's Open Multilingual WordNet lists Spanish under the
                # ISO 639-3 code "spa" (the same scheme as "fra" for French);
                # "esp" is not among `wordnet.langs()`.
                # NOTE(review): assumes WordSwapWordNet forwards `language`
                # to NLTK's WordNet interface - confirm against its source.
                WordSwapWordNet(language="spa"),
                # The other two swaps keep the original "esp" key.
                WordSwapChangeLocation(language="esp"),
                WordSwapChangeName(language="esp"),
            ]
        )
        constraints = [RepeatModification(), StopwordModification("spanish")]
        goal_function = UntargetedClassification(model_wrapper)
        search_method = GreedyWordSwapWIR()
        return Attack(goal_function, constraints, transformation, search_method)
4 changes: 2 additions & 2 deletions textattack/shared/attacked_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def pos_of_word_index(self, desired_word_idx):
f"Did not find word from index {desired_word_idx} in flair POS tag"
)

def ner_of_word_index(self, desired_word_idx):
def ner_of_word_index(self, desired_word_idx, model_name="ner"):
"""Returns the ner tag of the word at index `word_idx`.
Uses FLAIR ner tagger.
Expand All @@ -170,7 +170,7 @@ def ner_of_word_index(self, desired_word_idx):
sentence = Sentence(
self.text, use_tokenizer=textattack.shared.utils.words_from_text
)
textattack.shared.utils.flair_tag(sentence, "ner")
textattack.shared.utils.flair_tag(sentence, model_name)
self._ner_tags = sentence
flair_word_list, flair_ner_list = textattack.shared.utils.zip_flair_result(
self._ner_tags, "ner"
Expand Down

0 comments on commit 33c9873

Please sign in to comment.