# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Send INFO-level log records to stdout as bare messages so they show up in the cell output.
logger = logging.getLogger()
logging.basicConfig(
    format='%(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)

### Initialize cfg object and system setup:

This prints the case name, the server type, and the attack settings.
There are many possible configuration options, but usually there is no need to worry about most of them; only a few are printed below.

In [None]:
with hydra.initialize(config_path="config"):
    # Compose the top-level `cfg` config, selecting the `1_single_image_small` case.
    cfg = hydra.compose(config_name='cfg', overrides=['case=1_single_image_small'])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# cfg.case.impl.dtype is the name of a torch attribute, resolved to the dtype object via getattr.
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Override the model selected by the case config before the parties are constructed.
cfg.case.model = 'convnetsmall'

### Instantiate all parties

In [None]:
# Build the user, server, model and loss function for the configured case.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# The attacker is prepared from the server's model and loss.
attacker = breaching.attacks.prepare_attack(
    server.model, server.loss, cfg.attack, setup
)
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

The true user data is returned only for analysis (to evaluate the reconstruction); it is returned separately from the shared data.

In [None]:
# One protocol round: the server distributes its payload, the user computes its local updates.
server_payload = server.distribute_payload()
# `shared_data` is what the user shares; `true_user_data` is kept only for evaluation.
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Visualize the user's true data for later comparison with the reconstruction.
user.plot(true_user_data)

In [None]:
grads = shared_data['gradients'][0]
# Key each shared gradient by its parameter name.
# Assumes the gradients are listed in user.model.named_parameters() order — TODO confirm.
named_grads = {name: grad for grad, (name, _) in zip(grads, user.model.named_parameters())}
named_modules = dict(user.model.named_modules())
named_modules

In [None]:
# Shape of the row-wise ratio (weight gradient / bias gradient) of the final linear layer.
torch.div(named_grads['model.linear.weight'], named_grads['model.linear.bias'].unsqueeze(-1)).shape

In [None]:
# Parameter names for which a gradient is available.
named_grads.keys()

In [None]:
# Rows whose bias gradient is exactly zero would divide by zero, so keep only the others.
valid_classes = named_grads['model.linear.bias'] != 0
# Divide each kept row of the weight gradient by that row's bias gradient.
named_grads_fc_debiased = (named_grads['model.linear.weight'][valid_classes]
                           / named_grads['model.linear.bias'][valid_classes].unsqueeze(1))

### Replicate the debiased gradient with autograd

In [None]:
# Reference: gradient of the summed model outputs w.r.t. the final linear weight.
presum = user.model(true_user_data['data']).sum()
(debiased_rec,) = torch.autograd.grad(presum, user.model.model.linear.weight)

In [None]:
# debiased_rec has one row per output of the linear layer, while named_grads_fc_debiased only
# has rows for `valid_classes`; select the same rows so both operands have the same shape.
torch.dist(named_grads_fc_debiased, debiased_rec[valid_classes])

In [None]:
# Replace the `linear` layer with an identity so user.model(...) returns whatever was fed into it
# (presumably the penultimate features, which the next cell compares against).
# NOTE(review): this mutates user.model in place and is never undone. After it, the
# "Replicate debiased grad" cells cannot be re-run (Identity has no .weight). If server.model
# is the same object, later reconstruction also sees the headless model — verify.
setattr(user.model.model, 'linear', torch.nn.Identity())

In [None]:
# Compare the debiased rows with a negative bias gradient (presumably the true label under a
# softmax cross-entropy loss — TODO confirm) against the model output, which no longer includes
# the `linear` head (replaced by an Identity in the cell above).
# named_grads_fc_debiased only has rows for `valid_classes`, so the mask is built over that same
# subset; positions computed over all classes would misalign once any class is filtered out.
torch.dist(named_grads_fc_debiased[named_grads['model.linear.bias'][valid_classes] < 0].squeeze(),
           user.model(true_user_data['data']))

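Let the loss be $l = h(cx + b)$, with $x \in \mathbb{R}^n$, $c, b \in \mathbb{R}$, and $y = cx + b$ (so $y \in \mathbb{R}^n$, with $b$ added to every entry).

$\frac{\partial h}{\partial y_i} = g_i$

$\frac{\partial h}{\partial b} = \sum_{i=1}^n \frac{\partial h}{\partial y_i} = \langle g, \mathbf{1}\rangle$

$\frac{\partial h}{\partial c} = \sum_{i=1}^n \frac{\partial h}{\partial y_i} x_i = \langle g, x\rangle$

Goal: recover $\langle \mathbf{1}, x\rangle$ (or any other map $f: \mathbb{R}^n \to \mathbb{R}$ from $x$ to a scalar) without depending on $g$.

We can compute $\frac{\langle g, x\rangle}{\langle g, \mathbf{1}\rangle} = \frac{\partial h / \partial c}{\partial h / \partial b}$. When all $g_i$ have the same sign, this is a $g$-weighted average of the entries of $x$, so it is roughly bounded by $\|x\|$. It still depends on $g$, though, so it is not a great solution.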

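### Now do the same for a conv + batchnorm layer

These cells assume the model has a `stem` whose entries 0 and 1 are a conv and a batchnorm. They raise a `KeyError` if `named_grads` has no `stem.*` entries, so check `named_grads.keys()` first.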

In [None]:
# Debias the stem conv's weight gradient through the batchnorm that follows it.
# Channels whose batchnorm-bias gradient is exactly zero would divide by zero, so skip them.
valid_channels = named_grads['stem.1.bias'] != 0
stem_bn = named_modules['stem.1']
# BatchNorm scales its (centered) input by weight / sqrt(var + eps), so multiply by the inverse
# factor sqrt(running_var + eps) / weight to undo it.
# NOTE(review): running_var matches the normalization only in eval mode; in train mode the batch
# statistics are used instead, so the correction is approximate there.
correction = torch.sqrt(stem_bn.running_var + stem_bn.eps) / stem_bn.weight
# Per valid channel: conv-weight gradient divided by the batchnorm-bias gradient.
weight_bias_ratio = (named_grads['stem.0.weight'][valid_channels]
                     / named_grads['stem.1.bias'][valid_channels, None, None, None])
debiased_conv1 = weight_bias_ratio * correction[valid_channels, None, None, None]

# conv1.weight, bn1.weight, bn1.bias

In [None]:
# Reference: gradient of the stem conv output, summed over the valid channels only,
# w.r.t. that conv's weight.
presum = user.model.stem[0](true_user_data['data'])[:, valid_channels].sum()
(debiased_conv_rec,) = torch.autograd.grad(presum, user.model.stem[0].weight)

In [None]:
# debiased_conv_rec has a row for every output channel (zero for channels excluded from presum),
# whereas debiased_conv1 only has the valid channels; select those so the shapes match.
torch.dist(debiased_conv_rec[valid_channels], debiased_conv1)

In [None]:
# Inspect the batchnorm-bias gradient (zero entries are excluded by valid_channels).
named_grads['stem.1.bias']

In [None]:
# Autograd reference for the first output channel of the stem conv.
debiased_conv_rec[0]

## Simpler case: a plain conv layer (no batchnorm)

In [None]:
# Without a batchnorm, no rescaling is needed: the debiased gradient is just the
# per-channel ratio of weight gradient to bias gradient (zero-bias-grad channels skipped).
valid_channels = named_grads['model.conv0.bias'] != 0
debiased_conv1 = divisor = (named_grads['model.conv0.weight'][valid_channels]
                            / named_grads['model.conv0.bias'][valid_channels, None, None, None])

In [None]:
# Reference: gradient of conv0's output, summed over the valid channels only,
# w.r.t. conv0's weight.
presum = user.model.model.conv0(true_user_data['data'])[:, valid_channels].sum()
(debiased_conv_rec,) = torch.autograd.grad(presum, user.model.model.conv0.weight)

In [None]:
# debiased_conv_rec has a row for every output channel (zero for channels excluded from presum),
# whereas debiased_conv1 only has the valid channels; select those so the shapes match.
torch.dist(debiased_conv_rec[valid_channels], debiased_conv1)

### Reconstruct user data:

In [None]:
# Run the attack on what the server observed: its payload, the shared data and its secrets.
reconstructed_user_data, stats = attacker.reconstruct(
    server_payload, shared_data, server.secrets, dryrun=cfg.dryrun
)

# How good is the reconstruction?
metrics = breaching.analysis.report(
    reconstructed_user_data, true_user_data, server_payload, server.model, user.dataloader,
    setup=setup, order_batch=True, compute_full_iip=False,
)

In [None]:
# Visualize the reconstruction for side-by-side comparison with the true data plotted earlier.
user.plot(reconstructed_user_data)

In [None]:
# Repeat the report with compute_full_iip=True; this overwrites the `metrics` computed
# right after reconstruction.
metrics = breaching.analysis.report(
    reconstructed_user_data, true_user_data, server_payload, server.model, user.dataloader,
    setup=setup, order_batch=True, compute_full_iip=True,
)