Commit

Merge pull request #540 from QData/metric-module
Fix metric-module Issue#532
qiyanjun committed Nov 2, 2021
2 parents 74ba1ff + e09f81d commit e33f73b
Showing 7 changed files with 164 additions and 81 deletions.
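
The substance of the change shows up in the new test file and in perplexity.py below: Perplexity now takes a model_name argument (default "gpt2") and calculate() returns only the rounded averages instead of averages plus per-example lists. A minimal sketch of the new call pattern, mirroring tests/test_metric_api.py from this diff (the sample strings are the ones used in that test):

    from textattack.attack_results import SuccessfulAttackResult
    from textattack.goal_function_results.classification_goal_function_result import (
        ClassificationGoalFunctionResult,
    )
    from textattack.metrics.quality_metrics import Perplexity
    from textattack.shared.attacked_text import AttackedText

    original = "hide new secretions from the parental units "
    perturbed = "Ehide enw secretions from the parental units "

    # Wrap one original/perturbed pair as a SuccessfulAttackResult so the
    # metric can be exercised without running a full attack.
    results = [
        SuccessfulAttackResult(
            ClassificationGoalFunctionResult(
                AttackedText(original), None, None, None, None, None, None
            ),
            ClassificationGoalFunctionResult(
                AttackedText(perturbed), None, None, None, None, None, None
            ),
        )
    ]

    # Any name other than "gpt2" is loaded via AutoModelForMaskedLM / AutoTokenizer.
    metrics = Perplexity(model_name="distilbert-base-uncased").calculate(results)
    print(metrics["avg_original_perplexity"], metrics["avg_attack_perplexity"])
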
1 change: 1 addition & 0 deletions .github/workflows/run-pytest.yml
@@ -39,6 +39,7 @@ jobs:
sudo apt-get autoclean -y >/dev/null 2>&1
sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
docker rmi $(docker image ls -aq) >/dev/null 2>&1
df -h
- name: Test with pytest
run: |
pytest tests -v
@@ -68,7 +68,7 @@ throws in enough clever and unexpected twists to make the formula feel fresh .
| Average perturbed word %: | 5.56% |
| Average num. words per input: | 15.5 |
| Avg num queries: | 1.33 |
| Average Original Perplexity: | 291.47 |
| Average Attack Perplexity: | 320.33 |
| Average Original Perplexity: | 291/.*/|
| Average Attack Perplexity: | 320/.*/|
| Average Attack USE Score: | 0.91 |
+-------------------------------+--------+
@@ -62,7 +62,7 @@ that [[lodes]] its characters and communicates [[somethNng]] [[rathrer]] [[beaut
| Average perturbed word %: | 30.95% |
| Average num. words per input: | 8.33 |
| Avg num queries: | 22.67 |
| Average Original Perplexity: | 1126.57 |
| Average Attack Perplexity: | 2823/.*/|
| Average Original Perplexity: | 734/.*/ |
| Average Attack Perplexity: | 1744/.*/|
| Average Attack USE Score: | 0.76 |
+-------------------------------+---------+
+-------------------------------+---------+
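
In both sample-output tables the exact perplexity values are replaced with /.*/ placeholders; presumably the test harness treats each expected line as a regular expression, so small numeric drift across transformers releases no longer breaks the comparison. A rough sketch of that kind of tolerant line match (the helper name is hypothetical, not from the repo):

    import re

    def line_matches(expected: str, actual: str) -> bool:
        # Hypothetical helper: escape the literal parts of the expected line and
        # turn each /.*/ placeholder into a wildcard before matching.
        pattern = ".*".join(re.escape(part) for part in expected.split("/.*/"))
        return re.fullmatch(pattern, actual) is not None

    assert line_matches(
        "| Average Original Perplexity: | 291/.*/|",
        "| Average Original Perplexity: | 291.47 |",
    )
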
54 changes: 15 additions & 39 deletions tests/test_command_line/test_attack.py
@@ -35,19 +35,6 @@
"tests/sample_outputs/interactive_mode.txt",
),
#
# test loading an attack from the transformers model hub
#
(
"attack_from_transformers",
(
"textattack attack --model-from-huggingface "
"distilbert-base-uncased-finetuned-sst-2-english "
"--dataset-from-huggingface glue^sst2^train --recipe deepwordbug --num-examples 3 "
""
),
"tests/sample_outputs/run_attack_transformers_datasets.txt",
),
#
# test loading an attack from the transformers model hub and calculate perplexity and use
#
(
@@ -75,17 +75,6 @@
"tests/sample_outputs/run_attack_transformers_datasets.txt",
),
#
# test hotflip on 10 samples from LSTM MR
#
(
"run_attack_hotflip_lstm_mr_4",
(
"textattack attack --model lstm-mr --recipe hotflip "
"--num-examples 4 --num-examples-offset 3 "
),
"tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt",
),
#
# test hotflip on 10 samples from LSTM MR and calculate perplexity and use
#
(
@@ -106,21 +82,21 @@
),
"tests/sample_outputs/run_attack_deepwordbug_lstm_mr_2.txt",
),
#
# test: run_attack targeted classification of class 2 on BERT MNLI with log-to-csv
# and attack_n set, using the WordNet transformation and beam search with
# beam width 2, using language tool constraint, on 10 samples
# (takes about 72s)
#
(
"run_attack_targeted_mnli_misc",
(
"textattack attack --attack-n --goal-function targeted-classification^target_class=2 --log-to-csv "
"/tmp/textattack_test.csv --model bert-base-uncased-mnli --num-examples 2 --attack-n --transformation "
"word-swap-wordnet --constraints lang-tool repeat stopword --search beam-search^beam_width=2 "
),
"tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_log-to-csv_beamsearch2_attack_n.txt",
),
# #
# # test: run_attack targeted classification of class 2 on BERT MNLI with log-to-csv
# # and attack_n set, using the WordNet transformation and beam search with
# # beam width 2, using language tool constraint, on 10 samples
# # (takes about 72s)
# #
# (
# "run_attack_targeted_mnli_misc",
# (
# "textattack attack --attack-n --goal-function targeted-classification^target_class=2 --log-to-csv "
# "/tmp/textattack_test.csv --model bert-base-uncased-mnli --num-examples 2 --attack-n --transformation "
# "word-swap-wordnet --constraints lang-tool repeat stopword --search beam-search^beam_width=2 "
# ),
# "tests/sample_outputs/run_attack_targetedclassification2_wordnet_langtool_log-to-csv_beamsearch2_attack_n.txt",
# ),
#
# fmt: off
# test: run_attack untargeted classification on BERT MR using word embedding transformation and greedy-word-WIR search
58 changes: 58 additions & 0 deletions tests/test_metric_api.py
@@ -0,0 +1,58 @@
def test_perplexity():
from textattack.attack_results import SuccessfulAttackResult
from textattack.goal_function_results.classification_goal_function_result import (
ClassificationGoalFunctionResult,
)
from textattack.metrics.quality_metrics import Perplexity
from textattack.shared.attacked_text import AttackedText

sample_text = "hide new secretions from the parental units "
sample_atck_text = "Ehide enw secretions from the parental units "

results = [
SuccessfulAttackResult(
ClassificationGoalFunctionResult(
AttackedText(sample_text), None, None, None, None, None, None
),
ClassificationGoalFunctionResult(
AttackedText(sample_atck_text), None, None, None, None, None, None
),
)
]
ppl = Perplexity(model_name="distilbert-base-uncased").calculate(results)

assert int(ppl["avg_original_perplexity"]) == int(81.95)


def test_use():
import transformers

from textattack import AttackArgs, Attacker
from textattack.attack_recipes import DeepWordBugGao2018
from textattack.datasets import HuggingFaceDataset
from textattack.metrics.quality_metrics import USEMetric
from textattack.models.wrappers import HuggingFaceModelWrapper

model = transformers.AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased-finetuned-sst-2-english"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
"distilbert-base-uncased-finetuned-sst-2-english"
)
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
attack = DeepWordBugGao2018.build(model_wrapper)
dataset = HuggingFaceDataset("glue", "sst2", split="train")
attack_args = AttackArgs(
num_examples=1,
log_to_csv="log.csv",
checkpoint_interval=5,
checkpoint_dir="checkpoints",
disable_stdout=True,
)
attacker = Attacker(attack, dataset, attack_args)

results = attacker.attack_dataset()

usem = USEMetric().calculate(results)

assert usem["avg_attack_use_score"] == 0.76
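
The new tests can be exercised on their own before pushing; a small Python equivalent of the workflow's pytest tests -v step, narrowed to just this file (using pytest's programmatic entry point):

    import pytest

    # Run only the new metric API tests, with the same -v flag as the CI step above.
    pytest.main(["tests/test_metric_api.py", "-v"])
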
100 changes: 63 additions & 37 deletions textattack/metrics/quality_metrics/perplexity.py
@@ -2,27 +2,40 @@
Perplexity Metric:
======================
Class for calculating perplexity from AttackResults
"""

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric
import textattack.shared.utils


class Perplexity(Metric):
def __init__(self):
def __init__(self, model_name="gpt2"):
self.all_metrics = {}
self.original_candidates = []
self.successful_candidates = []
self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
self.ppl_model.to(textattack.shared.utils.device)
self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
self.ppl_model.eval()
self.max_length = self.ppl_model.config.n_positions

if model_name == "gpt2":
from transformers import GPT2LMHeadModel, GPT2Tokenizer

self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
self.ppl_model.to(textattack.shared.utils.device)
self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
self.ppl_model.eval()
self.max_length = self.ppl_model.config.n_positions
else:
from transformers import AutoModelForMaskedLM, AutoTokenizer

self.ppl_model = AutoModelForMaskedLM.from_pretrained(model_name)
self.ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.ppl_model.to(textattack.shared.utils.device)
self.ppl_model.eval()
self.max_length = self.ppl_model.config.max_position_embeddings

self.stride = 512

def calculate(self, results):
Expand All @@ -31,6 +44,27 @@ def calculate(self, results):
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
Example::
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
num_examples=1,
log_to_csv="log.csv",
checkpoint_interval=5,
checkpoint_dir="checkpoints",
disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
"""
self.results = results
self.original_candidates_ppl = []
@@ -52,42 +86,34 @@ def calculate(self, results):
ppl_orig = self.calc_ppl(self.original_candidates)
ppl_attack = self.calc_ppl(self.successful_candidates)

self.all_metrics["avg_original_perplexity"] = round(ppl_orig[0], 2)
self.all_metrics["original_perplexity_list"] = ppl_orig[1]
self.all_metrics["avg_original_perplexity"] = round(ppl_orig, 2)

self.all_metrics["avg_attack_perplexity"] = round(ppl_attack[0], 2)
self.all_metrics["attack_perplexity_list"] = ppl_attack[1]
self.all_metrics["avg_attack_perplexity"] = round(ppl_attack, 2)

return self.all_metrics

def calc_ppl(self, texts):

ppl_vals = []

with torch.no_grad():
for text in texts:
eval_loss = []
input_ids = torch.tensor(
self.ppl_tokenizer.encode(text, add_special_tokens=True)
).unsqueeze(0)
# Strided perplexity calculation from huggingface.co/transformers/perplexity.html
for i in range(0, input_ids.size(1), self.stride):
begin_loc = max(i + self.stride - self.max_length, 0)
end_loc = min(i + self.stride, input_ids.size(1))
trg_len = end_loc - i
input_ids_t = input_ids[:, begin_loc:end_loc].to(
textattack.shared.utils.device
)
target_ids = input_ids_t.clone()
target_ids[:, :-trg_len] = -100

outputs = self.ppl_model(input_ids_t, labels=target_ids)
log_likelihood = outputs[0] * trg_len

eval_loss.append(log_likelihood)

ppl_vals.append(
torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
text = " ".join(texts)
eval_loss = []
input_ids = torch.tensor(
self.ppl_tokenizer.encode(text, add_special_tokens=True)
).unsqueeze(0)
# Strided perplexity calculation from huggingface.co/transformers/perplexity.html
for i in range(0, input_ids.size(1), self.stride):
begin_loc = max(i + self.stride - self.max_length, 0)
end_loc = min(i + self.stride, input_ids.size(1))
trg_len = end_loc - i
input_ids_t = input_ids[:, begin_loc:end_loc].to(
textattack.shared.utils.device
)
target_ids = input_ids_t.clone()
target_ids[:, :-trg_len] = -100

outputs = self.ppl_model(input_ids_t, labels=target_ids)
log_likelihood = outputs[0] * trg_len

eval_loss.append(log_likelihood)

return sum(ppl_vals) / len(ppl_vals), ppl_vals
return torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
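
For reference, the rewritten calc_ppl now joins every candidate into one string and reports a single corpus-level perplexity instead of averaging per-text values. A self-contained sketch of the same strided computation as the diff, runnable on its own (the function name is just for illustration):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    def strided_perplexity(texts, stride=512):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        max_length = model.config.n_positions

        # Join all candidates and score them as one sequence, as calc_ppl now does.
        text = " ".join(texts)
        input_ids = torch.tensor(
            tokenizer.encode(text, add_special_tokens=True)
        ).unsqueeze(0)

        nlls = []
        with torch.no_grad():
            for i in range(0, input_ids.size(1), stride):
                begin_loc = max(i + stride - max_length, 0)
                end_loc = min(i + stride, input_ids.size(1))
                trg_len = end_loc - i  # number of tokens scored in this window
                window = input_ids[:, begin_loc:end_loc].to(device)
                target_ids = window.clone()
                target_ids[:, :-trg_len] = -100  # ignore the overlapping context
                loss = model(window, labels=target_ids)[0]
                nlls.append(loss * trg_len)

        return torch.exp(torch.stack(nlls).sum() / end_loc).item()

    print(strided_perplexity(["hide new secretions from the parental units "]))
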
22 changes: 22 additions & 0 deletions textattack/metrics/quality_metrics/use.py
@@ -17,7 +17,29 @@ def calculate(self, results):
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
Example::
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
num_examples=1,
log_to_csv="log.csv",
checkpoint_interval=5,
checkpoint_dir="checkpoints",
disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
"""

self.results = results

for i, result in enumerate(self.results):
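
The use.py hunk shown here only adds the docstring example, so the scoring itself is not visible in this diff. For orientation: the avg_attack_use_score asserted in the new test is presumably an average Universal Sentence Encoder similarity between each original text and its perturbation. A standalone sketch of that idea using the public TF-Hub model and plain cosine similarity, offered as an assumption for illustration rather than TextAttack's own encoder wrapper or its exact similarity function:

    import numpy as np
    import tensorflow_hub as hub

    # Public USE model from TF-Hub; TextAttack ships its own encoder wrapper.
    embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")

    original = "hide new secretions from the parental units "
    perturbed = "Ehide enw secretions from the parental units "

    e_orig, e_pert = np.asarray(embed([original, perturbed]))
    use_score = float(
        np.dot(e_orig, e_pert) / (np.linalg.norm(e_orig) * np.linalg.norm(e_pert))
    )
    print(round(use_score, 2))
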
