# Decepticons: Corrupted Transformers Breach Privacy in Federated Learning for Language Models

This notebook shows an example of the threat model and attack described in "Decepticons: Corrupted Transformers Breach Privacy in Federated Learning for Language Models". Unlike the other examples, which assume an "honest-but-curious" server, this example investigates a malicious server that may send malicious server updates. The attack succeeds against a range of common transformer architectures and requires only a single malicious query to the user model.

In this notebook, we attack the commonly used BERT model (`bert-base-uncased` from the huggingface implementation).



Paper URL: https://arxiv.org/abs/2201.12675

### Abstract:
A central tenet of federated learning (FL), which trains models without centralizing user data, is privacy. However, previous work has shown that the gradient updates used in FL can leak user information. While most industrial uses of FL are for text applications (e.g. keystroke prediction), nearly all attacks on FL privacy have focused on simple image classifiers. We propose a novel attack that reveals private user text by deploying malicious parameter vectors, and which succeeds even with mini-batches, multiple users, and long sequences. Unlike previous attacks on FL, the attack exploits characteristics of both the Transformer architecture and the token embedding, separately extracting tokens and positional embeddings to retrieve high-fidelity text. This work suggests that FL on text, which has historically been resistant to privacy attacks, is far more vulnerable than previously thought.
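The abstract's key step is that token identity and position are extracted separately. As a rough illustration (our own simplification, not the paper's procedure), suppose the attacker has recovered BERT-style input embeddings, each the sum of a token embedding and a positional embedding, but in unknown order. If both embedding tables are known, matching every vector against all token/position pairs recovers the token and its position, so the sequence can be put back in order. The sketch ignores LayerNorm, segment embeddings, and measurement noise, and uses random tables as stand-ins for learned weights.

```python
import random


def make_embeddings(rows, dim, rng):
    """Random Gaussian embedding table (a stand-in for learned weights)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(rows)]


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def decode(vec, token_table, pos_table):
    """Return the (token, position) pair whose summed embedding is closest to vec."""
    best = None
    for p, pos in enumerate(pos_table):
        residual = [v - q for v, q in zip(vec, pos)]  # remove the candidate position
        for t, tok in enumerate(token_table):
            d = sq_dist(residual, tok)
            if best is None or d < best[0]:
                best = (d, t, p)
    return best[1], best[2]


rng = random.Random(0)
vocab, max_len, dim = 50, 10, 16
token_table = make_embeddings(vocab, dim, rng)
pos_table = make_embeddings(max_len, dim, rng)

sentence = [3, 17, 42, 8, 17, 5]
# Input embedding = token embedding + positional embedding (no LayerNorm, no segments).
inputs = [[a + b for a, b in zip(token_table[t], pos_table[p])] for p, t in enumerate(sentence)]
rng.shuffle(inputs)  # the attacker sees the vectors in no particular order

pairs = [decode(v, token_table, pos_table) for v in inputs]
recovered = [t for t, p in sorted(pairs, key=lambda tp: tp[1])]
print(recovered)  # [3, 17, 42, 8, 17, 5]
```

Because position is decoded alongside the token, repeated tokens (here `17`) are still placed correctly.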

### Startup

In [1]:
```python
try:
    import breaching
except ModuleNotFoundError:
    # You only really need this safety net if you want to run these notebooks directly in the examples directory
    # Don't worry about this if you installed the package or moved the notebook to the main directory.
    import os
    os.chdir("..")
    import breaching

import torch
%load_ext autoreload
%autoreload 2

# Redirects logs directly into the jupyter notebook
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()
```

### Initialize cfg object and system setup:

This loads the full configuration object, which includes the configuration for the use case and threat model as `cfg.case` and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below or overridden with `overrides=`, as if they were command-line arguments.

In [2]:
```python
cfg = breaching.get_config(overrides=["attack=decepticon", "case=9_bert_training",
                                     "case/server=malicious-transformer"])

device = torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup
```

```
Investigating use case bert_training with server type malicious_transformer_parameters.

{'device': device(type='cpu'), 'dtype': torch.float32}
```

### Modify config options here

You can use attribute access (`.attribute`) to modify any of these configuration options, for either the attack or the case:

In [3]:
```python
cfg.case.user.num_data_points = 1  # How many sentences?
cfg.case.user.user_idx = 1  # From which user?
cfg.case.data.shape = [512]  # This is the sequence length

cfg.case.server.provide_public_buffers = True  # Send server signal to disable dropout
cfg.case.server.has_external_data = True  # Not strictly necessary, but could also use random text (see Appendix)
cfg.case.data.tokenizer = "bert-base-uncased"
cfg.case.model = "bert-base-uncased"  # Could also choose "bert-sanity-check", which contains ReLU activations
cfg.case.server.pretrained = False


# Attack hyperparameters:

# Server side:
cfg.case.server.param_modification.reset_embedding = True
cfg.case.server.param_modification.v_length = 32  # Length of the sentence component
cfg.case.server.param_modification.eps = 1e-8
cfg.case.server.param_modification.measurement_scale = 1e8  # Circumvent GELU
cfg.case.server.param_modification.imprint_sentence_position = 0
cfg.case.server.param_modification.softmax_skew = 1e8
cfg.case.server.param_modification.sequence_token_weight = 1


# Attacker side:

# This option requires installation of `k-means-constrained`, which can be tricky.
# If this doesn't work for you, falling back to "dynamic-threshold" is still a decent option.
cfg.attack.sentence_algorithm = "k-means"
cfg.attack.token_strategy = "embedding-norm"  # Can also use "mixed" for BERT
cfg.attack.embedding_token_weight = 0.25  # This can slightly improve performance for long sequences
```
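Two of the server-side settings above are deliberately extreme. The sketch below illustrates our reading of why, assuming that `measurement_scale` multiplies the pre-activation feeding a GELU and that `softmax_skew` multiplies attention scores before the softmax (both are assumptions; the actual parameter construction in breaching is more involved). Since GELU(x) = x·Φ(x), a huge input scale pushes every value far outside the smooth region around zero, so GELU behaves like a ReLU; likewise, a huge skew turns the softmax into an essentially one-hot selection.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def relu(x):
    return max(0.0, x)


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


# 1) Scaling a pre-activation by s pushes GELU into its linear or zero regime:
#    gelu(s * x) / s approaches relu(x) as s grows.
for x in (-0.3, 0.05, 0.7):
    for s in (1.0, 1e2, 1e8):
        print(f"x={x:+.2f} s={s:.0e} gelu(s*x)/s={gelu(s * x) / s:+.6f} relu={relu(x):+.2f}")

# 2) Multiplying attention scores by a large skew makes the softmax nearly one-hot.
scores = [0.10, 0.12, 0.05]
print(softmax(scores))
print(softmax([1e8 * s for s in scores]))
```

At `s = 1` the GELU output for `x = 0.05` is only about half of `x`, while at `s = 1e8` the rescaled output matches the ReLU value.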

### Instantiate all parties

The following lines generate "server", "user", and "attacker" objects and print an overview of their configurations.

In [4]:
```python
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)
```

```
Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Model architecture bert-base-uncased loaded with 109,514,298 parameters and 1,024 buffers.
Overall this is a data ratio of  213895:1 for target shape [1, 512] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 1

    Threat model:
    User provides labels: False
    User provides buffers: False
    User provides number of data points: True

    Data:
    Dataset: wikitext
    user: 1

Server (of type MaliciousTransformerServer) with settings:
    Threat model: Malicious (Parameters)
    Number of planned queries: 1
    Has external/public data: True

    Model:
        model specification: bert-base-uncased
        model state: default
        public b
```
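The "data ratio" in the log appears to be the number of model parameters divided by the number of input tokens the user contributes (one sentence of 512 tokens in a single query). That reading reproduces the logged value, though it is our interpretation rather than a documented formula:

```python
num_parameters = 109_514_298  # from the log above
num_data_points = 1           # cfg.case.user.num_data_points
sequence_length = 512         # cfg.case.data.shape

ratio = num_parameters / (num_data_points * sequence_length)
print(f"{ratio:.0f}:1")  # 213895:1
```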

### Simulate an attacked FL protocol

This exchange simulates a single query in a federated learning protocol. The server sends out a `server_payload`, and the user computes an update based on their private local data. This user update is `shared_data`; in the simplest case, it contains the parameter gradient of the model. `true_user_data` is also returned by `.compute_local_updates`, but it is of course not forwarded to the server or attacker and is only used for (our) analysis.
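Why can a single shared update leak data at all? A minimal sketch with a toy linear layer in plain Python (squared-error loss, hypothetical helper names, not breaching's API): for one input `x`, a fully connected layer `y = W x + b` has `dL/dW_i = (dL/dy_i) * x` and `dL/db_i = dL/dy_i`, so dividing any weight-gradient row by its bias gradient returns `x` exactly. Malicious-parameter attacks build on this identity; with many tokens the per-token gradients are summed, which is why the server must also arrange its parameters so that individual tokens can be separated.

```python
import random


def forward(W, b, x):
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def user_update(W, b, x, target):
    """Gradient of 0.5 * ||W x + b - target||^2 for one private input x."""
    y = forward(W, b, x)
    dy = [yi - ti for yi, ti in zip(y, target)]  # dL/dy
    grad_W = [[d * xi for xi in x] for d in dy]  # dL/dW = dL/dy x^T
    grad_b = dy                                  # dL/db = dL/dy
    return grad_W, grad_b


rng = random.Random(1)
n_in, n_out = 4, 3
W = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)]  # "server payload"
b = [rng.uniform(-1, 1) for _ in range(n_out)]

private_x = [0.5, -1.25, 2.0, 0.75]  # never leaves the user
grad_W, grad_b = user_update(W, b, private_x, target=[0.0] * n_out)  # "shared data"

# Server side: any row with a nonzero bias gradient reveals the input.
i = max(range(n_out), key=lambda k: abs(grad_b[k]))
reconstruction = [g / grad_b[i] for g in grad_W[i]]
print(reconstruction)  # equal to private_x up to floating-point rounding
```

The identity does not depend on the loss: only the per-unit factor `dL/dy_i` changes, and it cancels in the division.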

In [5]:
```python
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)
```

```
Found attention of shape torch.Size([768, 768]).
Found attention of shape torch.Size([768, 768]).
Computing feature distribution before the probe layer Linear(in_features=768, out_features=3072, bias=True) from external data.
Feature mean is 8976919.0, feature std is 103903328.0.
Computing user update on user 1 in model mode: eval.
```


In [6]:
```python
user.print(true_user_data)
```

```
[CLS] the tower building of the little rock arsenal transformed also known as u. s. arsenal building [MASK] is [MASK] building located in macarthur park [unused573] downtown little rock, arkansas. built in 1840, it was part of little rocklar s first [MASK] installation. since its decommissioning, [MASK] tower building has housed two museums. it was home to the arkansas museum of natural history and navigate from 1942 to 1997 [MASK] the macarthur museum of arkansas military history since 2001. [MASK] has also been the headquarters of the little rock æsthetic club since 1894. [SEP] [CLS] the building receives [MASK] name [MASK] [MASK] distinct octagonal tower. besides [MASK] the last remaining structure of the original [MASK] rock arsenal and one of the oldest buildings in central arkansas, it was [MASK] the birthplace of general douglas macarthur, who became the supreme commander of us forces in the [MASK] [MASK] during world war ii. it was also the starting place of the camden expediti
```

### Reconstruct user data:

Now we launch the attack, reconstructing user data based on only the `server_payload` and the `shared_data`. 

For this attack, we also share secret information from the malicious server with the attacker (`server.secrets`), namely the location and structure of the imprint block.
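To see why the attacker needs to know where the imprint block is and how it is built, the sketch below illustrates, in heavily simplified form, the binning idea that such blocks rely on (our illustration with hypothetical names, not breaching's implementation). A row of ReLU neurons all read the same measurement `m · x`, each with a different threshold. Neuron `k` fires for every token whose measurement exceeds its threshold, so subtracting the gradients of neighboring neurons isolates the tokens that land in one bin, and a bin holding exactly one token yields that token's features exactly. We assume an upstream gradient of 1 per token and standard-normal features; the earlier log line about computing the feature distribution from external data suggests that breaching instead estimates the measurement statistics empirically.

```python
import random
from statistics import NormalDist

rng = random.Random(2)
dim, num_tokens, num_bins = 6, 5, 64

tokens = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(num_tokens)]  # private features
m = [rng.gauss(0, 1) for _ in range(dim)]  # measurement direction chosen by the server


def measure(x):
    return sum(mi * xi for mi, xi in zip(m, x))


# Server: place thresholds at quantiles of the measurement. For standard-normal
# features, m . x is N(0, ||m||^2), so its quantiles are known in closed form.
scale = sum(mi * mi for mi in m) ** 0.5
dist = NormalDist(0.0, scale)
edges = [dist.inv_cdf((k + 1) / (num_bins + 1)) for k in range(num_bins)]

# User: neuron k computes relu(m . x - edges[k]). With an upstream gradient of 1
# per token, its weight gradient sums x over the tokens with m . x > edges[k],
# and its bias gradient counts those tokens.
grad_W = [[0.0] * dim for _ in edges]
grad_b = [0.0 for _ in edges]
for x in tokens:
    h = measure(x)
    for k, c in enumerate(edges):
        if h > c:
            grad_W[k] = [g + xi for g, xi in zip(grad_W[k], x)]
            grad_b[k] += 1.0

# Attacker: differences of adjacent neurons isolate the tokens in a single bin.
recovered = []
for k in range(num_bins):
    next_W = grad_W[k + 1] if k + 1 < num_bins else [0.0] * dim
    next_b = grad_b[k + 1] if k + 1 < num_bins else 0.0
    if grad_b[k] - next_b == 1.0:  # exactly one token in this bin: exact recovery
        recovered.append([a - c for a, c in zip(grad_W[k], next_W)])

print(f"recovered {len(recovered)} of {num_tokens} tokens exactly")
```

Tokens that share a bin are only recovered as an average, and tokens below the lowest threshold are missed; more bins make both cases rarer.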

In [7]:
```python
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], server.secrets,
                                                      dryrun=cfg.dryrun)
```

```
Recovered tokens tensor([   27,    96,   101,   102,   103,   138,   202,   425,   468,   532,
          578,   665,   670,   808,   896,   991,  1000,  1002,  1005,  1006,
         1007,  1010,  1011,  1012,  1014,  1017,  1026,  1028,  1030,  1037,
         1038,  1049,  1055,  1057,  1097,  1360,  1554,  1765,  1865,  1996,
         1997,  1998,  1999,  2000,  2001,  2003,  2004,  2005,  2006,  2007,
         2009,  2011,  2012,  2013,  2017,  2018,  2020,  2021,  2022,  2028,
         2029,  2034,  2036,  2037,  2038,  2040,  2042,  2048,  2049,  2050,
         2069,  2075,  2076,  2083,  2086,  2088,  2101,  2104,  2108,  2109,
         2110,  2112,  2116,  2117,  2124,  2126,  2137,  2142,  2144,  2147,
         2148,  2149,  2150,  2162,  2163,  2166,  2170,  2171,  2173,  2174,
         2184,  2188,  2194,  2195,  2197,  2198,  2199,  2210,  2211,  2231,
         2236,  2243,  2249,  2252,  2261,  2273,  2274,  2281,  2284,  2308,
         2311,  2315,  2327,  2328,  2332,  233
```

Next we'll evaluate metrics, comparing the `reconstructed_user_data` to the `true_user_data`.

In [8]:
```python
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload],
                                    server.model, order_batch=True, compute_full_iip=False,
                                    cfg_case=cfg.case, setup=setup)
```

```
Starting evaluations for attack effectiveness report...
Using default tokenizer.
METRICS: | Accuracy: 0.9102 | S-BLEU: 0.82 | FMSE: 1.0441e+01 | 
 G-BLEU: 0.79 | ROUGE1: 0.97| ROUGE2: 0.85 | ROUGE-L: 0.93| Token Acc T:95.51%/A:96.35% | Label Acc: 52.44%
```
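For intuition about two of the reported metrics, here are simplified definitions (our assumptions; breaching's `report` may tokenize, align, or aggregate differently): token accuracy as the fraction of positions where the reconstruction matches the reference, and ROUGE-1 as the F1 score of clipped unigram overlap, which ignores token order. The sentence pair is made up for illustration.

```python
from collections import Counter


def token_accuracy(reconstructed, reference):
    """Fraction of positions where the reconstructed token equals the reference token."""
    matches = sum(r == t for r, t in zip(reconstructed, reference))
    return matches / len(reference)


def rouge1_f1(reconstructed, reference):
    """Unigram-overlap F1 (ROUGE-1); each word counts at most as often as in both texts."""
    overlap = sum((Counter(reconstructed) & Counter(reference)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(reconstructed)
    recall = overlap / len(reference)
    return 2 * precision * recall / (precision + recall)


reference = "the tower building of the little rock arsenal".split()
reconstructed = "the tower building of the little life arsenal".split()
print(token_accuracy(reconstructed, reference))  # 0.875
print(rouge1_f1(reconstructed, reference))       # 0.875
```

Swapping two correct words would leave ROUGE-1 unchanged but lower token accuracy, which is one reason the report lists several metrics.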


And finally, we also print the reconstructed data:

In [9]:
```python
user.print(reconstructed_user_data)
```

```
[CLS] the tower building of the little rock arsenal transformed also known as u. s. arsenal building [MASK] is [MASK] building located in macarthur park [unused573] downtown little rock life arkansas. established in 1840 tower it was part of little rocklar s first [MASK] installation. since its decommissioning handful [MASK] and building has housed two museums. it was home to the of museum of natural history and navigate from 1942 to 1997 [MASK] the macarthur museum of arkansas military history since 2001. [MASK] has also been the headquarters of the little rock æsthetic club since 1894. [SEP] [SEP] the building receives [MASK] name [MASK] [MASK] distinct arkansas tower. besides [MASK] the world remaining structure of the original [MASK] rock arsenal and one of the oldest buildings in central arkansas who it was [MASK] referred birthplace of general douglas macarthur, who. the supreme commander of us forces in the [MASK] [MASK] during place war ii. it was also the starting also of the 
```

In [11]:
```python
user.print_with_confidence(reconstructed_user_data)
```

```
[CLS] the tower building of the little rock arsenal transformed also known as u . s . arsenal building [MASK] is [MASK] building located in macarthur park [unused573] downtown little rock life arkansas . established in 1840 tower it was part of little rock ##lar s first [MASK]
```

In the notebook, each token's background is shaded by the attacker's confidence; within this excerpt, `life`, `established`, and `tower` receive a different shade from the rest.

In [12]:
```python
user.print_and_mark_correct(reconstructed_user_data, true_user_data)
```

```
[CLS] the tower building of the little rock arsenal transformed also known as u . s . arsenal building [MASK] is [MASK] building located in macarthur park [unused573] downtown little rock life arkansas . established in 1840 tower it was part of little rock ##lar s first [MASK]
```

In the notebook, tokens are colored by correctness; within this excerpt, only `life`, `established`, and `tower` are marked as incorrect (red).

### Notes:
* There are a variety of hyperparameters to the attack which are set to reasonable defaults. Performance of the attack could be improved in some unusual use cases (datasets or models) by tuning these parameters further.
* In this example, dropout is disabled under the assumption that this is a parameter that can be controlled in the server update. The optimal attack simply disables dropout. However, the attack can still succeed when dropout is enforced by the user, albeit with a minor loss in reconstruction quality.
* This example also assumes complete freedom to choose the parameter vector; for this reason, we circumvent the smooth part of the GELU activation with a "very" large measurement vector magnitude. This is arguably excessive for only a small gain in accuracy. A similar argument can be made for the default softmax skew value.
* We also want to re-emphasize that the design space of these parameter modification attacks is large. A defense against the specific parameter modification described here is unlikely to be safe in general!
* Token recovery is much easier when the embedding is randomly initialized. Here we explicitly re-initialize the BERT embedding (`reset_embedding=True`) to improve label accuracy.
* The embedding gradient is not strictly necessary for the attack. The attack can also be run with `token_strategy=None`, in which case the embedding gradient is entirely disregarded (and does not have to be learnable).
* In our setup it made the most sense to consider the `[MASK]` tokens as part of the input (and this is also how the metrics are scored). However, the contents of the masked values are actually also leaked and are retrievable from the gradient of the decoder bias.
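The last note can be illustrated with the standard label-leakage argument (a sketch under simplifying assumptions, not breaching's code). For softmax cross-entropy, the gradient with respect to the output bias is the sum over predicted positions of (probabilities minus one-hot label), so entries belonging to the hidden tokens tend to be negative while all others are positive. The sketch assumes a single decoder bias shared across the masked positions and an uninformed (uniform) decoder; under these assumptions the sign test is exact whenever there are fewer masked positions than vocabulary entries. It reveals which tokens were masked, not where they were.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def decoder_bias_grad(logits_per_position, labels, vocab_size):
    """Cross-entropy gradient w.r.t. a shared output bias, summed over masked positions."""
    grad = [0.0] * vocab_size
    for logits, label in zip(logits_per_position, labels):
        probs = softmax(logits)
        for v in range(vocab_size):
            grad[v] += probs[v] - (1.0 if v == label else 0.0)
    return grad


vocab_size = 10
hidden_labels = [7, 2, 7]  # token ids hidden behind [MASK]
logits = [[0.0] * vocab_size for _ in hidden_labels]  # assumption: uniform decoder
grad = decoder_bias_grad(logits, hidden_labels, vocab_size)

leaked = sorted(v for v, g in enumerate(grad) if g < 0)
print(leaked)  # [2, 7]
```

With a trained, confident decoder the sign test becomes a heuristic, since a token that is predicted strongly at many positions can push its own entry above zero.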