In [1]:
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install torch-geometric

Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html


In [2]:
!pip install torchmetrics
!pip install opacus



In [3]:
import torch
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
import torch_geometric
from torch_geometric.nn import global_mean_pool, global_max_pool


class FingerprintsModel(torch.nn.Module):
    def __init__(self, hidden_channels, dataset, model_type:str):
        super(FingerprintsModel, self).__init__()
        self.model_type = model_type
        self.hidden_channels = hidden_channels

        if self.model_type == "GCN":
            self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, hidden_channels*2)
            self.conv3 = GCNConv(hidden_channels*2, hidden_channels*4)
            # self.conv4 = GCNConv(hidden_channels*4, hidden_channels*8)
        elif self.model_type == "GraphSAGE":
            self.conv1 = SAGEConv(dataset.num_node_features, hidden_channels)
            self.conv2 = SAGEConv(hidden_channels, hidden_channels*2)
            self.conv3 = SAGEConv(hidden_channels*2, hidden_channels*4)
            # self.conv4 = SAGEConv(hidden_channels*4, hidden_channels*8)
        elif self.model_type == "GAT":
            self.conv1 = GATConv(dataset.num_node_features, hidden_channels)
            self.conv2 = GATConv(hidden_channels, hidden_channels*2)
            self.conv3 = GATConv(hidden_channels*2, hidden_channels*4)
            # self.conv4 = GATConv(hidden_channels*4, hidden_channels*8)

        self.lin = torch.nn.Sequential(torch.nn.Linear(hidden_channels*4, 256), torch.nn.ReLU(), torch.nn.Linear(256, dataset.num_tasks))
        self.norm = torch_geometric.nn.InstanceNorm(1, affine=True)

    def forward(self, x, edge_index, batch):
        x = self.norm(x, batch)
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        x = x.relu()

        # 2. Readout layer
        x = global_max_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = self.lin(x)

        return x

In [4]:
import numpy as np
import torch
from tqdm import tqdm
from functorch import jacrev
import torch_geometric
from functorch import make_functional_with_buffers
from opacus.accountants.utils import get_noise_multiplier
from opacus.optimizers import DPOptimizer
from opacus.utils.batch_memory_manager import wrap_data_loader


def train(model, train_loader, optimizer, criterion, device):
    model.train()

    correct = 0
    epoch_loss = 0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y.squeeze())
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        pred = out.cpu().argmax(dim=1)
        correct += int((pred == data.y.squeeze().cpu()).sum())

    return epoch_loss/len(train_loader), correct/len(train_loader.dataset)


def compute_loss(params, buffers, data_x, data_edge_index, data_batch, targets, fmodel, loss_fn):
    predictions = fmodel(params, buffers, data_x, data_edge_index, data_batch)
    loss = loss_fn(predictions.squeeze(), targets.squeeze())
    return loss


# functorch implementation of per sample gradients needed for DP
compute_per_sample_grads = jacrev(compute_loss)


def train_dp(fmodel, params, buffers, train_loader, device, optimizer, criterion, scheduler=None):
    epoch_losses = []
    correct = 0

    for step, data in enumerate(tqdm(train_loader, desc="Iteration")):
        optimizer.zero_grad(True)
        data = data.to(device)
        out = fmodel(params, buffers, data.x.float(), data.edge_index, data.batch)
        pred = out.cpu().argmax(dim=1)
        correct += int((pred == data.y.squeeze().cpu()).sum())

        if isinstance(criterion, torch.nn.CrossEntropyLoss):
            loss = criterion(out.squeeze(), data.y.squeeze())
        else:
            loss = criterion(out.squeeze(), data.y.squeeze().float().to(device))

        per_sample_grads = compute_per_sample_grads(
            params,
            buffers,
            data.x,
            data.edge_index,
            data.batch,
            data.y,
            fmodel,
            criterion,
        )

        for param, grad_sample in zip(params, per_sample_grads):
            param.grad_sample = grad_sample
            param.grad = (grad_sample.mean(0))

        optimizer.step()
        epoch_losses.append(torch.mean(loss.detach().cpu()))

        if scheduler is not None:
            scheduler.step()

    acc = correct/len(train_loader.dataset)
    return np.mean(epoch_losses), acc, params


def test(model, test_loader, criterion, device):
    model.eval()
    epoch_losses = []

    correct = 0
    for data in tqdm(test_loader):
        data = data.to(device)
        y = data.y.squeeze()
        out = model(data.x, data.edge_index, data.batch)
        pred = out.cpu().argmax(dim=1)
        correct += int((pred == data.y.squeeze().cpu()).sum())
        loss = criterion(out.squeeze(), data.y.squeeze())

    epoch_losses.append(torch.mean(loss.detach().cpu()))

    return correct/len(test_loader.dataset) , np.mean(epoch_losses)#, f1_score, roc_auc, specificity, sensitivity


def set_up_train_environment(dp:bool,
                             model:torch.nn.Module,
                             nr_train_samples:int,
                             epochs:int,
                             train_loader:torch_geometric.loader.DataLoader,
                             clip:float,
                             learning_rate:float,
                             batch_size:int,
                             max_epsilon:float=None):

    fmodel, params, buffers = make_functional_with_buffers(model)

    if dp:
        optimizer = torch.optim.SGD(params, lr=learning_rate)
        criterion = torch.nn.CrossEntropyLoss(reduction="none")
        NOISE = get_noise_multiplier(target_epsilon=max_epsilon, target_delta=1/nr_train_samples, sample_rate=1/len(train_loader), epochs=epochs)
        optimizer = DPOptimizer(
                            optimizer,
                            noise_multiplier=NOISE,
                            max_grad_norm=clip,
                            expected_batch_size=batch_size,
                            loss_reduction="mean",
                        )
        train_loader = wrap_data_loader(data_loader=train_loader, max_batch_size=batch_size, optimizer=optimizer)
        torch.set_grad_enabled(False)
        return fmodel, params, buffers, optimizer, criterion, train_loader
    else:
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.NAdam(model.parameters(), lr=learning_rate)
        torch.set_grad_enabled(True)

        return model, params, buffers, optimizer, criterion, train_loader

  compute_per_sample_grads = jacrev(compute_loss)


In [5]:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = TUDataset(root='data/TUD', name='MUTAG')  # 188 graphs

#Dataset does not have num_tasks attribute - create it!
dataset.num_tasks = 2
#dataset.num_node_features = dataset.num_features

num_graphs = len(dataset)
train = int(0.7*num_graphs)
random_idx = torch.randperm(num_graphs)
train_idx = random_idx[:train]
test_idx = random_idx[train:]

train = dataset[train_idx]
tester = dataset[test_idx]

#Make batches of data
trainer = DataLoader(train, batch_size=32,shuffle = True)
tester = DataLoader(tester, batch_size=32,shuffle = True)

model = FingerprintsModel(16, dataset,"GCN").to(device)

In [6]:
fmodel, params, buffers, optimizer, criterion, train_loader = set_up_train_environment(dp=True, model = model, nr_train_samples=len(dataset), epochs = 50, train_loader=trainer, clip = 1.0, learning_rate = .01, batch_size = 32, max_epsilon= 5)

  fmodel, params, buffers = make_functional_with_buffers(model)


In [7]:
mean_poch_loss, acc, params = train_dp(fmodel=fmodel, params=params, buffers=buffers, optimizer=optimizer, criterion= criterion, train_loader = train_loader, device = device)

Iteration: 100%|██████████| 5/5 [00:01<00:00,  4.50it/s]


In [8]:
mean_poch_loss

np.float32(0.6517491)

In [9]:
acc

0.6793893129770993

In [10]:
test_acc, test_mean_loss = test(model, tester, criterion, device)

100%|██████████| 2/2 [00:00<00:00, 51.16it/s]


In [11]:
test_acc

0.631578947368421

In [12]:
test_mean_loss

np.float32(0.64881104)

# Observe how this defense holds up against attacks


## Member Inference Attack