# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [1]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging
import sys  # needed for the stdout stream handler below
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

### Initialize cfg object and system setup:

This loads the full configuration and prints the attack settings.
There are many possible configuration options, but there is usually no need to worry about most of them, so only a few are printed below.

In [2]:
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=['attack=imprint', 'case=8_industry_fed_avg'])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))
          
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=torch.float)
setup

Investigating use case industry_fed_avg with server type malicious_model.
Attack settings are:
type: analytic
attack_type: imprint-readout



{'device': device(type='cpu'), 'dtype': torch.float32}

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [3]:
users = 100

cfg.case.user.num_users = users
cfg.case.user.num_data_points = 100 * users
cfg.case.user.num_local_updates = 10
cfg.case.user.num_data_per_local_update_step = 10
cfg.case.user.local_learning_rate = 1e-4

cfg.case.examples_from_split = 'training' #'training'

cfg.case.user.user_type = 'multiuser_aggregate'
# cfg.case.user.user_type = 'local_update'
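
As a quick sanity check on the values in the configuration cell above, the total number of data points can be recomputed from the per-user settings. This is only a sketch: it assumes that every user consumes `num_local_updates * num_data_per_local_update_step` distinct examples, and the variable names are local stand-ins rather than part of the `breaching` API.

```python
# Stand-in values mirroring the configuration cell above (not read from cfg).
users = 100
num_local_updates = 10
num_data_per_local_update_step = 10

# Assumption: each local update step draws a fresh batch, so no example is reused.
points_per_user = num_local_updates * num_data_per_local_update_step
total_points = points_per_user * users

print(points_per_user)  # 100
print(total_points)     # 10000
```

Under that assumption the result agrees with `num_data_points = 100 * users`.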

In [4]:
cfg.case.user.data_with_labels = 'same' # just bet one 1-in-4 :>


cfg.case.model = 'none'

cfg.case.server.model_modification.type = 'OneShotBlock' 
cfg.case.server.model_modification.num_bins = cfg.case.user.num_data_points
cfg.case.server.model_modification.position = None # '4.0.conv'
cfg.case.server.model_modification.connection = 'add'


cfg.case.server.model_modification.linfunc = 'fourier'
cfg.case.server.model_modification.mode = 32

### Instantiate all parties

In [5]:
user, server = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

Model architecture <class 'torch.nn.modules.container.Sequential'> loaded with 301,058 parameters and 0 buffers.
Overall this is a data ratio of       0:1 for target shape [10000, 3, 224, 224] given that num_queries=1.


In [6]:
print(user)
print(server)
print(attacker)

User (of type MultiUserAggregate with settings:
            number of data points: 10000
            number of user queries 1

            Threat model:
            User provides labels: True
            User provides number of data points: True

            Model:
            model specification: Sequential
            loss function: CrossEntropyLoss()

            Data:
            Dataset: ImageNet
            data_idx: 743981
        
<breaching.cases.servers.MaliciousModelServer object at 0x7f9f95cb90a0>
<breaching.attacks.analytic_attack.ImprintAttacker object at 0x7f9f95c7b280>


In [7]:
user.model[1].bins

[0.0, 0.0001414355002588186]

### Simulate an attacked FL protocol

The true user data is returned only for analysis; the attack itself does not use it.

In [8]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

Now loading data for user 0 ...

RuntimeError: DataLoader worker (pid(s) 13306) exited unexpectedly

In [None]:
shared_data['gradients']

### Reconstruct user data

In [None]:
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, 
                                                      server.secrets, dryrun=cfg.dryrun)

In [None]:
reconstructed_user_data['data'].shape

In [None]:
found_data = dict(data = reconstructed_user_data['data'][1:2], labels=None)
user.plot(found_data, scale=False)

### Identify the index of the user data point that matches this reconstruction:

In [None]:
# Compare the reconstruction against every example in the user's dataset.
matches = dict()
for idx, (data, label) in enumerate(user.dataloader.dataset):
    matches[idx] = torch.dist(found_data['data'], data.to(**setup))
    if matches[idx] < 1:
        break  # close enough: stop searching early
    if idx % 1000 == 0:
        print(f'Currently at index {idx}')
# Pick the dataset index with the smallest distance seen so far.
idx = min(matches, key=matches.get)
print(idx)
true_data = user.dataloader.dataset[idx]
matching_user_data = dict(data=true_data[0][None, ...], labels=true_data[1])
user.plot(matching_user_data, scale=False)
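
The search above is a nearest-neighbor lookup in L2 distance with an early exit. The following standalone sketch shows the same idea on toy data, using plain Python lists instead of tensors. The function name `find_closest` and the toy values are hypothetical and only serve as illustration.

```python
import math


def find_closest(query, candidates, threshold=1.0):
    """Return (index, distance) of the candidate closest to `query` in L2 distance.

    The scan stops early once a candidate is closer than `threshold`.
    """
    best_idx, best_dist = None, math.inf
    for idx, candidate in enumerate(candidates):
        dist = math.dist(query, candidate)
        if dist < best_dist:
            best_idx, best_dist = idx, dist
        if dist < threshold:
            break
    return best_idx, best_dist


# Toy data: three flattened "images" with four pixels each.
dataset = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 2.0, 3.0, 4.0],
    [9.0, 9.0, 9.0, 9.0],
]
recovered = [1.0, 2.0, 3.0, 4.5]

index, distance = find_closest(recovered, dataset)
print(index, distance)  # 1 0.5
```

The early exit trades exactness for speed: the first candidate under the threshold is accepted, even if a later one would be slightly closer.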

In [None]:
# How good is the reconstruction?
metrics = breaching.analysis.report(found_data, matching_user_data, 
                                    server_payload, server.model, user.dataloader, setup=setup,
                                    order_batch=False, compute_full_iip=False, skip_rpsnr=True)
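
One common way to quantify reconstruction quality is the peak signal-to-noise ratio (PSNR). The sketch below is illustrative only: it assumes pixel values in [0, 1] and flattened images, and it is not the implementation inside `breaching.analysis.report`, whose exact metrics are not shown in this notebook.

```python
import math


def psnr(reference, reconstruction, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two equally sized flat images."""
    if len(reference) != len(reconstruction):
        raise ValueError("images must have the same number of pixels")
    mse = sum((r - x) ** 2 for r, x in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value ** 2 / mse)


# Toy example: two of the four pixels are off by 0.1.
reference = [0.0, 0.5, 1.0, 0.5]
reconstruction = [0.1, 0.5, 0.9, 0.5]
print(round(psnr(reference, reconstruction), 2))  # 23.01
```

Higher values indicate a closer match, and an exact reconstruction yields an infinite PSNR.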