# TAG: Gradient Attack on Transformer-based Language Models

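This notebook shows an example of a **short-sentence gradient inversion** as described in "TAG: Gradient Attack on Transformer-based Language Models". The setting configured below is a small transformer model (`transformer3`, about 10.8 million parameters) trained for causal language modeling on wikitext, and the federated learning algorithm is **fedSGD**.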

Paper URL: https://aclanthology.org/2021.findings-emnlp.305/

#### Abstract
Although distributed learning has increasingly gained attention in terms of effectively utilizing local devices for data privacy enhancement, recent studies show that publicly shared gradients in the training process can reveal the private training data (gradient leakage) to a third-party. We have, however, no systematic understanding of the gradient leakage mechanism on the Transformer based language models. In this paper, as the first attempt, we formulate the gradient attack problem on the Transformer-based language models and propose a gradient attack algorithm, TAG, to reconstruct the local training data. Experimental results on Transformer, TinyBERT4, TinyBERT6, BERT_BASE, and BERT_LARGE using GLUE benchmark show that compared with DLG, TAG works well on more weight distributions in reconstructing training data and achieves 1.5x recover rate and 2.5x ROUGE-2 over prior methods without the need of ground truth label. TAG can obtain up to 90% data by attacking gradients in CoLA dataset. In addition, TAG is stronger than previous approaches on larger models, smaller dictionary size, and smaller input length. We hope the proposed TAG will shed some light on the privacy leakage problem in Transformer-based NLP models.

### Startup

In [1]:
try:
    import breaching
except ModuleNotFoundError:
    # You only really need this safety net if you want to run these notebooks directly in the examples directory
    # Don't worry about this if you installed the package or moved the notebook to the main directory.
    import os; os.chdir("..")
    import breaching
    
import torch
%load_ext autoreload
%autoreload 2

# Redirects logs directly into the jupyter notebook
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

### Initialize cfg object and system setup:

This loads the full configuration object. It includes the configuration for the use case and threat model as `cfg.case`, and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overridden with `overrides=` as if they were command-line arguments.
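To make the idea of command-line style overrides concrete, here is a minimal, self-contained sketch of applying `dotted.key=value` strings to a nested dictionary. The function `apply_overrides` and the toy configuration are hypothetical and only illustrate the concept; the real parsing is handled by the library's configuration backend and may behave differently (for example, `case=10_causal_lang_training` selects a whole configuration group rather than setting a single value).

```python
from copy import deepcopy


def apply_overrides(config, overrides):
    """Return a copy of `config` with each "dotted.key=value" override applied."""
    updated = deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = updated
        for name in parents:
            node = node.setdefault(name, {})
        # Interpret plain integers where possible, otherwise keep the raw string.
        node[leaf] = int(raw_value) if raw_value.lstrip("-").isdigit() else raw_value
    return updated


# Hypothetical toy configuration, not the real breaching config tree.
toy_cfg = {"case": {"user": {"num_data_points": 4}}, "attack": {"name": "default"}}
new_cfg = apply_overrides(toy_cfg, ["case.user.num_data_points=1", "attack.name=tag"])
print(new_cfg)
```

The original dictionary is left untouched, which mirrors the usual expectation that overrides produce a new effective configuration rather than editing defaults in place.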

In [2]:
cfg = breaching.get_config(overrides=["case=10_causal_lang_training", "attack=tag"])

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

Investigating use case causal_lang_training with server type honest_but_curious.


{'device': device(type='cuda'), 'dtype': torch.float32}

### Modify config options here

You can use `.attribute` access to modify any of these configurations for the attack, or the case:

In [3]:
cfg.case.user.num_data_points = 1 # How many sentences?
cfg.case.user.user_idx = 1 # From which user?
cfg.case.data.shape = [16] # This is the sequence length

# cfg.attack.optim.max_iterations = 12000 # Increasing the number of iterations can help this attack

### Instantiate all parties

The following lines generate "server", "user" and "attacker" objects and print an overview of their configurations.

In [4]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Model architecture transformer3 loaded with 10,800,433 parameters and 0 buffers.
Overall this is a data ratio of  675027:1 for target shape [1, 16] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 1

    Threat model:
    User provides labels: False
    User provides buffers: False
    User provides number of data points: True

    Data:
    Dataset: wikitext
    user: 1
    
        
Server (of type HonestServer) with settings:
    Threat model: Honest-but-curious
    Number of planned queries: 1
    Has external/public data: False

    Model:
        model specification: transformer3
        model state: default
        

    Secrets: {}
    
Attacker (of type OptimizationJointAttacker) with settings:
    Hyperparameter Template: tag

    Objective: Tag loss with scale=1.0, weight scheme linear, L1
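The data ratio reported in the overview can be reproduced by hand. The sketch below assumes the ratio is the number of model parameters (times the number of queries) divided by the number of entries in the target; this assumption is consistent with the printed numbers, but it is not taken from the library source.

```python
import math

num_params = 10_800_433   # parameter count printed in the overview
target_shape = (1, 16)    # one sentence with a sequence length of 16
num_queries = 1

# Number of unknown token positions the attacker has to recover.
target_entries = math.prod(target_shape)

# Assumed definition: shared gradient entries per unknown target entry.
data_ratio = num_queries * num_params / target_entries
print(f"{data_ratio:.0f}:1")
```

Under this reading, the attacker observes several hundred thousand gradient entries for every token it needs to recover, which is part of why short sequences are comparatively easy targets.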

### Simulate an attacked FL protocol

This exchange simulates a single query in a federated learning protocol. The server sends out a `server_payload` and the user computes an update based on their private local data. This user update is `shared_data`; in the simplest case it contains the parameter gradient of the model. `true_user_data` is also returned by `.compute_local_updates`, but it is not forwarded to the server or attacker and is used only for our analysis.
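The exchange can be illustrated with a toy model that is much simpler than the transformer used in this notebook. The following sketch uses a one-output linear model with squared loss; the function names only mirror the notebook's terminology and are hypothetical, not the library's implementation. For this toy model the shared gradient reveals the private input in closed form, whereas for a transformer the attacker has to rely on the iterative optimization shown later.

```python
def distribute_payload(weights, bias):
    """Server side: send the current model parameters to the user."""
    return {"weights": list(weights), "bias": bias}


def compute_local_update(payload, x, y):
    """User side: gradient of 0.5 * (w.x + b - y)^2 on one private example."""
    w, b = payload["weights"], payload["bias"]
    residual = sum(wi * xi for wi, xi in zip(w, x)) + b - y
    grad_w = [residual * xi for xi in x]  # dL/dw
    grad_b = residual                     # dL/db
    return {"grad_weights": grad_w, "grad_bias": grad_b}


# Hypothetical private example held by the user; never sent to the server.
private_x, private_y = [2.0, -1.0, 0.5], 1.0

payload = distribute_payload(weights=[0.1, 0.2, 0.3], bias=0.0)
shared = compute_local_update(payload, private_x, private_y)

# In this toy model, grad_weights / grad_bias equals x whenever the
# residual is nonzero, so the gradient alone leaks the input.
leaked_x = [g / shared["grad_bias"] for g in shared["grad_weights"]]
print(leaked_x)
```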

In [5]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

Computing user update in model mode: eval.


In [6]:
user.print(true_user_data)

 The Tower Building of the Little Rock Arsenal, also known as U.S.


### Reconstruct user data:

Now we launch the attack, reconstructing user data based only on the `server_payload` and the `shared_data`.

You can interrupt the computation early to see a partial solution.
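The reconstruction loss printed during the attack measures how far the gradient of a candidate input is from the shared gradient. The TAG paper describes this distance as an L2 term plus an L1 term scaled by a layer-dependent factor. The sketch below is one possible reading of that objective in plain Python; the function `tag_distance`, the linear decay of the factor, and its endpoints are assumptions for illustration, and the library's `Tag loss` may differ in details such as squaring or scaling.

```python
import math


def tag_distance(grads_candidate, grads_shared, alpha_first=1.0, alpha_last=0.0):
    """Sum over layers of the L2 distance plus a layer-weighted L1 distance."""
    num_layers = len(grads_shared)
    total = 0.0
    for idx, (cand, ref) in enumerate(zip(grads_candidate, grads_shared)):
        diffs = [c - r for c, r in zip(cand, ref)]
        l2 = math.sqrt(sum(d * d for d in diffs))
        l1 = sum(abs(d) for d in diffs)
        # Assumed linear schedule: alpha moves from alpha_first to alpha_last.
        frac = idx / (num_layers - 1) if num_layers > 1 else 0.0
        alpha = alpha_first + frac * (alpha_last - alpha_first)
        total += l2 + alpha * l1
    return total


# Hypothetical flattened per-layer gradients for a two-layer model.
shared = [[3.0, 4.0], [1.0]]
candidate = [[0.0, 0.0], [0.0]]
print(tag_distance(candidate, shared))
```

An attack of this kind would repeatedly adjust the candidate input so that its gradient reduces this distance, which is what the decreasing `Rec. loss` values in the log reflect.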

In [7]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], {}, dryrun=cfg.dryrun)

| It: 1 | Rec. loss: 82.7034 |  Task loss: 10.9869 | T: 0.04s |  Label Entropy: 0.9998.
| It: 101 | Rec. loss: 57.0354 |  Task loss: 10.9823 | T: 12.54s |  Label Entropy: 0.8390.
| It: 201 | Rec. loss: 27.3703 |  Task loss: 11.0093 | T: 3.49s |  Label Entropy: 0.1865.
| It: 301 | Rec. loss: 22.8579 |  Task loss: 11.0011 | T: 3.57s |  Label Entropy: 0.1487.
| It: 401 | Rec. loss: 21.0937 |  Task loss: 11.0083 | T: 3.51s |  Label Entropy: 0.1369.
| It: 501 | Rec. loss: 20.4863 |  Task loss: 11.0221 | T: 3.50s |  Label Entropy: 0.1334.
| It: 601 | Rec. loss: 19.9562 |  Task loss: 11.0169 | T: 3.53s |  Label Entropy: 0.1306.
| It: 701 | Rec. loss: 19.5584 |  Task loss: 11.0229 | T: 3.57s |  Label Entropy: 0.1271.
| It: 801 | Rec. loss: 19.3122 |  Task loss: 11.0220 | T: 3.52s |  Label Entropy: 0.1266.
| It: 901 | Rec. loss: 19.0028 |  Task loss: 11.0192 | T: 3.58s |  Label Entropy: 0.1254.
| It: 1000 | Rec. loss: 18.6502 |  Task loss: 11.0211 | T: 3.45s |  Label Entropy: 0.1229.
Optimal ca

Next we'll evaluate metrics, comparing the `reconstructed_user_data` to the `true_user_data`.

In [8]:
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)

METRICS: | Accuracy: 0.8750 | S-BLEU: 0.68 | FMSE: 1.2835e-02 | 
 G-BLEU: 0.61 | ROUGE1: 0.80| ROUGE2: 0.70 | ROUGE-L: 0.80| Token Acc: 87.50% | Label Acc: 0.00%
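As a sanity check on the numbers above, a token accuracy of 87.50% on a sequence of length 16 corresponds to 14 of 16 positions matching. The sketch below assumes the metric is a simple position-wise comparison of token ids; the function `token_accuracy` and the token ids are hypothetical and not taken from the library.

```python
def token_accuracy(reconstructed, reference):
    """Fraction of positions where the reconstructed token matches the reference."""
    if len(reconstructed) != len(reference):
        raise ValueError("Sequences must have the same length.")
    matches = sum(r == t for r, t in zip(reconstructed, reference))
    return matches / len(reference)


# Hypothetical token ids for a 16-token sentence with two wrong positions.
reference = list(range(100, 116))
reconstructed = list(reference)
reconstructed[3] = 999
reconstructed[5] = 998

print(f"{token_accuracy(reconstructed, reference):.2%}")
```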


Finally, we print the reconstructed data:

In [9]:
user.print(reconstructed_user_data)

 The Tower Building adapted thechant Rock Arsenal, also known as U.S.


### Notes:
* Sentence classification is a better scenario for TAG than, for example, next-token prediction, because the attack has to recover the label in addition to the input sentence. For CoLA, the label is just a binary choice, but for next-token prediction the "label" space is the entire vocabulary.
* `huggingface` needs an internet connection for metrics, datasets and tokenizers. After these objects have been cached, it can be switched to offline mode with `cfg.case.impl.enable_huggingface_offline_mode=True`.