In [1]:
from robust_gcn_structure.certification import certify
from robust_gcn_structure.utils import load_npz
from matplotlib import pyplot as plt
import torch

In [2]:
dataset = "citeseer"
robust_gcn = False  # Whether to load weights for GCN trained with the approach by [Zügner and Günnemann 2019

local_budget = 3
global_budget = 5

In [3]:
target_node = 3311
eval_class = None #0

In [4]:
solver = "ECOS"
max_iters = 250
tolerance = 1e-2
kwargs = {
    'tolerance': tolerance,
    'max_iter': max_iters
}

In [5]:
A, X, z = load_npz(f'../datasets/{dataset}.npz')
A = A + A.T
A[A > 1] = 1
A.setdiag(0)

X = (X>0).astype("float32")
z = z.astype("int64")
N, D = X.shape

weight_path = f"../pretrained_weights/{dataset}"
if robust_gcn:
    weight_path = f"{weight_path}_robust_gcn.pkl"
else:
    weight_path = f"{weight_path}_gcn.pkl"

state_dict = torch.load(weight_path, map_location="cpu")

weights = [v for k,v in state_dict.items() if "weight" in k and "conv" in k]
biases = [v for k,v in state_dict.items() if "bias" in k and "conv" in k]

W1, W2 = [w.cpu().detach().numpy() for w in weights]
b1, b2 = [b.cpu().detach().numpy() for b in biases]

shapes = [x.shape[0] for x in biases]
num_hidden = len(shapes) - 1
if num_hidden > 1:
    raise NotImplementedError("Only one hidden layer is supported.")

weight_list = [W1, b1, W2, b2]
# info_dict = {}

results = certify(target_node, A, X, weight_list, z,
                  local_changes=local_budget,
                  global_changes=global_budget,
                  solver=solver, eval_class=eval_class,
                  use_predicted_class=True,
#                   info_dict=info_dict, 
                  **kwargs)


  self._set_arrayXarray(i, j, x)


In [20]:
import torch as th
import numpy as np

def gcn_forward(A_hat, X, weights, i=None):
    W1, b1, W2, b2 = weights
    l1 = Linear(W1.shape[0], W1.shape[1], bias=True)
    l2 = Linear(W2.shape[0], W2.shape[1], bias=True)
    abs_ahat = Linear(A_hat.shape[0], A_hat.shape[1], bias=False)

    W1 = th.from_numpy(W1)
    b1 = th.from_numpy(b1)
    W2 = th.from_numpy(W2)
    b2 = th.from_numpy(b2)
    A_hat = A_hat.tocoo()
    A_hat = th.sparse.DoubleTensor(th.LongTensor([A_hat.row.tolist(), A_hat.col.tolist()]),
                                   th.DoubleTensor(A_hat.data.astype(np.int32)))

    l1.weight.data = W1
    l1.bias.data = b1
    l2.weight.data = W2
    l2.bias.data = b2
    abs_ahat.weight.data = A_hat

    l1_out = th.relu(abs_ahat(l1(X)))
    logits = abs_ahat(l2(l1_out))

    if i is not None:
        logits = logits[i]

    return logits

gcn_forward(A, X, weight_list)

NameError: name 'Linear' is not defined

In [6]:
results

{'all_robust': True,
 1: {},
 5: {'robust': True,
  'best_uppers': [4.957600847403593],
  'best_lowers': [1.024755233774429],
  'An_pert': <7x24 sparse matrix of type '<class 'numpy.float64'>'
  	with 168 stored elements in Compressed Sparse Row format>,
  'logit_diff_before': 7.977204183916103,
  'solve_times': [0.041622505]},
 2: {'robust': True,
  'best_uppers': [4.676333493537031, 4.0924052085308285],
  'best_lowers': [-0.36070931688405894, 0.39936849497546323],
  'An_pert': <7x24 sparse matrix of type '<class 'numpy.float64'>'
  	with 168 stored elements in Compressed Sparse Row format>,
  'logit_diff_before': 8.001205324903085,
  'solve_times': [0.052148657,
   0.041674059,
   0.034963481,
   0.054064643,
   0.050784099]},
 4: {'robust': True,
  'best_uppers': [3.388207887405914],
  'best_lowers': [0.1493721788140856],
  'An_pert': <7x24 sparse matrix of type '<class 'numpy.float64'>'
  	with 168 stored elements in Compressed Sparse Row format>,
  'logit_diff_before': 8.614593765

In [8]:
if results['robust'] == True:
    print(f"Robustness for node {target_node} and class {eval_class} successfully certified.")
else:
    print(f"Robustness for node {target_node} and class {eval_class} could not be certified.")

KeyError: 'robust'

In [None]:
plt.plot(results['best_lowers'], label="lower bound")
plt.plot(results['best_uppers'], label="upper bound")
plt.plot((0,len(results['best_uppers'])-1), (0,0), color="black", linestyle="--")
plt.legend()
plt.show()