In [2]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data.dataset import random_split

import GCL.augmentors as A
import GCL.losses as L
from GCL.models import DualBranchContrast

from torch_geometric.nn import GINConv, global_add_pool, summary
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.contrib.nn import PRBCDAttack

from tqdm import tqdm
import itertools
import warnings
import sys
from sklearn.model_selection import StratifiedKFold
import numpy as np
import os.path as osp
import argparse

from attacker.greedy import Greedy
from attacker.PGD import PGDAttack
from wgin_conv import WGINConv

from gin import GIN, LogReg, GCL_classifier, eval_encoder
from MyGCL import Encoder, train, train_classifier

In [3]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Hand-recorded scores, five values per setting.
# NOTE(review): presumably accuracies from five runs without attack (`clean`),
# under the Greedy attack (`greedy`) and under the PRBCD attack (`PRBCD`) -- confirm.
clean, greedy, PRBCD = (
    [0.7990, 0.7790, 0.8180, 0.7790, 0.7240],
    [0.7380, 0.6800, 0.7150, 0.6670, 0.6530],
    [0.3980, 0.3950, 0.3910, 0.3430, 0.3230],
)

In [3]:
# Hyperparameters
lr = 0.01
num_layers = 3
epochs = 20
seed = 42

# Use the GPU when one is available; otherwise fall back to the CPU so the cell still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
path = osp.join(osp.expanduser('~'), 'datasets')
dataset = TUDataset(path, name='PROTEINS')
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
# GIN needs at least one input feature, so a featureless dataset is treated as 1-dimensional.
num_features = max(dataset.num_features, 1)
num_classes = dataset.num_classes
if dataset.num_features == 0:
    print("No node feature, paddings of 1 will be used in GIN when forwarding.")

# Two views for contrastive learning: the untouched graph and one randomly chosen augmentation.
aug1 = A.Identity()
aug2 = A.RandomChoice([A.RWSampling(num_seeds=1000, walk_length=86),
                       A.EdgeRemoving(pe=0.2),
                       A.NodeDropping(pn=0.2)], 1)

# The graph neural network backbone model to use.
# The seed is reset before each constructor so that every component starts from the same RNG state.
torch.manual_seed(seed)
gconv = GIN(num_features=num_features, dim=32, num_gc_layers=num_layers, device=device).to(device)
torch.manual_seed(seed)
encoder_model = Encoder(encoder=gconv, augmentor=(aug1, aug2)).to(device)
torch.manual_seed(seed)
contrast_model = DualBranchContrast(loss=L.InfoNCE(tau=0.2), mode='G2G').to(device)
optimizer = Adam(encoder_model.parameters(), lr=lr)

# Train the encoder with full dataset without labels using contrastive learning.
# The progress-bar total follows `epochs`, so it stays correct when the hyperparameter changes.
with tqdm(total=epochs, desc='(T)') as pbar:
    for epoch in range(1, epochs + 1):
        loss = train(encoder_model, contrast_model, dataloader, optimizer)
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.27it/s, loss=13.7]


In [4]:
# Hold out 10% of the graphs (eval_set) for the final evaluation; the remaining 90% (train_set)
# is used to fit the classifier and could be divided further into training and validation parts.
torch.manual_seed(seed) # fix the RNG so that the random split is reproducible
train_set, eval_set = random_split(dataset, [0.9, 0.1])
dataloader_train = DataLoader(train_set, batch_size=128, shuffle=True)
dataloader_eval = DataLoader(eval_set, batch_size=128, shuffle=False) # keep the evaluation order fixed so results are reproducible

# Get embeddings for the train_set
encoder_model.eval()
embedding_global, y = encoder_model.encoder.get_embeddings(dataloader_train)

# ====== Train one classifier and do attack =====================================
# NOTE(review): `num_classse` looks like a misspelling of `num_classes`, but it must match the
# keyword declared by MyGCL.train_classifier -- confirm there before renaming.
classifier = train_classifier(embedding_global, y, num_classse=num_classes)

# Put encoder and classifier together, drop the augmentor
encoder_classifier = GCL_classifier(encoder_model.encoder, classifier)

encoder_classifier.eval() # Try to save memory
# Gradients w.r.t. the parameters are switched off; as the last expression of the cell,
# the returned module is also what gets displayed as the cell output.
encoder_classifier.requires_grad_(False) # Try to save memory

GCL_classifier(
  (encoder): GIN(
    (convs): ModuleList(
      (0): WGINConv(nn=Sequential(
        (0): Linear(in_features=3, out_features=32, bias=True)
        (1): ReLU()
        (2): Linear(in_features=32, out_features=32, bias=True)
      ))
      (1-2): 2 x WGINConv(nn=Sequential(
        (0): Linear(in_features=32, out_features=32, bias=True)
        (1): ReLU()
        (2): Linear(in_features=32, out_features=32, bias=True)
      ))
    )
    (bns): ModuleList(
      (0-2): 3 x BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (project): Sequential(
      (0): Linear(in_features=96, out_features=96, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=96, out_features=96, bias=True)
    )
  )
  (classifier): LogReg(
    (fc): Linear(in_features=96, out_features=2, bias=True)
  )
)

In [38]:
# NOTE(review): this rebinds the name `PRBCDAttack`, replacing the class of the same name
# imported from torch_geometric.contrib.nn in the first cell. From here on `PRBCDAttack`
# refers to the project-local MyPRBCDAttack.
from attacker.PRBCD import MyPRBCDAttack as PRBCDAttack


In [42]:
# Re-seed before building the attacker so that different runs can be compared fairly.
# The notebook-wide `seed` is reused instead of repeating the literal 42.
torch.manual_seed(seed)
# The attacker wraps the frozen encoder+classifier model.
prbcd = PRBCDAttack(encoder_classifier, block_size=250_000, lr=2_000, log=False)

In [43]:
# Accuracy on the clean evaluation data
# NOTE(review): `mask` presumably flags the correctly classified graphs, so that only those
# are attacked -- confirm in eval_encoder.
acc_clean, mask = eval_encoder(encoder_classifier, dataloader_eval)

# Perturb the evaluation graphs with the `prbcd` attacker; the result is a new loader.
dataloader_eval_adv = prbcd.attack(eval_set, mask, attack_ratio=0.05)

# Accuracy on the adversarial data only
# NOTE(review): the `_PGD` suffix is misleading -- these numbers come from the PRBCD
# attacker (`prbcd`), not from PGD. Renaming requires updating the print cell that follows.
acc_adv_only_PGD, _ = eval_encoder(encoder_classifier, dataloader_eval_adv)

# Overall adversarial accuracy
acc_adv_PGD = acc_clean * acc_adv_only_PGD # T/all * Tadv/T = Tadv/all

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 76/76 [02:54<00:00,  2.30s/it]


NameError: name 'dataloader_eval_adv_PGD' is not defined

In [45]:
# Report clean accuracy next to the overall adversarial accuracy (printed as raw tensor reprs).
print(acc_clean, acc_adv_PGD)

tensor(0.6847, device='cuda:0') tensor(0.2523, device='cuda:0')


In [47]:
# Scratch check of the logging switch: verbose is on only when log is exactly 2.
log = 1
verbose = log == 2  # the comparison already yields a bool; no ternary needed
verbose

False

In [8]:
# Forward pass on a single graph using its original edge list.
# NOTE(review): neither `model` nor `one_graph` is defined by an earlier visible cell, so this
# relies on leftover kernel state (`model` is presumably the encoder_classifier) -- confirm.
model(one_graph.x, one_graph.edge_index, one_graph.batch)

tensor([[-0.9111, -3.5493]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [9]:
# Convert the graph into the attacker's representation: an edge index plus one weight per listed edge.
# NOTE(review): judging by the displayed values, `fc_edge_index` presumably lists every ordered
# node pair and `edge_weight` is 1 for existing edges and 0 otherwise -- confirm in PGDAttack.preprocess.
# NOTE(review): `attacker` is only created in a later cell; this relies on leftover kernel state.
fc_edge_index, edge_weight = attacker.preprocess(one_graph.x, one_graph.edge_index)

In [10]:
# Inspect the edge index returned by attacker.preprocess.
fc_edge_index

tensor([[ 0,  0,  0,  ..., 41, 41, 41],
        [ 1,  2,  3,  ..., 38, 39, 40]], device='cuda:0')

In [11]:
# Inspect the per-edge weights returned by attacker.preprocess.
edge_weight

tensor([0., 0., 0.,  ..., 0., 0., 1.], device='cuda:0')

In [12]:
# Sanity check: the same forward pass, now through the preprocessed edge list with explicit weights.
# The displayed logits match those of the call on the original edges.
model(one_graph.x, fc_edge_index, one_graph.batch, edge_weight=edge_weight)

tensor([[-0.9111, -3.5493]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [13]:
# Cross-entropy loss of the unperturbed graph, computed through the weighted edge list.
logit = model(one_graph.x, fc_edge_index, one_graph.batch, edge_weight=edge_weight)
loss = F.cross_entropy(logit, one_graph.y)
print(loss.item())

0.06904802471399307


### 上面已经成功地进行了改造，接下来求取对于edge_weight的梯度

In [14]:
# Perturbation budget: 5% of the graph's node count, rounded down.
# The 0.05 is hard-coded here; a later cell stores the same value as `attack_ratio`.
n_perturbations = int(one_graph.x.shape[0] * 0.05)
# Attack one graph. Returns the edge index used with weights, the modified weights,
# and an edge index that can be passed to the model without weights.
fc_edge_index, modified_weight, adv_edge_index = attacker.attack_one_graph(one_graph.x, one_graph.edge_index, one_graph.batch, one_graph.y, n_perturbations)

In [15]:
# Inspect the shape of the adversarial edge index.
adv_edge_index.shape

torch.Size([2, 160])

In [16]:
# Loss of the attacked graph, evaluated on the adversarial edge list without edge weights.
logit = model(one_graph.x, adv_edge_index, one_graph.batch)
loss = F.cross_entropy(logit, one_graph.y)
print(loss.item())

0.14834341406822205


In [17]:
# Same attacked graph, evaluated through the weighted edge list with the modified weights.
# The printed loss equals the one obtained from `adv_edge_index`, so the two representations agree.
logit = model(one_graph.x, fc_edge_index, one_graph.batch, edge_weight=modified_weight)
loss = F.cross_entropy(logit, one_graph.y)
print(loss.item())

0.14834341406822205


In [18]:
# Compare the size of the weighted edge list with that of the adversarial edge list.
print(fc_edge_index.shape)
print(adv_edge_index.shape)

torch.Size([2, 1722])
torch.Size([2, 160])


### 已经成功实现了PGD，接下来将整个test dataset进行转换

In [54]:
from torch_geometric.data import Data

In [55]:
# PGD attacker that uses `model` as its surrogate.
# NOTE(review): `model` is not defined in any visible cell (presumably the encoder_classifier) -- confirm.
attacker = PGDAttack(surrogate=model, device=device)
# Fraction of a graph's node count used as the perturbation budget.
attack_ratio = 0.05

In [56]:
# Index the full dataset with the indices of the evaluation split, so the evaluation
# graphs can be iterated one by one.
dataset_eval = dataset[eval_set.indices]

In [57]:
# Peek at the first evaluation graph.
# NOTE(review): `onegraph` is a different name from the `one_graph` used by the other cells.
onegraph = dataset_eval[0]
onegraph

Data(edge_index=[2, 286], x=[77, 3], y=[1])

In [52]:
# Build an adversarial counterpart for every evaluation graph, one graph at a time.
adv_datalist_eval = []

for one_graph in tqdm(dataset_eval):
    # Rebind to the returned object instead of relying on `.to()` working in place.
    one_graph = one_graph.to(device)
    # Perturbation budget: `attack_ratio` of the node count, rounded down.
    n_perturbations = int(one_graph.x.shape[0] * attack_ratio)
    # A fresh attacker is built for each graph, exactly as before.
    # NOTE(review): could be hoisted out of the loop if PGDAttack keeps no per-graph state -- confirm.
    attacker = PGDAttack(surrogate=model, device=device)
    _, _, adv_edge_index = attacker.attack_one_graph(one_graph.x, one_graph.edge_index, one_graph.batch, one_graph.y, n_perturbations)
    # Same features and label, adversarial edges.
    new_graph = Data(edge_index=adv_edge_index, x=one_graph.x, y=one_graph.y)
    adv_datalist_eval.append(new_graph)

# Every evaluation graph must have produced exactly one adversarial graph.
assert len(adv_datalist_eval) == len(dataset_eval), (
    f"expected {len(dataset_eval)} adversarial graphs, got {len(adv_datalist_eval)}"
)

  1%|█▍                                                                                                                                                         | 1/111 [00:02<04:51,  2.65s/it]



 12%|██████████████████                                                                                                                                        | 13/111 [00:33<04:05,  2.50s/it]



 14%|████████████████████▊                                                                                                                                     | 15/111 [00:45<06:25,  4.02s/it]



 19%|█████████████████████████████▏                                                                                                                            | 21/111 [01:02<04:08,  2.76s/it]



 27%|█████████████████████████████████████████▌                                                                                                                | 30/111 [01:26<03:27,  2.56s/it]



 34%|████████████████████████████████████████████████████▋                                                                                                     | 38/111 [01:48<03:12,  2.64s/it]



 44%|███████████████████████████████████████████████████████████████████▉                                                                                      | 49/111 [02:16<02:42,  2.62s/it]



 51%|███████████████████████████████████████████████████████████████████████████████                                                                           | 57/111 [02:37<02:05,  2.32s/it]



 54%|███████████████████████████████████████████████████████████████████████████████████▏                                                                      | 60/111 [02:47<02:22,  2.80s/it]



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [04:54<00:00,  2.65s/it]


In [59]:
# Wrap the adversarial graphs in a loader.
# NOTE(review): this reuses the name `dataloader_eval_adv`, which the PRBCD cell also assigns;
# whichever cell ran last determines what the name refers to.
dataloader_eval_adv = DataLoader(adv_datalist_eval, batch_size=128)

### 接下来尝试带入原来的框架

In [63]:
# Run the PGD attack over the evaluation split via the attacker's dataset-level entry point.
# NOTE(review): no `device` is passed here, unlike the other PGDAttack constructions in this
# notebook -- confirm that the default is the intended device.
attacker = PGDAttack(surrogate=model)
# `mask` comes from an earlier eval_encoder call; `dataloader_eval_adv` is overwritten again here.
dataloader_eval_adv = attacker.attack(eval_set, mask)

 12%|███████████████▎                                                                                                                    | 10/86 [00:23<03:00,  2.37s/it]



 50%|██████████████████████████████████████████████████████████████████                                                                  | 43/86 [01:47<01:39,  2.30s/it]



 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊                             | 67/86 [02:47<00:41,  2.18s/it]



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 86/86 [03:40<00:00,  2.56s/it]


In [64]:
# Confirm that the attack returned a loader object.
dataloader_eval_adv

<torch_geometric.loader.dataloader.DataLoader at 0x7feeee46ff10>

In [65]:
# Accuracy on the adversarial data only
acc_adv_only, _ = eval_encoder(model, dataloader_eval_adv)

# Overall adversarial accuracy
acc_adv = acc_clean * acc_adv_only # T/all * Tadv/T = Tadv/all

print(f'(A): clean accuracy={acc_clean:.4f}, adversarial accuracy={acc_adv:.4f}')

(A): clean accuracy=0.7748, adversarial accuracy=0.7568


In [68]:
# Scratch check: (row, col) index pairs of the strictly upper-triangular part of an
# nnodes x nnodes matrix; `offset=1` skips the diagonal.
nnodes = 3
torch.triu_indices(nnodes, nnodes, offset=1)

tensor([[0, 0, 1],
        [1, 2, 2]])