# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Send log records of level INFO and above to stdout, showing only the bare message text.
logging.basicConfig(
    format="%(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
# Handle on the root logger.
logger = logging.getLogger()

### Initialize cfg object and system setup:

This loads the full configuration.
There are a lot of possible configuration options, but there is usually no need to worry about most of them. Below, only a few options are printed.

In [None]:
# Compose the "cfg" configuration from the local "config" directory, selecting the
# analytic attack, the sanity-check case and the ImageNet data config.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name="cfg",
        overrides=["attack=analytic", "case=0_sanity_check", "case/data=ImageNet"],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype pair that is handed to the case and attack constructors below.
setup = {"device": device, "dtype": getattr(torch, cfg.case.impl.dtype)}
setup

These configurations run the sanity check, which is a simple linear model. Recovering information from the linear model is trivial, especially if the user data has unique class labels.

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Number of data points held by the simulated user.
cfg.case.user.num_data_points = 49

# Set both local differential-privacy options to 0.0
# (presumably this disables gradient noise and per-example clipping — confirm against the user implementation).
cfg.case.user.local_diff_privacy.gradient_noise = 0.0
cfg.case.user.local_diff_privacy.per_example_clipping = 0.0

# The server is configured without external data.
cfg.case.server.has_external_data = False

### Instantiate all parties

In [None]:
# Build the user, the server, the model and the loss function from the case configuration.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# Prepare the attacker from the server's model and loss plus the attack configuration.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
# Print a summary of the three parties.
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

In [None]:
# The server produces the payload that is sent to the user.
server_payload = server.distribute_payload()
# The user computes its local update from that payload.
shared_data, true_user_data = user.compute_local_updates(server_payload)  
# True user data is returned only for analysis

In [None]:
# Visualize the user's true data (the reference the reconstruction is compared against).
user.plot(true_user_data)

### Reconstruct user data:

In [None]:
# Run the attack. The payload and the shared update are each wrapped in a one-element list,
# and the server's secrets are passed along; `dryrun` is taken from the configuration.
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)

In [None]:
# How good is the reconstruction?
# Compare the reconstruction against the true user data and collect the resulting metrics.
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)

In [None]:
# Visualize the data recovered by the attacker.
user.plot(reconstructed_user_data)

# Visualize differences

The problem is not well-conditioned. Even in this setting, information is lost due to floating-point precision after the division of weight and bias gradients, so the reconstruction is only near-perfect.

In [None]:
# Element-wise absolute reconstruction error, paired with the labels from the shared update.
# `.abs()` replaces the original `.pow(2).sqrt()`: both compute |x|, but squaring first can
# underflow to zero (or overflow) for the very small floating-point differences that this
# cell is meant to expose.
diff_data = dict(data=(reconstructed_user_data['data'] - true_user_data['data']).abs(),
                 labels=shared_data['labels'])

In [None]:
# Visualize the per-element differences.
# NOTE(review): `scale=True` presumably rescales the values for display — confirm in `user.plot`.
user.plot(diff_data, scale=True)

### Sept 14: CIFAR? 

In [None]:
# Compose the "cfg" configuration again, this time with the CIFAR10 data config, and
# print the attack settings as YAML.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name="cfg",
        overrides=["attack=analytic", "case=0_sanity_check", "case/data=CIFAR10"],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print("Attack settings are:")
    print(OmegaConf.to_yaml(cfg.attack))

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype pair that is handed to the case and attack constructors below.
setup = {"device": device, "dtype": getattr(torch, cfg.case.impl.dtype)}
setup

In [None]:
# Per-experiment overrides. The commented-out lines are optional settings that are
# currently left at their configured values.
# cfg.dryrun = False
# cfg.attack.optim.step_size=1.0
# Index selecting which user data is used.
cfg.case.user.data_idx = 0
# cfg.case.model = 'resnet50'
#cfg.case.server.model_state='moco'
# Number of data points held by the simulated user.
cfg.case.user.num_data_points = 64

In [None]:
# `construct_case` yields four objects (user, server, model, loss_fn), as unpacked in the
# ImageNet walkthrough earlier in this notebook; unpacking into only two names would raise
# a ValueError.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# Prepare the attacker from the server's model and loss plus the attack configuration.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

In [None]:
# Show the textual description of each party, one per line.
print(user, server, attacker, sep="\n")

In [None]:
# The server produces the payload that is sent to the user.
server_payload = server.distribute_payload()
# The user computes its local update from that payload.
shared_data, true_user_data = user.compute_local_updates(server_payload)  
# True user data is returned only for analysis

In [None]:
# Visualize the user's true data (the reference the reconstruction is compared against).
user.plot(true_user_data)

In [None]:
# Run the attack using the same call convention as the ImageNet walkthrough earlier in this
# notebook: the payload and the shared update are each wrapped in a one-element list and the
# server's secrets are passed explicitly.
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data],
                                                      server.secrets, dryrun=cfg.dryrun)

# How good is the reconstruction? (Disabled; uncomment to compute the metrics.)
# metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload],
#                                     server.model, order_batch=True, compute_full_iip=False,
#                                     cfg_case=cfg.case, setup=setup)

In [None]:
# Visualize the data recovered by the attacker.
user.plot(reconstructed_user_data)

### CIFAR but with Imprint module?

In [None]:
# Compose the "cfg" configuration with the imprint attack and the malicious-model server
# on CIFAR10, and print the attack settings as YAML.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name="cfg",
        overrides=[
            "attack=imprint",
            "case/server=malicious-model",
            "case=0_sanity_check",
            "case/data=CIFAR10",
        ],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print("Attack settings are:")
    print(OmegaConf.to_yaml(cfg.attack))

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype pair that is handed to the case and attack constructors below.
setup = {"device": device, "dtype": getattr(torch, cfg.case.impl.dtype)}
setup

In [None]:
# Per-experiment overrides. The commented-out lines are optional settings that are
# currently left at their configured values.
# cfg.dryrun = False
# cfg.attack.optim.step_size=1.0
# Index selecting which user data is used.
cfg.case.user.data_idx = 0
# Number of data points held by the simulated user.
cfg.case.user.num_data_points = 64

In [None]:
# `construct_case` yields four objects (user, server, model, loss_fn), as unpacked in the
# ImageNet walkthrough earlier in this notebook; unpacking into only two names would raise
# a ValueError.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# Prepare the attacker from the server's model and loss plus the attack configuration.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

In [None]:
# Show the textual description of each party, one per line.
print(user, server, attacker, sep="\n")

In [None]:
# The server produces the payload that is sent to the user.
server_payload = server.distribute_payload()
# The user computes its local update from that payload.
shared_data, true_user_data = user.compute_local_updates(server_payload)  
# True user data is returned only for analysis

In [None]:
# Visualize the user's true data (the reference the reconstruction is compared against).
user.plot(true_user_data)

In [None]:
# Run the attack using the same call convention as the ImageNet walkthrough earlier in this
# notebook: the payload and the shared update are each wrapped in a one-element list, followed
# by the server's secrets.
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data],
                                                      server.secrets, dryrun=cfg.dryrun)
# The reconstruction quality report is computed in its own cell further down.

In [None]:
# Visualize the data recovered by the attacker (in the order the attack returned it).
user.plot(reconstructed_user_data)

In [None]:
# How good is the reconstruction?
# Use the same call convention as the report in the ImageNet walkthrough earlier in this
# notebook: the payload is wrapped in a list and `setup` is passed by keyword. Passing
# `setup` positionally would put it in the slot that call fills with keyword options.
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload],
                                    server.model, order_batch=True, compute_full_iip=False,
                                    cfg_case=cfg.case, setup=setup)

In [None]:
# Re-index both the reconstructed data and its labels with the ordering stored in the metrics.
ordered_user_data = {
    field: reconstructed_user_data[field][metrics['order']]
    for field in ('data', 'labels')
}

In [None]:
# Visualize the reconstruction after re-indexing it with `metrics['order']`.
user.plot(ordered_user_data)