In [1]:
import torch
import torchvision
from collections import namedtuple
import os
import matplotlib.pyplot as plt
import torch.nn as nn

from attacks.analytic_attack import ImprintAttacker
from modifications.imprint import ImprintBlock
from utils.breaching_utils import *
from opacus import GradSampleModule
%load_ext autoreload
%autoreload 2

  Referenced from: <2BD1B165-EC09-3F68-BCE4-8FE4E70CA7E2> /opt/homebrew/lib/python3.11/site-packages/torchvision/image.so
  warn(


# Attack begins here:

### Initialize your model

In [3]:
setup = dict(device=torch.device("cpu"), dtype=torch.float)

# This could be any model:
model = torchvision.models.resnet18(num_classes=7)

model.eval()
loss_fn = torch.nn.CrossEntropyLoss()
# It will be modified maliciously:
input_dim = data_cfg_default.shape[0] * data_cfg_default.shape[1] * data_cfg_default.shape[2]
num_bins = 100 # Here we define number of imprint bins



### And your dataset (ImageNet by default)

In [4]:
import medmnist
from medmnist import INFO, Evaluator

batch_size = 8 # Number of images in the user's batch. We have a small one here for visualization purposes
import random
random.seed(234324) # You can change this to get a new batch.

transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),
    ]
)


data_flag = 'dermamnist'
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])
dataset = DataClass(split="val", transform=transforms, download=True, size=224)
samples = [dataset[i] for i in random.sample(range(len(dataset)), batch_size)]
data = torch.stack([sample[0] for sample in samples])
labels = torch.tensor([sample[1] for sample in samples]).flatten()

Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz


  labels = torch.tensor([sample[1] for sample in samples]).flatten()


In [5]:
import numpy as np
print(np.unique(dataset.labels))

[0 1 2 3 4 5 6]


### Simulate an attacked FL protocol

In [6]:
from torchsummary import summary
summary(model, input_size=(3, 224, 224))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [7]:
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from torch.utils.data import DataLoader

model = ModuleValidator.fix(model)
model = GradSampleModule(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Fix: Added missing code for optimizer initialization
training_set = DataClass(split="train", transform=transforms, download=True, size=224)
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)

#if hasattr(model, "autograd_grad_sample_hooks"):
#    del model.autograd_grad_sample_hooks

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=training_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
    poisson_sampling=False,
    grad_sample_mode="hooks",

)




Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz




In [8]:
criterion = nn.CrossEntropyLoss()

# Training function for classification
def train_classification(model, optimizer, data_loader, criterion, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(data_loader):
            optimizer.zero_grad()
            outputs = model(inputs)

            # Flatten labels if necessary (assuming labels shape is (batch_size, 1))
            labels = labels.squeeze()

            # Print outputs and labels for debugging
            print(f"Batch {i+1}/{len(data_loader)}")
            # print("Outputs:", outputs)
            # print("Labels:", labels)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(data_loader.dataset)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Example usage
# train_classification(model, optimizer, training_loader, criterion, num_epochs=5)

# Train the model for classification
train_classification(model, optimizer, training_loader, criterion, num_epochs=1)



Batch 1/876
Batch 2/876
Batch 3/876
Batch 4/876
Batch 5/876
Batch 6/876
Batch 7/876
Batch 8/876
Batch 9/876
Batch 10/876
Batch 11/876
Batch 12/876
Batch 13/876
Batch 14/876
Batch 15/876
Batch 16/876
Batch 17/876
Batch 18/876
Batch 19/876
Batch 20/876
Batch 21/876
Batch 22/876
Batch 23/876
Batch 24/876
Batch 25/876
Batch 26/876
Batch 27/876
Batch 28/876
Batch 29/876
Batch 30/876
Batch 31/876
Batch 32/876
Batch 33/876
Batch 34/876
Batch 35/876
Batch 36/876
Batch 37/876
Batch 38/876
Batch 39/876
Batch 40/876
Batch 41/876
Batch 42/876
Batch 43/876
Batch 44/876
Batch 45/876
Batch 46/876
Batch 47/876
Batch 48/876
Batch 49/876
Batch 50/876
Batch 51/876
Batch 52/876
Batch 53/876
Batch 54/876
Batch 55/876
Batch 56/876
Batch 57/876
Batch 58/876
Batch 59/876
Batch 60/876
Batch 61/876
Batch 62/876
Batch 63/876
Batch 64/876
Batch 65/876
Batch 66/876
Batch 67/876
Batch 68/876
Batch 69/876
Batch 70/876
Batch 71/876
Batch 72/876
Batch 73/876
Batch 74/876
Batch 75/876
Batch 76/876
Batch 77/876
Batch 78

In [14]:
model_trained = model


In [15]:
model_trained.eval()
block = ImprintBlock(input_dim, num_bins=num_bins)
model_trained = torch.nn.Sequential(
    torch.nn.Flatten(), block, torch.nn.Unflatten(dim=1, unflattened_size=data_cfg_default.shape), model_trained
)
secret = dict(weight_idx=0, bias_idx=1, shape=tuple(data_cfg_default.shape), structure=block.structure)
secrets = {"ImprintBlock": secret}

In [16]:
# This is the attacker:
attacker = ImprintAttacker(model_trained, loss_fn, attack_cfg_default, setup)

# Server-side computation:
queries = [dict(parameters=[p for p in model_trained.parameters()], buffers=[b for b in model_trained.buffers()])]
server_payload = dict(queries=queries, data=data_cfg_default)
# User-side computation:
loss = loss_fn(model_trained(data), labels)


In [17]:
shared_data = dict(
    gradients = [param.grad for param in model_trained.parameters() if param.grad is not None],
    buffers=None,
    num_data_points=1,
    labels=labels.flatten(),
    local_hyperparams=None,
)

### Reconstruct data from the update

In [20]:
# Attack:
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, secrets, dryrun=False)

torch.Size([3, 7, 7])
torch.Size([3, 7, 7])
torch.Size([3, 7, 7])


IndexError: too many indices for tensor of dimension 3

In [None]:
print

In [None]:
# Metrics?:
from utils.analysis import report
true_user_data = {'data': data, 'labels': labels}
metrics = report(reconstructed_user_data,
    true_user_data,
    server_payload,
    model, compute_ssim=False) # Can change to true and install a package...
print(f"MSE: {metrics['mse']}, PSNR: {metrics['psnr']}, LPIPS: {metrics['lpips']}, SSIM: {metrics['ssim']} ")

### Plot ground-truth data

In [None]:
plot_data(data_cfg_default, true_user_data, setup)

# Create the "images" folder if it doesn't exist
if not os.path.exists("images"):
    os.makedirs("images")

# Save the images inside the "images" folder
plt.savefig("images/true_user_data.png")


### Now plot reconstructed data

In [None]:
plot_data(data_cfg_default, reconstructed_user_data, setup)
# Save the images inside the "images" folder
plt.savefig("images/reconstructed_user_data.png")