# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but it also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching

### Initialize cfg object and system setup:

This cell loads the configuration and prints the attack settings.
There are many possible configuration options, but there is usually no need to worry about most of them; only the attack options are printed below.

In [None]:
# Compose the Hydra configuration from the local "config" directory, selecting
# the attack and the use case through overrides.
# To investigate a different use case, swap the `case=` override, e.g. for
# 'case=1_single_image_small'.
# NOTE(review): the case override carries a '.yaml' suffix while the attack
# override does not -- confirm that this is intended.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name='cfg',
        overrides=[
            'attack=invertinggradients',
            'case=7_small_batch_cifar_pathnet.yaml',
        ],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# `setup` bundles the device and dtype that are passed to the library calls
# in the later cells.
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
# Trailing expression so that the notebook displays the chosen setup.
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# One shared value drives the case-level path count, the server path count and
# the server bin count, so the three settings cannot drift apart.
num_paths = 8
cfg.case.user.data_idx = 0
cfg.case.num_paths = num_paths
cfg.case.server.num_paths = num_paths
cfg.case.server.num_bins = num_paths
print(cfg.case.server.num_paths)

# Number of user data points. (An earlier assignment of 1 in this cell was
# immediately overwritten by this value and has been removed.)
cfg.case.user.num_data_points = 10
cfg.case.server.model_state = 'orthogonal'
# The total variation scale should be small for CIFAR images
cfg.attack.regularization.total_variation.scale = 1e-4

cfg.attack.objective.type = 'masked-cosine-similarity'
cfg.attack.optim.signed = False

# Alternatives that were left disabled in this cell:
#   cfg.attack.objective.type = 'euclidean'
#   cfg.case.model = 'ConvNetSmall'

### Instantiate all parties

In [None]:
# Build the user and the server from the case configuration and the shared
# device/dtype setup.
user, server = breaching.cases.construct_case(cfg.case, setup)
# Build the attacker from the server's model and loss, the attack configuration
# and the same setup.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

In [None]:
# Show the printable summary of each party, one per line.
print(user, server, attacker, sep='\n')

In [None]:
# List every named parameter of the server model as (index, name, shape).
param_overview = [
    (index, name, param.shape)
    for index, (name, param) in enumerate(server.model.named_parameters())
]
print(param_overview)

### Simulate an attacked FL protocol

The true user data is returned only so that it can be used for analysis.

In [None]:
# The server hands out its payload and the user computes its local update on it.
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

# Display the mean and standard deviation of each entry in the first element
# of the shared gradients.
first_gradients = shared_data['gradients'][0]
[(grad.mean(), grad.std()) for grad in first_gradients]

In [None]:
# Euclidean norm over the first element of the shared gradients, leaving out
# its last four entries: square root of the summed squared entries.
leading_grads = shared_data['gradients'][0][:-4]
squared_sums = [grad.pow(2).sum() for grad in leading_grads]
torch.stack(squared_sums).sum().sqrt()

In [None]:
# Plot the true user data that was returned alongside the shared update.
user.plot(true_user_data)

In [None]:
# Compute the gradient-uniqueness statistics for the true user data.
results = breaching.analysis.metrics.gradient_uniqueness(
    user.model, user.loss, true_user_data, server_payload, setup, fudge=1e-5
)
(
    unique_entries,
    average_hits_per_entry,
    nonzero_uniques,
    nonzero_hits_per_entry,
    uniques,
    uniques_nonzero,
) = results

# Pair the first and second components of each `uniques` result into a dict
# for the printed report.
hit_stats = dict(zip(uniques[0].tolist(), uniques[1].tolist()))
nonzero_hit_stats = dict(zip(uniques_nonzero[0].tolist(), uniques_nonzero[1].tolist()))

print(f'Unique entries (hitting 1 or all): {unique_entries:.2%}, average hits: {average_hits_per_entry:.2%} \n'
      f'Stats (as N hits:val): {hit_stats}\n'
      f'Unique nonzero (hitting 1 or all): {nonzero_uniques:.2%} Average nonzero: {nonzero_hits_per_entry:.2%}. \n'
      f'nonzero-Stats (as N hits:val): {nonzero_hit_stats}')

In [None]:
# Sparsity check: fraction of entries in the first shared gradient tensor whose
# magnitude exceeds 1e-7 (i.e. the share of effectively non-zero entries).
first_grad = shared_data['gradients'][0][0]
(first_grad.abs() > 1e-7).sum() / first_grad.numel()

### Reconstruct user data:

In [None]:
# Run the attack on the server payload and the user's shared data.
reconstructed_user_data, stats = attacker.reconstruct(
    server_payload,
    shared_data,
    server.secrets,
    dryrun=cfg.dryrun,
)

# How good is the reconstruction? Compare it with the true user data, with
# `compute_full_iip` disabled for this call.
metrics = breaching.analysis.report(
    reconstructed_user_data,
    true_user_data,
    server_payload,
    server.model,
    user.dataloader,
    setup=setup,
    order_batch=True,
    compute_full_iip=False,
)

In [None]:
# Plot the data reconstructed by the attacker.
user.plot(reconstructed_user_data)

In [None]:
# Repeat the report on the same reconstruction, this time with
# `compute_full_iip` enabled.
metrics = breaching.analysis.report(
    reconstructed_user_data,
    true_user_data,
    server_payload,
    server.model,
    user.dataloader,
    setup=setup,
    order_batch=True,
    compute_full_iip=True,
)