In [None]:
import torch
import argparse
import pickle as pkl
import scipy.sparse as sp
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import dgl
import random
import re
import copy 
from HAN_RoHe.model import *
from HAN_RoHe.utils import *

In [None]:
# --- Experiment configuration -------------------------------------------------
dataname = 'acm'
device = 2  # CUDA device index used for the graph, the tensors and the model

# One settings dict per ACM metapath, in the same order as meta_paths_dict['acm'].
# NOTE(review): 'T' is consumed by HAN_RoHe.model (not visible here) — confirm its
# exact meaning there before tuning it.
settings_pap = {'T': 2, 'device': 2}  # acm
settings_psp = {'T': 5, 'device': 2}  # acm
settings = [settings_pap, settings_psp]

# Metapaths per dataset, each written as the sequence of edge types it follows.
meta_paths_dict = {
    'acm': [['pa', 'ap'], ['pf', 'fp']],
    'dblp': [['ap', 'pa'], ['ap', 'pc', 'cp', 'pa'], ['ap', 'pt', 'tp', 'pa']],
    'aminer': [['pa', 'ap'], ['pr', 'rp']],
}

#1.init
# setup() receives the user overrides and returns the full argument dict that the
# rest of the notebook reads (hidden_units, num_heads, lr, ...).
args = setup({'seed': 2, 'hetero': True, 'log_dir': 'results'})

# NOTE(review): the positional False is forwarded to load_acm_raw as-is; see
# HAN_RoHe.utils for what it toggles.
(g, hete_adjs, features, labels, num_classes,
 train_idx, val_idx, test_idx,
 train_mask, val_mask, test_mask) = load_acm_raw(False)

g = g.to(device)

# Convert the masks to boolean tensors when this torch build provides BoolTensor.
if hasattr(torch, 'BoolTensor'):
    train_mask, val_mask, test_mask = (
        mask.bool() for mask in (train_mask, val_mask, test_mask)
    )

# Move every tensor the model touches onto the chosen device.
features = features.to(device)
labels = labels.to(device)
train_mask, val_mask, test_mask = (
    mask.to(device) for mask in (train_mask, val_mask, test_mask)
)

#2.generate transition matrix
def get_transition(given_hete_adjs, metapath_info):
    """Build one metapath transition matrix per metapath.

    Every relation adjacency matrix is first row-normalized (each non-empty row
    sums to 1; all-zero rows are left as zeros).  The transition matrix of a
    metapath is then the product of the normalized matrices of its edge types,
    taken in order.

    Args:
        given_hete_adjs: mapping ``edge type -> adjacency matrix``.  Values may
            be SciPy sparse matrices, ``numpy.matrix`` or plain 2-D
            ``numpy.ndarray`` objects.  The inputs are not modified.
        metapath_info: iterable of metapaths, each a non-empty sequence of edge
            type keys, e.g. ``[['pa', 'ap'], ['pf', 'fp']]``.

    Returns:
        list of ``scipy.sparse.csc_matrix``, one per metapath, in the order of
        ``metapath_info``.

    Raises:
        KeyError: if a metapath names an edge type missing from
            ``given_hete_adjs``.
    """
    # Row-normalize each relation once, staying sparse throughout so large
    # adjacency matrices are never expanded to dense form.
    row_normalized = {}
    for etype, adj in given_hete_adjs.items():
        adj = sp.csr_matrix(adj)
        # Flatten explicitly: sparse/np.matrix sums come back as (n, 1) while
        # ndarray sums are (n,); ravel() makes all three input kinds agree.
        deg = np.asarray(adj.sum(axis=1)).ravel()
        inv_deg = 1.0 / np.where(deg > 0, deg, 1)  # avoid dividing empty rows by 0
        row_normalized[etype] = sp.diags(inv_deg).dot(adj)

    homo_adj_list = []
    for metapath in metapath_info:
        trans = row_normalized[metapath[0]]
        for etype in metapath[1:]:
            trans = trans.dot(row_normalized[etype])
        trans = sp.csc_matrix(trans)
        # Drop explicitly stored zeros (e.g. left by edges that were set to 0)
        # so the sparsity pattern only contains real transitions.
        trans.eliminate_zeros()
        homo_adj_list.append(trans)
    return homo_adj_list
# Transition matrices of the clean graph, one per metapath of the chosen dataset.
trans_adj_list = get_transition(hete_adjs, meta_paths_dict[dataname])
# Attach each matrix (and the device index) to the settings dict of its metapath.
for i, trans_adj in enumerate(trans_adj_list):
    settings[i].update(device=device, TransM=trans_adj)


#3.train model
# Build HAN with one attention branch per metapath; `settings` carries the
# per-metapath entries prepared above (including 'TransM').
han_kwargs = dict(
    meta_paths=meta_paths_dict[dataname],
    in_size=features.shape[1],
    hidden_size=args['hidden_units'],
    out_size=num_classes,
    num_heads=args['num_heads'],
    dropout=args['dropout'],
    settings=settings,
)
model = HAN(**han_kwargs).to(device)

stopper = EarlyStopping(patience=args['patience'])
# Optional: redirect the checkpoint file written by the stopper.
# stopper.filename = 'atk_result/mid_routdglHan_hyper_'+dataname+'.pth'
loss_fcn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args['lr'],
    weight_decay=args['weight_decay'],
)

for epoch in range(args['num_epochs']):
    # Forward pass on the full graph; the loss only looks at training nodes.
    model.train()
    logits = model(g, features)
    loss = loss_fcn(logits[train_mask], labels[train_mask])

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Metrics: training scores from this epoch's logits, validation via evaluate().
    train_acc, train_micro_f1, train_macro_f1 = score(
        logits[train_mask], labels[train_mask]
    )
    val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
        model, g, features, labels, val_mask, loss_fcn
    )

    # The stopper decides from validation loss/accuracy whether to stop.
    early_stop = stopper.step(val_loss.item(), val_acc, model)
    print(epoch,"|",val_micro_f1, val_macro_f1)
    if early_stop:
        break

#4.test model
# Reload the checkpoint kept by the stopper before scoring on the test split.
stopper.load_checkpoint(model)
test_loss, _, test_micro_f1, test_macro_f1 = evaluate(
    model, g, features, labels, test_mask, loss_fcn
)
print("@@@@test:", test_micro_f1, test_macro_f1)

In [None]:
# Load the target node IDs.
tar_idx = []
for i in range(1):  # can attack 500 target nodes by setting range(5)
    target_file = 'data/preprocess/target_nodes/'+dataname+'_r_target' + str(i) + '.pkl'
    with open(target_file, 'rb') as f:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        tar_tmp = np.sort(pkl.load(f))
    tar_idx.extend(tar_tmp)

# Evaluate the trained model on the clean graph, restricted to the target nodes.
# Switch to eval mode explicitly so dropout is disabled here regardless of the
# mode the previous cell happened to leave the model in.
model.eval()
with torch.no_grad():
    logits = model(g, features)
logits_clean = logits[tar_idx]
labels_clean = labels[tar_idx]
_, tar_micro_f1_clean, tar_macro_f1_clean = score(logits_clean, labels_clean)
print("Clean data:  Micro-F1:", tar_micro_f1_clean, " Macro-F1:",tar_macro_f1_clean)

In [None]:
n_perturbation = 1
adv_filename = 'data/generated_attacks/adv_acm_pap_pa_'+str(n_perturbation)+'.pkl'
tar_mask = get_binary_mask(train_mask.shape[0], tar_idx)
micro_f1_list_adv = []
macro_f1_list_adv = []

# Load the pre-generated adversarial perturbations (one record per target node).
with open(adv_filename,'rb') as f:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    modified_opt = pkl.load(f)


def _install_transition_matrices(han_model, trans_list):
    """Give each metapath layer of ``han_model`` its transition matrix.

    Writes ``'device'`` and ``'TransM'`` into the ``settings`` of
    ``han_model.layers[0].gat_layers[i]`` for every matrix in ``trans_list``.
    """
    for i, trans in enumerate(trans_list):
        layer_settings = han_model.layers[0].gat_layers[i].settings
        layer_settings['device'] = device
        layer_settings['TransM'] = trans


#2.attack
# Inference only: make the mode explicit instead of relying on earlier cells.
model.eval()
logits_adv = []
labels_adv = []
try:
    for items in modified_opt:
        #2.1 init: record layout is (target, <unused here>, deleted edges, added edges)
        target_node = items[0]
        del_list = items[2]
        add_list = items[3]
        if target_node not in tar_idx:
            continue

        #2.2 attack adjs: perturb a private copy so hete_adjs stays clean
        mod_hete_adj_dict = copy.deepcopy(hete_adjs)
        for edge in del_list:
            mod_hete_adj_dict['pa'][edge[0],edge[1]] = 0
            mod_hete_adj_dict['ap'][edge[1],edge[0]] = 0
        for edge in add_list:
            mod_hete_adj_dict['pa'][edge[0],edge[1]] = 1
            mod_hete_adj_dict['ap'][edge[1],edge[0]] = 1
        trans_adj_list = get_transition(mod_hete_adj_dict, meta_paths_dict[dataname])
        _install_transition_matrices(model, trans_adj_list)
        hg_atk = get_hg(dataname, mod_hete_adj_dict).to(device)

        #2.3 run model on the perturbed graph
        with torch.no_grad():
            logits = model(hg_atk, features)

        #2.4 keep only the target node's row.  A one-element slice yields exactly
        # one row, whereas a nested (1, 1) NumPy index depends on how nested
        # sequence indices are interpreted and can add an extra dimension.
        node = int(target_node)
        logits_adv.append(logits[node:node + 1])
        labels_adv.append(labels[node:node + 1])
finally:
    # Put the clean transition matrices back so the model (and the global
    # trans_adj_list) match the clean graph again; otherwise re-running the
    # clean evaluation cell would silently use the last target's perturbation.
    trans_adj_list = get_transition(hete_adjs, meta_paths_dict[dataname])
    _install_transition_matrices(model, trans_adj_list)

logits_adv = torch.cat(logits_adv,0)
labels_adv = torch.cat(labels_adv)
_, tar_micro_f1_atk, tar_macro_f1_atk = score(logits_adv, labels_adv)
print("Attacked data:  Micro-F1:", tar_micro_f1_atk, " Macro-F1:",tar_macro_f1_atk)

In [None]:
import matplotlib.pyplot as plt
# --- Figure configuration -----------------------------------------------------
y_testAcc_name = 'Results of HAN-RoHe (%)'
tick_label = ['Mi-F1','Ma-F1']
font_size = 35
bar_width = 0.2
font_legend = {'family' : 'Times New Roman', 'size':font_size}

# Scores as percentages: [micro-F1, macro-F1] for the clean and attacked graphs.
Y_clean = [tar_micro_f1_clean*100, tar_macro_f1_clean*100]
Y_attack = [tar_micro_f1_atk*100, tar_macro_f1_atk*100]
X = np.arange(len(Y_attack))

plt.figure(figsize=(9, 10))#dpi=xx
plt.ylim(0,120)

# Integer value labels above the bars: clean series first, then attacked.
for x_offset, values in ((0.05, Y_clean), (0.25, Y_attack)):
    for x, y in zip(X, values):
        plt.text(x+x_offset, y+0.005, '%d' %y,
                 ha='center', va='bottom', fontsize=font_size)

# Grouped bars: the attacked series sits one bar width to the right of clean.
clean = plt.bar(X, Y_clean, width=bar_width, color='g', edgecolor='black')
attack = plt.bar(X + 0.2, Y_attack, width=bar_width, color='red', edgecolor='black')

# Axes, ticks and legend.
plt.ylabel(y_testAcc_name,
           fontdict={'family' : 'Times New Roman','weight':'bold', 'size':font_size})
plt.xticks(X, tick_label, size=font_size+3)
plt.yticks([0,20,40,60,80,100], fontsize=font_size)
plt.legend([clean, attack], [r'clean', r'attack'],
           ncol=2, loc='upper center', bbox_to_anchor=(0.5,1.025), prop=font_legend)

plt.savefig("attack_HAN_RoHe.png", bbox_inches='tight', dpi=400, pad_inches=0.0)