# Breaching privacy

This notebook does the same job as the command-line tool `simulate_breach.py`, but also directly visualizes the user data and the reconstruction.

In [1]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

### Initialize cfg object and system setup:

This loads the full configuration.
Many options are available, but most of them rarely need to be changed. The cell below prints only a few of them.

Choose `shakespeare`, `wikitext`, or `stackoverflow` for `case/data=` here:

In [2]:
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=["case/data=wikitext", "case/server=malicious-model-rtf",
                                                      "case.model=transformer3",
                                                      "attack=imprint"])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
          
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=torch.float)
setup

Investigating use case single_imagenet with server type malicious_model.


{'device': device(type='cpu'), 'dtype': torch.float32}

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [6]:
cfg.case.user.num_data_points = 128 # How many sentences?
cfg.case.user.user_idx = 1 # From which user?
cfg.case.data.shape = [32] # This is the sequence length

cfg.case.server.model_modification.num_bins = 2048
cfg.case.server.model_modification.position = None # '4.0.conv'
cfg.case.server.model_modification.linfunc = 'randn'

cfg.case.server.has_external_data = False
cfg.case.data.tokenizer = "gpt2"

### Instantiate all parties

In [7]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

First layer determined to be pos_encoder
Block inserted at feature shape torch.Size([32, 96]).
Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20)
Model architecture transformer3 loaded with 23,388,465 parameters and 0 buffers.
Overall this is a data ratio of    5710:1 for target shape [128, 32] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 128

    Threat model:
    User provides labels: False
    User provides buffers: False
    User provides number of data points: True

    Data:
    Dataset: wikitext
    user: 1
    
        
Server (of type MaliciousModelServer) with settings:
    Threat model: Malicious (Analyst)
    Number of planned queries: 1
    Has external/public data: False

    Model:
        model specification: transformer3
        model state: default
        

    Secrets: {'ImprintBlock': {'weight_idx': 0, 'bias_idx': 1, 'shape':
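
The "data ratio" line in the log above can be checked by hand. The sketch below assumes the ratio is the number of model parameters divided by the number of token positions in the target shape. That reading reproduces the logged value, but it is inferred from the numbers and not taken from the library source.

```python
# Values copied from the log output above.
num_parameters = 23_388_465   # "loaded with 23,388,465 parameters"
target_shape = (128, 32)      # num_data_points x sequence length

# Assumption: the ratio counts parameters per token position in the user batch.
num_tokens = target_shape[0] * target_shape[1]
data_ratio = num_parameters / num_tokens

print(f"{num_tokens} tokens, data ratio {data_ratio:.0f}:1")
# → 4096 tokens, data ratio 5710:1
```

A large ratio means the shared update contains far more values than the user data it was computed from.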

### Simulate an attacked FL protocol

The true user data is returned only for analysis; it is not passed to the attacker.

In [8]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

Computing user update in model mode: eval.


In [9]:
# user.print(true_user_data)

### Reconstruct user data

In [10]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)
# user.print(reconstructed_user_data)
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

Recovered tokens tensor([[   11,    12,    13,  ...,   284,   285,   286],
        [  287,   290,   291,  ...,   370,   371,   373],
        [  376,   379,   382,  ...,   513,   517,   530],
        ...,
        [   11,    32,   110,  ..., 28139, 40926, 43084],
        [  314,   327,   351,  ..., 16063, 24375, 37637],
        [   13,    29,    31,  ..., 12877, 41075, 41601]]) through strategy decoder-bias.
METRICS: | Accuracy: 0.9531 | S-BLEU: 0.95 | FMSE: 9.9912e-08 | 
 G-BLEU: 0.95 | ROUGE1: 0.96| ROUGE2: 0.95 | ROUGE-L: 0.95| Token Acc: 95.68% | Label Acc: 0.00%
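
To make the token-level numbers easier to interpret, here is a minimal sketch of a position-wise token accuracy. The helper `token_accuracy` and the toy token ids are hypothetical and exist only for illustration. How `breaching.analysis.report` defines "Accuracy" and "Token Acc" is not shown in this notebook, so the reported values may be computed differently.

```python
def token_accuracy(reconstructed, reference):
    """Fraction of positions where the reconstructed token id equals the reference."""
    matches = 0
    total = 0
    for rec_seq, ref_seq in zip(reconstructed, reference):
        for rec_tok, ref_tok in zip(rec_seq, ref_seq):
            matches += int(rec_tok == ref_tok)
            total += 1
    return matches / total if total else 0.0


# Toy data, made up for illustration (not taken from the run above).
reference = [[11, 12, 13, 284], [287, 290, 291, 370]]
reconstructed = [[11, 12, 13, 284], [287, 290, 999, 370]]

print(f"Token Acc: {token_accuracy(reconstructed, reference):.2%}")
# → Token Acc: 87.50%
```

Under this definition a token recovered at the wrong position counts as an error, so it is stricter than a measure that only checks which tokens were recovered.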
