In [1]:
import torch
import torchvision
from collections import namedtuple
import os
import matplotlib.pyplot as plt
import torch.nn as nn

from attacks.analytic_attack import ImprintAttacker
from modifications.imprint import ImprintBlock
from utils.breaching_utils import *
from opacus import GradSampleModule
%load_ext autoreload
%autoreload 2

# Attack begins here:

### Initialize your model

In [2]:
# Device/dtype settings; later handed to the attacker and to the plotting helper.
setup = {"device": torch.device("cpu"), "dtype": torch.float32}

# Victim network: a ResNet-18 with a 7-way output head, switched to eval mode.
# Any architecture could stand in here.
model = torchvision.models.resnet18(num_classes=7).eval()
loss_fn = nn.CrossEntropyLoss()

# The model gets modified maliciously further down. The imprint block needs the
# size of one flattened input: the product of the first three shape entries
# (presumably channels x height x width -- confirm against data_cfg_default).
input_dim = (
    data_cfg_default.shape[0]
    * data_cfg_default.shape[1]
    * data_cfg_default.shape[2]
)
num_bins = 100  # number of imprint bins



### And your dataset (DermaMNIST from MedMNIST)

In [3]:
import random

import medmnist
from medmnist import INFO, Evaluator

# Number of images in the user's batch. Kept small here for visualization purposes.
batch_size = 4
random.seed(234324)  # Change the seed to draw a different batch.

# Preprocessing: resize to 256, center-crop to 224, convert to a tensor and
# normalize with the mean/std stored in data_cfg_default.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),
    ]
)

# Look up the MedMNIST dataset class by its flag and load the validation split.
data_flag = 'dermamnist'
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])
dataset = DataClass(split="val", transform=transforms, download=True, size=224)

# Draw batch_size distinct samples; each sample is indexed as (image, label).
samples = [dataset[i] for i in random.sample(range(len(dataset)), batch_size)]
data = torch.stack([sample[0] for sample in samples])
# Convert every label on its own and stack the results. Feeding the whole list
# to torch.tensor() produced a warning in the original run (presumably because
# the labels are numpy arrays -- confirm); the resulting 1-D tensor is the same.
labels = torch.stack([torch.as_tensor(sample[1]) for sample in samples]).flatten()

Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz


  labels = torch.tensor([sample[1] for sample in samples]).flatten()


In [4]:
# Malicious front layer, built for flattened inputs of size input_dim.
block = ImprintBlock(input_dim, num_bins=num_bins)

# Pipeline: flatten -> imprint block -> restore the data shape -> original model.
# NOTE(review): this rebinds `model` to a wrapper around its previous value, so
# re-running the cell nests the wrapper once more.
model = nn.Sequential(
    nn.Flatten(),
    block,
    nn.Unflatten(dim=1, unflattened_size=data_cfg_default.shape),
    model,
)

# Description of the imprint layer that is passed to attacker.reconstruct later.
# Presumably weight_idx/bias_idx index into the shared gradient list -- confirm
# against ImprintAttacker.
secret = {
    "weight_idx": 0,
    "bias_idx": 1,
    "shape": tuple(data_cfg_default.shape),
    "structure": block.structure,
}
secrets = {"ImprintBlock": secret}

In [5]:
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from torch.utils.data import DataLoader

# Training split, using the same transforms and batch size as the attacked batch.
training_set = DataClass(split="train", transform=transforms, download=True, size=224)
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)

# Run the model through ModuleValidator.fix, then wrap it for per-sample gradients.
# NOTE(review): make_private below is handed an already-wrapped GradSampleModule;
# confirm against the Opacus documentation that wrapping by hand is intended.
model = GradSampleModule(ModuleValidator.fix(model))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Attach differential privacy to model, optimizer and loader.
privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=training_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
    poisson_sampling=False,
)




Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz




In [6]:
# Sanity check: show the model's output for the sampled batch, then the labels.
print(model(data), labels, sep="\n")

tensor([[-1.0280, -1.1038,  0.0524,  0.3591,  0.2203, -0.2732, -0.1617],
        [-1.0279, -1.1035,  0.0522,  0.3590,  0.2204, -0.2732, -0.1618],
        [-1.0276, -1.1025,  0.0516,  0.3589,  0.2206, -0.2731, -0.1621],
        [-1.0275, -1.1021,  0.0514,  0.3588,  0.2206, -0.2731, -0.1622]],
       grad_fn=<AddmmBackward0>)
tensor([5, 5, 5, 5])




In [7]:
criterion = nn.CrossEntropyLoss()


def train_classification(model, optimizer, data_loader, criterion, num_epochs=5):
    """Train ``model`` as a classifier and report the mean loss of every epoch.

    Args:
        model: Module mapping a batch of inputs to class logits.
        optimizer: Optimizer that updates the model's parameters.
        data_loader: Sized iterable of ``(inputs, labels)`` batches that also
            exposes ``dataset``; ``len(data_loader.dataset)`` normalizes the
            epoch loss.
        criterion: Loss called as ``criterion(outputs, labels)``.
        num_epochs: Number of passes over ``data_loader``.

    Returns:
        list[float]: Mean per-sample loss of each epoch, in order.
    """
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(data_loader):
            optimizer.zero_grad()
            outputs = model(inputs)

            # Labels arrive as (batch, 1); drop only that trailing axis. A bare
            # squeeze() would also drop the batch axis of a one-sample batch,
            # leaving a 0-d target that no longer matches the (1, C) outputs.
            if labels.dim() > 1 and labels.size(-1) == 1:
                labels = labels.squeeze(-1)

            # Progress indicator.
            print(f"Batch {i+1}/{len(data_loader)}")

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # loss is a batch mean; weight it by the batch size so the epoch
            # figure below is a per-sample average.
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(data_loader.dataset)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
        epoch_losses.append(epoch_loss)

    return epoch_losses

# Train the model for classification for one epoch. Iterate the loader handed
# back by privacy_engine.make_private rather than the raw training_loader, so
# that any sampling behaviour Opacus attaches to the loader (for example when
# poisson_sampling is switched on) is the one actually used for training.
train_classification(model, optimizer, data_loader, criterion, num_epochs=1)

Batch 1/1752


ValueError: Per sample gradient is not initialized. Not updated in backward pass?

In [None]:
# Plain alias: model_trained and model refer to the same object (no copy is made).
model_trained = model


In [None]:
# This is the attacker:
attacker = ImprintAttacker(model_trained, loss_fn, attack_cfg_default, setup)

# Server-side computation: the payload carries the model's current parameters
# and buffers together with the data configuration.
queries = [dict(parameters=list(model_trained.parameters()), buffers=list(model_trained.buffers()))]
server_payload = dict(queries=queries, data=data_cfg_default)

# User-side computation: clear whatever gradients earlier steps left behind,
# then backpropagate the loss on the user's batch. Without the backward pass the
# .grad fields collected in the next cell would not correspond to `data`/`labels`.
model_trained.zero_grad()
loss = loss_fn(model_trained(data), labels)
loss.backward()


In [None]:
# The update the user shares: every gradient currently stored on the model's
# parameters (parameters whose .grad is None are skipped) plus the flat labels.
# NOTE(review): num_data_points is 1 while batch_size is 4 -- confirm which value
# the attacker expects.
shared_data = {
    "gradients": [p.grad for p in model_trained.parameters() if p.grad is not None],
    "buffers": None,
    "num_data_points": 1,
    "labels": labels.flatten(),
    "local_hyperparams": None,
}

### Reconstruct data from the update

In [None]:
# Run the attack on the server payload, the shared update and the
# imprint-layer secrets; it returns the reconstructed data and statistics.
reconstructed_user_data, stats = attacker.reconstruct(
    server_payload,
    shared_data,
    secrets,
    dryrun=False,
)

In [None]:
# NOTE(review): a bare `print` calls nothing -- the cell merely evaluates to the
# builtin function. Looks like a leftover scratch cell; remove it or complete
# the intended call.
print

In [None]:
from utils.analysis import report

# Score the reconstruction against the ground-truth batch.
true_user_data = {"data": data, "labels": labels}
metrics = report(
    reconstructed_user_data,
    true_user_data,
    server_payload,
    model,
    compute_ssim=False,  # may be set to True; per the original note that needs an extra package
)
print(f"MSE: {metrics['mse']}, PSNR: {metrics['psnr']}, LPIPS: {metrics['lpips']}, SSIM: {metrics['ssim']} ")

### Plot ground-truth data

In [None]:
plot_data(data_cfg_default, true_user_data, setup)

# Make sure the "images" folder exists. exist_ok=True replaces the separate
# existence check, so there is no gap between checking and creating.
os.makedirs("images", exist_ok=True)

# Save the current figure inside the "images" folder.
plt.savefig("images/true_user_data.png")


### Now plot reconstructed data

In [None]:
plot_data(data_cfg_default, reconstructed_user_data, setup)

# Ensure the output folder exists so this cell does not depend on an earlier
# cell having created it, then save the current figure there.
os.makedirs("images", exist_ok=True)
plt.savefig("images/reconstructed_user_data.png")