In [1]:
import torch
import torchvision
from collections import namedtuple
import os
import matplotlib.pyplot as plt




In [2]:
from attacks.analytic_attack import ImprintAttacker
from modifications.imprint import ImprintBlock
from utils.breaching_utils import *

# Attack begins here:

### Initialize your model

In [3]:
setup = dict(device=torch.device("cpu"), dtype=torch.float)

# This could be any model:
model = torchvision.models.resnet18()
model.eval()
loss_fn = torch.nn.CrossEntropyLoss()
# It will be modified maliciously:
input_dim = data_cfg_default.shape[0] * data_cfg_default.shape[1] * data_cfg_default.shape[2]
num_bins = 100 # Here we define number of imprint bins
block = ImprintBlock(input_dim, num_bins=num_bins)
model = torch.nn.Sequential(
    torch.nn.Flatten(), block, torch.nn.Unflatten(dim=1, unflattened_size=data_cfg_default.shape), model
)
secret = dict(weight_idx=0, bias_idx=1, shape=tuple(data_cfg_default.shape), structure=block.structure)
secrets = {"ImprintBlock": secret}

### And your dataset (ImageNet by default)

In [4]:
# transforms = torchvision.transforms.Compose(
#     [
#         torchvision.transforms.Resize(256),
#         torchvision.transforms.CenterCrop(224),
#         torchvision.transforms.ToTensor(),
#         torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),
#     ]
# )
# dataset = torchvision.datasets.ImageNet(root="~/data/imagenet/", split="val", transform=transforms)
# batch_size = 64 # Number of images in the user's batch. We have a small one here for visualization purposes
# import random
# random.seed(123) # You can change this to get a new batch. 
# samples = [dataset[i] for i in random.sample(range(len(dataset)), batch_size)]
# data = torch.stack([sample[0] for sample in samples])
# labels = torch.tensor([sample[1] for sample in samples])

In [5]:
import medmnist
from medmnist import INFO, Evaluator

batch_size = 8 # Number of images in the user's batch. We have a small one here for visualization purposes
import random
random.seed(234324) # You can change this to get a new batch.

transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),
    ]
)


data_flag = 'dermamnist'
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])
dataset = DataClass(split="val", transform=transforms, download=True, size=224)
samples = [dataset[i] for i in random.sample(range(len(dataset)), batch_size)]
data = torch.stack([sample[0] for sample in samples])
labels = torch.tensor([sample[1] for sample in samples]).flatten()

Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz


  labels = torch.tensor([sample[1] for sample in samples]).flatten()


In [6]:
print(labels)

tensor([5, 5, 5, 5, 0, 5, 5, 5])


### Simulate an attacked FL protocol

In [8]:
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from torch.utils.data import DataLoader

model = ModuleValidator.fix(model)
optimizer = torch.optim.SGD(lr=0.01)  # Fix: Added missing code for optimizer initialization
training_set = DataClass(split="train", transform=transforms, download=True, size=224)
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)

#if hasattr(model, "autograd_grad_sample_hooks"):
#    del model.autograd_grad_sample_hooks

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=training_loader,  # Use the DataLoader here
    noise_multiplier=100,
    max_grad_norm=1.0,
    poisson_sampling= False,
    grad_sample_mode = 

)


Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz




In [None]:
# This is the attacker:
attacker = ImprintAttacker(model, loss_fn, attack_cfg_default, setup)

# Server-side computation:
queries = [dict(parameters=[p for p in model.parameters()], buffers=[b for b in model.buffers()])]
server_payload = dict(queries=queries, data=data_cfg_default)
# User-side computation:
loss = loss_fn(model(data), labels)
shared_data = dict(
    gradients=[torch.autograd.grad(loss, model.parameters())],
    buffers=None,
    num_data_points=1,
    labels=labels,
    local_hyperparams=None,
)

### Reconstruct data from the update

In [None]:
# Attack:
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, secrets, dryrun=False)

In [9]:
# Metrics?: 
from utils.analysis import report
true_user_data = {'data': data, 'labels': labels}
metrics = report(reconstructed_user_data,
    true_user_data,
    server_payload,
    model, compute_ssim=False) # Can change to true and install a package...
print(f"MSE: {metrics['mse']}, PSNR: {metrics['psnr']}, LPIPS: {metrics['lpips']}, SSIM: {metrics['ssim']} ")

NameError: name 'reconstructed_user_data' is not defined

### Plot ground-truth data

In [None]:
plot_data(data_cfg_default, true_user_data, setup)

# Create the "images" folder if it doesn't exist
if not os.path.exists("images"):
    os.makedirs("images")

# Save the images inside the "images" folder
plt.savefig("images/true_user_data.png")


### Now plot reconstructed data

In [None]:
plot_data(data_cfg_default, reconstructed_user_data, setup)
# Save the images inside the "images" folder
plt.savefig("images/reconstructed_user_data.png")