# Breaching privacy

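This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction. Here the server sends maliciously modified parameters (`class_malicious_parameters`) that target a single class. It isolates single-sample gradients for that class with a binary feature search, and then runs an optimization-based reconstruction on one of them.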

In [1]:
import torch
import numpy as np
import hydra
import copy
from omegaconf import OmegaConf, open_dict
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Send log records (message text only) to stdout at INFO level, via the root logger.
logging.basicConfig(
    format='%(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger()

### Initialize cfg object and system setup:

This composes the configuration with Hydra. It then prints the use case and server type under investigation, followed by the device/dtype setup.
There are a lot of possible configurations, but there is usually no need to worry about most of these.

In [2]:
# Compose the Hydra config with the class-targeted attack, and switch the server
# to the class-malicious-parameters variant.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=['attack=clsattack'])
    cfg.case.server.name = 'class_malicious_parameters'
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')

# Plain string literal: the former f'cuda' had no placeholders (pyflakes F541).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# The dtype is named in the config (e.g. 'float32') and resolved on the torch module.
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
Investigating use case single_imagenet with server type class_malicious_parameters.


{'device': device(type='cuda'), 'dtype': torch.float32}

### Modify config options here

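You can use `.attribute` access to modify any of these configurations. The changes below are made inside `open_dict(cfg)`, which also allows setting keys that do not yet exist in the config: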

In [15]:
# Experiment overrides. open_dict temporarily lifts OmegaConf's struct mode,
# so keys can be set (or added) without raising.
with open_dict(cfg):
    # Model and server.
    cfg.case.model = 'resnet18'
    cfg.case.server.model_state = 'trained'

    # User data: 16 samples of user 25 under the "unique-class" partition.
    cfg.case.data.partition = "unique-class"
    cfg.case.user.user_idx = 25
    cfg.case.user.num_data_points = 16

    # Threat model: what the user shares in addition to gradients.
    cfg.case.user.provide_labels = True
    cfg.case.user.provide_buffers = True
    cfg.case.user.provide_num_data_points = True

    # Optimization budget of the reconstruction attack.
    cfg.attack.optim.max_iterations = 24000

### Instantiate all parties

In [16]:
# Build the simulated participants for cfg.case, then the attacker, which is
# constructed from the server's model and loss.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(
    server.model,
    server.loss,
    cfg.attack,
    setup,
)
# Print a summary of the settings of all three parties.
breaching.utils.overview(server, user, attacker)

Model architecture resnet18 loaded with 11,380,173 parameters and 9,620 buffers.
Overall this is a data ratio of       5:1 for target shape [16, 3, 224, 224] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 16

    Threat model:
    User provides labels: True
    User provides buffers: True
    User provides number of data points: True

    Data:
    Dataset: ImageNetAnimals
    user: 25
    
        
Server (of type ClassParameterServer) with settings:
    Threat model: Malicious (Parameters)
    Number of planned queries: 1
    Has external/public data: False

    Model:
        model specification: resnet18
        model state: trained
        public buffers: True

    Secrets: {}
    
Attacker (of type OptimizationBasedAttacker) with settings:
    Hyperparameter Template: invertinggradients

    Objective: Cosine Similarity with scale=1.0 and task reg=0.0
    Regularizers: Total Variation, scale=0.2. p=2 q=0.5. Color TV: double opppo

### Simulate an attacked FL protocol

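True user data is returned for analysis. Note, however, that this notebook also uses it in two places:
- The target class is read from the true labels.
- The true target-class samples are used to compute reference gradients, against which the recovered single gradients are reordered.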

In [19]:
# Choose the class to attack, then query the user twice: once with the unmodified
# model and once with the model reconfigured by 'cls_attack' for that class.
# which_to_recover: start position used to slice the target class out of the true
#   labels here; later also selects which recovered gradient is reconstructed.
# how_many: number of consecutive labels (from which_to_recover) whose classes are targeted.
# how_many_rec: number of data points handed to the reconstruction step.
which_to_recover = 0
how_many = 1
how_many_rec = 1

# First round: freshly reset (unmodified) model.
server.reset_model()
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

# NOTE(review): the target class is read from the *true* labels, i.e. ground-truth
# user data is used to decide what to attack.
t_labels = true_user_data['labels'].cpu().detach().numpy()
cls_to_obtain = t_labels[which_to_recover:(which_to_recover + how_many)]
extra_info = {'cls_to_obtain': cls_to_obtain}
# Output of server.reconstruct_feature for the target class(es), flattened to 1-D.
all_feature = torch.flatten(server.reconstruct_feature(shared_data, cls_to_obtain))

# Second round: model reconfigured by 'cls_attack' for the target class(es).
server.reset_model()
server.reconfigure_model('cls_attack', extra_info=extra_info)
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)
# avg_feature ranks candidate features for the search that follows (largest first).
avg_feature = torch.flatten(server.reconstruct_feature(shared_data, cls_to_obtain))
# NOTE(review): other_feature is not used anywhere later in the visible notebook.
other_feature = avg_feature - all_feature

# Search for a feature on which the binary attack separates the target-class
# samples into single-sample gradients. Candidates are tried in decreasing order
# of avg_feature; a candidate that fails is masked out and the next one is tried.
single_gradient_recovered = False
user.counted_queries = 0

# Loop-invariant inputs, computed once. `.item()` replaces `int(array)`, which
# NumPy >= 1.25 deprecates for arrays with ndim > 0.
num_target_data = int(torch.count_nonzero(
    (shared_data["metadata"]["labels"] == int(cls_to_obtain.item())).to(int)))
num_data_points = int(cfg.case.user.num_data_points)

# Mask failed candidates in a copy so `avg_feature` itself stays intact, and use
# -inf (not a finite sentinel such as -1000) so a masked entry can never
# outrank a candidate that has not been tried yet.
candidate_scores = avg_feature.clone()

while not single_gradient_recovered:
    feat_to_obtain = int(torch.argmax(candidate_scores))
    # argmax only lands on a masked entry once every candidate has failed;
    # without this guard the search would loop forever.
    if bool(candidate_scores[feat_to_obtain] == float('-inf')):
        raise RuntimeError(
            f"Binary attack failed on all {candidate_scores.numel()} candidate features "
            f"after {user.counted_queries} user queries."
        )
    feat_value = float(avg_feature[feat_to_obtain])

    # binary attack to recover all single gradients
    extra_info["feat_to_obtain"] = feat_to_obtain
    extra_info["feat_value"] = feat_value
    extra_info["multiplier"] = 1
    extra_info["num_target_data"] = num_target_data
    extra_info["num_data_points"] = num_data_points

    recovered_single_gradients = server.binary_attack(user, extra_info)
    # Alternative: recovered_single_gradients = server.one_shot_binary_attack(user, extra_info)
    if recovered_single_gradients is not None:
        single_gradient_recovered = True
    else:
        candidate_scores[feat_to_obtain] = float('-inf')

    logger.info(f"Spent {user.counted_queries} user queries so far.")

# Reconfigure the server with the successful feature, compute reference
# per-sample gradients on the true target-class samples, and reorder the
# recovered gradients to match them.
server.reset_model()
extra_info.update(multiplier=1, feat_value=feat_value, feat_to_obtain=feat_to_obtain)
for attack_name in ('cls_attack', 'feature_attack'):
    server.reconfigure_model(attack_name, extra_info=extra_info)
server_payload = server.distribute_payload()

# Batch positions of the target-class samples (single-argument np.where == np.nonzero).
target_indx = np.nonzero(t_labels == cls_to_obtain)
# Ground truth restricted to those samples, used only as the ordering reference.
tmp_true_user_data = {key: true_user_data[key][target_indx] for key in ('data', 'labels')}

attacker.objective.cfg_impl = cfg.attack.impl
single_gradients, single_losses = server.cal_single_gradients(attacker, tmp_true_user_data)
recovered_single_gradients = server.order_gradients(recovered_single_gradients, single_gradients)

Too many attempts (256) on this feature!
Spent 256 user queries so far.
Spent 377 user queries so far.


In [20]:
# Sanity check: cosine similarity between each recovered gradient (all parameter
# tensors flattened and concatenated) and its reference counterpart.
# NOTE(review): single_gradients entries are used as-is, presumably already flat.
cosine = torch.nn.CosineSimilarity(dim=0)
for idx, recovered in enumerate(recovered_single_gradients):
    recovered_flat = torch.cat([param_grad.flatten() for param_grad in recovered])
    reference = single_gradients[idx]
    print(float(cosine(recovered_flat, reference).detach()))

0.978751540184021
0.9852755069732666
0.9795252680778503
0.9870787262916565
0.9832985401153564
0.9819676280021667
0.9787515997886658
0.9743768572807312
0.9806594252586365
0.9841616153717041
0.9756428003311157
0.9761735200881958
0.9753624200820923
0.9849905371665955
0.979377031326294
0.9783847332000732


In [None]:
# One more round with the class and feature attacks applied. extra_info still
# carries feat_to_obtain from the feature search. Then collect the user's update.
server.reset_model()
extra_info.update(multiplier=1, feat_value=feat_value)
for attack_name in ('cls_attack', 'feature_attack'):
    server.reconfigure_model(attack_name, extra_info=extra_info)

server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Plot the ground-truth user data, for visual comparison with the reconstruction.
user.plot(true_user_data)

### Reconstruct user data:

In [None]:
# Build a reduced "shared update" holding one recovered single-sample gradient,
# plus the matching ground truth for evaluation.
# recovered_single_gradients was reordered (server.order_gradients) against
# single_gradients. Those were computed on the target-class samples at batch
# positions target_indx[0], so entry k belongs to batch position target_indx[0][k].
# Map through target_indx so the ground truth refers to the same sample as the
# gradient. Raw batch position == k only when every sample is of the target class.
sample_idx = torch.as_tensor(
    target_indx[0][which_to_recover:(which_to_recover + how_many_rec)], dtype=torch.long)
assert len(sample_idx) == how_many_rec, (
    f"Only {len(target_indx[0])} target-class samples available; cannot select "
    f"{how_many_rec} starting at index {which_to_recover}.")
# NOTE(review): only one gradient (entry `which_to_recover`) is passed below, so
# this setup is only meaningful for how_many_rec == 1.

# Every recovered gradient belongs to the target class, so take that class's
# labels from the *shared* labels. This keeps the attacker on shared data only
# and preserves the labels' dtype/device.
shared_labels = shared_data['metadata']['labels']
target_labels = shared_labels[shared_labels == int(cls_to_obtain.item())]

tmp_share_data = copy.deepcopy(shared_data)
tmp_share_data['metadata']['num_data_points'] = how_many_rec
tmp_share_data['metadata']['labels'] = target_labels[which_to_recover:(which_to_recover + how_many_rec)]
tmp_share_data['gradients'] = recovered_single_gradients[which_to_recover]
tmp_true_user_data = {}
tmp_true_user_data['data'] = true_user_data['data'][sample_idx]
tmp_true_user_data['labels'] = true_user_data['labels'][sample_idx]

In [None]:
# Run the optimization-based reconstruction on the single recovered gradient.
reconstructed_user_data, stats = attacker.reconstruct(
    [server_payload],
    [tmp_share_data],
    server.secrets,
    dryrun=cfg.dryrun,
)

In [None]:
# How good is the reconstruction? Compare it against the matching ground truth.
metrics = breaching.analysis.report(
    reconstructed_user_data,
    tmp_true_user_data,
    [server_payload],
    server.model,
    cfg_case=cfg.case,
    setup=setup,
    order_batch=True,
    compute_full_iip=False,
)

In [None]:
# Plot the reconstructed data.
user.plot(reconstructed_user_data)

In [None]:
# server.reset_model()
# extra_info['multiplier'] = 1000
# extra_info["db_flip"] = 1
# extra_info['feat_value'] = server.all_feat_value[1]
# server.reconfigure_model('cls_attack', extra_info=extra_info)
# server.reconfigure_model('feature_attack', extra_info=extra_info)
# server_payload = server.distribute_payload()
# shared_data, true_user_data = user.compute_local_updates(server_payload) 

# attacker.objective.cfg_impl = cfg.attack.impl
# single_gradients, single_losses = server.cal_single_gradients(attacker, true_user_data)
# server.print_gradients_norm(single_gradients, single_losses)

In [None]:
# server.all_feat_value