In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data.dataset import random_split

import GCL.augmentors as A
import GCL.losses as L
from GCL.models import DualBranchContrast

from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import dropout_edge, add_random_edge, to_dense_adj

from wgin_conv import WGINConv

from tqdm import tqdm
import itertools
import warnings
import sys
from sklearn.model_selection import StratifiedKFold
import numpy as np
import os.path as osp

In [2]:
dataset_name = 'PROTEINS'
device = torch.device('cuda')
path = osp.join(osp.expanduser('~'), 'datasets')
dataset = TUDataset(path, name=dataset_name)
one_graph = dataset[0]

adj, features, labels = (to_dense_adj(one_graph.edge_index).squeeze().to(device), one_graph.x.to(device), one_graph.y.to(device))
print(adj.device)
print(features.device)
print(labels.device)

cuda:0
cuda:0
cuda:0


In [3]:
# Prepare the victim model
from gin import *

dataset_name = 'PROTEINS'
train_multiple_classifiers = False

# Hyperparams
lr = 0.01
num_layers = 3
epochs = 20
print(f'======The hyperparams: lr={lr}, num_layers={num_layers}, epochs={epochs}. On dataset:{dataset_name}======')

device = torch.device('cuda')
path = osp.join(osp.expanduser('~'), 'datasets')
dataset = TUDataset(path, name=dataset_name)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
num_features = max(dataset.num_features, 1)
num_classes = dataset.num_classes
if dataset.num_features==0 :
    print("No node feature, paddings of 1 will be used in GIN when forwarding.")


best_hyperparams = {'lr': 0.01, 'num_layer': 5, 'hidden_dim': 32, 'dropout': 0.5, 'batch_size': 128}
lr, num_layer, hidden_dim, dropout, batch_size = best_hyperparams.values()

train_val_set, eval_set = random_split(dataset, [0.9, 0.1])
dataloader_train = DataLoader(train_val_set, batch_size=batch_size, shuffle=True) # Use all train+val to train the final model
dataloader_eval = DataLoader(eval_set, batch_size=128, shuffle=False) # Do not shuffle the evaluation set to make it reproduceable



encoder_model = GIN(num_features=num_features, dim=hidden_dim, num_gc_layers=num_layer, dropout=dropout).to(device)
classifier = LogReg(hidden_dim * num_layer, num_classes).to(device)
model = GCL_classifier(encoder_model, classifier)
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)

# Train the encoder with full dataset without labels using contrastive learning
with tqdm(total=epochs, desc='(T)') as pbar:
    for epoch in range(1, epochs + 1):
        loss = train(model, dataloader_train, optimizer, scheduler)
        pbar.set_postfix({'loss': loss})
        pbar.update()

# Accuracy on the clean evaluation data
acc_clean, mask = eval_encoder(model, dataloader_eval)

# Instantiate an attacker and attack the victim model
attacker = Greedy(pn=0.05)
dataloader_eval_adv = attacker.attack(eval_set, mask)

# Accuracy on the adversarial data only
acc_adv_only, _ = eval_encoder(model, dataloader_eval_adv)

# Overall adversarial accuracy
acc_adv = acc_clean * acc_adv_only # T/all * Tadv/T = Tadv/all

print(f'(A): clean accuracy={acc_clean:.4f}, adversarial accuracy={acc_adv:.4f}')




(T):   0%|                                                                                                                                 | 0/20 [00:00<?, ?it/s]

Shape of out:  torch.Size([5136, 3])
Shape of out:  torch.Size([5136, 32])
Shape of out:  torch.Size([5136, 32])
Shape of out:  torch.Size([5136, 32])
Shape of out:  torch.Size([5136, 32])
Shape of out:  torch.Size([4730, 3])
Shape of out:  torch.Size([4730, 32])
Shape of out:  torch.Size([4730, 32])
Shape of out:  torch.Size([4730, 32])
Shape of out:  torch.Size([4730, 32])
Shape of out:  torch.Size([4714, 3])
Shape of out:  torch.Size([4714, 32])
Shape of out:  torch.Size([4714, 32])
Shape of out:  torch.Size([4714, 32])
Shape of out:  torch.Size([4714, 32])
Shape of out:  torch.Size([4925, 3])
Shape of out:  torch.Size([4925, 32])
Shape of out:  torch.Size([4925, 32])
Shape of out:  torch.Size([4925, 32])
Shape of out:  torch.Size([4925, 32])
Shape of out:  torch.Size([4229, 3])
Shape of out:  torch.Size([4229, 32])
Shape of out:  torch.Size([4229, 32])
Shape of out:  torch.Size([4229, 32])
Shape of out:  torch.Size([4229, 32])
Shape of out:  torch.Size([4552, 3])
Shape of out:  tor

(T):  10%|███████████                                                                                                   | 2/20 [00:00<00:07,  2.57it/s, loss=31.6]

Shape of out:  torch.Size([4892, 3])
Shape of out:  torch.Size([4892, 32])
Shape of out:  torch.Size([4892, 32])
Shape of out:  torch.Size([4892, 32])
Shape of out:  torch.Size([4892, 32])
Shape of out:  torch.Size([4986, 3])
Shape of out:  torch.Size([4986, 32])
Shape of out:  torch.Size([4986, 32])
Shape of out:  torch.Size([4986, 32])
Shape of out:  torch.Size([4986, 32])
Shape of out:  torch.Size([4481, 3])
Shape of out:  torch.Size([4481, 32])
Shape of out:  torch.Size([4481, 32])
Shape of out:  torch.Size([4481, 32])
Shape of out:  torch.Size([4481, 32])
Shape of out:  torch.Size([4660, 3])
Shape of out:  torch.Size([4660, 32])
Shape of out:  torch.Size([4660, 32])
Shape of out:  torch.Size([4660, 32])
Shape of out:  torch.Size([4660, 32])
Shape of out:  torch.Size([4501, 3])
Shape of out:  torch.Size([4501, 32])
Shape of out:  torch.Size([4501, 32])
Shape of out:  torch.Size([4501, 32])
Shape of out:  torch.Size([4501, 32])
Shape of out:  torch.Size([4357, 3])
Shape of out:  tor

(T):  15%|████████████████▊                                                                                               | 3/20 [00:01<00:04,  3.59it/s, loss=20]

Shape of out:  torch.Size([5241, 3])
Shape of out:  torch.Size([5241, 32])
Shape of out:  torch.Size([5241, 32])
Shape of out:  torch.Size([5241, 32])
Shape of out:  torch.Size([5241, 32])
Shape of out:  torch.Size([4619, 3])
Shape of out:  torch.Size([4619, 32])
Shape of out:  torch.Size([4619, 32])
Shape of out:  torch.Size([4619, 32])
Shape of out:  torch.Size([4619, 32])
Shape of out:  torch.Size([4989, 3])
Shape of out:  torch.Size([4989, 32])
Shape of out:  torch.Size([4989, 32])
Shape of out:  torch.Size([4989, 32])
Shape of out:  torch.Size([4989, 32])
Shape of out:  torch.Size([4777, 3])
Shape of out:  torch.Size([4777, 32])
Shape of out:  torch.Size([4777, 32])
Shape of out:  torch.Size([4777, 32])
Shape of out:  torch.Size([4777, 32])
Shape of out:  torch.Size([5099, 3])
Shape of out:  torch.Size([5099, 32])
Shape of out:  torch.Size([5099, 32])
Shape of out:  torch.Size([5099, 32])
Shape of out:  torch.Size([5099, 32])
Shape of out:  torch.Size([3915, 3])
Shape of out:  tor

(T):  25%|███████████████████████████▌                                                                                  | 5/20 [00:01<00:03,  4.99it/s, loss=18.5]

Shape of out:  torch.Size([4630, 3])
Shape of out:  torch.Size([4630, 32])
Shape of out:  torch.Size([4630, 32])
Shape of out:  torch.Size([4630, 32])
Shape of out:  torch.Size([4630, 32])
Shape of out:  torch.Size([4393, 3])
Shape of out:  torch.Size([4393, 32])
Shape of out:  torch.Size([4393, 32])
Shape of out:  torch.Size([4393, 32])
Shape of out:  torch.Size([4393, 32])
Shape of out:  torch.Size([3521, 3])
Shape of out:  torch.Size([3521, 32])
Shape of out:  torch.Size([3521, 32])
Shape of out:  torch.Size([3521, 32])
Shape of out:  torch.Size([3521, 32])
Shape of out:  torch.Size([5268, 3])
Shape of out:  torch.Size([5268, 32])
Shape of out:  torch.Size([5268, 32])
Shape of out:  torch.Size([5268, 32])
Shape of out:  torch.Size([5268, 32])
Shape of out:  torch.Size([6012, 3])
Shape of out:  torch.Size([6012, 32])
Shape of out:  torch.Size([6012, 32])
Shape of out:  torch.Size([6012, 32])
Shape of out:  torch.Size([6012, 32])
Shape of out:  torch.Size([4860, 3])
Shape of out:  tor

(T):  30%|█████████████████████████████████                                                                             | 6/20 [00:01<00:02,  5.47it/s, loss=20.6]

Shape of out:  torch.Size([4634, 3])
Shape of out:  torch.Size([4634, 32])
Shape of out:  torch.Size([4634, 32])
Shape of out:  torch.Size([4634, 32])
Shape of out:  torch.Size([4634, 32])
Shape of out:  torch.Size([5694, 3])
Shape of out:  torch.Size([5694, 32])
Shape of out:  torch.Size([5694, 32])
Shape of out:  torch.Size([5694, 32])
Shape of out:  torch.Size([5694, 32])
Shape of out:  torch.Size([4839, 3])
Shape of out:  torch.Size([4839, 32])
Shape of out:  torch.Size([4839, 32])
Shape of out:  torch.Size([4839, 32])
Shape of out:  torch.Size([4839, 32])
Shape of out:  torch.Size([4674, 3])
Shape of out:  torch.Size([4674, 32])
Shape of out:  torch.Size([4674, 32])
Shape of out:  torch.Size([4674, 32])
Shape of out:  torch.Size([4674, 32])
Shape of out:  torch.Size([4855, 3])
Shape of out:  torch.Size([4855, 32])
Shape of out:  torch.Size([4855, 32])
Shape of out:  torch.Size([4855, 32])
Shape of out:  torch.Size([4855, 32])
Shape of out:  torch.Size([5089, 3])
Shape of out:  tor

(T):  35%|██████████████████████████████████████▌                                                                       | 7/20 [00:01<00:02,  5.83it/s, loss=9.49]

Shape of out:  torch.Size([5365, 3])
Shape of out:  torch.Size([5365, 32])
Shape of out:  torch.Size([5365, 32])
Shape of out:  torch.Size([5365, 32])
Shape of out:  torch.Size([5365, 32])
Shape of out:  torch.Size([4875, 3])
Shape of out:  torch.Size([4875, 32])
Shape of out:  torch.Size([4875, 32])
Shape of out:  torch.Size([4875, 32])
Shape of out:  torch.Size([4875, 32])
Shape of out:  torch.Size([4491, 3])
Shape of out:  torch.Size([4491, 32])
Shape of out:  torch.Size([4491, 32])
Shape of out:  torch.Size([4491, 32])
Shape of out:  torch.Size([4491, 32])
Shape of out:  torch.Size([4914, 3])
Shape of out:  torch.Size([4914, 32])
Shape of out:  torch.Size([4914, 32])
Shape of out:  torch.Size([4914, 32])
Shape of out:  torch.Size([4914, 32])
Shape of out:  torch.Size([3963, 3])
Shape of out:  torch.Size([3963, 32])
Shape of out:  torch.Size([3963, 32])
Shape of out:  torch.Size([3963, 32])
Shape of out:  torch.Size([3963, 32])
Shape of out:  torch.Size([4566, 3])
Shape of out:  tor

(T):  45%|█████████████████████████████████████████████████▌                                                            | 9/20 [00:01<00:01,  5.92it/s, loss=6.24]

Shape of out:  torch.Size([4648, 3])
Shape of out:  torch.Size([4648, 32])
Shape of out:  torch.Size([4648, 32])
Shape of out:  torch.Size([4648, 32])
Shape of out:  torch.Size([4648, 32])
Shape of out:  torch.Size([4000, 3])
Shape of out:  torch.Size([4000, 32])
Shape of out:  torch.Size([4000, 32])
Shape of out:  torch.Size([4000, 32])
Shape of out:  torch.Size([4000, 32])
Shape of out:  torch.Size([4295, 3])
Shape of out:  torch.Size([4295, 32])
Shape of out:  torch.Size([4295, 32])
Shape of out:  torch.Size([4295, 32])
Shape of out:  torch.Size([4295, 32])
Shape of out:  torch.Size([5189, 3])
Shape of out:  torch.Size([5189, 32])
Shape of out:  torch.Size([5189, 32])
Shape of out:  torch.Size([5189, 32])
Shape of out:  torch.Size([5189, 32])
Shape of out:  torch.Size([5342, 3])
Shape of out:  torch.Size([5342, 32])
Shape of out:  torch.Size([5342, 32])
Shape of out:  torch.Size([5342, 32])
Shape of out:  torch.Size([5342, 32])
Shape of out:  torch.Size([5703, 3])
Shape of out:  tor

(T):  50%|██████████████████████████████████████████████████████▌                                                      | 10/20 [00:02<00:01,  5.96it/s, loss=7.87]

Shape of out:  torch.Size([4401, 3])
Shape of out:  torch.Size([4401, 32])
Shape of out:  torch.Size([4401, 32])
Shape of out:  torch.Size([4401, 32])
Shape of out:  torch.Size([4401, 32])
Shape of out:  torch.Size([5089, 3])
Shape of out:  torch.Size([5089, 32])
Shape of out:  torch.Size([5089, 32])
Shape of out:  torch.Size([5089, 32])
Shape of out:  torch.Size([5089, 32])
Shape of out:  torch.Size([5147, 3])
Shape of out:  torch.Size([5147, 32])
Shape of out:  torch.Size([5147, 32])
Shape of out:  torch.Size([5147, 32])
Shape of out:  torch.Size([5147, 32])
Shape of out:  torch.Size([3897, 3])
Shape of out:  torch.Size([3897, 32])
Shape of out:  torch.Size([3897, 32])
Shape of out:  torch.Size([3897, 32])
Shape of out:  torch.Size([3897, 32])
Shape of out:  torch.Size([5083, 3])
Shape of out:  torch.Size([5083, 32])
Shape of out:  torch.Size([5083, 32])
Shape of out:  torch.Size([5083, 32])
Shape of out:  torch.Size([5083, 32])
Shape of out:  torch.Size([5903, 3])
Shape of out:  tor

(T):  60%|█████████████████████████████████████████████████████████████████▍                                           | 12/20 [00:02<00:01,  5.01it/s, loss=6.57]

Shape of out:  torch.Size([4293, 3])
Shape of out:  torch.Size([4293, 32])
Shape of out:  torch.Size([4293, 32])
Shape of out:  torch.Size([4293, 32])
Shape of out:  torch.Size([4293, 32])
Shape of out:  torch.Size([3998, 3])
Shape of out:  torch.Size([3998, 32])
Shape of out:  torch.Size([3998, 32])
Shape of out:  torch.Size([3998, 32])
Shape of out:  torch.Size([3998, 32])
Shape of out:  torch.Size([4007, 3])
Shape of out:  torch.Size([4007, 32])
Shape of out:  torch.Size([4007, 32])
Shape of out:  torch.Size([4007, 32])
Shape of out:  torch.Size([4007, 32])
Shape of out:  torch.Size([6614, 3])
Shape of out:  torch.Size([6614, 32])
Shape of out:  torch.Size([6614, 32])
Shape of out:  torch.Size([6614, 32])
Shape of out:  torch.Size([6614, 32])
Shape of out:  torch.Size([4801, 3])
Shape of out:  torch.Size([4801, 32])
Shape of out:  torch.Size([4801, 32])
Shape of out:  torch.Size([4801, 32])
Shape of out:  torch.Size([4801, 32])
Shape of out:  torch.Size([5455, 3])
Shape of out:  tor

(T):  65%|██████████████████████████████████████████████████████████████████████▊                                      | 13/20 [00:02<00:01,  5.41it/s, loss=6.13]

Shape of out:  torch.Size([4531, 3])
Shape of out:  torch.Size([4531, 32])
Shape of out:  torch.Size([4531, 32])
Shape of out:  torch.Size([4531, 32])
Shape of out:  torch.Size([4531, 32])
Shape of out:  torch.Size([5039, 3])
Shape of out:  torch.Size([5039, 32])
Shape of out:  torch.Size([5039, 32])
Shape of out:  torch.Size([5039, 32])
Shape of out:  torch.Size([5039, 32])
Shape of out:  torch.Size([5753, 3])
Shape of out:  torch.Size([5753, 32])
Shape of out:  torch.Size([5753, 32])
Shape of out:  torch.Size([5753, 32])
Shape of out:  torch.Size([5753, 32])
Shape of out:  torch.Size([3888, 3])
Shape of out:  torch.Size([3888, 32])
Shape of out:  torch.Size([3888, 32])
Shape of out:  torch.Size([3888, 32])
Shape of out:  torch.Size([3888, 32])
Shape of out:  torch.Size([5127, 3])
Shape of out:  torch.Size([5127, 32])
Shape of out:  torch.Size([5127, 32])
Shape of out:  torch.Size([5127, 32])
Shape of out:  torch.Size([5127, 32])
Shape of out:  torch.Size([4073, 3])
Shape of out:  tor

(T):  75%|█████████████████████████████████████████████████████████████████████████████████▊                           | 15/20 [00:03<00:00,  6.05it/s, loss=7.84]

Shape of out:  torch.Size([5815, 3])
Shape of out:  torch.Size([5815, 32])
Shape of out:  torch.Size([5815, 32])
Shape of out:  torch.Size([5815, 32])
Shape of out:  torch.Size([5815, 32])
Shape of out:  torch.Size([4069, 3])
Shape of out:  torch.Size([4069, 32])
Shape of out:  torch.Size([4069, 32])
Shape of out:  torch.Size([4069, 32])
Shape of out:  torch.Size([4069, 32])
Shape of out:  torch.Size([3939, 3])
Shape of out:  torch.Size([3939, 32])
Shape of out:  torch.Size([3939, 32])
Shape of out:  torch.Size([3939, 32])
Shape of out:  torch.Size([3939, 32])
Shape of out:  torch.Size([5167, 3])
Shape of out:  torch.Size([5167, 32])
Shape of out:  torch.Size([5167, 32])
Shape of out:  torch.Size([5167, 32])
Shape of out:  torch.Size([5167, 32])
Shape of out:  torch.Size([4611, 3])
Shape of out:  torch.Size([4611, 32])
Shape of out:  torch.Size([4611, 32])
Shape of out:  torch.Size([4611, 32])
Shape of out:  torch.Size([4611, 32])
Shape of out:  torch.Size([4937, 3])
Shape of out:  tor

(T):  80%|███████████████████████████████████████████████████████████████████████████████████████▏                     | 16/20 [00:03<00:00,  6.25it/s, loss=6.33]

Shape of out:  torch.Size([4895, 3])
Shape of out:  torch.Size([4895, 32])
Shape of out:  torch.Size([4895, 32])
Shape of out:  torch.Size([4895, 32])
Shape of out:  torch.Size([4895, 32])
Shape of out:  torch.Size([4857, 3])
Shape of out:  torch.Size([4857, 32])
Shape of out:  torch.Size([4857, 32])
Shape of out:  torch.Size([4857, 32])
Shape of out:  torch.Size([4857, 32])
Shape of out:  torch.Size([5361, 3])
Shape of out:  torch.Size([5361, 32])
Shape of out:  torch.Size([5361, 32])
Shape of out:  torch.Size([5361, 32])
Shape of out:  torch.Size([5361, 32])
Shape of out:  torch.Size([4650, 3])
Shape of out:  torch.Size([4650, 32])
Shape of out:  torch.Size([4650, 32])
Shape of out:  torch.Size([4650, 32])
Shape of out:  torch.Size([4650, 32])
Shape of out:  torch.Size([5159, 3])
Shape of out:  torch.Size([5159, 32])
Shape of out:  torch.Size([5159, 32])
Shape of out:  torch.Size([5159, 32])
Shape of out:  torch.Size([5159, 32])
Shape of out:  torch.Size([4892, 3])
Shape of out:  tor

(T):  85%|████████████████████████████████████████████████████████████████████████████████████████████▋                | 17/20 [00:03<00:00,  6.40it/s, loss=7.17]

Shape of out:  torch.Size([4785, 32])
Shape of out:  torch.Size([4785, 32])
Shape of out:  torch.Size([4785, 32])
Shape of out:  torch.Size([4785, 32])
Shape of out:  torch.Size([4698, 3])
Shape of out:  torch.Size([4698, 32])
Shape of out:  torch.Size([4698, 32])
Shape of out:  torch.Size([4698, 32])
Shape of out:  torch.Size([4698, 32])
Shape of out:  torch.Size([5170, 3])
Shape of out:  torch.Size([5170, 32])
Shape of out:  torch.Size([5170, 32])
Shape of out:  torch.Size([5170, 32])
Shape of out:  torch.Size([5170, 32])
Shape of out:  torch.Size([4353, 3])
Shape of out:  torch.Size([4353, 32])
Shape of out:  torch.Size([4353, 32])
Shape of out:  torch.Size([4353, 32])
Shape of out:  torch.Size([4353, 32])
Shape of out:  torch.Size([4574, 3])
Shape of out:  torch.Size([4574, 32])
Shape of out:  torch.Size([4574, 32])
Shape of out:  torch.Size([4574, 32])
Shape of out:  torch.Size([4574, 32])
Shape of out:  torch.Size([4881, 3])
Shape of out:  torch.Size([4881, 32])
Shape of out:  to

(T):  95%|███████████████████████████████████████████████████████████████████████████████████████████████████████▌     | 19/20 [00:03<00:00,  6.30it/s, loss=5.26]

Shape of out:  torch.Size([5024, 32])
Shape of out:  torch.Size([3804, 3])
Shape of out:  torch.Size([3804, 32])
Shape of out:  torch.Size([3804, 32])
Shape of out:  torch.Size([3804, 32])
Shape of out:  torch.Size([3804, 32])
Shape of out:  torch.Size([4247, 3])
Shape of out:  torch.Size([4247, 32])
Shape of out:  torch.Size([4247, 32])
Shape of out:  torch.Size([4247, 32])
Shape of out:  torch.Size([4247, 32])
Shape of out:  torch.Size([4878, 3])
Shape of out:  torch.Size([4878, 32])
Shape of out:  torch.Size([4878, 32])
Shape of out:  torch.Size([4878, 32])
Shape of out:  torch.Size([4878, 32])
Shape of out:  torch.Size([5828, 3])
Shape of out:  torch.Size([5828, 32])
Shape of out:  torch.Size([5828, 32])
Shape of out:  torch.Size([5828, 32])
Shape of out:  torch.Size([5828, 32])
Shape of out:  torch.Size([4385, 3])
Shape of out:  torch.Size([4385, 32])
Shape of out:  torch.Size([4385, 32])
Shape of out:  torch.Size([4385, 32])
Shape of out:  torch.Size([4385, 32])
Shape of out:  to

(T): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:03<00:00,  5.20it/s, loss=5.25]


Shape of out:  torch.Size([4872, 3])
Shape of out:  torch.Size([4872, 32])
Shape of out:  torch.Size([4872, 32])
Shape of out:  torch.Size([4872, 32])
Shape of out:  torch.Size([4872, 32])
Shape of out:  torch.Size([4732, 3])
Shape of out:  torch.Size([4732, 32])
Shape of out:  torch.Size([4732, 32])
Shape of out:  torch.Size([4732, 32])
Shape of out:  torch.Size([4732, 32])
Shape of out:  torch.Size([4860, 3])
Shape of out:  torch.Size([4860, 32])
Shape of out:  torch.Size([4860, 32])
Shape of out:  torch.Size([4860, 32])
Shape of out:  torch.Size([4860, 32])
Shape of out:  torch.Size([4621, 3])
Shape of out:  torch.Size([4621, 32])
Shape of out:  torch.Size([4621, 32])
Shape of out:  torch.Size([4621, 32])
Shape of out:  torch.Size([4621, 32])
Shape of out:  torch.Size([5470, 3])
Shape of out:  torch.Size([5470, 32])
Shape of out:  torch.Size([5470, 32])
Shape of out:  torch.Size([5470, 32])
Shape of out:  torch.Size([5470, 32])
Shape of out:  torch.Size([5749, 3])
Shape of out:  tor

Shape of out:  torch.Size([42, 3])
Shape of out:  torch.Size([42, 32])
Shape of out:  torch.Size([42, 32])
Shape of out:  torch.Size([42, 32])
Shape of out:  torch.Size([42, 32])
Shape of out:  torch.Size([13, 3])
Shape of out:  torch.Size([13, 32])
Shape of out:  torch.Size([13, 32])
Shape of out:  torch.Size([13, 32])
Shape of out:  torch.Size([13, 32])
Shape of out:  torch.Size([77, 3])
Shape of out:  torch.Size([77, 32])
Shape of out:  torch.Size([77, 32])
Shape of out:  torch.Size([77, 32])
Shape of out:  torch.Size([77, 32])
Shape of out:  torch.Size([20, 3])
Shape of out:  torch.Size([20, 32])
Shape of out:  torch.Size([20, 32])
Shape of out:  torch.Size([20, 32])
Shape of out:  torch.Size([20, 32])
Shape of out:  torch.Size([37, 3])
Shape of out:  torch.Size([37, 32])
Shape of out:  torch.Size([37, 32])
Shape of out:  torch.Size([37, 32])
Shape of out:  torch.Size([37, 32])
Shape of out:  torch.Size([34, 3])
Shape of out:  torch.Size([34, 32])
Shape of out:  torch.Size([34, 32]

In [6]:
import torch_geometric

In [7]:
one_graph.to(device)

Data(edge_index=[2, 162], x=[42, 3], y=[1])

In [8]:
model(one_graph.x, one_graph.edge_index, one_graph.batch)

Shape of out:  torch.Size([42, 3])
Shape of out:  torch.Size([42, 32])
Shape of out:  torch.Size([42, 32])
Shape of out:  torch.Size([42, 32])
Shape of out:  torch.Size([42, 32])


tensor([[5.8821, 4.3928]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [9]:
def preprocess(graph, device='cuda'):
    """输入一张图，返回它的全连接版本，只不过对应的不存在的边权重为0"""
    num_nodes = graph.x.shape[0]
    adj = torch_geometric.utils.to_dense_adj(edge_index=graph.edge_index, max_num_nodes=num_nodes).to(device)
    fc_adj = torch.ones_like(adj).to(device) - torch.eye(num_nodes).to(device)
    cplmt_adj = fc_adj - adj
    cplmt_edge_index, _ = torch_geometric.utils.dense_to_sparse(cplmt_adj)
    
    fc_edge_index = torch.cat((graph.edge_index, cplmt_edge_index), dim=1).to(device)
    
    edge_weight = torch.zeros(int(num_nodes*(num_nodes-1)), dtype=float).to(device)
    edge_weight[:graph.edge_index.shape[1]*2] = 1

    return fc_edge_index, edge_weight

In [10]:
fc_edge_index, edge_weight = preprocess(one_graph)

In [14]:
edge_weight.shape

torch.Size([1722])

In [15]:
model(one_graph.x, fc_edge_index, one_graph.batch, edge_weight=edge_weight.unsqueeze(0))

Shape of out:  torch.Size([42, 3])


RuntimeError: mat1 and mat2 must have the same dtype

出现了一个非常奇怪的现象。一旦加入了edge_weight之后，就会出现‘RuntimeError: mat1 and mat2 must have the same dtype’。但是如果把out的形状打出来，就会发现明明没有变过，和把edge_weight=None是一样的。不能理解这个问题出现的原因。

In [31]:
# Do the attack
from attacker.PGD import PGDAttack

attacker = PGDAttack(model=encoder_classifier, nnodes=adj.shape[0], loss_type='CE', device=device).to(device)

In [28]:
adj

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 1., 0.]], device='cuda:0')

device(type='cpu')