# TAG: Gradient Attack on Transformer-based Language Models

This notebook shows an example of a **short-sentence gradient inversion** attack as described in "TAG: Gradient Attack on Transformer-based Language Models". The setting is a BERT-base model, and the federated learning algorithm is **fedSGD**.

Paper URL: https://aclanthology.org/2021.findings-emnlp.305/

#### Abstract
Although distributed learning has increasingly gained attention in terms of effectively utilizing local devices for data privacy enhancement, recent studies show that publicly shared gradients in the training process can reveal the private training data (gradient leakage) to a third-party. We have, however, no systematic understanding of the gradient leakage mechanism on the Transformer based language models. In this paper, as the first attempt, we formulate the gradient attack problem on the Transformer-based language models and propose a gradient attack algorithm, TAG, to reconstruct the local training data. Experimental results on Transformer, TinyBERT4, TinyBERT6, BERT_BASE, and BERT_LARGE using GLUE benchmark show that compared with DLG, TAG works well on more weight distributions in reconstructing training data and achieves 1.5x recover rate and 2.5x ROUGE-2 over prior methods without the need of ground truth label. TAG can obtain up to 90% data by attacking gradients in CoLA dataset. In addition, TAG is stronger than previous approaches on larger models, smaller dictionary size, and smaller input length. We hope the proposed TAG will shed some light on the privacy leakage problem in Transformer-based NLP models.

### Startup

In [1]:
try:
    import breaching
except ModuleNotFoundError:
    # You only really need this safety net if you want to run these notebooks directly in the examples directory
    # Don't worry about this if you installed the package or moved the notebook to the main directory.
    import os; os.chdir("..")
    import breaching
    
import torch
%load_ext autoreload
%autoreload 2

# Redirects logs directly into the jupyter notebook
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

### Initialize cfg object and system setup:

This will load the full configuration object. It includes the configuration for the use case and threat model as `cfg.case`, and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overridden with `overrides=` as if they were command-line arguments.

In [2]:
cfg = breaching.get_config(overrides=["case=9_bert_training", "case/data=cola", "case.data.task=classification",
                                      "attack=tag"])
          
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

Investigating use case bert_training with server type honest_but_curious.


{'device': device(type='cuda'), 'dtype': torch.float32}
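
The `overrides=` strings above have two shapes: `group=option` entries that select a whole configuration, and dotted `key.path=value` entries that change a single value. As a rough illustration of the second shape only, the sketch below applies dotted overrides to a plain nested dict. It is not how `breaching` resolves its configuration; `apply_overrides`, `base_cfg`, and the example values are hypothetical names chosen for this illustration.

```python
from copy import deepcopy


def apply_overrides(cfg, overrides):
    """Return a copy of a nested dict with dotted `key.path=value` overrides applied."""
    out = deepcopy(cfg)
    for item in overrides:
        dotted_key, value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = out
        for name in parents:
            # Create intermediate levels on demand
            node = node.setdefault(name, {})
        node[leaf] = value  # values stay strings in this sketch
    return out


# Hypothetical base configuration, loosely shaped like cfg.case
base_cfg = {"case": {"data": {"task": "placeholder"}, "user": {"num_data_points": 1}}}
demo_cfg = apply_overrides(
    base_cfg,
    ["case.data.task=classification", "attack.optim.max_iterations=12000"],
)
print(demo_cfg["case"]["data"]["task"])
```

Returning a copy rather than mutating the input keeps the base configuration reusable across several attack variants.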

### Modify config options here

You can use `.attribute` access to modify any of these configurations for the attack, or the case:

In [23]:
cfg.case.user.num_data_points = 1 # How many sentences?
cfg.case.user.user_idx = 1 # From which user?
cfg.case.data.shape = [16] # This is the sequence length

cfg.case.model="bert-sanity-check"

cfg.attack.optim.max_iterations = 12000

### Instantiate all parties

The following lines generate "server", "user", and "attacker" objects and print an overview of their configurations.

In [24]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Reusing dataset glue (/home/jonas/data/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Model architecture bert-sanity-check loaded with 109,483,778 parameters and 1,024 buffers.
Overall this is a data ratio of 6842736:1 for target shape [1, 16] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 1

    Threat model:
    User provides labels: False
    User provides buffers: False
    User provides number of data points: True

    Data:
    Dataset: cola
    user: 1
    
        
Server (of type HonestServer) with settings:
    Threat model: Honest-but-curious
    Number of planned queries: 1
    Has external/public data: False

    Model:
        model specification: bert-sanity-check
        model state: default
        public buffers: True

    Secrets: {}
    
Attacker (of type OptimizationJointAttacker) with settings:
    Hyperparameter Template: tag

    Objective: Tag loss with scale=1.0, weight schem

### Simulate an attacked FL protocol

This exchange simulates a single query in a federated learning protocol. The server sends out a `server_payload`, and the user computes an update based on their private local data. This update is `shared_data`; in the simplest case it contains the parameter gradient of the model. `.compute_local_updates` also returns `true_user_data`, which is never forwarded to the server or attacker and is used only for our analysis.

In [25]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

Computing user update in model mode: eval.


In [26]:
user.print(true_user_data)

[CLS] one more pseudo generalization and i'm giving up. [SEP] [PAD] [PAD]
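
To make the exchange concrete, here is a toy version of a single fedSGD query with a linear model and squared-error loss in place of BERT. All names (`toy_distribute_payload`, `toy_compute_local_update`, the parameter values) are hypothetical and only mirror the roles of `server_payload`, `shared_data`, and `true_user_data`; this is a sketch, not the library's implementation.

```python
def toy_distribute_payload():
    """Server side: send the current model parameters to the user."""
    return {"weights": [0.5, -1.0, 2.0], "bias": 0.1}


def toy_compute_local_update(payload, private_x, private_y):
    """User side: gradient of a squared-error loss on one private example."""
    w, b = payload["weights"], payload["bias"]
    prediction = sum(wi * xi for wi, xi in zip(w, private_x)) + b
    residual = prediction - private_y
    # d/dw (pred - y)^2 = 2 * residual * x  and  d/db (pred - y)^2 = 2 * residual
    shared = {
        "grad_weights": [2.0 * residual * xi for xi in private_x],
        "grad_bias": 2.0 * residual,
    }
    true_data = {"x": private_x, "y": private_y}  # kept locally, analysis only
    return shared, true_data


toy_payload = toy_distribute_payload()
toy_shared, toy_truth = toy_compute_local_update(toy_payload, [1.0, 2.0, 3.0], 0.0)

# In this toy model the input leaks in closed form: grad_weights / grad_bias == x
leaked_x = [g / toy_shared["grad_bias"] for g in toy_shared["grad_weights"]]
print(leaked_x)
```

For this linear model the private input can be read off the shared gradient up to floating-point error. A Transformer offers no such closed form, which is why the attack below has to optimize a candidate input instead.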


### Reconstruct user data:

Now we launch the attack, reconstructing user data based only on the `server_payload` and the `shared_data`.

You can interrupt the computation early to see a partial solution.

In [27]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], {}, dryrun=cfg.dryrun)

| It: 1 | Rec. loss: 95.2092 |  Task loss: 0.6980 | T: 0.23s |  Label Entropy: 0.9979.
| It: 101 | Rec. loss: 27.7278 |  Task loss: 0.8908 | T: 18.31s |  Label Entropy: 0.2830.
| It: 201 | Rec. loss: 22.1621 |  Task loss: 0.8842 | T: 19.68s |  Label Entropy: 0.1864.
| It: 301 | Rec. loss: 19.5941 |  Task loss: 0.8836 | T: 20.15s |  Label Entropy: 0.1639.
| It: 401 | Rec. loss: 19.5519 |  Task loss: 0.8839 | T: 11.45s |  Label Entropy: 0.1755.
| It: 501 | Rec. loss: 19.2239 |  Task loss: 0.8855 | T: 11.45s |  Label Entropy: 0.1721.
| It: 601 | Rec. loss: 18.9395 |  Task loss: 0.8847 | T: 11.34s |  Label Entropy: 0.1671.
| It: 701 | Rec. loss: 19.3074 |  Task loss: 0.8886 | T: 11.62s |  Label Entropy: 0.1712.
| It: 801 | Rec. loss: 18.7212 |  Task loss: 0.8857 | T: 11.54s |  Label Entropy: 0.1664.
| It: 901 | Rec. loss: 17.2360 |  Task loss: 0.8819 | T: 11.69s |  Label Entropy: 0.1611.
| It: 1001 | Rec. loss: 17.2748 |  Task loss: 0.8853 | T: 11.68s |  Label Entropy: 0.1539.
| It: 1101 |

| It: 9101 | Rec. loss: 16.1902 |  Task loss: 0.8859 | T: 10.41s |  Label Entropy: 0.1519.
| It: 9201 | Rec. loss: 16.5991 |  Task loss: 0.8865 | T: 10.19s |  Label Entropy: 0.1531.
| It: 9301 | Rec. loss: 16.2550 |  Task loss: 0.8864 | T: 10.30s |  Label Entropy: 0.1545.
| It: 9401 | Rec. loss: 16.0714 |  Task loss: 0.8860 | T: 10.29s |  Label Entropy: 0.1539.
| It: 9501 | Rec. loss: 16.2131 |  Task loss: 0.8857 | T: 10.04s |  Label Entropy: 0.1543.
| It: 9601 | Rec. loss: 16.4635 |  Task loss: 0.8862 | T: 10.13s |  Label Entropy: 0.1534.
| It: 9701 | Rec. loss: 16.4038 |  Task loss: 0.8867 | T: 10.28s |  Label Entropy: 0.1535.
| It: 9801 | Rec. loss: 16.4264 |  Task loss: 0.8866 | T: 10.45s |  Label Entropy: 0.1531.
| It: 9901 | Rec. loss: 16.7151 |  Task loss: 0.8868 | T: 10.24s |  Label Entropy: 0.1528.
| It: 10001 | Rec. loss: 16.4080 |  Task loss: 0.8857 | T: 10.25s |  Label Entropy: 0.1544.
| It: 10101 | Rec. loss: 15.9016 |  Task loss: 0.8857 | T: 10.21s |  Label Entropy: 0.151
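
The "Rec. loss" in the log measures how far the gradient of the current candidate is from the gradient the user shared. The sketch below illustrates this kind of gradient-matching objective on a toy linear model. The exact form used here (squared L2 distance plus an `alpha`-weighted L1 distance) is an assumption about the TAG-style objective, not code taken from `breaching`, and the label is held fixed for simplicity even though the attack in this notebook also has to recover it. All function and variable names are hypothetical.

```python
def linear_gradient(weights, bias, x, y):
    """Gradient of (w.x + b - y)^2 with respect to (w, b), flattened into one list."""
    residual = sum(wi * xi for wi, xi in zip(weights, x)) + bias - y
    return [2.0 * residual * xi for xi in x] + [2.0 * residual]


def tag_style_distance(candidate_grad, observed_grad, alpha=1.0):
    """Squared L2 distance plus an alpha-weighted L1 distance (assumed form)."""
    diffs = [c - o for c, o in zip(candidate_grad, observed_grad)]
    return sum(d * d for d in diffs) + alpha * sum(abs(d) for d in diffs)


weights, bias = [0.5, -1.0, 2.0], 0.1
# Stands in for the gradient shared by the user
observed = linear_gradient(weights, bias, [1.0, 2.0, 3.0], 0.0)

candidates = {
    "true input": [1.0, 2.0, 3.0],
    "near miss": [1.0, 2.0, 2.5],
    "far off": [-3.0, 0.0, 1.0],
}
scores = {
    name: tag_style_distance(linear_gradient(weights, bias, x, 0.0), observed)
    for name, x in candidates.items()
}
best = min(scores, key=scores.get)
print(best, scores)
```

The real attack does not enumerate candidates; it updates one candidate iteratively so that this distance shrinks, which is what the decreasing "Rec. loss" column reflects.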

Next we'll evaluate metrics, comparing the `reconstructed_user_data` to the `true_user_data`.

In [28]:
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)

METRICS: | Accuracy: 0.0625 | S-BLEU: 0.15 | FMSE: 5.8470e-03 | 
 G-BLEU: 0.22 | ROUGE1: 0.77| ROUGE2: 0.25 | ROUGE-L: 0.46| Token Acc: 75.00% | Label Acc: 0.00%
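
The gap between `Accuracy` and `Token Acc` above is what one would expect if one metric compares tokens position by position while the other ignores order. The sketch below shows that distinction on made-up tokens. It is one plausible reading of the two numbers, not the implementation in `breaching.analysis`, and both function names are hypothetical.

```python
from collections import Counter


def positional_accuracy(reconstructed, reference):
    """Fraction of positions where the reconstructed token equals the reference token."""
    matches = sum(r == t for r, t in zip(reconstructed, reference))
    return matches / len(reference)


def bag_of_tokens_accuracy(reconstructed, reference):
    """Fraction of reference tokens recovered when token order is ignored."""
    overlap = Counter(reconstructed) & Counter(reference)  # multiset intersection
    return sum(overlap.values()) / len(reference)


# Made-up example: every token is recovered, but almost none in the right place
reference_tokens = ["[CLS]", "the", "cat", "sat", "[SEP]"]
reconstructed_tokens = ["[SEP]", "sat", "cat", "the", "[CLS]"]

pos_acc = positional_accuracy(reconstructed_tokens, reference_tokens)
bag_acc = bag_of_tokens_accuracy(reconstructed_tokens, reference_tokens)
print(pos_acc, bag_acc)
```

A reconstruction can therefore score high on an order-free metric and low on a positional one when the attack finds the right words in the wrong order.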


And finally, we also print the reconstructed data:

In [29]:
user.print(reconstructed_user_data)

[SEP] generalization and [SEP] pseudo i giving up [CLS] 20 one. 陽 more [SEP]


### Notes:
* Sentence classification is a better scenario for TAG than, e.g., next-token prediction, because the attack has to recover the label in addition to the input sentence. For CoLA, the label is just a binary choice, but for next-token prediction, the "label" space is the entire vocabulary.
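
To put rough numbers on the note above, the following arithmetic compares the size of the two label spaces. It assumes the `bert-base-uncased` vocabulary of 30,522 tokens and treats each label as a uniform choice, which is a simplification made for this illustration.

```python
import math

num_cola_labels = 2      # CoLA: acceptable vs. unacceptable
bert_vocab_size = 30522  # vocabulary size of bert-base-uncased

# Bits needed to pin down one label if every option were equally likely
bits_cola = math.log2(num_cola_labels)
bits_next_token = math.log2(bert_vocab_size)

print(f"CoLA label: {bits_cola:.1f} bit")
print(f"Next-token label: {bits_next_token:.1f} bits per predicted position")
```

Under this simplification, a single next-token label is as hard to guess as roughly fifteen independent binary labels, and a language-modeling batch carries one such label for every predicted position.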