-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_attack_v5.py
531 lines (463 loc) · 23.4 KB
/
main_attack_v5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
'''
tianrui liu @2020/10/12
Fun: model attack
concept: node_graph: (graph with node_id, 0-1472 for tweet, 1473....for user)
index_graph: (graph used in model, use user_map and loss_tweet_map
to transfer from node_graph to index_graph)
malicious users are immediate neighbors of target tweets
'''
import os
import sys
import time
import json
import argparse
import random
from tools import save_dict_to_json
import torch
import torch.nn.functional as F
import numpy as np
import os.path as pth
import pandas as pd
from time import sleep
from tqdm import tqdm
from datetime import datetime
from torch import nn
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
import copy
import util
from util import str2bool
from GloveEmbed import _get_embedding
from early_stopping_attack_v2 import EarlyStopping
from dataset_attack import get_dataloader, data_split
from evaluate import evaluation4class
from tools import txt2iterable, iterable2txt
# Number of source tweets in the dataset; graph node ids at or above this are
# treated as users. NOTE(review): the module docstring says tweets are
# 0-1472 and users start at 1473, while code asserts user >= 1472 — possible
# off-by-one worth confirming against the dataset.
SOURCE_TWEET_NUM = 1472
## add args
# Command-line configuration for model architecture, training, and attack mode.
parser = argparse.ArgumentParser(description='GAT for fake news detection')
parser.add_argument('--model', default='ensemble', type=str, help='ensemble/graph2tree/tree2graph/tree/graph')
parser.add_argument('--train', default=True, type=str2bool, help='train or traverse')
parser.add_argument('--patience', default=10, type=int, help='how long to wait after last time validation loss improved')
parser.add_argument('--freeze', default=False, type=str2bool, help='embedding freeze or not')
parser.add_argument('--load_ckpt', default=True, type=str2bool, help='load checkpoint')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--gpu', default=1, type=int, help='gpu id')
parser.add_argument('--epoches', default=90, type=int, help='maximum training epoches')
parser.add_argument('--batch_size', default=64, type=int, help='batch size')
parser.add_argument('--dropout', default=0.3, type=float, help='drop out rate')
parser.add_argument('--weight_decay', default=1e-6, type=float, help='weight decay')
parser.add_argument('--embed_dim', default=100, type=int, help='pretrain embed size')
parser.add_argument('--tweet_embed_size', default=100, type=int, help='tweet embed size')
parser.add_argument('--tree_hidden_size1', default=100, type=int, help='hidden size for TreeGCN')
parser.add_argument('--tree_hidden_size2', default=100, type=int, help='hidden size for TreeGCN')
parser.add_argument('--graph_hidden_size1', default=100, type=int, help='hidden size for GraphGCN')
parser.add_argument('--graph_hidden_size2', default=100, type=int, help='hidden size for GraphGCN')
parser.add_argument('--linear_hidden_size1', default=64, type=int, help='hidden size for fuly connected layer')
parser.add_argument('--direction', default='td', type=str, help='tree direction: topdown(td)/bottomup(bu)')
parser.add_argument('--user_feat_size', default=9, type=int, help = 'user features')
# parser.add_argument('--text_input_size', default=200, type=int, help = 'tweets and description input size')
parser.add_argument('--user_out_size', default=100, type=int, help = 'user description embed size')
# parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
parser.add_argument('--ckpt_name', default='best', type=str, help='load previous checkpoint. insert checkpoint filename')
parser.add_argument('--attack', default='True', type=str2bool, help='whether testing attack')
args = parser.parse_args()
# User-description embedding is fixed to half the tweet embedding size,
# overriding the --user_out_size CLI value.
args.user_out_size = int(args.tweet_embed_size/2)
# args.graph_hidden_size2 = args.tweet_embed_size
# import model
# Select the Net implementation based on --model. The graph2tree/tree2graph
# variants force a hidden size equal to the tweet embedding size, presumably
# so feature dimensions line up between the two sub-networks.
if args.model == 'ensemble':
    print('---using ensemble---')
    from model_ensemble import Net
elif args.model == 'graph':
    print('---using graph---')
    from model_graph import Net
elif args.model == 'tree':
    print('---using tree---')
    from model_tree import Net
elif args.model == 'graph2tree':
    print('---using graph2tree---')
    from model_graph2tree import Net
    args.graph_hidden_size2 = args.tweet_embed_size
elif args.model == 'tree2graph':
    print('---using tree2graph---')
    from model_tree2graph import Net
    args.tree_hidden_size2 = args.tweet_embed_size
else:
    raise ValueError('parameter not found! input should among "ensemble/graph2tree/tree2graph/tree/graph"')
# Device selection; args.gpu is only used for display here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: {}{} ".format(device, str(args.gpu)))
def _load_word2index(word_file):
with open(word_file) as jsonfile:
word_map = json.load(jsonfile)
vocab_size = len(word_map)
return word_map, vocab_size
def _load_checkpoint():
    """Restore model/optimizer state from args.ckpt_dir/args.ckpt_name.

    Returns:
        (model, optimizer, global_step, curr_epoch) loaded from the file.

    Raises:
        FileNotFoundError: when the checkpoint file does not exist.
        (The original silently fell through and returned stale module-level
        globals in that case, which made a missing checkpoint look like a
        successful load.)
    """
    ckp_path = os.path.join(args.ckpt_dir, args.ckpt_name)
    if not os.path.isfile(ckp_path):
        raise FileNotFoundError('checkpoint not found: {}'.format(ckp_path))
    # map_location keeps the load working when the checkpoint was saved on
    # GPU but the current run is CPU-only.
    checkpoint = torch.load(ckp_path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optim_state'])
    global_step = checkpoint['global_step']
    curr_epoch = checkpoint['curr_epoch']
    return model, optimizer, global_step, curr_epoch
def adjust_learning_rate(optim, epoch, decay_every=15, decay_rate=0.1):
    """Set the learning rate to args.lr decayed by `decay_rate` every
    `decay_every` epochs.

    (The previous docstring claimed a decay every 30 epochs while the code
    decayed every 15; the schedule is now parameterized, defaults preserve
    the original behavior.)

    Args:
        optim: optimizer whose param_groups get the new lr.
        epoch: current epoch number.
        decay_every: epochs between decay steps (default 15, as before).
        decay_rate: multiplicative decay factor per step (default 0.1).
    """
    lr = args.lr * (decay_rate ** (epoch // decay_every))
    for param_group in optim.param_groups:
        param_group['lr'] = lr
def _compute_accy_count(y_pred, y_labels):
return 1.0*y_pred.eq(y_labels).sum().item()
def _compute_accuracy(y_pred, y_labels):
return 1.0*y_pred.eq(y_labels).sum().item()/y_labels.size(0)
def _train_model(train_indices, test_indices, model):
    """Train `model` with early stopping, logging to a timestamped file,
    TensorBoard, and parameter_record.md.

    Relies on module-level globals: args, optimizer, loss_fun, adjdict,
    curr_epoch, device, and the EarlyStopping/_testing_model helpers.
    """
    # NOTE(review): this local name shadows the imported `time` module for
    # the rest of the function body.
    time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    os.makedirs('logs/', exist_ok=True)
    log_file = 'logs/' + time + 'args.log'
    fw_log = open(log_file, "w")
    # Dump the full run configuration at the top of the log.
    json.dump(args.__dict__, fw_log, indent=2)
    fw_log.write('\n')
    # take record of parameters
    parameter_record = pth.join('./parameter_record.md')
    md = open(parameter_record, 'a')
    model.train()
    global_step = 0
    max_iter = int(np.ceil(len(train_indices) / args.batch_size)) * args.epoches
    writer = SummaryWriter()
    pbar = tqdm(total=max_iter)
    pbar.update(global_step)
    early_stopping = EarlyStopping(args.patience)
    # Resume from `curr_epoch` (module global, set when a checkpoint loads).
    for epoch in range(curr_epoch, args.epoches):
        # adjust_learning_rate(args.lr, optimizer, epoch)
        # A fresh dataloader per epoch reshuffles the training batches.
        train_loader, _ = get_dataloader(args.batch_size, train_indices, test_indices, adjdict)
        adjust_learning_rate(optimizer, epoch)
        train_accy = []
        for data in train_loader:
            graph_node_features, graph_edge_index, user_feats = Variable(data[0]).to(device), Variable(data[1]).to(device), Variable(data[3]).to(device)
            merged_tree_edge_index, merged_tree_feature, labels = Variable(data[4]).to(device), Variable(data[5]).to(device), Variable(data[6]).to(device)
            indx = data[7].to(device)
            # print('data', data)
            global_step += 1
            pbar.update(1)
            optimizer.zero_grad()
            output = model(user_feats, graph_node_features, graph_edge_index, merged_tree_feature, merged_tree_edge_index, indx)
            # print(output)
            # print("label: ",labels)
            loss = loss_fun(output, labels)
            loss.backward()
            optimizer.step()
            ## compute accuracy
            _, y_pred = output.max(dim=1)
            accy = _compute_accuracy(y_pred, labels)
            train_accy.append(accy)
            # Console/file logging every 10 optimizer steps.
            if global_step % 10 == 0:
                print("epoch: {} global step: {} loss: {:.5f} accuracy: {:.5f}"\
                    .format(epoch, global_step, loss.item(), accy))
                fw_log.write("epoch: {} global step: {} loss: {:.5f} train accuracy: {:.5f}\n"\
                    .format(epoch, global_step, loss.item(), accy))
            writer.add_scalar('Loss/train', loss, global_step)
            if global_step % 10 == 0:
                writer.flush()
        ## save checkpoint for best accuracy epoch
        # if (epoch+1) % 1 == 0:
        #     _save_checkpoint(model, global_step, epoch)
        # if (epoch+1) % 2 == 0:
        #     test_accy = _testing_model(test_loader,model)
        # Evaluate once per epoch; EarlyStopping tracks the best accuracy
        # and handles checkpointing internally.
        Acc_all, Acc1, Prec1, Recll1, F1, Acc2, Prec2, Recll2, F2, Acc3, Prec3, Recll3, F3, Acc4, Prec4, Recll4, F4 = _testing_model(model,test_loader)
        early_stopping(Acc_all, F1, F2, F3, F4, global_step, epoch, args.ckpt_dir, args.ckpt_name, model, optimizer)
        fw_log.write('epoch: {} testing accuracy: {:4f}\n'.format(epoch, Acc_all))
        fw_log.flush()
        if early_stopping.early_stop:
            # _save_checkpoint(model, early_stopping.global_step, early_stopping.epoch) #### false
            print('early stop!')
            break
        # _testing_model left the model in eval mode; switch back.
        model.train()
    fw_log.write("BEST Accuracy: {:.4f}".format(early_stopping.best_accs))
    print("BEST Accuracy: {:.4f}".format(early_stopping.best_accs))
    # Append a one-line summary of this run to parameter_record.md.
    md_write = '|{}| gpu: {} | {} | {} | {} | seed: {} | direction: {} | acc: {:.4f} | F1: {:.4f} | F2: {:.4f} | F3: {:.4f} | F4: {:.4f} | \n'.format(
        str(time), str(args.gpu), str(args.model), str(args.batch_size), str(args.lr),
        str(args.seed), str(args.direction), early_stopping.best_accs, early_stopping.F1, early_stopping.F2, early_stopping.F3, early_stopping.F4)
    md.write(md_write)
    md.close()
    fw_log.close()
    writer.close()
def _testing_model(model, test_loader):
    """Evaluate `model` over `test_loader` and return the 4-class metrics.

    Collects predictions and gold labels across all batches, then delegates
    to evaluation4class. Leaves the model in eval mode.
    """
    model.eval()
    predictions = []
    gold_labels = []
    for batch in test_loader:
        graph_feats = Variable(batch[0]).to(device)
        graph_edges = Variable(batch[1]).to(device)
        usr_feats = Variable(batch[3]).to(device)
        tree_edges = Variable(batch[4]).to(device)
        tree_feats = Variable(batch[5]).to(device)
        batch_labels = Variable(batch[6]).to(device)
        batch_indx = batch[7].to(device)
        logits = model(usr_feats, graph_feats, graph_edges, tree_feats, tree_edges, batch_indx)
        log_probs = F.log_softmax(logits, dim=1)
        # print('output', log_probs)
        batch_pred = log_probs.max(dim=1)[1]
        predictions += batch_pred
        gold_labels += batch_labels
    Acc_all, Acc1, Prec1, Recll1, F1, Acc2, Prec2, Recll2, F2, Acc3, Prec3, Recll3, F3, Acc4, Prec4, Recll4, F4 = evaluation4class(predictions, gold_labels)
    print('testing accuracy >>: ', Acc_all)
    return Acc_all, Acc1, Prec1, Recll1, F1, Acc2, Prec2, Recll2, F2, Acc3, Prec3, Recll3, F3, Acc4, Prec4, Recll4, F4
# this will take record of all the add edge with its label and value, K=1 of beam search
# Keys are "user-tweet" edge strings; values are (correct_label_count, fake_value).
greedy_search_attack_trace = dict()
# Opened in append mode at import time; closed at the end of __main__.
add_edge_trace = open('add_edge_trace.md', 'a')
"""
greedy_search_attack_trace = {
(1,2000): (62, 34.378),
(2,6000): (61, 34.26),
(3,4000): (60,20/98)
...
}
"""
def node_graph_add_edge(node_graph, user, tweet):
    """Return a copy of `node_graph` with an undirected user-tweet edge added.

    Args:
        node_graph: dict mapping str(node_id) -> list of neighbor node ids.
        user: user node id (must be >= SOURCE_TWEET_NUM), or None.
        tweet: tweet node id, or None.

    Returns:
        A deep copy of the unchanged graph when both endpoints are None;
        the ORIGINAL graph object when the edge already exists; otherwise a
        deep copy with the edge added in both adjacency lists.
    """
    # Explicit None comparison: tweet id 0 is a valid node but falsy, so the
    # original truthiness test (`not user and not tweet`) could wrongly take
    # the "no add edge" path for it.
    if user is None and tweet is None:
        print('no add edge')
        return copy.deepcopy(node_graph)
    # User ids come after all tweet ids in the node numbering.
    assert int(user) >= SOURCE_TWEET_NUM
    if int(user) in node_graph[str(tweet)] and \
            int(tweet) in node_graph[str(user)]:
        print('{}-{} edge exists!'.format(user, tweet))
        return node_graph
    node_graph_new = copy.deepcopy(node_graph)
    node_graph_new[str(user)].append(int(tweet))
    node_graph_new[str(tweet)].append(int(user))
    return node_graph_new
def index_graph_add_edge(index_graph, bad_user_node, target_tweet_node):
    """Return a copy of the edge-index tensor with a user-tweet edge added.

    Args:
        index_graph: 2 x E tensor of edge endpoints (model index space).
        bad_user_node: user node id (node_graph space), or None.
        target_tweet_node: tweet node id (node_graph space), or None.

    Node ids are translated to model indices through the module-level maps
    user_map / loss_tweet_map / no_loss_tweet_map; both edge directions are
    appended.

    Returns:
        The input tensor unchanged when both endpoints are None, otherwise
        a new tensor with two columns appended.

    Raises:
        ValueError: when the tweet node is in neither tweet map, or when an
        index lookup yields None.
    """
    # Explicit None comparison: node/index 0 is valid but falsy, so the
    # original truthiness tests could wrongly treat it as "missing".
    if bad_user_node is None and target_tweet_node is None:
        print('no add edge')
        return index_graph
    bad_user_index = user_map[str(bad_user_node)]
    if str(target_tweet_node) in loss_tweet_map:
        tweet_index = loss_tweet_map[str(target_tweet_node)]
    elif str(target_tweet_node) in no_loss_tweet_map:
        tweet_index = no_loss_tweet_map[str(target_tweet_node)]
    else:
        raise ValueError(target_tweet_node)
    if bad_user_index is None or tweet_index is None:
        raise ValueError('no available user or tweet to add edge')
    new_edges = torch.tensor([[bad_user_index, tweet_index],
                              [tweet_index, bad_user_index]]).to(device)
    return torch.cat((copy.deepcopy(index_graph), new_edges), 1)
def _compute_accuracy(y_pred, y_labels):
return 1.0*y_pred.eq(y_labels).sum().item()/y_labels.size(0)
def calc_target_output(idx_graph, label_list):
    """Score the current attack graph against the target rumor tweets.

    Runs the model with edge index `idx_graph`, using the module-level
    tensors captured from the test batch (user_feats, graph_node_features,
    merged_tree_feature, merged_tree_edge_index, indx, labels).

    Args:
        idx_graph: edge-index tensor in model index space.
        label_list: batch positions of the attack targets; every position
            must have gold label 1 (rumor), otherwise ValueError is raised.

    Returns:
        (correct, rumor_score, accy):
        - correct: how many targets are still predicted as class 1,
        - rumor_score: summed predicted probability of class 1 over targets,
        - accy: overall batch accuracy.
    """
    model.eval()
    with torch.no_grad():
        output = model(user_feats, graph_node_features, idx_graph, merged_tree_feature, merged_tree_edge_index, indx)
        output = F.softmax(output, dim=1)
    # all_pred = torch.Tensor.cpu(output).detach().numpy()
    all_pred = output.cpu().data.numpy()
    _, label_pred = output.max(dim=1)
    accy = _compute_accuracy(label_pred,labels)
    rumor_score = 0
    correct = 0
    for i in range(len(label_list)):
        # Guard: attack targets must be gold rumors (label 1).
        if labels[label_list[i]] != 1:
            raise ValueError(labels[label_list[i]])
        if labels[label_list[i]] == 1 and label_pred[label_list[i]] == 1:
            correct += 1
        if labels[label_list[i]] == 1:
            rumor_score += all_pred[label_list[i],1]
    return correct, rumor_score, accy
def alter_graph(original_node_graph, original_index_graph, user_set, label_list):
    """
    alter graph using Beam Search Algorithm (K=1, i.e. greedy)
    {(user1,tweet1): loss1, (user2,tweet2): loss2 ...}

    Evaluates every candidate (user, tweet) edge — users from `user_set`,
    tweets from the module-level target_tweet_set — and keeps the single
    edge that most reduces the summed rumor probability of the targets in
    `label_list`. Records the chosen edge in greedy_search_attack_trace.

    Returns:
        (best_node_graph, best_index_graph, improve)

    NOTE(review): the outer loop iterates the module-level target_tweet_set,
    not a parameter — presumably intentional; confirm against callers.
    """
    node_graph = copy.deepcopy(original_node_graph)
    index_graph = copy.deepcopy(original_index_graph)
    best_node_graph = node_graph
    best_index_graph = index_graph
    chosen_edge = (None, None)
    correct_label_origin, fake_value_origin, accy = calc_target_output(original_index_graph, label_list)
    print('before attack: label:{}, fake value: {}'.format(correct_label_origin, fake_value_origin))
    print('-'*89)
    correct_label_best = correct_label_origin
    fake_value_best = fake_value_origin
    # bad_user_node = user_set[:10]
    # for bad_user_node in tqdm(user_set[:3]):
    # for tweet_node in tqdm(test_indices): # all test tweet indices
    # for tweet_node in test_indices:
    pbar = tqdm(total=len(target_tweet_set) * len(user_set))
    pbar.update(0)
    # BUGFIX: initialise `improve` before the loop so the function does not
    # raise NameError at `return` when target_tweet_set is empty. (Inside
    # the loop it is still reset per tweet, preserving original behavior.)
    improve = False
    for tweet_node in target_tweet_set:
        improve = False
        for bad_user_node in user_set:
            # Only consider edges that do not already exist.
            if int(bad_user_node) not in node_graph[str(tweet_node)] and int(tweet_node) not in node_graph[str(bad_user_node)]:
                pbar.update(1)
                new_node_graph = node_graph_add_edge(original_node_graph, bad_user_node, tweet_node)
                index_graph_new = index_graph_add_edge(original_index_graph, bad_user_node, tweet_node)
                correct_label_new, fake_value_new, accy = calc_target_output(index_graph_new, label_list)
                # Keep the edge that lowers the targets' rumor score most.
                if fake_value_best - fake_value_new > 0 :
                    fake_value_best = fake_value_new
                    correct_label_best = correct_label_new
                    chosen_edge = (bad_user_node, tweet_node)
                    best_index_graph = index_graph_new
                    best_node_graph = new_node_graph
                    improve = True
                # if correct_label_new < correct_label_best:
                #     correct_label_best = correct_label_new
                #     label_chosen_edge = (bad_user_node, tweet_node)
                #     best_index_graph = index_graph_new
                #     best_node_graph = new_node_graph
                #     improve = True
        # print after finishing all user
        if improve:
            print('origin correct_label: {}'.format(correct_label_origin))
            print('origin label_score: {}'.format(fake_value_origin))
            print("origin accuracy: {}".format(accy))
            print('*' * 90)
            # print('edge: {}, {}'.format(label_chosen_edge, value_chosen_edge))
            print('new correct_label num: {}'.format(correct_label_best))
            print('new fake value : {}'.format(fake_value_best))
            print("new accuracy: {}".format(accy))
        else:
            print("no improve!")
    print("{}: ({}, {})\n".format(chosen_edge, correct_label_best, fake_value_best))
    print("--------------------------finish 1 add----------------------------")
    attack_edge = "{}-{}".format(chosen_edge[0], chosen_edge[1])
    greedy_search_attack_trace[attack_edge] = (correct_label_best, fake_value_best)
    return best_node_graph, best_index_graph, improve
    # chosen_edge: (user_node, tweet_node)
    # select the best attacking edge according to output value only
    # recalculate the label and value loss in best chosen edge
    # best_node_graph = node_graph_add_edge(original_node_graph, chosen_edge[0], chosen_edge[1])
    # best_index_graph = index_graph_add_edge(original_index_graph, chosen_edge[0], chosen_edge[1])
    # add_edge_trace.write("{}: ({}, {})\n".format(chosen_edge, correct_label_final, fake_value_final))
if __name__ == '__main__':
    print("-"*89)
    print('start model training now!!')
    print("-"*89)
    # torch.cuda.set_device(args.gpu)
    # Fix all RNG seeds for reproducibility. NOTE(review): this hard-codes
    # seed=0 and ignores args.seed.
    seed = 0
    torch.manual_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic=True
    np.random.seed(seed)
    # random.seed(seed)
    # load vocab of tweets
    TWEETS_WORD_FILE = pth.join('../load_data15_1473/tweets_words_mapping.json')
    tweets_word_map, _ = _load_word2index(TWEETS_WORD_FILE)
    glove_file = '../glove/glove.twitter.27B.{}d.txt'.format(args.embed_dim)
    embed_dim = args.embed_dim
    print("--load pretrain embedding now--")
    # Build the embedding matrix from pretrained GloVe vectors.
    tweet_embedding_matrix = _get_embedding(glove_file, tweets_word_map, embed_dim)
    model = Net(args, tweet_embedding_matrix) # load model
    model.to(device)
    # train_indices, test_indices = data_split()
    # Reuse the persisted train/test split instead of re-splitting.
    train_indices = txt2iterable(pth.join('train_indices.txt'))
    test_indices = txt2iterable(pth.join('test_indices.txt'))
    graph_connection_path = '../load_data15_1473/graph_connections2.json'
    adjdict = util.read_dict_from_json(pth.join(graph_connection_path))
    train_loader, test_loader = get_dataloader(args.batch_size, train_indices, test_indices, adjdict)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.weight_decay)
    loss_fun = nn.CrossEntropyLoss()
    global_step = 0
    curr_epoch = 0
    # get all info from test dataloader
    # print(bad_user_set)
    # Attack mode: load a checkpoint and greedily add adversarial edges.
    if args.load_ckpt and args.attack:
        # Map precomputed "bad" user ids to graph node ids via the profile CSV.
        bad_user_path = pth.join('../attack15/bad_user_score25.json')
        bad_users_dict = util.read_dict_from_json(bad_user_path)
        bad_users = list(bad_users_dict.keys())
        bad_users = [i.split('.')[0] for i in bad_users]
        all_user_path = pth.join('../load_data15_1473/filtered_user_profile2_encode2.csv')
        df = pd.read_csv(all_user_path, lineterminator='\n')
        node_id_list = df['node_id'].tolist()
        user_id_list = df['user_id'].tolist()
        node_id_list = [int(i) for i in node_id_list]
        user_id_list = [str(i) for i in user_id_list]
        user2node_dict = dict(zip(user_id_list, node_id_list))
        bad_user_set = [user2node_dict[i] for i in bad_users]
        assert bad_user_set is not None
        # NOTE(review): this loop keeps only the LAST batch's tensors and
        # index maps — presumably the test loader yields a single batch for
        # the attack; confirm.
        for data in test_loader:
            graph_node_features, original_index_graph, user_feats = Variable(data[0]).to(device), Variable(data[1]).to(device), Variable(data[3]).to(device)
            merged_tree_edge_index, merged_tree_feature, labels = Variable(data[4]).to(device), Variable(data[5]).to(device), Variable(data[6]).to(device)
            indx = data[7].to(device)
            loss_tweet_map, user_map, no_loss_tweet_map = data[8], data[9], data[10]
        # bad_user_set = [usr_id for usr_id in bad_user_set if str(usr_id) in user_map.keys()]
        # print('bad user num: {}'.format(len(bad_user_set)))
        model, optimizer, global_step, curr_epoch = _load_checkpoint()
        # train_loader, test_loader = get_dataloader(args.batch_size, train_indices, test_indices, adjdict)
        print("***loading checkpoint successfully***")
        print("[checkpoint current epoch: {} and step: {}]".format(curr_epoch, global_step))
        # _testing_model(model, test_loader)
        # for loop for insert edge
        index_graph = copy.deepcopy(original_index_graph)
        node_graph = copy.deepcopy(adjdict)
        # get beginning score
        model.eval()
        with torch.no_grad():
            output = model(user_feats, graph_node_features, index_graph, merged_tree_feature, merged_tree_edge_index, indx)
            output = F.softmax(output, dim=1)
        # all_pred = torch.Tensor.cpu(output).detach().numpy()
        all_pred = output.cpu().data.numpy()
        _, label_pred = output.max(dim=1)
        correct = 0
        rumor_score = 0
        lowest_acc = 1
        lowest_idx = None
        # Find the correctly-classified rumor (label 1) the model is least
        # confident about — that becomes the attack target.
        for i in range(len(labels)):
            if labels[i] == 1 and label_pred[i] == 1:
                if all_pred[i, 1] < lowest_acc:
                    lowest_acc = all_pred[i, 1]
                    lowest_idx = i
                    # correct_fake_label_list.append(i)
                correct += 1
            if labels[i] == 1:
                rumor_score += all_pred[i,1]
        print('beginning correct fake label', correct)
        print('beginning rumor value', rumor_score)
        # acurracy_list = [all_pred[i, 1] for i in correct_fake_label_list]
        # correct_fake_label_list.sort()
        # find target tweet node
        assert lowest_idx is not None
        # Invert loss_tweet_map to translate batch index back to node id.
        reverse_loss_tweet_map = dict(zip(loss_tweet_map.values(), loss_tweet_map.keys())) # index2node
        target_node = reverse_loss_tweet_map[int(lowest_idx)]
        # Candidate attackers are the target tweet's immediate neighbors.
        attack_user_list = adjdict[str(target_node)]
        print("target_node: {}| attack_user_list number: {}".format(target_node, len(attack_user_list)))
        # get all tweet list in the subgraph
        test_tweet_set = set(test_indices)
        test_user_set = set()
        for tweet in test_indices:
            user_list = [str(u) for u in adjdict[str(tweet)]]
            test_user_set.update(user_list)
        for u in test_user_set:
            tweet_list = adjdict[str(u)]
            test_tweet_set.update(tweet_list)
        print("user num: {}".format(len(test_user_set)))
        print("tweet num: {}".format(len(test_tweet_set)))
        target_tweet_set = test_tweet_set # a set of node
        bad_user_set = test_user_set
        print('len(target_tweet_set): {}'.format(len(target_tweet_set)))
        print('len(bad_user_set): {}'.format(len(bad_user_set)))
        print('len(attack_user_list): {}'.format(len(attack_user_list)))
        count_improve = 0
        # Greedily add up to 10 adversarial edges, one per alter_graph call.
        for i in range(10):
            node_graph, index_graph, improve = alter_graph(node_graph, index_graph, attack_user_list, [lowest_idx])
            # if improve:
            #     count_improve += 1
        save_dict_to_json(greedy_search_attack_trace, os.path.join('greedy_search_attack_trace.json'))
    # Training mode (mutually exclusive with attack mode).
    if (not args.attack) and args.train:
        _train_model(train_indices, test_indices, model)
    # else:
    #     print("*****testing model now*****")
    #     accy = _testing_model()
    add_edge_trace.close()