-
Notifications
You must be signed in to change notification settings - Fork 375
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #88 from QData/kuleshov
add Kuleshov recipe
- Loading branch information
Showing
60 changed files
with
396 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
Kuleshov, V. et al. | ||
Generating Natural Language Adversarial Examples. | ||
https://openreview.net/pdf?id=r1QZ3zbAZ. | ||
""" | ||
|
||
from textattack.constraints.overlap import WordsPerturbed | ||
from textattack.constraints.grammaticality.language_models import GPT2 | ||
from textattack.constraints.semantics.sentence_encoders import ThoughtVector | ||
from textattack.goal_functions import UntargetedClassification | ||
from textattack.search_methods import GreedyWordSwap | ||
from textattack.transformations import WordSwapEmbedding | ||
|
||
def Kuleshov2017(model):
    """Builds the attack from Kuleshov et al.

    Kuleshov, V. et al.
    Generating Natural Language Adversarial Examples.
    https://openreview.net/pdf?id=r1QZ3zbAZ.

    "Specifically, in all experiments, we used a target of tau = 0.7,
    a neighborhood size of N = 15, and parameters lambda_1 = 0.2 and
    delta = 0.5; we set the syntactic bound to lambda_2 = 2 nats for
    sentiment analysis"

    Args:
        model: the victim classification model to attack.

    Returns:
        A configured ``GreedyWordSwap`` attack.
    """
    #
    # Word swap with top-15 counter-fitted embedding neighbors (N = 15).
    #
    transformation = WordSwapEmbedding(max_candidates=15)
    constraints = []
    #
    # Maximum of 50% of words perturbed (delta = 0.5 in the paper).
    #
    constraints.append(
        WordsPerturbed(max_percent=0.5)
    )
    #
    # Maximum thought-vector Euclidean distance of lambda_1 = 0.2. (eq. 4)
    #
    constraints.append(
        ThoughtVector(embedding_type='paragramcf', threshold=0.2, metric='max_euclidean')
    )
    #
    # Maximum language-model log-probability difference of lambda_2 = 2. (eq. 5)
    #
    constraints.append(
        GPT2(max_log_prob_diff=2.0)
    )
    #
    # Goal is untargeted classification: reduce the original probability score
    # to below tau = 0.7 (Algorithm 1).
    #
    goal_function = UntargetedClassification(model, target_max_score=0.7)
    #
    # Greedily swap words until the goal is reached.
    # NOTE: the original comment claimed a genetic algorithm, but the search
    # method used here is GreedyWordSwap.
    #
    attack = GreedyWordSwap(goal_function, constraints=constraints,
                            transformation=transformation)

    return attack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .constraint import Constraint | ||
|
||
from . import grammaticality | ||
from . import semantics | ||
from . import syntax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from . import language_models | ||
|
||
from .language_tool import LanguageTool | ||
from .part_of_speech import PartOfSpeech |
3 changes: 3 additions & 0 deletions
3
textattack/constraints/grammaticality/language_models/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .google_language_model import Google1BillionWordsLanguageModel | ||
from .gpt2 import GPT2 | ||
from .language_model_constraint import LanguageModelConstraint |
1 change: 1 addition & 0 deletions
1
textattack/constraints/grammaticality/language_models/google_language_model/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .google_language_model import GoogleLanguageModel as Google1BillionWordsLanguageModel |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
51 changes: 51 additions & 0 deletions
51
textattack/constraints/grammaticality/language_models/gpt2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import torch | ||
from textattack.shared import utils | ||
from transformers import GPT2Tokenizer, GPT2LMHeadModel | ||
|
||
from .language_model_constraint import LanguageModelConstraint | ||
|
||
class GPT2(LanguageModelConstraint):
    """ A constraint based on the GPT-2 language model.

        See "Better Language Models and Their Implications"
        (openai.com/blog/better-language-models/)
    """
    def __init__(self, **kwargs):
        # Load the pretrained tokenizer and LM-head model, and move the
        # model onto the configured device before delegating to the base
        # constraint for parameter handling.
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.to(utils.get_device())
        super().__init__(**kwargs)

    def get_log_probs_at_index(self, tokenized_text_list, word_index):
        """ Gets the probability of the word at index `word_index` according
            to GPT-2. Assumes that all items in `tokenized_text_list`
            have the same prefix up until `word_index`.
        """
        context = tokenized_text_list[0].text_until_word_index(word_index)

        # The language-model perplexity is not defined for a word with an
        # empty prefix; in that case report a log-probability of 0.0 for
        # every candidate.
        if not utils.has_letter(context):
            return torch.zeros(len(tokenized_text_list), dtype=torch.float)

        context_tensor = torch.tensor([self.tokenizer.encode(context)])
        context_tensor = context_tensor.to(utils.get_device())

        # No gradients needed: this is inference only.
        with torch.no_grad():
            predictions = self.model(context_tensor)[0]

        # Score each candidate by the model's output for its first
        # sub-token at the position immediately after the shared prefix.
        scores = []
        for candidate in tokenized_text_list:
            candidate_ids = self.tokenizer.encode(candidate.words[word_index])
            scores.append(predictions[0, -1, candidate_ids[0]])

        return scores
42 changes: 42 additions & 0 deletions
42
textattack/constraints/grammaticality/language_models/language_model_constraint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import math | ||
import torch | ||
|
||
from textattack.constraints import Constraint | ||
|
||
class LanguageModelConstraint(Constraint):
    """
    Determines if two sentences have a swapped word that has a similar
    probability according to a language model.

    Args:
        max_log_prob_diff (float): the maximum difference in log-probability
            between x and x_adv

    Raises:
        ValueError: if ``max_log_prob_diff`` is not provided.
    """

    def __init__(self, max_log_prob_diff=None):
        if max_log_prob_diff is None:
            raise ValueError('Must set max_log_prob_diff')
        self.max_log_prob_diff = max_log_prob_diff

    def get_log_probs_at_index(self, text_list, word_index):
        """ Gets the log-probability of items in `text_list` at index
        `word_index` according to a language model. Subclasses must override.
        """
        raise NotImplementedError()

    def __call__(self, x, x_adv, original_text=None):
        # `attack_attrs` is a mapping, so a missing entry raises KeyError —
        # the original `except AttributeError:` alone never caught it. Catch
        # both so a missing index always produces the documented error.
        try:
            i = x_adv.attack_attrs['modified_word_index']
        except (KeyError, AttributeError):
            raise AttributeError('Cannot apply language model constraint without `modified_word_index`')

        probs = self.get_log_probs_at_index((x, x_adv), i)
        if len(probs) != 2:
            raise ValueError(f'Error: get_log_probs_at_index returned {len(probs)} values for 2 inputs')
        x_prob, x_adv_prob = probs
        # Removed a dead `if self.max_log_prob_diff is None:` branch that
        # referenced undefined names (p1/p2 -> NameError if ever reached);
        # __init__ guarantees max_log_prob_diff is never None here.
        return abs(x_prob - x_adv_prob) <= self.max_log_prob_diff

    def extra_repr_keys(self):
        return ['max_log_prob_diff']
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.