# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching

### Initialize cfg object and system setup:

This composes the configuration object and prints the use case, the server type, and the attack settings.
There are a lot of possible configuration options, but there is usually no need to worry about most of them; only a few are printed below.

In [None]:
# Compose the configuration for the imprint attack on the chosen use case and print
# the use case, the server type and the attack settings.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=['attack=imprint', 'case=8_industry_fed_avg'])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the first GPU when one is available, otherwise fall back to the CPU.
# (Previously both branches selected the CPU, which made the availability check a no-op.)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype pair that is passed to every party and used for tensor conversions below.
setup = dict(device=device, dtype=torch.float)
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Federated scenario: number of participating users and how many data points each holds.
users = 100
data_points_per_user = 100

cfg.case.user.num_users = users
cfg.case.user.num_data_points = data_points_per_user * users
# Local training performed by each user before its update is shared.
cfg.case.user.num_local_updates = 10
cfg.case.user.num_data_per_local_update_step = 10
cfg.case.user.local_learning_rate = 1e-4

# Dataset split the user data is drawn from.
cfg.case.data.examples_from_split = 'training'

# User type. To simulate the alternative setting instead, use:
#   cfg.case.user.user_type = 'local_update'
cfg.case.user.user_type = 'multiuser_aggregate'

In [None]:
# Replace the model with 'none' to save some memory: apart from the inserted block,
# the rest of the model is almost irrelevant here.
cfg.case.model = 'none'

# Server-side modification: a 'OneShotBlock', additively connected, with position left
# unset (an alternative used previously was '4.0.conv').
cfg.case.server.model_modification.type = 'OneShotBlock'
cfg.case.server.model_modification.position = None
cfg.case.server.model_modification.connection = 'add'
# One bin per configured user data point.
cfg.case.server.model_modification.num_bins = cfg.case.user.num_data_points

# Choice of linear function for the block and its mode parameter.
cfg.case.server.model_modification.linfunc = 'fourier'
cfg.case.server.model_modification.mode = 32

### Instantiate all parties

In [None]:
# Build the user, server, model and loss function for the configured case, then prepare
# the attacker on the server's model and loss. Finally, pass all three parties to
# breaching.utils.overview (presumably prints a summary of the setup).
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

In [None]:
# Display the `bins` attribute of the module at index 1 of the user's model.
# NOTE(review): presumably these are the bins of the inserted OneShotBlock — confirm.
user.model[1].bins

### Simulate an attacked FL protocol

The true user data is returned only for analysis; the attacker never receives it.

In [None]:
# The server distributes its payload; the user computes its local updates on it and returns
# the shared data (later handed to the attacker) together with its true data (evaluation only).
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Display the gradients contained in the shared data.
shared_data['gradients']

# Reconstruct user data

In [None]:
# Run the attack using only server-side information: the payload, the shared data and the
# server's secrets. `dryrun` is taken from the config.
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, 
                                                      server.secrets, dryrun=cfg.dryrun)

In [None]:
# Display the shape of the reconstructed data tensor.
reconstructed_user_data['data'].shape

In [None]:
# Select one reconstructed sample (index 1; the slice keeps the leading batch dimension)
# and plot it. The reconstruction carries no labels.
found_data = {'data': reconstructed_user_data['data'][1:2], 'labels': None}
user.plot(found_data, scale=False)

### Identify the dataset index of the user sample that matches this reconstruction:

In [None]:
# Scan the user's dataset for the sample closest (L2 distance, torch.dist) to the selected
# reconstruction. Scanning stops as soon as a sample is within `match_threshold`;
# otherwise the closest sample seen during the full scan is used.
match_threshold = 1.0
matches = dict()
for idx, (data, _) in enumerate(user.dataloader.dataset):
    # Store plain Python floats rather than 0-dim tensors.
    matches[idx] = torch.dist(found_data['data'], data.to(**setup)).item()
    if matches[idx] < match_threshold:
        break
    if idx % 1000 == 0:
        print(f'Currently at index {idx}')
idx = min(matches, key=matches.get)
print(idx)
true_data = user.dataloader.dataset[idx]
# Move the matched sample to the same device/dtype as the comparison above (and as the
# reconstruction), so plotting and the metrics below compare tensors on the same device.
matching_user_data = dict(data=true_data[0][None, ...].to(**setup), labels=true_data[1])
user.plot(matching_user_data, scale=False)

In [None]:
# How good is the reconstruction? Compare the found sample against its matching sample
# from the user's dataset via breaching.analysis.report.
metrics = breaching.analysis.report(found_data, matching_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)