# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [1]:
import torch
import numpy as np
import hydra
import copy
from omegaconf import OmegaConf, open_dict
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Route INFO-level (and above) log records to stdout, formatted as the bare message text.
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
# Root logger handle for use in later cells.
logger = logging.getLogger()

### Initialize cfg object and system setup:

This loads the full configuration object.
There are a lot of possible configuration options, but there is usually no need to worry about most of them. Below, only a few of the options are printed.

In [2]:
# Compose the Hydra configuration `cfg` from the `config` directory, using the
# `clsattack` attack preset, and select the class-attack malicious server.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=['attack=clsattack'])
    cfg.case.server.name = 'class_malicious_parameters'
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# cuDNN autotuning is switched on/off according to the case configuration.
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype pair that is passed to the case and attack constructors in later cells.
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

ANTLR runtime and generated code versions disagree: 4.7.2!=4.8
ANTLR runtime and generated code versions disagree: 4.7.2!=4.8
ANTLR runtime and generated code versions disagree: 4.7.2!=4.8
ANTLR runtime and generated code versions disagree: 4.7.2!=4.8
Investigating use case single_imagenet with server type class_malicious_parameters.


{'device': device(type='cuda'), 'dtype': torch.float32}

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [11]:
# Override selected case/attack options for this experiment.
# `open_dict` lets keys be assigned even if they are not yet present in the (struct-mode) config.
with open_dict(cfg):
    # Data partitioning scheme and which user/partition to simulate.
    cfg.case.data.partition = "unique-class"
    cfg.case.user.user_idx = 0
    # Batch size of the simulated user update.
    cfg.case.user.num_data_points = 8
    cfg.case.model = 'resnet18'
    cfg.case.server.model_state = 'trained'
    # Threat model: what side information the user reveals along with the update.
    cfg.case.user.provide_labels = True
    cfg.case.user.provide_buffers = True
    cfg.case.user.provide_num_data_points = True

    # NOTE(review): 0 iterations presumably means the attacker's optimization loop does not run — TODO confirm this is intended.
    cfg.attack.optim.max_iterations = 0

### Instantiate all parties

In [12]:
# Build the simulated user, the server, the model and the loss function from the case config.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# The attacker is prepared against the server's own model and loss.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
# Print a summary of all three parties.
breaching.utils.overview(server, user, attacker)

Model architecture resnet18 loaded with 11,380,173 parameters and 9,620 buffers.
Overall this is a data ratio of       9:1 for target shape [8, 3, 224, 224] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 8

    Threat model:
    User provides labels: True
    User provides buffers: True
    User provides number of data points: True

    Data:
    Dataset: ImageNetAnimals
    user: 0
    
        
Server (of type ClassParameterServer) with settings:
    Threat model: Honest-but-curious
    Number of planned queries: 1
    Has external/public data: False

    Model:
        model specification: resnet18
        model state: trained
        public buffers: True

    Secrets: {}
    
Attacker (of type OptimizationBasedAttacker) with settings:
    Hyperparameter Template: invertinggradients

    Objective: Cosine Similarity with scale=1.0 and task reg=0.0
    Regularizers: Total Variation, scale=0.2. p=2 q=0.5. Color TV: double oppponents
 

### Simulate an attacked FL protocol

The true user data is returned only so that the reconstruction can be analyzed afterwards.

In [33]:
# Modify the model to catch the class we want.
# NOTE(review): the stored output of this cell ends with `IndexError: list index out of range`;
# the traceback is not preserved, so the failing call cannot be identified from the notebook alone.
which_to_recover = 0  # index (within the user's batch) of the first targeted sample
how_many = 1  # number of labels, starting at `which_to_recover`, used as target classes
how_many_rec = 1  # number of data points used later when reconstructing

# preparing for the attack
# First round: obtain the labels the user reports in the update metadata.
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)
t_labels = shared_data['metadata']['labels'].cpu().detach().numpy()
cls_to_obtain = t_labels[which_to_recover:(which_to_recover + how_many)]
extra_info = {'cls_to_obtain': cls_to_obtain}

# Second round: reset the model, reconfigure it for the class attack and query the user again.
server.reset_model()
server.reconfigure_model('cls_attack', extra_info=extra_info)
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)
# Pick the index and value of the largest entry of the flattened reconstructed feature.
avg_feature = torch.flatten(server.reconstruct_feature(shared_data, cls_to_obtain))
feat_to_obtain = int(torch.argmax(avg_feature))
feat_value = float(avg_feature[feat_to_obtain])

# Additional settings passed to the server through `extra_info`.
# NOTE(review): the meaning of `multiplier`, `non_target_logit` and `db_flip` is defined in the server implementation, not here.
extra_info['feat_to_obtain'] = feat_to_obtain
extra_info['feat_value'] = feat_value
extra_info['multiplier'] = 1
extra_info['non_target_logit'] = 0
extra_info['db_flip'] = 1

# iteratively get gradients back
# (presumably the source of the number triples printed in this cell's output — TODO confirm)
recovered_single_gradients = server.binary_attack(user, extra_info)

# reorder the gradients
server.reset_model()
# Re-assign these two entries before reconfiguring — presumably because `binary_attack` may have changed them; TODO confirm.
extra_info['multiplier'] = 1
extra_info['feat_value'] = feat_value
server.reconfigure_model('cls_attack', extra_info=extra_info)
server.reconfigure_model('feature_attack', extra_info=extra_info)
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload) 

# Compute reference per-sample gradients from the true user data and use them to order the recovered ones.
attacker.objective.cfg_impl = cfg.attack.impl
single_gradients, single_losses = server.cal_single_gradients(attacker, true_user_data)
recovered_single_gradients = server.order_gradients(recovered_single_gradients, single_gradients)

3.7050085067749023 2.0416505336761475 5.368366479873657
2.0416505336761475 1.2761552333831787 2.807145833969116
1.2761552333831787 0.5332276821136475 2.01908278465271
0.5332276821136475 0.5332276821136475 0.5332276821136475
2.01908278465271 1.2752189636230469 2.762946605682373
1.6476190090179443 0.5332276821136475 2.762010335922241
2.807145833969116 1.6155763864517212 3.9987152814865112
1.6155763864517212 0.5332276821136475 2.697925090789795
3.9987152814865112 2.425417423248291 5.5720131397247314
2.425417423248291 1.6155763864517212 3.235258460044861
5.5720131397247314 2.7234790325164795 8.420547246932983
2.7234790325164795 1.6155763864517212 3.831381678581238
8.420547246932983 3.7050085067749023 13.136085987091064
6.996280193328857 3.7050085067749023 10.287551879882812
4.785364210605621 2.7234790325164795 6.847249388694763
3.4029305577278137 2.0416505336761475 4.76421058177948
5.368366479873657 2.7234790325164795 8.013253927230835
4.53668749332428 2.7234790325164795 6.34989595413208


IndexError: list index out of range

In [None]:
# Reset the server model, apply the class attack and feature attack configuration again
# (with multiplier 1 and the previously measured `feat_value`), then query the user once more.
server.reset_model()
extra_info['multiplier'] = 1
extra_info['feat_value'] = feat_value
server.reconfigure_model('cls_attack', extra_info=extra_info)
server.reconfigure_model('feature_attack', extra_info=extra_info)

server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)  
# Disabled debug helper: per-tensor mean/std of the shared gradients.
# [(g.mean(), g.std()) for g in shared_data['gradients'][0]]

In [None]:
# Show the user's true data (kept only for analysis) for comparison with the reconstruction.
user.plot(true_user_data)

### Reconstruct user data:

In [None]:
# Index window selecting the targeted sample(s) inside the user's batch.
target_window = slice(which_to_recover, which_to_recover + how_many_rec)

# Deep-copy the shared update, then narrow its metadata and gradients
# to the targeted sample(s) only.
tmp_share_data = copy.deepcopy(shared_data)
tmp_share_data['metadata']['num_data_points'] = how_many_rec
tmp_share_data['metadata']['labels'] = shared_data['metadata']['labels'][target_window]
tmp_share_data['gradients'] = recovered_single_gradients[which_to_recover]

# Matching slice of the ground truth, used later when scoring the reconstruction.
tmp_true_user_data = {
    field: true_user_data[field][target_window] for field in ('data', 'labels')
}

In [None]:
# Run the attacker on the single-sample update assembled above (payload and update are passed as one-element lists).
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [tmp_share_data], 
                                                      server.secrets, dryrun=cfg.dryrun)

In [None]:
#How good is the reconstruction?
metrics = breaching.analysis.report(reconstructed_user_data, tmp_true_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)

In [None]:
# Visualize the reconstructed data.
user.plot(reconstructed_user_data)

In [None]:
# Print, for each recovered gradient, its cosine similarity to the reference
# per-sample gradient stored at the same position.
for sample_idx, recovered_parts in enumerate(recovered_single_gradients):
    recovered_flat = torch.cat([part.flatten() for part in recovered_parts])
    reference_flat = single_gradients[sample_idx]

    similarity = torch.nn.functional.cosine_similarity(recovered_flat, reference_flat, dim=0)
    print(float(similarity.detach()))

In [35]:
# Re-run the class attack and feature attack configuration with hand-picked settings
# and print the resulting per-sample gradient norms and losses.
server.reset_model()
# NOTE(review): magic numbers. 6.847249388694763 also appears in the printed output of the
# earlier attack cell; the choice of multiplier 300 is not explained in the notebook — TODO document.
extra_info['multiplier'] = 300
extra_info['feat_value'] = 6.847249388694763
server.reconfigure_model('cls_attack', extra_info=extra_info)
server.reconfigure_model('feature_attack', extra_info=extra_info)
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload) 

# reorder the gradients
attacker.objective.cfg_impl = cfg.attack.impl
single_gradients, single_losses = server.cal_single_gradients(attacker, true_user_data)
server.print_gradients_norm(single_gradients, single_losses)

grad norm         loss
67619.1328125 138.51657104492188
39547.1875 872.0106811523438
19662.994140625 1900.1875
27306.884765625 1454.4305419921875
26.657997131347656 0.0003480305604171008
42842.28515625 1064.194091796875
33386.9453125 1371.830810546875
46745.5390625 796.0202026367188


In [34]:
# Inspect the feature values collected on the server (presumably filled in by `binary_attack` — TODO confirm).
server.all_feat_value

[1.2761552333831787,
 2.0416505336761475,
 2.807145833969116,
 3.7050085067749023,
 3.9987152814865112,
 5.5720131397247314]

In [None]:
# Scratch arithmetic on two consecutive entries of `server.all_feat_value` (3.705... and 3.998...).
# NOTE(review): purpose is undocumented — TODO explain or remove.
3.7050085067749023 * 2 - 3.9987152814865112

In [22]:
# Short alias for the collected feature values, used by the scratch cells that follow.
a = server.all_feat_value

In [26]:
# Mean of the collected feature values.
np.mean(a)

3.233448088169098

In [27]:
# Scratch arithmetic: 7 times one of the collected values (3.705...) minus the sum of all collected values.
# NOTE(review): presumably an estimate of a value missing from `a` — TODO confirm and document.
3.7050085067749023 * 7 - sum(a)

6.534371018409729

In [16]:
# Scratch arithmetic: midpoint between one collected value (3.998...) and 6.5.
# NOTE(review): 6.5 looks like a rounded form of the previous cell's result — TODO confirm.
(3.9987152814865112 + 6.5)/2

5.249357640743256