# Load Cora Dataset 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cpu')
import numpy as np
import random

# Load the pre-processed Cora split: graph, node features, labels and node splits.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
# map_location remaps any tensors saved on another device at load time, matching
# how 'model.pth' is loaded further down.
data = torch.load('data.pth', map_location=device)
g = data['g'].to(device)
feat = data['feat'].to(device)
label = data['label'].to(device)
train_nodes = data['train_nodes']
val_nodes = data['val_nodes']
test_nodes = data['test_nodes']

def setup_seed(seed):
    """Make runs repeatable.

    Seeds the PyTorch CPU generator, every CUDA device, NumPy's global RNG and
    Python's `random` module with the same `seed`, then asks cuDNN to use
    deterministic kernels.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(20)

Using backend: pytorch


# Load victim GCN model

In [2]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer Graph Convolution Network (GCN).

    Parameters
    ----------
    in_feats : int
        Dimensionality of the input node features.
    out_feats : int
        Number of output units per node (the number of classes here).
    hid : int, optional
        Width of the single hidden layer, by default 16.
    dropout : float, optional
        Dropout probability applied to the input features and to the hidden
        activations, by default 0.5.

    Example
    -------
    # GCN with one hidden layer
    >>> model = GCN(100, 10, hid=32)
    """
    def __init__(self,
                 in_feats: int,
                 out_feats: int,
                 hid: int = 16,
                 dropout: float = 0.5):
        super().__init__()
        self.conv1 = GraphConv(in_feats, hid)
        self.conv2 = GraphConv(hid, out_feats)
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat):
        """Compute per-node output scores (no softmax is applied).

        Parameters
        ----------
        g : DGL graph or torch.Tensor
            Either a DGL graph, or a tensor used directly as the propagation
            matrix in ``g @ X``.
        feat : torch.Tensor
            Node feature matrix, one row per node.

        Returns
        -------
        torch.Tensor
            Output of the second layer, one row per node.
        """
        if torch.is_tensor(g):
            # Tensor path: reuse the GraphConv parameters with plain matrix
            # products. No self-loops or normalization are added here.
            # NOTE(review): presumably `g` must already be the normalized
            # adjacency (with self-loops) -- confirm with callers.
            feat = self.dropout(feat)
            feat = g @ (feat @ self.conv1.weight) + self.conv1.bias
            feat = F.relu(feat)
            feat = self.dropout(feat)
            feat = g @ (feat @ self.conv2.weight) + self.conv2.bias
            return feat

        # DGL path: add_self_loop returns a new graph, so the caller's graph is
        # untouched. Existing self-loops, if any, are not deduplicated.
        g = g.add_self_loop()
        feat = self.dropout(feat)
        feat = self.conv1(g, feat)
        feat = F.relu(feat)
        feat = self.dropout(feat)
        feat = self.conv2(g, feat)
        return feat
    
# `device` is configured once in the first cell. It is deliberately not rebound
# here, so the model always lives on the same device as `g`, `feat` and `label`.
num_feats = feat.size(1)
# NOTE(review): assumes class labels are contiguous integers 0..C-1.
num_classes = int(label.max() + 1)
model = GCN(num_feats, num_classes).to(device)

model.load_state_dict(torch.load('model.pth', map_location=device))
# Inference mode: disables dropout for all predictions below.
model.eval()

GCN(
  (conv1): GraphConv(in=1433, out=16, normalization=both, activation=None)
  (conv2): GraphConv(in=16, out=7, normalization=both, activation=None)
  (dropout): Dropout(p=0.5, inplace=False)
)

# Jaccard Similarity based defense 

*Wu et al.* [📝Adversarial Examples on Graph Data: Deep Insights into Attack and Defense](https://arxiv.org/abs/1903.01610), *IJCAI'19*

In [3]:
import dgl
import scipy.sparse as sp
import torch

class JaccardPurification(torch.nn.Module):
    """Graph purification by Jaccard similarity of node features (Wu et al., IJCAI'19).

    Every edge ``(u, v)`` whose endpoint feature rows have a Jaccard similarity
    ``<= threshold`` is removed from the graph.

    Parameters
    ----------
    threshold : float, optional
        Edges scoring at or below this value are dropped, by default 0.
        With the default, only edges whose endpoints share no non-zero
        feature are removed.
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold

    def forward(self, g, feat):
        """Return a purified copy of ``g``.

        Parameters
        ----------
        g : DGL graph
            Input graph. It is not modified.
        feat : torch.Tensor
            Node feature matrix; row ``i`` holds the features of node ``i``.

        Returns
        -------
        DGL graph
            Graph with the dissimilar edges removed. The removed edges are
            stored in ``self.edges`` as a ``(2, num_removed)`` tensor of
            (source, destination) node ids.
        """
        # local_var() gives a shallow local copy, so the in-place remove_edges
        # below does not alter the caller's graph.
        g = g.local_var()
        row, col = g.edges()
        score = jaccard_similarity(feat[row], feat[col])

        e_id = torch.where(score <= self.threshold)[0]
        g.remove_edges(e_id)

        # Keep the dropped edges around for later inspection.
        self.edges = torch.stack([row[e_id], col[e_id]], dim=0)
        return g

    def extra_repr(self) -> str:
        # Report only the settings this module actually stores.
        return f"threshold={self.threshold}"


def jaccard_similarity(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Row-wise Jaccard similarity between the non-zero supports of A and B.

    For each row ``i``, the non-zero positions of ``A[i]`` and ``B[i]`` are
    treated as sets and ``|A ∩ B| / |A ∪ B|`` is returned, where
    ``|A ∪ B| = |A| + |B| - |A ∩ B|``.

    Parameters
    ----------
    A, B : torch.Tensor
        2-D tensors whose rows are compared pairwise (same shape, or
        broadcastable against each other).

    Returns
    -------
    torch.Tensor
        1-D float tensor with one score in ``[0, 1]`` per row. Rows where both
        inputs are all-zero score 0; the small epsilon avoids 0 / 0.
    """
    # A non-zero product marks a position that is non-zero in both rows.
    intersection = torch.count_nonzero(A * B, dim=1)
    union = torch.count_nonzero(A, dim=1) + torch.count_nonzero(B, dim=1) - intersection
    return intersection / (union + 1e-7)



In [4]:
# Purify the clean graph with the default threshold to see how many edges the defense drops.
defense_g = JaccardPurification(threshold=0.)(g, feat)

In [5]:
# Clean input graph, before Jaccard purification.
g

Graph(num_nodes=2485, num_edges=10138,
      ndata_schemes={}
      edata_schemes={})

In [6]:
# Graph after Jaccard purification; compare its edge count with the clean graph.
defense_g

Graph(num_nodes=2485, num_edges=9042,
      ndata_schemes={}
      edata_schemes={})

# Evaluation

In the following, you can conduct any attack to obtain a perturbed graph `attack_g`.
To resist adversarial attacks, the graph can be purified with
```python
defense_g = JaccardPurification()(attack_g, feat)
```

Purification removes every edge whose endpoint features have a Jaccard similarity at or below `threshold` (0 by default). The purified (defended) graph can then be used for any downstream task.

In [7]:
# Load the attacker's perturbed graph and features.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
d = torch.load('attack_graph.pth', map_location=device)
attack_g = d['attack_g']
attack_feat = d['attack_feat']

# Defense: score edges with the same (attacked) features the model is evaluated
# on below. A defender only observes the perturbed input, so using the clean
# `feat` here would leak information. It would also break for attacks that
# inject new nodes, because `feat` has no rows for them.
defense_g = JaccardPurification()(attack_g, attack_feat)

In [8]:
# Node whose prediction we follow on the clean, attacked and purified graphs.
target = 1
target_label = label[target]
print("target label: ", target_label, sep=" ")

target label:  tensor(2)


In [9]:
# Clean graph: predicted class of the target node.
torch.argmax(model(g, feat)[target])

tensor(2)

In [10]:
# Attacked graph and features: the target node gets misclassified.
torch.argmax(model(attack_g, attack_feat)[target])

tensor(1)

In [11]:
# Purified (Jaccard-defended) graph with the attacked features.
torch.argmax(model(defense_g, attack_feat)[target])

tensor(2)