In [2]:
!pip install --upgrade pip
!pip3 install torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121
!pip install torch_geometric
!pip install torch_cluster torch_scatter -f https://data.pyg.org/whl/torch-2.3.1+cu121.html
!pip install gdown

Looking in indexes: https://download.pytorch.org/whl/cu121
Looking in links: https://data.pyg.org/whl/torch-2.3.1+cu121.html


In [3]:
!pip -q install --force-reinstall --no-deps fsspec==2023.6.0



In [4]:
!nvcc --version


nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import random_split
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.nn import GINEConv, global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.utils import dense_to_sparse
import random
import math
import numpy as np
from typing import List
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
import math, random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device



device(type='cpu')

#Dataset Preprocessing
Doing node-level classification using a k-means synthetic feature

In [166]:
from sklearn.cluster import KMeans
K = 3 #how many clusters to generate
def preprocess_MUTAG(data: Data):
  #create synthetic node-level labels with k-means clustering
  node_labels = data.x.float()  # [N, D]
  N, D = node_labels.shape

  km = KMeans(n_clusters=K, n_init=10, random_state=0).fit(node_labels.numpy())
  y_node = torch.from_numpy(km.labels_).long()
  #add the graph label as a feature to all nodes
  #graph_label = torch.unsqueeze(data.y, dim=0).repeat(N, 1)

  #data.x = torch.cat((node_labels, graph_label.view(-1, 1)), dim=1)
  #substitute in the node-level k-means labels
  data.y = y_node
  return data


In [167]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUD', name='MUTAG', use_node_attr=True,pre_transform = preprocess_MUTAG,force_reload=True)  # 188 graphs
sizes = [data.num_nodes for data in dataset]
idx = int(np.argmax([n if n >= 25 else 0 for n in sizes]))  # pick a larger graph
data = dataset[idx]
print(f"Graph index {idx}: nodes={data.num_nodes}, edges={data.num_edges // 2} (undirected)")

Processing...


Graph index 5: nodes=28, edges=31 (undirected)


Done!


In [168]:
data = dataset[0]
print(data)
print(data.y)
for data in dataset:
  #print(max(data.y))
  pass

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[17])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 1])


In [101]:
"""
dataset = TUDataset(root="data/TUDataset", name="MUTAG", use_node_attr=True)
#MUTAG typically has node attributes; if not present, dataset.data.x may be None
if dataset.num_node_features == 0:
    #fallback: use one-hot of node labels if available, else create constant feature
    if dataset.num_node_labels > 0:
        print("No node features found — using node labels one-hot")
        #replace x by one-hot of node_label (PyG stores node labels in data.x sometimes)
        for data in dataset:
            #convert data.x (assumed scalar) to one-hot classification
            if data.x is not None and data.x.dim() == 1:
                num_cat = int(data.x.max().item()) + 1
                one_hot = F.one_hot(data.x.long(), num_classes=num_cat).to(torch.float)
                data.x = one_hot
            else:
                # create constant feature
                data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)
    else:
        for data in dataset:
            data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)
"""


'\ndataset = TUDataset(root="data/TUDataset", name="MUTAG", use_node_attr=True)\n#MUTAG typically has node attributes; if not present, dataset.data.x may be None\nif dataset.num_node_features == 0:\n    #fallback: use one-hot of node labels if available, else create constant feature\n    if dataset.num_node_labels > 0:\n        print("No node features found — using node labels one-hot")\n        #replace x by one-hot of node_label (PyG stores node labels in data.x sometimes)\n        for data in dataset:\n            #convert data.x (assumed scalar) to one-hot classification\n            if data.x is not None and data.x.dim() == 1:\n                num_cat = int(data.x.max().item()) + 1\n                one_hot = F.one_hot(data.x.long(), num_classes=num_cat).to(torch.float)\n                data.x = one_hot\n            else:\n                # create constant feature\n                data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)\n    else:\n        for data in dataset:\n

# GIN Model
Customizable GINE model built around the pyg GINEConv layer. Use GINE over GIN to train on edge features

In [169]:
class EdgeMLP(nn.Module):
    """Small MLP to combine node and edge features inside each GINE layer."""
    def __init__(self, input_dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class GINEModel(nn.Module):
    def __init__(self, node_input_dim, edge_input_dim, hidden_dim=64,
                 num_layers=3, dropout=0.2, num_node_classes=None):
        super().__init__()
        """
        GINE model for node classification
        """
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.edge_transform = nn.ModuleList()
        for i in range(num_layers):
            in_dim = node_input_dim if i == 0 else hidden_dim
            mlp = EdgeMLP(in_dim, hidden_dim, dropout)
            #don't explicitly define edge_dim here; transform edge features into
            #the same dimension as the GINE in the forward pass with edge_layer
            conv = GINEConv(nn=mlp, train_eps=True)
            edge_layer = nn.Linear(edge_input_dim, in_dim)
            self.convs.append(conv)
            self.bns.append(nn.BatchNorm1d(hidden_dim))
            self.edge_transform.append(edge_layer)

        #Per-node classifier head
        if num_node_classes is None:
            #make a best guess from node features
            num_node_classes = node_input_dim
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_node_classes)
        )


    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        if edge_attr is None:
            #if dataset doesn’t have edge attributes, use zeros
            edge_attr = torch.zeros((edge_index.size(1), 1), device=x.device)
        for conv, bn, edge_transform in zip(self.convs, self.bns, self.edge_transform):
            #transform edge features to match node input dimensions
            edge_attr_trans = edge_transform(edge_attr)
            x = conv(x, edge_index, edge_attr_trans)
            x = bn(x)
            x = F.relu(x)
        #g = self.pool(x, batch)
        logits = self.classifier(x)
        return logits


# Training Utility Functions

In [127]:
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total_nodes = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        logits = model(data)  # shape [N_nodes_batch, C]
        target = data.y.to(device).view(-1)

        loss = F.cross_entropy(logits, target)
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item()) * logits.size(0)
        preds = logits.argmax(dim=1)
        correct += int((preds == target).sum().item())
        total_nodes += logits.size(0)
    return total_loss / total_nodes, correct / total_nodes


In [128]:
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total_nodes = 0
    for data in loader:
        data = data.to(device)
        logits = model(data)
        target = data.y.to(device).view(-1)
        loss = F.cross_entropy(logits, target)
        total_loss += float(loss.item()) * logits.size(0)
        preds = logits.argmax(dim=1)
        correct += int((preds == target).sum().item())
        total_nodes += logits.size(0)
    return total_loss / total_nodes, correct / total_nodes





# Baseline Training & Evaluation

In [170]:
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Subset

#hyperparams
train_frac = 0.8 #80/20 train-test split
seed = 42
epochs = 400
batch_size = 32
lr = 1e-3
weight_decay = 1e-5
#initialize random #s
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


indices = list(range(len(dataset)))
#y = list(dataset[i].y for i in range(len(dataset)))
#StratifiedKFold for cross-validation (90/10 train-test split each time)
#skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
np.random.shuffle(indices)
train_idx = indices[:int(len(indices) * train_frac)]
test_idx = indices[:int(len(indices) * train_frac)]
#split with subset
train_subset = Subset(dataset, train_idx)
test_subset = Subset(dataset, test_idx)

In [171]:
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

#initialize model from scratch
model = GINEModel(node_input_dim=dataset[0].x.shape[1],
                  edge_input_dim=dataset[0].edge_attr.shape[1],
                  hidden_dim=32,
                  num_layers=2,
                  dropout=0.2,
                  num_node_classes=K).to(device)
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

best_test_acc = 0.0
for epoch in range(1, epochs + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, device)
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        torch.save(model.state_dict(), "best_gine_node.pth")
    if epoch % 5 == 0 or epoch == 1 or epoch == epochs:
        print(f"Epoch {epoch:03d} | Train loss {train_loss:.4f} acc {train_acc:.4f} | Test loss {test_loss:.4f} acc {test_acc:.4f} (best {best_test_acc:.4f})")

print("Finished. Best test node acc:", best_test_acc)
#fold_accuracies.append(best_test_acc)

#mean_acc = np.mean(fold_accuracies)
#std_acc = np.std(fold_accuracies)
#print(f"\n10-fold CV accuracy: {mean_acc:.4f} ± {std_acc:.4f}")

  train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
  test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)


Epoch 001 | Train loss 1.0532 acc 0.5816 | Test loss 1.0855 acc 0.6457 (best 0.6457)
Epoch 005 | Train loss 0.8560 acc 0.7146 | Test loss 1.0064 acc 0.6330 (best 0.6577)
Epoch 010 | Train loss 0.6891 acc 0.7326 | Test loss 0.8721 acc 0.6790 (best 0.6790)
Epoch 015 | Train loss 0.6279 acc 0.7543 | Test loss 0.7584 acc 0.7749 (best 0.7749)
Epoch 020 | Train loss 0.5922 acc 0.7678 | Test loss 0.6252 acc 0.7757 (best 0.7779)
Epoch 025 | Train loss 0.5726 acc 0.7730 | Test loss 0.5894 acc 0.7820 (best 0.7820)
Epoch 030 | Train loss 0.5578 acc 0.7809 | Test loss 0.6436 acc 0.7891 (best 0.7903)
Epoch 035 | Train loss 0.5497 acc 0.7858 | Test loss 0.6109 acc 0.7940 (best 0.7944)
Epoch 040 | Train loss 0.5467 acc 0.7876 | Test loss 0.6447 acc 0.7921 (best 0.7944)
Epoch 045 | Train loss 0.5420 acc 0.7865 | Test loss 0.5775 acc 0.7921 (best 0.7944)
Epoch 050 | Train loss 0.5292 acc 0.7899 | Test loss 0.6000 acc 0.7944 (best 0.7948)
Epoch 055 | Train loss 0.5384 acc 0.7891 | Test loss 0.6116 acc 0

# LINKTELLER Attack

In [172]:
import scipy as sp
class Linkteller():
      def __init__(self, model, device, test_node_feats, test_edge_idx,
                   test_edge_attr = None):
        """
        model: Pretrained model
        device: torch.device
        test_edge_idx: adjacency matrix for the graph being evaluated
        undirected: whether the graph is undirected or not (default: True)
        """
        self.model = model
        self.device = device
        #graph dataset node features
        self.test_node_feats = test_node_feats
        self.num_nodes = test_node_feats.shape[0]
        self.test_edge_idx = test_edge_idx
        self.test_edge_attr = test_edge_attr
        #build adjcency matrix for test graph from edge indices
        self.test_adj = torch.zeros((self.num_nodes, self.num_nodes), dtype=torch.float)
        self.test_adj[test_edge_idx[0], test_edge_idx[1]] = True
        self.test_adj[test_edge_idx[1], test_edge_idx[0]] = True
        self.test_adj.fill_diagonal_(False)
        self.true_edges_undirected = torch.nonzero(torch.triu(self.test_adj,
                                                         diagonal=1),
                                              as_tuple=False)  # [M, 2]
        self.M_true = self.true_edges_undirected.shape[0]
        self.density = self.M_true / (self.num_nodes*(self.num_nodes-1)/2)

      @torch.no_grad()
      def gbb_api(self, node_ids, X_query):
          """
          node_ids: 1D LongTensor of node indices to fetch from output
          X_query: (N, D) full feature matrix Bob provides (Alice uses it with her private edge_index)
          returns: logits[node_ids] shape (len(node_ids), K)

          modified from Linkteller.ipynb
          """
          model.eval()
          #reconstruct graph using Bob's provided node features & Alice's edges
          if self.test_edge_attr is None:
            test_graph = Data(x=X_query, edge_index=self.test_edge_idx)
          else:
            test_graph = Data(x=X_query,
                              edge_index=self.test_edge_idx,
                              edge_attr=self.test_edge_attr)
          test_graph = test_graph.to(device)
          out = model(test_graph)
          return out[node_ids.to(device)].detach().cpu()

      def influence_matrix_for_v(self,v, V_I, X_base, delta=1e-2):
          """
          v: node index (int)
          V_I: 1D LongTensor of nodes-of-interest to score against
          X_base: (N, D) baseline features
          returns: Iv (|V_I|, K) = (P' - P)/delta where rows correspond to u in V_I
          """
          node_ids = V_I
          P = self.gbb_api(node_ids, X_base)

          Xp = X_base.clone()
          Xp[v] = (1.0 + delta) * Xp[v]  # upweight features of v
          Pp = self.gbb_api(node_ids, Xp)

          Iv = (Pp - P) / delta  # finite-diff approximation
          return Iv  # (|V_I|, K)

      def linkteller_scores(self, V_C, X_base, delta=1e-2):
          """
          V_C: nodes-of-interest (attack surface) as 1D LongTensor
          returns: dict {(u,v): score} for u != v, unordered pairs
          """
          V_C = V_C.cpu()
          scores = {}
          for j, v in enumerate(V_C.tolist()):
              # rows aligned with V_C
              Iv = self.influence_matrix_for_v(v,
                                               V_C,
                                               X_base,
                                               delta=delta).numpy()
              # influence value of v on each u = ||Iv[u,:]||_2
              norms = np.linalg.norm(Iv, axis=1)
              for i, u in enumerate(V_C.tolist()):
                  if u == v:
                      continue
                  key = (min(u,v), max(u,v))
                  # symmetrical score: max of v→u and u→v will be handled later; accumulate max
                  scores[key] = max(scores.get(key, 0.0), float(norms[i]))
          return scores



In [173]:
#pick larger graph
sizes = [data.num_nodes for data in dataset]
idx = int(np.argmax([n if n >= 25 else 0 for n in sizes]))  # pick a larger graph

test_graph = dataset[idx]
linkteller_MUTAG_GIN = Linkteller(model = model,
                                  device=device,
                                  test_node_feats=test_graph.x,
                                  test_edge_idx=test_graph.edge_index,
                                  test_edge_attr=test_graph.edge_attr)


In [174]:
# Choose attack node set V_C (we’ll use all nodes to make life easy)
N = test_graph.x.shape[0]
V_C = torch.arange(N, dtype=torch.long)
X = test_graph.x
scores = linkteller_MUTAG_GIN.linkteller_scores(V_C, X, delta=1e-2)

# Turn scores into a sorted list
sorted_pairs = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
len(sorted_pairs), sorted_pairs[:5]


(378,
 [((7, 8), 6.03274393081665),
  ((8, 13), 6.03274393081665),
  ((2, 7), 6.032741069793701),
  ((6, 7), 6.032741069793701),
  ((8, 9), 6.032741069793701)])

In [175]:
n = N
m_true = linkteller_MUTAG_GIN.M_true
m_belief = int(round(linkteller_MUTAG_GIN.density * (n*(n-1)/2)))

pred_edges = set([pair for (pair, _) in sorted_pairs[:m_belief]])

# ground truth undirected edges as set of tuples (i,j) with i<j
true_edges = set([tuple(e.tolist()) for e in linkteller_MUTAG_GIN.true_edges_undirected])

tp = len(pred_edges & true_edges)
fp = len(pred_edges - true_edges)
fn = len(true_edges - pred_edges)

precision = tp / (tp + fp + 1e-12)
recall    = tp / (tp + fn + 1e-12)
f1        = 2*precision*recall / (precision + recall + 1e-12)
print(f"Precision={precision:.3f} | Recall={recall:.3f} | F1={f1:.3f} | m_belief={m_belief} | true M={m_true}")


Precision=0.419 | Recall=0.419 | F1=0.419 | m_belief=31 | true M=31


In [176]:
def evaluate_at_fraction(frac):
    m = int(round(frac * (n*(n-1)/2)))
    pred = set([pair for (pair, _) in sorted_pairs[:m]])
    tp = len(pred & true_edges)
    fp = len(pred - true_edges)
    fn = len(true_edges - pred)
    p = tp / (tp + fp + 1e-12)
    r = tp / (tp + fn + 1e-12)
    f1 = 2*p*r / (p + r + 1e-12)
    return p, r, f1, m
density = linkteller_MUTAG_GIN.density
for frac in [0.5*density, 0.8*density, density, 1.2*density, 1.5*density]:
    p, r, f1, m = evaluate_at_fraction(frac)
    print(f"k_hat={frac:.4f}  m={m:3d}  P={p:.3f} R={r:.3f} F1={f1:.3f}")


k_hat=0.0410  m= 16  P=0.312 R=0.161 F1=0.213
k_hat=0.0656  m= 25  P=0.360 R=0.290 F1=0.321
k_hat=0.0820  m= 31  P=0.419 R=0.419 F1=0.419
k_hat=0.0984  m= 37  P=0.459 R=0.548 F1=0.500
k_hat=0.1230  m= 46  P=0.413 R=0.613 F1=0.494


In [177]:
scores_different_delta = linkteller_MUTAG_GIN.linkteller_scores(V_C, X, delta=5e-3)
sorted_pairs_2 = sorted(scores_different_delta.items(), key=lambda kv: kv[1], reverse=True)
pred_edges_2 = set([pair for (pair, _) in sorted_pairs_2[:m_belief]])

tp2 = len(pred_edges_2 & true_edges)
fp2 = len(pred_edges_2 - true_edges)
fn2 = len(true_edges - pred_edges_2)
p2 = tp2 / (tp2 + fp2 + 1e-12)
r2 = tp2 / (tp2 + fn2 + 1e-12)
f12 = 2*p2*r2 / (p2 + r2 + 1e-12)
print(f"(Δ=5e-3) Precision={p2:.3f} | Recall={r2:.3f} | F1={f12:.3f}")


(Δ=5e-3) Precision=0.355 | Recall=0.355 | F1=0.355
