In [1]:
import torch
import torchvision
from collections import namedtuple
import os
import matplotlib.pyplot as plt
from attacks.analytic_attack import ImprintAttacker
from modifications.imprint import ImprintBlock
from utils.breaching_utils import *

import medmnist
from medmnist import INFO, Evaluator

from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from torch.utils.data import DataLoader

In [2]:
batch_size = 1 # Number of images in the user's batch. We have a small one here for visualization purposes
import random
random.seed(234324) # You can change this to get a new batch.

transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),
    ]
)
data_flag = 'dermamnist'
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])
dataset = DataClass(split="val", transform=transforms, download=True, size=224)
samples = [dataset[i] for i in random.sample(range(len(dataset)), batch_size)]
data = torch.stack([sample[0] for sample in samples])
labels = torch.tensor([sample[1] for sample in samples]).flatten()

Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz


  labels = torch.tensor([sample[1] for sample in samples]).flatten()


### Initialize your model

In [3]:
setup = dict(device=torch.device("cpu"), dtype=torch.float)

# This could be any model:
model = torchvision.models.resnet18(pretrained = True)
# Modify the final layer to have 7 output classes
model.fc = torch.nn.Linear(512, 7)

loss_fn = torch.nn.CrossEntropyLoss()




In [4]:
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
#add opacus here -> problem with the model structure (ImprintBlock) so do it after the imprint block
model = ModuleValidator.fix(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  
training_set = DataClass(split="train", transform=transforms, download=True, size=224)
training_loader = DataLoader(training_set, batch_size=batch_size)

#if hasattr(model, "autograd_grad_sample_hooks"):
#    del model.autograd_grad_sample_hooks

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=training_loader, 
    noise_multiplier=1.1,
    max_grad_norm=1,
    poisson_sampling= False,
    grad_sample_mode= "hooks",
    #grad_sample_mode="ew",
)

Using downloaded and verified file: /Users/maximilianeckert/.medmnist/dermamnist_224.npz




In [6]:
# It will be modified maliciously:
input_dim = data_cfg_default.shape[0] * data_cfg_default.shape[1] * data_cfg_default.shape[2]
num_bins = 100 # Here we define number of imprint bins
block = ImprintBlock(input_dim, num_bins=num_bins)
model = torch.nn.Sequential(
    torch.nn.Flatten(), block, torch.nn.Unflatten(dim=1, unflattened_size=data_cfg_default.shape), model
)
secret = dict(weight_idx=0, bias_idx=1, shape=tuple(data_cfg_default.shape), structure=block.structure)
secrets = {"ImprintBlock": secret}

In [7]:
# Model training

num_epochs = 1
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in training_loader:
        labels = labels.flatten()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(training_loader)}")

print("Training finished")




Epoch 1, Loss: 3.806044608444461
Training finished


In [11]:

safe_model = model


In [12]:
model_trained = model._modules["3"]
print(model_trained)

GradSampleModule(ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm(32, 64, eps=1e-05, affine=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): GroupNorm(32, 64, eps=1e-05, af

### Simulate an attacked FL protocol

In [13]:

# This is the attacker:
attacker = ImprintAttacker(model_trained, loss_fn, attack_cfg_default, setup)

#Server-side computation:

queries = [dict(parameters=[p for p in model_trained.parameters()], buffers=[b for b in model_trained.buffers()])]
server_payload = dict(queries=queries, data=data_cfg_default)

#User-side computation:
labels_expanded = labels.repeat(2)  # Adjust 2 to match the actual batch size of data


loss = loss_fn(model_trained(data), labels_expanded)



ValueError: Expected input batch_size (1) to match target batch_size (2).

In [None]:
print(model_trained(data))
print(labels_expanded)

tensor([[-28.4729,  -5.0741,  -9.9015, -14.9834,  -8.5387,   4.6916, -19.3648],
        [-28.6136,  -5.5390,  -9.0047, -16.7639,  -8.8951,   5.8726, -17.7156]],
       grad_fn=<AddmmBackward0>)
tensor([5, 5])




In [None]:
shared_data = dict(
    gradients=[torch.autograd.grad(loss, model_trained.parameters())],
    buffers=None,
    num_data_points=1,
    labels=labels,
    local_hyperparams=None,
)

### Reconstruct data from the update

In [None]:
# Attack:
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, secrets, dryrun=False)

RuntimeError: The size of tensor a (7) must match the size of tensor b (64) at non-singleton dimension 2

In [None]:
# Metrics?: 
from utils.analysis import report
true_user_data = {'data': data, 'labels': labels}
metrics = report(reconstructed_user_data,
    true_user_data,
    server_payload,
    model, compute_ssim=False) # Can change to true and install a package...
print(f"MSE: {metrics['mse']}, PSNR: {metrics['psnr']}, LPIPS: {metrics['lpips']}, SSIM: {metrics['ssim']} ")

### Plot ground-truth data

In [None]:
plot_data(data_cfg_default, true_user_data, setup)

# Create the "images" folder if it doesn't exist
if not os.path.exists("images"):
    os.makedirs("images")

# Save the images inside the "images" folder
plt.savefig("images/true_user_data.png")


### Now plot reconstructed data

In [None]:
plot_data(data_cfg_default, reconstructed_user_data, setup)
# Save the images inside the "images" folder
plt.savefig("images/reconstructed_user_data.png")