# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching

### Initialize the cfg object and system setup

The cell below composes the configuration and prints the use case, the server type, and the attack settings.
There are a lot of possible configuration options, but there is usually no need to worry about most of them.

In [None]:
with hydra.initialize(config_path="config"):
    # Compose the 'cfg' config: imprint attack, malicious-model server,
    # small-batch CIFAR use case.
    cfg = hydra.compose(
        config_name='cfg',
        overrides=['attack=imprint', 'case/server=malicious-model', 'case=6_small_batch_cifar'],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype settings shared by the parties constructed later (double precision).
setup = dict(device=device, dtype=torch.double)
setup

### Modify config options here

Any of these configuration options can be modified via attribute access (e.g. `cfg.case.user.data_idx = 200`):

In [None]:
# User-side data selection: which data index to start from and how many points.
cfg.case.user.data_idx = 200
cfg.case.model = 'linear'

cfg.case.user.num_data_points = 5
# Server-side model modification: a 'DifferentialBlock' configured with num_bins=3.
cfg.case.server.model_modification = {'DifferentialBlock': {'num_bins': 3}}

### Instantiate all parties

In [None]:
# Build the user and server for the configured case, then the attacker from the
# server's model and loss, all with the shared device/dtype `setup`.
user, server = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

In [None]:
# Show a summary of each party, one per line.
print(user, server, attacker, sep='\n')

In [None]:
# Display the model object held by the user.
user.model

### Simulate an attacked FL protocol

The true user data is returned only for analysis.

In [None]:
# One round of the simulated protocol: the server builds its payload, the user
# computes its update from it. `true_user_data` is kept only for analysis.
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)  
# (mean, std) of every gradient tensor in shared_data['gradients'][0]
# (presumably the gradient list of the single shared update — confirm).
[(g.mean().item(), g.std().item()) for g in shared_data['gradients'][0]]

In [None]:
# Global L2 norm over all gradient tensors except the last four.
# NOTE(review): why the last four tensors are excluded is not visible here — confirm.
torch.stack([g.pow(2).sum() for g in shared_data['gradients'][0][:-4]]).sum().sqrt()

In [None]:
# Visualize the true user data.
user.plot(true_user_data)

# Hand inspection

In [None]:
# First two gradient tensors of the shared update.
# Assumes these are the weight and bias of the server-modified (imprint) layer — confirm.
weight_grad = shared_data['gradients'][0][0]
bias_grad = shared_data['gradients'][0][1]

# Alternative bin-size estimate from the first query parameter (kept for reference):
# bin_sizes = 1 / (server_payload['queries'][0]['parameters'][0].mean(dim=1) * 3072).detach().to(**setup)
# Bin sizes read from the diagonal of the third query parameter.
bin_sizes = server_payload['queries'][0]['parameters'][2].diag().detach()
# Bin offsets from the second query parameter, moved to the notebook's device/dtype.
bins = server_payload['queries'][0]['parameters'][1].to(**setup)
bins

In [None]:
# Ground-truth user images; dim 0 indexes the images in the batch.
true_images = true_user_data['data']
# Batch-mean image. The name says "average", so compute a mean rather than a sum.
average_image = true_images.mean(dim=0)
# Mean intensity of each image (reduced over all non-batch dims), shown as output.
true_images.mean(dim=[1, 2, 3])

In [None]:
# Boolean mask: which bias-gradient entries are nonzero.
bias_grad.ne(0)

In [None]:
# Rows: images, columns: bins. True where (image mean + bin offset) is positive.
(true_images.mean(dim=[1, 2, 3]).unsqueeze(1) + bins.unsqueeze(0)).gt(0)

In [None]:
# Compare bias gradients and bin offsets against 1 / (number of values in a 3x32x32 image).
bias_grad, bins, 1 / (3 * 32 * 32)

In [None]:
# L2 norm of each row of the weight gradient.
weight_grad.norm(dim=-1)

In [None]:
# Pairwise ratio matrix: entry [i, j] is bias_grad[j] / bias_grad[i].
bias_grad.unsqueeze(0) / bias_grad.unsqueeze(1)

In [None]:
# Total bias gradient alongside the individual entries.
bias_grad.sum(), bias_grad

In [None]:
# Estimate of a single image: row 2 of the weight gradient divided by its bias gradient.
ship = weight_grad[2] / bias_grad[2]
# Estimate from rows 1 and 2 combined (summed weight rows over summed bias entries).
dog = (weight_grad[1] + weight_grad[2]) / (bias_grad[1] + bias_grad[2])

# Alternative weighting that was tried for `dog` (kept for reference):
# dog = ((1 - bias_grad[2]) * weight_grad[1] - (-1 - 0.476) * weight_grad[2]) / (bias_grad[1] + bias_grad[2]) * 1.1718521118164062

# L2 error of `ship` against the first true image ("Image 1" uses 1-based numbering).
print(f'Image 1 accurate to {(ship - true_images[0].view(-1)).norm()}')
# print(f'Image 2 accurate to {(dog - true_images[1].view(-1)).norm()}')
# `dog` is compared to the second true image after normalizing both to unit norm,
# i.e. a scale-invariant error.
print(f'Image 2 accurate to {(dog / dog.norm() - true_images[1].view(-1) / true_images[1].norm()).norm()}')

In [None]:
# Plot the `ship` estimate as a single 3x32x32 image.
# NOTE(review): label 5 is hard-coded; in the standard CIFAR-10 class order index 5
# is 'dog', not 'ship' — confirm the intended label.
reconstructed = dict(data=ship.reshape(1, 3, 32, 32), labels=torch.tensor(5))
user.plot(reconstructed, scale=True)

In [None]:
# Plot the sum of all true images as one image. torch.randint(1, ...) always yields 0
# (the upper bound is exclusive), so the label is a placeholder.
reconstructed = dict(data=true_images.sum(dim=0).reshape(1, 3, 32, 32), labels=torch.randint(1, (1,)))
user.plot(reconstructed, scale=True)

In [None]:
#reconstructed = dict(data=outputs, labels=torch.randint(1, (outputs.shape[0],)))
#user.plot(reconstructed, scale=True)

### Estimate ground-truth coefficients

In [None]:
# Bin offsets from the server query.
bins

In [None]:
# Mean intensity of each true image, for comparison with the bin offsets.
true_images.mean(dim=[1, 2, 3])

In [None]:
# Second true image, flattened to a vector; shown next to weight_grad.T's shape to
# check compatibility for the least-squares fit below.
gt = true_images[1].view(-1)
gt.shape, weight_grad.T.shape

In [None]:
# Shape of the weight gradient.
weight_grad.shape

In [None]:
# Number of leading weight_grad rows used as the least-squares basis.
# Hard-coded: a previous attempt assigned (weight_grad.sum(dim=1) > 0).nonzero()
# here, but that returns an index tensor (not a count) and was immediately overwritten.
cutoff = 2

In [None]:
# Eq.: weight_grad * (a1, a2, a3) = gt2

In [None]:
# For each true image, solve the least-squares problem
#   min_a || weight_grad[:cutoff].T @ a - image ||_2
# and print the full lstsq result.
# Note: the loop rebinds the notebook variable `gt`.
for gt in true_images:
    print(torch.linalg.lstsq(weight_grad[:cutoff].T, gt.view(-1)))
    print()

In [None]:
# Rows with a nonzero bias gradient.
valid_classes = bias_grad != 0
# intermediates = (weight_grad[valid_classes, :] / bias_grad[valid_classes, None])

# Each valid weight-gradient row divided by its own bias gradient.
intermediates = weight_grad[valid_classes, :] / bias_grad[valid_classes, None]
# Variant with an extra constant row of ones (kept for reference):
# intermediates = torch.cat([intermediates, torch.ones(1, 3072, **setup)], dim=0)

# Least-squares fit of each true image on the `intermediates` rows.
# Note: the loop rebinds the notebook variable `gt`.
for gt in true_images:
    print(torch.linalg.lstsq(intermediates.T, gt.view(-1)))
    print()

In [None]:
# Shape of the normalized rows used in the fit above (valid rows x features).
intermediates.shape

# Old Stuff

In [None]:
# Rows with a nonzero bias gradient.
valid_classes = bias_grad != 0

# weight_grad rows divided by their bias gradient; rows with zero bias gradient stay 0.
intermediates = torch.zeros_like(weight_grad)
intermediates[valid_classes] = weight_grad[valid_classes, :] / bias_grad[valid_classes, None]

# The normalized rows viewed directly as 3x32x32 images.
direct_outputs = intermediates.unflatten(1, (3, 32, 32))

# Consecutive-row differences (row i+1 minus row i); the last row is kept as-is.
# Vectorized form of the former per-row loop with a last-row special case.
# (Weighting each difference by bin_sizes[i] was tried and is not applied.)
differentials = torch.cat([intermediates[1:] - intermediates[:-1], intermediates[-1:]])
outputs = -differentials.unflatten(1, (3, 32, 32))

In [None]:
# Plot each weight-gradient row as a 3x32x32 image. One placeholder label (always 0,
# since randint's upper bound is exclusive) per row, sized from weight_grad itself
# so this cell does not depend on `outputs` from another cell.
reconstructed = dict(data=weight_grad.unflatten(1, (3, 32, 32)), labels=torch.randint(1, (weight_grad.shape[0],)))
user.plot(reconstructed, scale=True)

In [None]:
# Plot all differential reconstructions without rescaling; labels are placeholders (all 0).
reconstructed = dict(data=outputs, labels=torch.randint(1, (outputs.shape[0],)))
user.plot(reconstructed, scale=False)

In [None]:
# Plot only the first differential reconstruction (kept as a batch of one).
reconstructed = dict(
    data=outputs[:1],
    labels=torch.randint(1, (1,)),
)
user.plot(reconstructed, scale=False)

In [None]:
# Score `reconstructed` against the true user data.
# NOTE(review): `reconstructed` holds whatever the most recently run plotting cell
# assigned (a single image if run top to bottom), while the user batch has
# num_data_points = 5 — confirm this comparison is intended (order_batch=True).
metrics = breaching.analysis.report(reconstructed, true_user_data, 
                                    server_payload, server.model, user.dataloader, setup=setup,
                                    order_batch=True, compute_full_iip=False)