# textattack_main.py
"""
Adversarial example generation scripts with TextAttack_v2
"""
import os
import random
import time
import warnings
from datetime import timedelta

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from textattack.attack_results import (
    SuccessfulAttackResult,
    FailedAttackResult,
    SkippedAttackResult,
)

from model.textattack_model import CustomWrapper, print_function, AttackSummary, HuggingFaceModelWrapper
# The wildcard imports below supply (at least) SeqClsWrapper, load_tokenizer,
# load_base_model, and noisy_forward_loader used further down; their exact
# origin is not visible from this file alone.
from model.model_adv import *
from model.load_model import *
from utils.utils import print_args, load_checkpoint
from utils.dataloader import text_dataloader
from arguments import get_parser

warnings.filterwarnings('ignore')
args = get_parser("attack")

if args.seed > -1:
    SEED = args.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
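# NOTE: fully deterministic cuDNN behavior usually also requires disabling
# benchmark mode (not set in this script):
#   torch.backends.cudnn.benchmark = False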
# NOTE: these shorthands are assigned but not referenced again in this script
max_pert_ratio = int(args.max_rate * 100)
n_ens = args.num_ensemble
egm = args.ens_grad_mask
gms = args.grad_mask_sample
mpr = max_pert_ratio
ql = args.q_limit

f_name = "_" + str(args.load_model) + "_" + args.dataset_type + ".csv"
adv_path = os.path.join('./data', args.dataset + '_' + args.attack_method + f_name)
args.adv_path = adv_path
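# Resulting path shape: ./data/<dataset>_<attack_method>_<load_model>_<dataset_type>.csv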
print("Load Dataset...")
if args.dataset_type =='test':
_, test = text_dataloader(args.dataset, args)
dataset = test.train_test_split(test_size=args.n_trials, seed=0)['test']
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=args.shuffle, num_workers=0)
else:
train, _ = text_dataloader(args.dataset, args)
dataset = train.train_test_split(test_size=args.n_trials, seed=0)['test']
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=args.shuffle, num_workers=0)
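# batch_size is fixed at 1 because attack.attack() below consumes a single
# (text, label) pair at a time; TextAttack does its own candidate batching.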
print_args(args)
print(f"Adv Path: {adv_path}")
print(f"Save Data: {args.save_data}")
print("----------------------------------")
print(f"Dataset Type: {args.dataset_type}")
print("----------------------------------")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = load_tokenizer(args)
args.pad_idx = tokenizer.pad_token_id
args.mask_idx = tokenizer.mask_token_id
args.cls_token = tokenizer.cls_token_id
args.sep_token = tokenizer.sep_token_id
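# NOTE: args.cls_token / args.sep_token hold token *ids* (not token strings);
# the downstream wrappers presumably consume them as ids.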
if args.model_type == 'rsmi':
    print("RSMI is Loaded")
    model = noisy_forward_loader(args)
    model = SeqClsWrapper(model, args)
elif args.model_type == 'base':
    print("Base Model is Loaded")
    model = load_base_model(args)
    model = load_checkpoint(model, args.load_model, args.model_dir_path)
else:
    # Guard against an unbound `model` further down
    raise ValueError(f"Unknown model_type: {args.model_type}")
model.eval()  # inference mode: freeze dropout / batch-norm statistics
if args.model_type == 'base':
    model_wrapper = HuggingFaceModelWrapper(model, tokenizer, args)
else:
    # Custom wrapper for RSMI
    print("RSMI Custom Wrapper")
    model_wrapper = CustomWrapper(model, tokenizer, args)
model_wrapper.model.to(device)
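# TextAttack interacts with the victim model only through the wrapper's
# __call__ on raw text, so the attack loop below is model-agnostic.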
# Attack recipe
if args.attack_method == 'pwws':
    from textattack.attack_recipes import PWWSRen2019
    attack = PWWSRen2019.build(model_wrapper)
elif args.attack_method == 'textfooler':
    from textattack.attack_recipes import TextFoolerJin2019
    attack = TextFoolerJin2019.build(model_wrapper)
else:
    raise NotImplementedError(f"Attack method '{args.attack_method}' is not implemented")
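# Other word-level recipes could be wired in the same way, e.g. (illustrative,
# not part of the original script):
#   from textattack.attack_recipes import BERTAttackLi2020
#   attack = BERTAttackLi2020.build(model_wrapper)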
attack.goal_function.maximizable = False
attack.goal_function.batch_size = args.adv_batch_size
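# maximizable=False tells the goal function to stop as soon as the label flips
# instead of continuing to maximize its score; batch_size sets how many
# perturbed candidates are scored per batched query to the victim model.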
print(attack)
num_successes = 0
num_skipped = 0
num_failed = 0
df_adv = pd.DataFrame()
attack_result = AttackSummary(args.max_seq_length)
start_t_gen = time.perf_counter()
print("Start Attack...")
for batch_idx, batch in enumerate(dataloader):
    label = batch['label'].item()  # ground-truth label
    orig = batch['text'][0]        # clean (unperturbed) text
    text_len = len(orig.split(" "))
    # Truncate inputs that exceed the model's maximum sequence length
    if text_len > args.max_seq_length:
        orig = " ".join(orig.split(" ")[:args.max_seq_length])
        text_len = args.max_seq_length
    if args.q_limit > 0:
        q_limit = int(args.max_candidates * text_len)
        attack.goal_function.query_budget = q_limit
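    # query_budget caps how many times the goal function may query the victim
    # model on this example; tying it to the (truncated) word count gives
    # longer inputs a proportionally larger search budget.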
    result = attack.attack(orig, label)
    attack_result(result)
    n_query, n_pert, n_words = attack_result.text_analysis(result)
    pert = result.perturbed_text()
    pert_word_ratio = (n_pert / text_len) * 100  # NOTE: computed but never used below
    if isinstance(result, SuccessfulAttackResult):
        num_successes += 1
        result_type = 'Successful'
    elif isinstance(result, FailedAttackResult):
        num_failed += 1
        result_type = 'Failed'
    elif isinstance(result, SkippedAttackResult):
        num_skipped += 1
        result_type = 'Skipped'
    adv_dict = {'pert': pert, 'orig': orig, 'ground_truth_output': label,
                'result_type': result_type, 'n_query': n_query, 'n_pert': n_pert,
                'n_words': n_words}
    # DataFrame.append was removed in pandas 2.0; pd.concat is the portable equivalent
    df_adv = pd.concat([df_adv, pd.DataFrame([adv_dict])], ignore_index=True)
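    # For long runs it is typically cheaper to collect the dicts in a plain
    # Python list and build the DataFrame once after the loop, e.g.:
    #   rows.append(adv_dict)          # inside the loop
    #   df_adv = pd.DataFrame(rows)    # after the loop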
    if batch_idx % 10 == 0:
        print(result.__str__(color_method='ansi'))
        print_function(args, f_name, batch_idx, num_successes, num_failed, num_skipped)
        print(str(attack_result), flush=True)
eval_t = time.perf_counter() - start_t_gen
print_function(args, f_name, batch_idx, num_successes, num_failed, num_skipped)
print(str(attack_result), flush=True)
print(f"Total Elapsed Time: {timedelta(seconds=eval_t)}", flush=True)
if args.save_data:
    os.makedirs('./data/', exist_ok=True)
    df_adv.to_csv(adv_path)
    print(f"Save data...: {adv_path}", flush=True)