# RoboGraph
This is the demo accompanying the paper submission
>  __Certified Robustness of Graph Convolution Networks for Graph Classification under Topological Attacks__

Before running the demo, please make sure all the required packages are installed.

Detailed instructions are provided in [README.md](./README.md).

In [None]:
import torch
import numpy as np
import os.path as osp
import tempfile
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.data.makedirs import makedirs
from robograph.model.gnn import GC_NET, train, eval
from tqdm.notebook import tqdm
from robograph.utils import process_data, cal_logits

from robograph.attack.admm import admm_solver
from robograph.attack.cvx_env_solver import cvx_env_solver
from robograph.attack.dual import dual_solver
from robograph.attack.greedy_attack import Greedy_Attack
from robograph.attack.utils import calculate_Fc

## Graph classification with linear activation function

In [None]:
# Seed torch and NumPy before any of the steps below run.
torch.manual_seed(0)
np.random.seed(0)

# ---- dataset -------------------------------------------------------------
# Everything lives under the system temp directory: <tmp>/data/<ds_name>,
# with checkpoints and derived files in the 'saved' sub-directory.
ds_name = 'ENZYMES'
path = osp.join(tempfile.gettempdir(), 'data', ds_name)
save_path = osp.join(path, 'saved')
if not osp.isdir(save_path):
    makedirs(save_path)

dataset = TUDataset(path, name=ds_name, use_node_attr=True).shuffle()

# Split in units of one tenth of the dataset: 3 tenths for training,
# 2 tenths for validation, and whatever remains for testing.
n_graphs = len(dataset)
train_size = n_graphs // 10 * 3
val_size = n_graphs // 10 * 2
val_end = train_size + val_size
train_dataset = dataset[:train_size]
val_dataset = dataset[train_size:val_end]
test_dataset = dataset[val_end:]

# ---- dataloaders ---------------------------------------------------------
# All three loaders use the same batch size and the default (unshuffled) order.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=20)
    for split in (train_dataset, val_dataset, test_dataset)
)

# ---- model ---------------------------------------------------------------
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Linear activation with average pooling and no dropout.
model_kwargs = dict(
    hidden=64,
    n_features=dataset.num_features,
    n_classes=dataset.num_classes,
    act='linear',
    pool='avg',
    dropout=0.,
)
model = GC_NET(**model_kwargs).to(device)

## Training a vanilla model

In [None]:
# Train for 200 epochs. The weights are written to disk every time the
# validation accuracy ties or beats the best value seen so far.
best = 0
ckpt_file = osp.join(save_path, "result.pk")
for epoch in tqdm(range(200)):
    loss_all = train(model, train_loader)
    train_acc = eval(model, train_loader)
    val_acc = eval(model, val_loader)

    if val_acc >= best:
        best = val_acc
        torch.save(model.state_dict(), ckpt_file)

    # One progress line per epoch, written through tqdm so the bar is not broken.
    log_parts = (
        "epoch {:03d} ".format(epoch + 1),
        "train_loss {:.4f} ".format(loss_all),
        "train_acc {:.4f} ".format(train_acc),
        "val_acc {:.4f} ".format(val_acc),
    )
    tqdm.write("".join(log_parts))

# NOTE(review): with testing=True and save_path given, eval presumably reloads
# the saved checkpoint before scoring - confirm in robograph.model.gnn.
test_acc = eval(model, test_loader, testing=True, save_path=save_path)
print("test_acc {:.4f}".format(test_acc))

## Robustness certificate

In [None]:
# Robustness certificate / attack evaluation for the model currently bound to
# `model`. For every correctly classified test graph and every other class c,
# each solver's 'opt_f' is recorded, and the graph is then counted as
# certified robust (dual / cvx) or as vulnerable (admm / admm+greedy / greedy).

# Model weights as float64 NumPy arrays, which is what the solvers are fed.
W = model.conv.weight.detach().cpu().numpy().astype(np.float64)
U = model.lin.weight.detach().cpu().numpy().astype(np.float64)

k = dataset.num_classes

# counters of certifiably robust and vulnerable graphs
robust_dual = 0
robust_cvx = 0
vul_admm = 0
vul_admm_g = 0
vul_greedy = 0

# counter of correct classification
correct = 0

# attacker settings: `strength` drives the per-node budget delta_l computed
# below, `delta_g` is passed to every solver as the global budget
strength = 3
delta_g = 10

# settings for solvers
dual_params = dict(iter=200, nonsmooth_init='random')
cvx_params = dict(iter=400, lr=0.3, verbose=0, constr='1+2+3',
                  activation='linear', algo='swapping', nonsmooth_init='subgrad')
admm_params = dict(iter=200, mu=1)

for data in tqdm(test_dataset, desc='across graphs'):
    A, X, y = process_data(data)
    deg = A.sum(1)
    n_nodes = A.shape[0]
    n_edges = np.count_nonzero(A) // 2  # not used further in this cell

    # Per-node budget: deg - max(deg) + strength, clipped to [0, n_nodes - 1].
    delta_l = np.minimum(np.maximum(deg - np.max(deg) + strength, 0), n_nodes - 1).astype(int)

    # X @ W depends only on the graph, not on the target class, so compute it
    # once here and reuse it for the logits and for every solver call below.
    XW = X@W

    logits = cal_logits(A, XW, U, act='linear')
    c_pred = logits.argmax()

    # Only graphs the model already classifies correctly are evaluated.
    if c_pred != y:
        continue
    correct += 1

    # One slot per class. The slot of the true class y is never written and
    # stays 0, so the min-based checks after the loop only react to negative
    # values produced for the other classes.
    fc_vals_orig = [0] * k
    fc_vals_dual = [0] * k
    fc_vals_cvx = [0] * k
    fc_vals_admm = [0] * k
    fc_vals_admm_g = [0] * k
    fc_vals_greedy = [0] * k

    for c in tqdm(range(k), desc='across labels', leave=False):
        if c == y:
            continue
        # Difference between the true-class and target-class rows of U; it is
        # divided by n_nodes in every call below.
        u = U[y] - U[c]

        # fc_val_orig: value on the unperturbed graph
        fc_vals_orig[c] = calculate_Fc(A, XW, u / n_nodes)

        # fc_val_dual
        dual_sol = dual_solver(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g, **dual_params)
        fc_vals_dual[c] = dual_sol['opt_f']

        # fc_val_cvx
        cvx_sol = cvx_env_solver(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g, **cvx_params)
        fc_vals_cvx[c] = cvx_sol['opt_f']

        # fc_val_admm: ADMM is initialised from the dual solver's 'opt_A'
        admm_params['init_B'] = dual_sol['opt_A']
        admm_sol = admm_solver(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g, **admm_params)
        fc_vals_admm[c] = admm_sol['opt_f']

        # fc_val_admm_g: admm + greedy
        attack = Greedy_Attack(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g)
        if np.array_equal(admm_sol['opt_A'], admm_sol['opt_A'].T):
            admm_A = admm_sol['opt_A']
        else:
            # Symmetrise with an element-wise minimum of the matrix and its transpose.
            admm_A = np.minimum(admm_sol['opt_A'], admm_sol['opt_A'].T)
        admm_g_sol = attack.attack(admm_A)  # init from admm
        fc_vals_admm_g[c] = admm_g_sol['opt_f']

        # fc_val_greedy
        attack = Greedy_Attack(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g)
        greedy_sol = attack.attack(A)  # init from A
        fc_vals_greedy[c] = greedy_sol['opt_f']

    # Robust: no class produced a negative value from the dual / cvx solver.
    if np.min(fc_vals_dual) >= 0:
        robust_dual += 1
    if np.min(fc_vals_cvx) >= 0:
        robust_cvx += 1
    # Vulnerable: at least one class produced a negative value from the attack.
    if np.min(fc_vals_admm) < 0:
        vul_admm += 1
    if np.min(fc_vals_admm_g) < 0:
        vul_admm_g += 1
    if np.min(fc_vals_greedy) < 0:
        vul_greedy += 1

In [None]:
# Report every rate relative to the number of correctly classified graphs.
# If no graph was classified correctly, divide by NaN so the rates print as
# 'nan' instead of the cell failing with ZeroDivisionError.
denom = correct if correct else float('nan')
print('dataset {}'.format(ds_name),
      'strength {:02d}'.format(strength),
      'delta_g {:02d}'.format(delta_g),
      'dual {:.2f}'.format(robust_dual / denom),
      'cvx {:.2f}'.format(robust_cvx / denom),
      'admm rate {:.2f}'.format(vul_admm / denom),
      'admm_g rate {:.2f}'.format(vul_admm_g / denom),
      'greedy rate {:.2f}'.format(vul_greedy / denom),)

## Warm start from adversarial samples generated by the greedy method

In [None]:
# Build an adversarial version of the training set with the greedy attack.
# W, U and k are whatever an earlier cell left in scope. For each training
# graph, the attack is run against every class and the adjacency with the
# smallest 'opt_f' is written back as the graph's edge_index.
strength = 3
# Wrap the dataset (not the enumerate object) so tqdm can read its length and
# show a total / ETA.
for idx, data in enumerate(tqdm(train_dataset, desc='adversarial examples')):
    A, X, y = process_data(data)
    deg = A.sum(1)
    n_nodes = A.shape[0]
    # Per-node budget: deg - max(deg) + strength, clipped to [0, n_nodes - 1].
    delta_l = np.minimum(np.maximum(deg - np.max(deg) + strength, 0), n_nodes - 1).astype(int)
    # Global budget set to n_nodes times the largest per-node budget.
    delta_g = n_nodes * np.max(delta_l)

    # X @ W does not depend on the target class: compute it once per graph.
    XW = X@W
    logits = cal_logits(A, XW, U, act='linear')
    c_pred = logits.argmax()

    fc_vals_greedy = [0] * k
    fc_A_greedy = [A] * k
    # Unlike the certificate cells, c == y is not skipped here; for that class
    # u is the all-zero vector.
    for c in range(k):
        u = U[y] - U[c]
        # greedy attack, initialised from the clean adjacency A
        attack = Greedy_Attack(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g)
        greedy_sol = attack.attack(A)
        fc_vals_greedy[c] = greedy_sol['opt_f']
        fc_A_greedy[c] = greedy_sol['opt_A']

    # Keep the perturbed adjacency with the smallest objective value.
    pick_idx = np.argmin(fc_vals_greedy)
    # Stack the (row, col) index arrays into one ndarray first; building a
    # tensor directly from a tuple of ndarrays takes torch's slow path.
    adv_edges = np.stack(fc_A_greedy[pick_idx].nonzero())
    # NOTE(review): depending on the torch_geometric version, train_dataset[idx]
    # may return a fresh Data copy, in which case this assignment is lost and
    # the saved set equals the clean training set - confirm.
    train_dataset[idx].edge_index = torch.tensor(adv_edges)
torch.save(train_dataset, osp.join(save_path, 'adv_set.pk'))

## Robust linear model

In [None]:
# Start again from a freshly initialised network with the same
# hyper-parameters as the vanilla model.
robust_model_kwargs = dict(
    hidden=64,
    n_features=dataset.num_features,
    n_classes=dataset.num_classes,
    act='linear',
    pool='avg',
    dropout=0.,
)
model = GC_NET(**robust_model_kwargs).to(device)

# The adversarial loader serves the saved adversarial set concatenated with
# the clean training set.
adv_file = osp.join(save_path, 'adv_set.pk')
adv = torch.load(adv_file)
adv_loader = DataLoader(adv + train_dataset, batch_size=20)

# Robust training: 200 epochs, checkpointing whenever validation accuracy
# ties or beats the best value so far.
robust_ckpt = osp.join(save_path, 'result_robust.pk')
best = 0
for epoch in tqdm(range(200), desc='epoch'):
    loss_all = train(model, train_loader, robust=True, adv_loader=adv_loader, lamb=0.5)
    train_acc = eval(model, train_loader, robust=True)
    val_acc = eval(model, val_loader, robust=True)

    if val_acc >= best:
        best = val_acc
        torch.save(model.state_dict(), robust_ckpt)
    # Per-epoch logging is disabled in this cell; loss_all and train_acc are
    # still computed every epoch.

test_acc = eval(model, test_loader, testing=True, save_path=save_path, robust=True)
print("test_acc {:.4f}".format(test_acc))

## Robustness certificate with robust model

In [None]:
# Robustness certificate / attack evaluation, repeated for the model that is
# bound to `model` at this point (the robustly trained one when the cells are
# run in order). For every correctly classified test graph and every other
# class c, each solver's 'opt_f' is recorded, and the graph is then counted as
# certified robust (dual / cvx) or as vulnerable (admm / admm+greedy / greedy).

# Model weights as float64 NumPy arrays, which is what the solvers are fed.
W = model.conv.weight.detach().cpu().numpy().astype(np.float64)
U = model.lin.weight.detach().cpu().numpy().astype(np.float64)

k = dataset.num_classes

# counters of certifiably robust and vulnerable graphs
robust_dual = 0
robust_cvx = 0
vul_admm = 0
vul_admm_g = 0
vul_greedy = 0

# counter of correct classification
correct = 0

# attacker settings: `strength` drives the per-node budget delta_l computed
# below, `delta_g` is passed to every solver as the global budget
strength = 3
delta_g = 10

# settings for solvers
dual_params = dict(iter=200, nonsmooth_init='random')
cvx_params = dict(iter=400, lr=0.3, verbose=0, constr='1+2+3',
                  activation='linear', algo='swapping', nonsmooth_init='subgrad')
admm_params = dict(iter=200, mu=1)

for data in tqdm(test_dataset, desc='across graphs'):
    A, X, y = process_data(data)
    deg = A.sum(1)
    n_nodes = A.shape[0]
    n_edges = np.count_nonzero(A) // 2  # not used further in this cell

    # Per-node budget: deg - max(deg) + strength, clipped to [0, n_nodes - 1].
    delta_l = np.minimum(np.maximum(deg - np.max(deg) + strength, 0), n_nodes - 1).astype(int)

    # X @ W depends only on the graph, not on the target class, so compute it
    # once here and reuse it for the logits and for every solver call below.
    XW = X@W

    logits = cal_logits(A, XW, U, act='linear')
    c_pred = logits.argmax()

    # Only graphs the model already classifies correctly are evaluated.
    if c_pred != y:
        continue
    correct += 1

    # One slot per class. The slot of the true class y is never written and
    # stays 0, so the min-based checks after the loop only react to negative
    # values produced for the other classes.
    fc_vals_orig = [0] * k
    fc_vals_dual = [0] * k
    fc_vals_cvx = [0] * k
    fc_vals_admm = [0] * k
    fc_vals_admm_g = [0] * k
    fc_vals_greedy = [0] * k

    for c in tqdm(range(k), desc='across labels', leave=False):
        if c == y:
            continue
        # Difference between the true-class and target-class rows of U; it is
        # divided by n_nodes in every call below.
        u = U[y] - U[c]

        # fc_val_orig: value on the unperturbed graph
        fc_vals_orig[c] = calculate_Fc(A, XW, u / n_nodes)

        # fc_val_dual
        dual_sol = dual_solver(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g, **dual_params)
        fc_vals_dual[c] = dual_sol['opt_f']

        # fc_val_cvx
        cvx_sol = cvx_env_solver(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g, **cvx_params)
        fc_vals_cvx[c] = cvx_sol['opt_f']

        # fc_val_admm: ADMM is initialised from the dual solver's 'opt_A'
        admm_params['init_B'] = dual_sol['opt_A']
        admm_sol = admm_solver(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g, **admm_params)
        fc_vals_admm[c] = admm_sol['opt_f']

        # fc_val_admm_g: admm + greedy
        attack = Greedy_Attack(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g)
        if np.array_equal(admm_sol['opt_A'], admm_sol['opt_A'].T):
            admm_A = admm_sol['opt_A']
        else:
            # Symmetrise with an element-wise minimum of the matrix and its transpose.
            admm_A = np.minimum(admm_sol['opt_A'], admm_sol['opt_A'].T)
        admm_g_sol = attack.attack(admm_A)  # init from admm
        fc_vals_admm_g[c] = admm_g_sol['opt_f']

        # fc_val_greedy
        attack = Greedy_Attack(A, XW, u / n_nodes, delta_l=delta_l, delta_g=delta_g)
        greedy_sol = attack.attack(A)  # init from A
        fc_vals_greedy[c] = greedy_sol['opt_f']

    # Robust: no class produced a negative value from the dual / cvx solver.
    if np.min(fc_vals_dual) >= 0:
        robust_dual += 1
    if np.min(fc_vals_cvx) >= 0:
        robust_cvx += 1
    # Vulnerable: at least one class produced a negative value from the attack.
    if np.min(fc_vals_admm) < 0:
        vul_admm += 1
    if np.min(fc_vals_admm_g) < 0:
        vul_admm_g += 1
    if np.min(fc_vals_greedy) < 0:
        vul_greedy += 1

In [None]:
# Report every rate relative to the number of correctly classified graphs.
# If no graph was classified correctly, divide by NaN so the rates print as
# 'nan' instead of the cell failing with ZeroDivisionError.
denom = correct if correct else float('nan')
print('dataset {}'.format(ds_name),
      'strength {:02d}'.format(strength),
      'delta_g {:02d}'.format(delta_g),
      'dual {:.2f}'.format(robust_dual / denom),
      'cvx {:.2f}'.format(robust_cvx / denom),
      'admm rate {:.2f}'.format(vul_admm / denom),
      'admm_g rate {:.2f}'.format(vul_admm_g / denom),
      'greedy rate {:.2f}'.format(vul_greedy / denom),)