In [2]:
# !pip3 install scipy --upgrade
# !pip3 install cvxpy mosek

In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

import cvxpy as cp
import pickle as pkl
import numpy as np
import scipy.sparse as sp
from scipy.linalg import pinv, inv
import scipy.linalg as spl
from tqdm import tqdm
from sklearn.metrics import accuracy_score

In [2]:
def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix.

    Returns ``(A D^{-1/2})^T D^{-1/2}`` as a COO matrix, where D holds the row
    sums of ``adj``; for a symmetric ``adj`` this equals ``D^{-1/2} A D^{-1/2}``.
    Zero-degree rows/columns are scaled by 0 instead of inf.
    """
    mat = sp.coo_matrix(adj)
    inv_root_deg = np.power(np.array(mat.sum(1)), -0.5).flatten()
    # isolated nodes give 0 ** -0.5 == inf; zero them out
    inv_root_deg[np.isinf(inv_root_deg)] = 0.
    scaler = sp.diags(inv_root_deg)
    col_scaled = mat.dot(scaler)
    return col_scaled.transpose().dot(scaler).tocoo()

def propagation_matrix(adj, alpha=0.85, sigma=1, nodes=None):
    """
    Computes the propagation matrix  (1-alpha)(I - alpha D^{-sigma} A D^{sigma-1})^{-1}.
    Parameters
    ----------
    adj : sp.spmatrix or array-like, shape [n, n]
        Adjacency matrix. Anything accepted by ``sp.csr_matrix`` works
        (scipy sparse matrices/arrays, dense ndarrays, integer or float).
    alpha : float
        (1-alpha) is the teleport probability.
    sigma
        Hyper-parameter controlling the propagation style.
        Set sigma=1 to obtain the PPR matrix.
    nodes : np.ndarray, shape [?]
        Nodes for which we want to compute Personalized PageRank.
    Returns
    -------
    prop_matrix : np.ndarray, shape [n, n], or [len(nodes), n] if `nodes` is given
        Propagation matrix; with `nodes` only the rows for those nodes are returned.
    """
    adj = sp.csr_matrix(adj)
    n = adj.shape[0]
    # Float degrees: np.power on an int array with a negative exponent raises,
    # and np.asarray(...).ravel() works whether .sum returns np.matrix or ndarray.
    deg = np.asarray(adj.sum(1), dtype=float).ravel()

    deg_min_sig = sp.diags(np.power(deg, -sigma))
    deg_sig_min = sp.diags(np.power(deg, sigma - 1))
    pre_inv = sp.eye(n) - alpha * deg_min_sig @ adj @ deg_sig_min

    # Rows `nodes` of pre_inv^{-1}: solve pre_inv^T x = e_nodes, then transpose.
    b = np.eye(n)
    if nodes is not None:
        b = b[:, nodes]

    return (1 - alpha) * spl.solve(pre_inv.toarray().T, b).T


def flip_label(lbl, total_classes):
    """
    Return a copy of the one-hot label vector `lbl` moved to a uniformly random
    wrong class among `total_classes` classes; `lbl` itself is not modified.
    """
    true_class = np.argmax(lbl)
    wrong_classes = [c for c in range(total_classes) if c != true_class]
    new_class = np.random.choice(wrong_classes)

    flipped = lbl.copy()
    flipped[true_class] = 0.
    flipped[new_class] = 1.
    return flipped


def poison_labels_random(labels, train_idx, BUDGET):
    """
    Random label-flipping baseline.

    Picks `BUDGET` distinct nodes from `train_idx` and replaces each of their
    rows in a copy of `labels` with `flip_label(row, n_classes)`.

    Parameters
    ----------
    labels : np.ndarray, shape [n, c]
        Label matrix; only a copy is modified.
    train_idx : array-like of int
        Candidate node indices.
    BUDGET : int or float
        Number of nodes to flip; floats (e.g. from np.ceil) are truncated to int.

    Returns
    -------
    np.ndarray, shape [n, c]
        The poisoned copy of `labels`.
    """
    labels_ = labels.copy()
    total_classes = labels_.shape[1]

    # np.random.choice requires an integer size; budgets are often np.ceil floats.
    random_ids = np.random.choice(train_idx, size=int(BUDGET), replace=False)

    for idx in random_ids:
        labels_[idx] = flip_label(labels_[idx], total_classes)

    return labels_

# NTK Binary Attack

In [3]:
def ntk_binclass(kernel, train_ids, test_ids, dataset,
                 poison_budget, CV, verbose=False):
    """
    Label-flipping poisoning attack on NTK regression between the two largest classes.

    Solves a mixed-integer program (cvxpy + MOSEK). Among training nodes of the
    two most frequent classes, exactly `poison_budget` are swapped to the other
    class. The objective minimises the summed NTK score that test nodes assign
    to their true class.

    Parameters
    ----------
    kernel : torch.Tensor, shape [n, n]
        Kernel matrix, moved to CPU and passed to `NTK_inference`.
    train_ids, test_ids : array-like of int
        Indices of the labelled (attackable) nodes and of the evaluation nodes.
    dataset : dict
        Must contain 'labels', a one-hot np.ndarray of shape [n, c].
    poison_budget : int or float
        Exact number of labels to flip.
    CV : bool
        Unused; kept for interface compatibility.
    verbose : bool
        Print the optimal objective value and the poisoned test accuracy.

    Returns
    -------
    Y_copy : np.ndarray, shape [n, c]
        Copy of the labels with the training rows replaced by the poisoned one-hot labels.
    H_long : cvxpy.Variable, shape [len(train_ids), 1]
        Boolean flip indicator; its `.value` holds the solution.
    None
        Placeholder kept for interface compatibility.
    test_acc_lp : float
        Accuracy of the poisoned NTK predictions on `test_ids`.

    Raises
    ------
    RuntimeError
        If the solver returns no solution (e.g. the budget is infeasible).
    """
    Y = dataset['labels']
    # Integer labels derived locally rather than read from a notebook global.
    y_gt = Y.argmax(1)

    # ntk_pre maps training labels to NTK predictions for every node.
    _, ntk_pre = NTK_inference(kernel.cpu(), train_ids, test_ids, torch.tensor(Y))

    Y_L = Y[train_ids]
    num_classes = Y.shape[1]

    # The two most frequent classes (over all nodes) are the attack targets.
    class_a, class_b = np.bincount(y_gt).argsort()[::-1][:2]

    y_train = y_gt[train_ids]
    is_a = y_train == class_a
    is_b = y_train == class_b
    # Positions within train_ids of nodes belonging to class a or b.
    ab_train_ids = np.flatnonzero(is_a | is_b).tolist()

    # Y_L_flipped: class-a training labels become b and vice versa; others unchanged.
    y_swapped = y_train.copy()
    y_swapped[is_a] = class_b
    y_swapped[is_b] = class_a
    Y_L_flipped = np.eye(num_classes)[y_swapped]

    # NTK-binary objective: H_long[i] = 1 means training node i gets its flipped label.
    H_long = cp.Variable((len(train_ids), 1), boolean=True)
    Y_poisoned = cp.multiply(H_long, Y_L_flipped) + cp.multiply(1 - H_long, Y_L)
    Y_pred_poisoned = ntk_pre @ Y_poisoned

    # Minimise the total score the test nodes assign to their true class.
    multi_obj = cp.Minimize(cp.sum(cp.multiply(Y_pred_poisoned[test_ids], Y[test_ids])))
    # Budget applies only to class-a/b nodes; flipping any other node is a no-op
    # because its Y_L_flipped row equals its clean row.
    constraints = [cp.sum(H_long[ab_train_ids]) == poison_budget]

    prob = cp.Problem(multi_obj, constraints)
    prob.solve(solver=cp.MOSEK, verbose=False)
    if H_long.value is None:
        raise RuntimeError(
            "NTK poisoning problem returned no solution (status: {})".format(prob.status))

    predictions_argmax = Y_pred_poisoned.value.argmax(1)
    test_acc_lp = accuracy_score(predictions_argmax[test_ids], y_gt[test_ids])

    if verbose:
        print("optimal value", prob.value)
        print("Test Acc: {:.4f}".format(test_acc_lp))

    # Write the poisoned one-hot labels into the training rows of a copy.
    Y_copy = Y.copy()
    poisoned_labels = Y_poisoned.value.argmax(1)
    Y_copy[train_ids] = 0.
    Y_copy[train_ids, poisoned_labels] = 1.

    return Y_copy, H_long, None, test_acc_lp

# NTK for GCN

In [4]:
# load dataset 
dset = 'pubmed'
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('../small_val_random/{}.pkl'.format(dset), 'rb') as handle:
    dataset = pkl.load(handle)
X = dataset['X']
# Integer class ids: argmax over each row of the label matrix.
y_gt = dataset['labels'].argmax(1)

In [5]:
# Normalized adjacency and its square (two-hop propagation).
A = dataset['sym_adj']
A_hat = normalize_adj(A)
# Alternative (PPNP): propagation_matrix(A_hat).
# NOTE(review): `pie` is not referenced by the visible cells.
pie = A_hat.dot(A_hat)

In [6]:
def kappa_0(u):
    """
    Zeroth-order arc-cosine term ``(pi - arccos(u)) / pi``, applied element-wise.

    Parameters
    ----------
    u : torch.Tensor
        Normalized covariances; values are clamped to [-1, 1].

    Returns
    -------
    torch.Tensor
        Same shape, dtype and device as `u`. NaN inputs (e.g. from 0/0) map to 1.
    """
    # Floating-point error can push |u| slightly past 1; clamping keeps arccos
    # defined and yields the correct limits kappa_0(1) = 1 and kappa_0(-1) = 0.
    u = torch.clamp(u, -1.0, 1.0)
    r = (np.pi - torch.acos(u)) / np.pi
    return torch.where(torch.isnan(r), torch.ones_like(r), r)

def kappa_1(u):
    """
    First-order arc-cosine term ``(u (pi - arccos(u)) + sqrt(1 - u^2)) / pi``, element-wise.

    Parameters
    ----------
    u : torch.Tensor
        Normalized covariances; values are clamped to [-1, 1].

    Returns
    -------
    torch.Tensor
        Same shape, dtype and device as `u`. NaN inputs (e.g. from 0/0) map to 1.
    """
    # Clamp so arccos/sqrt stay real; gives kappa_1(1) = 1 and kappa_1(-1) = 0.
    u = torch.clamp(u, -1.0, 1.0)
    r = (u * (np.pi - torch.acos(u)) + torch.sqrt(1 - u * u)) / np.pi
    return torch.where(torch.isnan(r), torch.ones_like(r), r)

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Parameters
    ----------
    sparse_mx : scipy sparse matrix, shape [n, m]

    Returns
    -------
    torch.Tensor
        Sparse COO tensor with the same shape and (float32-cast) values.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

#run after loading data
# Builds the NTK of a `depth`-layer ReLU GCN on this graph and stores it in the
# dense [n, n] float64 tensor `kernel`.
# Reads `dataset`, `X` and the helpers defined earlier in this cell/notebook.
# `device` and `dtype` are module-level names that kappa_0/kappa_1 read.
device = torch.device('cpu')
dtype= torch.float64

# Dense normalized adjacency (float32 via sparse_mx_to_torch_sparse_tensor).
a = sparse_mx_to_torch_sparse_tensor(normalize_adj(dataset['sym_adj'])).to(device).to_dense()
feat = torch.FloatTensor(X).to(device)
# Linear kernel on the node features: X X^T.
x = feat @ feat.t()
csigma = 1
kernel = torch.zeros((a.shape), dtype=dtype).to(device)

#set depth of GCN
depth = 2

# First-layer covariance: A_hat (X X^T) A_hat^T.
sig = (a @ x @ a.t())

# ReLU GCN
# compute sigma_n + sigma_(n-1) * SS^T * der_relu_(n-1) + ... + sigma_1 * SS^T (n-1 times) * der_relu(n-1) * ... der_relu(1)
# kernel_sub[i] accumulates layer i's contribution; later layers multiply it by
# their derivative term below.
kernel_sub = torch.zeros((depth, a.shape[0], a.shape[1]), dtype=dtype).to(device)
for i in range(depth):
    # Broadcast diag(sig) so that q[i, j] = sqrt(sig_ii * sig_jj); p (float64)
    # also promotes sig_i/sig_j, and hence u, to float64.
    p = torch.zeros((a.shape), dtype=dtype).to(device)
    diag_sig = torch.diagonal(sig)
    sig_i = p + diag_sig.reshape(1, -1)
    sig_j = p + diag_sig.reshape(-1, 1)
    q = torch.sqrt(sig_i * sig_j)
    # Normalized covariance; 0/0 yields NaN, which kappa_0/kappa_1 map to 1.
    u = sig/q
    E = (q * kappa_1(u)) * csigma
    E_der = (kappa_0(u)) * csigma
    # Element-wise product of A_hat A_hat^T with the ReLU-derivative term.
    kernel_der = (a @ a.t()) * E_der
    kernel_sub[i] += sig * kernel_der

    # Next-layer covariance; E is cast to float32 to match `a`.
    E = E.float()
    sig = a @ E @ a.t()
#     if args.gcn_skip:
#         if args.skip_form == "gcn":
#             sig = sig + sig_1
#         else:
#             sig = (1-alpha)**2 * sig + alpha**2 * sig_1
    # Propagate earlier layers' contributions through this layer's derivative.
    for j in range(i):
        kernel_sub[j] *= kernel_der

# NTK = sum of per-layer contributions + final-layer covariance.
kernel += torch.sum(kernel_sub, dim=0)
kernel += sig

In [7]:
def NTK_inference(kernel, train_idx, test_idx, labels):
    """
    Closed-form NTK regression evaluated on every node.

    Parameters
    ----------
    kernel : torch.Tensor, shape [n, n]
        Kernel matrix; float64 is expected (the matmul with the float64
        pseudo-inverse fails for other dtypes).
    train_idx : sequence of int
        Labelled node indices.
    test_idx : sequence of int
        Unused: predictions are returned for all n nodes and callers index them.
    labels : torch.Tensor, shape [n, c]

    Returns
    -------
    output : torch.Tensor, shape [n, c]
        K[:, train] pinv(K[train, train]) Y[train].
    pre_output : np.ndarray, shape [n, len(train_idx)]
        K[:, train] pinv(K[train, train]), so that output = pre_output @ Y[train].
    """
    k_cross = kernel[:, train_idx]
    k_train = k_cross[train_idx]
    y_train = labels[train_idx].to(torch.float64)

    k_train_pinv = torch.pinverse(k_train, rcond=1e-8).to(torch.float64)
    weights = k_cross @ k_train_pinv

    predictions = weights @ y_train
    return predictions, weights.cpu().detach().numpy()

In [8]:
# CV setting: when True, validation nodes are merged into the attackable training set.
CV = True

# using new LP optimization
poisoned_labels_lp = {'dataset_name': dset}
POISON_PERCENTAGES = [5, 10, 15, 20, 30]

# Integer ground-truth labels; kept module-level because other cells read `y_gt`.
y_gt = dataset['labels'].argmax(1)

test_accs = []
test_accs_by_budget = {per: [] for per in POISON_PERCENTAGES}
for split_no in range(10):
    split_key = 'split_{}'.format(split_no)
    split = dataset[split_key]
    print("split: {:2d}".format(split_no))

    if CV:
        train_ids = np.append(split['train_ids'], split['val_ids'])
    else:
        train_ids = split['train_ids']
    test_ids = split['test_ids']

    for poison_per in POISON_PERCENTAGES:
        print("Poison percentage: {:2d}%".format(poison_per))
        poison_budget = int(np.ceil(len(train_ids) * (poison_per / 100)))

        y_poison, h_long, _, t_acc = ntk_binclass(kernel, train_ids, test_ids, dataset,
                                                  poison_budget, CV=CV, verbose=True)

        n_flips = int((y_poison[train_ids].argmax(1) != y_gt[train_ids]).sum())
        print("Given Budget: {:2d}   Total flips: {:2d}".format(poison_budget, n_flips))
        print()

        test_accs.append(t_acc)
        test_accs_by_budget[poison_per].append(t_acc)

        # log poisoned labels to dict
        poisoned_labels_lp.setdefault(split_key, {})['{}_percent_poison'.format(poison_per)] = y_poison

# Per-budget averages are the meaningful numbers; the pooled mean mixes all budgets.
for poison_per, accs in test_accs_by_budget.items():
    print("{:2d}% poison -> Test Acc: {:.2f} ({:.2f})".format(
        poison_per, np.mean(accs) * 100, np.std(accs) * 100))
print("Average Test Acc: {:.2f} ({:.2f})".format(np.mean(test_accs)*100, np.std(test_accs)*100))

split:  0
Poison percentage:  5%





optimal value 11964.388718245184
Test Acc: 0.6854
Given Budget:  6   Total flips:  6

Poison percentage: 10%
optimal value 10770.782105802013
Test Acc: 0.5976
Given Budget: 12   Total flips: 12

Poison percentage: 15%
optimal value 9820.48145560111
Test Acc: 0.5002
Given Budget: 18   Total flips: 18

Poison percentage: 20%
optimal value 9102.957725932853
Test Acc: 0.4085
Given Budget: 24   Total flips: 24

Poison percentage: 30%
optimal value 7996.953194667934
Test Acc: 0.3089
Given Budget: 36   Total flips: 36

split:  1
Poison percentage:  5%
optimal value 10945.10806157801
Test Acc: 0.7017
Given Budget:  6   Total flips:  6

Poison percentage: 10%
optimal value 9786.493565961542
Test Acc: 0.6094
Given Budget: 12   Total flips: 12

Poison percentage: 15%
optimal value 9012.107769981007
Test Acc: 0.4904
Given Budget: 18   Total flips: 18

Poison percentage: 20%
optimal value 8364.750298458872
Test Acc: 0.3981
Given Budget: 24   Total flips: 24

Poison percentage: 30%
optimal value 738

In [9]:
#dump NTK-attack poisoned labels
import os

# open(..., 'wb') fails if the output directory does not exist yet.
os.makedirs('./ntk_attack', exist_ok=True)
# NOTE(review): the filename says "cv" even when CV is False -- confirm intended.
with open(os.path.join('./ntk_attack', '{}_ntk_cv_poisoned_labels.pkl'.format(dset)), 'wb') as handle:
    pkl.dump(poisoned_labels_lp, handle, protocol=4)