# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Send INFO-level log records to stdout as bare messages (no level/name prefix).
logging.basicConfig(
    format='%(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger()  # the root logger

### Initialize cfg object and system setup:

This composes the configuration object and prints out its attack settings.
There are a lot of possible configurations, but there is usually no need to worry about most of these; only a few options are printed below.

In [None]:
# Compose the 'cfg' config from ./config, selecting the imprint attack and the malicious-model server.
# NOTE(review): an extra override such as 'case=1_single_image_small' can be appended to the list
# to switch the use case; confirm that the case name exists under config/case.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name='cfg',
        overrides=['attack=imprint', 'case/server=malicious-model'],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:', OmegaConf.to_yaml(cfg.attack), sep='\n')
          
# Use the CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Shared tensor options (device and float32 dtype) handed to the parties constructed later.
setup = dict(device=device, dtype=torch.float)
setup  # displayed as the cell output

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# User side: 64 data points per local update, and the user provides its labels.
cfg.case.user.num_data_points = 64
cfg.case.user.provide_labels = True

# Server side: insert an ImprintBlock into the model with 128 Fourier bins, joined by an 'add' connection.
# 'position' is not overridden here, so it keeps its value from the config files.
cfg.case.server.model_modification.type = 'ImprintBlock'
cfg.case.server.model_modification.linfunc = "fourier"
cfg.case.server.model_modification.num_bins = 128
cfg.case.server.model_modification.connection = 'add'

### Instantiate all parties

In [None]:
# Build the simulated user, the (malicious-model) server, the model and the loss for this case.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# The attacker is built from the server's view only: the server's model and loss.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
# Presumably prints a summary of the three parties.
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

The true user data is returned only for analysis.

In [None]:
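# Disabled experiments: hand-crafted initializations of the imprint layers (linear0 / linear2).
# NOTE: the lines using true_user_data only work after the cell that computes it has been run.
# server.model[1].linear0.reset_parameters()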
# server.model[1].linear0.weight.data = server.model[1].linear0.weight.cumsum(dim=0)
# server.model[1].linear0.bias.data = server.model[1].linear0.bias.cumsum(dim=0)

# server.model[1].linear0.weight.data = true_user_data["data"].view(-1, 3*224*224)
# corr = true_user_data["data"].view(64, -1).matmul(true_user_data["data"].view(64, -1).T).diag() / 2
# server.model[1].linear0.bias.data = corr

# server.model[1].linear2.reset_parameters()

In [None]:
# The server sends its payload; the user computes a local update on its private data.
server_payload = server.distribute_payload()
# true_user_data is returned only so that the attack can be evaluated below.
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Show the user's ground-truth batch.
user.plot(true_user_data)

# Reconstruct user data

In [None]:
# Run the attack on what the server observes: its payload, the user's shared update
# and the server's secrets (presumably the details of the model modification).
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, 
                                                      server.secrets, dryrun=cfg.dryrun)

In [None]:
# Plot the reconstructed batch (scale=False).
user.plot(reconstructed_user_data, scale=False)

In [None]:
# How good is the reconstruction?
# How good is the reconstruction?
# order_batch=True presumably pairs each reconstruction with its closest true sample before scoring.
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, 
                                    server_payload, server.model, user.dataloader, setup=setup,
                                    order_batch=True, compute_full_iip=False)

In [None]:
# Re-plot the reconstruction so that it appears next to the metrics output.
user.plot(reconstructed_user_data, scale=False)

In [None]:
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# First two shared gradients: weight (W) and bias (b) of the first modified layer.
W, b = shared_data["gradients"][:2]
# Number of samples in the user's local batch.
B = shared_data["metadata"]["num_data_points"]

# Ground-truth inputs, flattened to one row per sample.
x = true_user_data["data"].view(B, -1)

In [None]:
# Inspect the shapes of the observed gradients and of the flattened ground-truth batch.
W.shape, b.shape, x.shape

In [None]:
# Want to solve for: 10xB solutions

In [None]:
# Ground-truth a_{it}: gradient of the loss w.r.t. the linear0 outputs y_{it} (sample i, bin t):

In [None]:
# Recompute the ground-truth a = dL/dy, where y = linear0(x) are the pre-activations of the
# inserted block (presumably the ImprintBlock) for every sample and bin.
y = user.model[1].linear0(x)
activations = F.relu(y)
try:
    # Variant with a second linear layer (linear2) after the ReLU.
    outputs = user.model[3][1](user.model[1].linear2(activations))
except (AttributeError, IndexError, TypeError):
    # Variant without linear2 (presumably the 'add' connection): the mean activation is added
    # onto the flattened input, which is then reshaped back to the original image shape.
    # Only architecture mismatches are caught (missing submodule, index out of range,
    # non-subscriptable module); genuine errors such as shape mismatches now surface.
    outputs = user.model[3](
        (x + activations.mean(dim=1, keepdim=True)).view(true_user_data["data"].shape)
    )
loss = user.loss(outputs, true_user_data["labels"])
# Single-element unpack: the gradient w.r.t. y only.
a, = torch.autograd.grad(loss, [y])

In [None]:
# Rebuild the weight gradient from a and x: grad_W[t, d] = sum_i a[i, t] * x[i, d].
# a.T @ x computes the same sum without materializing the (B, bins, pixels) outer-product
# tensor, which needs several GB for 224x224x3 images.
recovered_grad = a.T @ x

In [None]:
# Sanity check: both distances should be close to 0 if W and b are fully explained by a
# (dL/dW = sum_i a_i^T x_i and dL/db = sum_i a_i for a linear layer).
torch.dist(W, recovered_grad), torch.dist(b, a.sum(dim=0))

In [None]:
# Heat map of a (samples x bins). matplotlib needs host memory, and a may live on the GPU
# (setup device), so move it to the CPU first.
plt.imshow(a.cpu())

In [None]:
# Heat map of the ReLU activations. They are part of the autograd graph, so detach them (no copy
# needed) and move them to the CPU before handing them to matplotlib.
plt.imshow(activations.detach().cpu())

In [None]:
# Shape of the per-sample outer product a_i x_i^T, computed from the broadcast rule alone
# instead of allocating the (B, bins, pixels) tensor just to read its shape.
torch.broadcast_shapes(a[:, :, None].shape, x[:, None, :].shape)

In [None]:
# Per-sample outer products summed over pixels: sum_d a[i, t] * x[i, d] == a[i, t] * sum_d x[i, d].
# The factored form avoids the (B, bins, pixels) intermediate; .cpu() lets matplotlib read GPU results.
plt.imshow((a * x.sum(dim=1, keepdim=True)).cpu())

In [None]:
# Distinct values occurring in a, and how many there are.
a.unique(), len(a.unique())

In [None]:
# Recover x from a and W:

In [None]:
# Mean over all flattened pixels, one value per sample.
x.mean(dim=1)

In [None]:
# Shapes of the system to solve: W (bins x pixels) against a (samples x bins).
W.shape, a.shape

In [None]:
# Per-sample scale factors 1 / a[:, 0], kept as a (B, 1) column for broadcasting; only used by the
# commented-out lstsq variant. NOTE(review): entries are inf wherever a[:, 0] == 0.
P = 1 / a[:,0:1]
P.shape, W.shape

In [None]:
# Shape of a (samples x bins).
a.shape

In [None]:
# Least-squares solve of W @ X = a.T (one column of X per sample); X.T is later compared against x.
# NOTE(review): torch documents 'gelsy', 'gelsd' and 'gelss' as CPU-only (CUDA accepts only 'gels'),
# so this call likely fails when W and a live on the GPU; confirm before running on CUDA.
x_rec = torch.linalg.lstsq(W, a.T, driver="gelsy", rcond=1e-16)  # drivers: 'gels', 'gelsy', 'gelsd', 'gelss'
# Alternative: rescale each sample's row of a by P = 1 / a[:, 0] before solving.
# P = 1 / a[:,0:1]
# x_rec = torch.linalg.lstsq(W, (a * P).T, driver="gelsy", rcond=1e-16)
x_rec

In [None]:
# Distance to x after rescaling the whole reconstruction to the overall (Frobenius) norm of x.
torch.dist(x_rec.solution.T * x.norm() / x_rec.solution.T.norm(), x)

In [None]:
# First reconstructed sample as an HWC image (assumes 3x224x224 inputs), min-max scaled over all
# channels to [0, 1] for imshow. NOTE: the result is NaN if the image is constant.
img = x_rec.solution.T[0].view(3, 224, 224).permute(1, 2, 0)
img_scaled = (img - img.min()) / (img.max() - img.min())

In [None]:
# Show the first reconstructed sample.
plt.imshow(img_scaled)

In [None]:
# Package all least-squares reconstructions as a reconstruction dict for plotting and metrics.
# Each image is min-max scaled per channel (dims 2, 3) to [0, 1], then normalized with
# attacker.dm / attacker.ds (presumably the dataset mean/std used for the true data; confirm).
# NOTE: rebinds the global name `data`; labels are copied from the ground truth.
data = x_rec.solution.T.view(B, 3, 224, 224)
min_val, max_val = data.amin(dim=[2, 3], keepdim=True), data.amax(dim=[2, 3], keepdim=True)
data = ((data - min_val) / (max_val - min_val) - attacker.dm) / attacker.ds
fake_rec_data = dict(data=data, labels=true_user_data["labels"])
user.plot(fake_rec_data, scale=False)

In [None]:
# Score the least-squares reconstruction with the same report as the attack output.
metrics = breaching.analysis.report(fake_rec_data, true_user_data, 
                                    server_payload, server.model, user.dataloader, setup=setup,
                                    order_batch=True, compute_full_iip=False)

In [None]:
# K: upper-triangular all-ones matrix with one row per row of W (the bins, 128 in this config).
# It is created on W's device and dtype so that D @ W also works when W lives on the GPU.
# Its inverse D is the first-difference operator: 1 on the diagonal, -1 on the superdiagonal,
# so (D @ W)[t] = W[t] - W[t + 1] (the last row is kept as is).
num_rows = W.shape[0]
K = torch.triu(torch.ones(num_rows, num_rows, device=W.device, dtype=W.dtype))
D = K.inverse()
# First parameter tensor of the first server query.
Wp = server_payload["queries"][0]["parameters"][0]

In [None]:
# Observed bias gradient.
b

In [None]:
# Difference consecutive rows of the weight gradient and divide each by the matching bias
# difference: row t is (W[t] - W[t+1]) / (b[t] - b[t+1]). Rows are reshaped to the per-sample
# shape of the ground-truth data rather than a hard-coded 3x224x224.
# NOTE(review): rows where (D @ b)[t] == 0 become inf/NaN.
rec = dict(
    data=(D @ W / (D @ b)[:, None]).view(-1, *true_user_data["data"].shape[1:]),
    labels=None,
)

In [None]:
# Plot the difference-based reconstruction (scale=False).
user.plot(rec, scale=False)

In [None]:
# Raw (unreshaped) difference-based reconstruction, one row per bin.
D@W / (D@b)[:, None]

In [None]:
# Shape of the bias gradient predicted from a (one entry per bin).
a.sum(dim=0).shape