# Breaching privacy

This notebook does the same job as the command-line tool `simulate_breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Route log records (message text only) to stdout so they appear in the notebook output.
logging.basicConfig(
    format='%(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger()  # root logger

In [None]:
import numpy as np
from scipy.optimize import linear_sum_assignment

### Initialize cfg object and system setup:

This composes the configuration and prints the chosen use case and server type.
There are a lot of possible configuration options, but there is usually no need to worry about most of them. A few of them are modified below.

Choose the dataset via `case/data=` (`shakespeare`, `wikitext`, or `stackoverflow`) here:

In [None]:
# Compose the Hydra config from ./config with the overrides for this experiment.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=["case/data=wikitext", "case/server=malicious-transformer",
                                                      "case.model=gpt2S",
                                                      "attack=decepticon"])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=torch.float)
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# User data under attack.
cfg.case.user.num_data_points = 2  # How many sentences?
cfg.case.user.user_idx = 1  # From which user?
cfg.case.data.shape = [4]  # This is the sequence length
cfg.case.data.tokenizer = "gpt2"

# Server settings (the server type composed above is malicious-transformer).
cfg.case.server.has_external_data = True
cfg.case.server.pretrained = False
param_mod = cfg.case.server.param_modification
param_mod.v_length = 64
param_mod.imprint_sentence_position = 0
param_mod.softmax_skew = 10_000_000
param_mod.sequence_token_weight = 1
param_mod.eps = 1e-6

# Attacker settings.
cfg.attack.token_strategy = "embedding-norm"

### Instantiate all parties

In [None]:
# Build user, server, model and loss function for the configured use case on `setup`.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# The attacker is prepared against the server's model and loss.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
# Summary of the three parties (presumably printed/logged by `overview` -- confirm).
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

The true user data is returned only for analysis.

In [None]:
# Server -> user: the payload the user computes its update against.
server_payload = server.distribute_payload()
# User -> server: the shared update. `true_user_data` is kept only for evaluating the attack.
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Show the user's ground-truth data.
user.print(true_user_data)

## Run through the first transformer block "by hand":

In [None]:
# Handles to the GPT-2 modules that the following cells evaluate by hand.
model = user.model
gpt2_blocks = model.model.transformer.h
embedding, pos_encoder = model.model.transformer.wte, model.model.transformer.wpe

first_block = gpt2_blocks[0]
norm_layer0, norm_layer1 = first_block.ln_1, first_block.ln_2

# Input (`c_attn`) and output (`c_proj`) projection parameters of the first attention layer.
attention_layer = {
    "in_proj_weight": first_block.attn.c_attn.weight,
    "in_proj_bias": first_block.attn.c_attn.bias,
    "out_proj_weight": first_block.attn.c_proj.weight,
    "out_proj_bias": first_block.attn.c_proj.bias,
}

# Collect the imprint layers: both MLP projections of every block, and the attention
# output projection of every block except the first.
first_linear_layers = [blk.mlp.c_fc for blk in gpt2_blocks]
second_linear_layers = [blk.mlp.c_proj for blk in gpt2_blocks]
unused_mhas = [blk.attn.c_proj for idx, blk in enumerate(gpt2_blocks) if idx > 0]

# The transposed weight shape is read as (hidden_dim, embedding_dim); `ff_transposed` marks
# that the feed-forward weights are stored transposed (presumably HF Conv1D layout -- confirm).
hidden_dim, embedding_dim = first_linear_layers[0].weight.T.shape
ff_transposed = True

In [None]:
inputs = true_user_data["data"]

In [None]:
trafo_inputs = pos_encoder(torch.arange(inputs.shape[1])[None, :]) + embedding(true_user_data["data"])
trafo_inputs.shape

In [None]:
attn_outputs, (K, V), attn_weights = model.model.transformer.h[0].attn(trafo_inputs, 
                                                                       output_attentions=True, use_cache=True)

In [None]:
# First 8 features of the attention output at every position of batch element 0
# (presumably the first sentence).
attn_outputs[0, :, :8]

In [None]:
# Shape of the value tensor returned by the attention layer.
V.shape

In [None]:
# Attention weights of batch element 0.
attn_weights[0]

In [None]:
# Correlation of the first token's attention output with every token's output. The feature
# width comes from the tensor (not a hard-coded 768), and `.cpu()` lets NumPy read CUDA tensors.
np.corrcoef(attn_outputs.reshape(-1, attn_outputs.shape[-1]).detach().cpu())[0]

In [None]:
# Residual connection around the first block's attention, then the block's ln_2
# (`norm_layer1`); presumably the input the first MLP layer (`c_fc`) sees.
residuals = attn_outputs + trafo_inputs
linear_inputs = norm_layer1(residuals)
linear_inputs

In [None]:
linear_inputs[0, :, 0:8]

In [None]:
linear_inputs.shape

### Simulate breached features

In [None]:
permutation = torch.randperm(32) # torch.randperm(32) # torch.arange(32)
num_breached_embeddings = 20
reverse_perm = torch.argsort(permutation[:num_breached_embeddings])
permutation

In [None]:
seq_features = linear_inputs.permute(0, 1, 2).reshape(-1, 96)[:, :8][permutation][:num_breached_embeddings]
seq_features.shape

In [None]:
corrs = torch.as_tensor(np.corrcoef(seq_features.detach()))

# Reconstruct user data

In [None]:
# Select the attacker's sentence algorithm ("k-means" is the commented-out alternative).
attacker.cfg.sentence_algorithm = "dynamic-threshold" # "k-means"

In [None]:
user.print(true_user_data)

In [None]:
# Run the attack on the shared update, score it against the true data, and show the result.
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)

metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)
user.print(reconstructed_user_data)

In [None]:
metrics

# Manually compute attention

In [None]:

# NOTE(review): this section expects `user.model` to expose `encoder`, `pos_encoder` and
# `transformer_encoder`; the GPT-2 used above is reached via `model.model.transformer`
# instead, so confirm these attributes exist for the configured model.
# Embed the token ids and add positional information. This rebinds `inputs`, so re-running
# the cell would embed already-embedded values.
inputs = user.model.pos_encoder(user.model.encoder(inputs))
inputs[0]

In [None]:
# Split the packed input projection of the first encoder layer's attention into its
# query/key/value parts. `chunk(3)` cuts dim 0 into equal thirds, i.e. [:d], [d:2d], [2d:]
# for embedding size d, without hard-coding d = 96.
first_self_attn = user.model.transformer_encoder.layers[0].self_attn
Q, K, V = first_self_attn.in_proj_weight.chunk(3)
q_b, k_b, v_b = first_self_attn.in_proj_bias.chunk(3)

O = first_self_attn.out_proj.weight.data

In [None]:
# Key projection weight.
K

In [None]:
# The first encoder layer's attention module, called directly further below.
self_attn = user.model.transformer_encoder.layers[0].self_attn

In [None]:
# Assuming an nn.MultiheadAttention: treat inputs as (batch, seq, embed) instead of (seq, batch, embed).
self_attn.batch_first = True

In [None]:
Q.shape, inputs[0].T.shape, V.shape, K.shape, q_b.shape

In [None]:
inputs.shape

In [None]:
Qv = ((Q@inputs[0].T).T + q_b)
Kv = ((K@inputs[0].T).T + k_b)
Vv = ((V@inputs[0].T).T + v_b)

In [None]:
M = (Qv.reshape(16, 8, 12) @ Kv.reshape(16, 8, 12).T).softmax(dim=-1)

In [None]:
M.shape

In [None]:
attn_map = torch.zeros(16, 16)
for head in range(8):
    mapp = (Qv.reshape(16, 8, 12)[:, head, :] @ Kv.reshape(16, 8, 12)[:, head, :].T).softmax(dim=-1)
    attn_map += mapp
    print(mapp)

In [None]:
Qv.reshape(16, 8, 12)[0, 0, :]

In [None]:
Kv.reshape(16, 8, 12)[:, 0, :].T

In [None]:
Qv @ Kv.T

In [None]:
Vv.reshape(16, 8, 12)[0, 0, :]

In [None]:
(((Qv.reshape(16, 8, 12)[:, head, :] @ Kv.reshape(16, 8, 12)[:, head, :].T).softmax(dim=-1) @ Vv.reshape(16, 8, 12)[:, head, :])).shape

In [None]:
(Qv.reshape(16, 8, 12)[:, head, :] @ Kv.reshape(16, 8, 12)[:, head, :].T).softmax(dim=-1).shape

In [None]:
outputs, attn_outputs = self_attn(inputs, inputs, inputs)

In [None]:
attn_outputs[0]

In [None]:
user.model.transformer_encoder.layers[0].norm1(outputs + inputs)

In [None]:
# Standalone LayerNorm over a last dimension of size 4, for experimentation.
normy = torch.nn.LayerNorm(normalized_shape=4)

In [None]:
# One non-constant row and two constant rows, as float32.
a = torch.tensor([[1, 2, 1, 2], [5, 5, 5, 5], [7, 7, 7, 7]], dtype=torch.float)
a

In [None]:
# Add a leading batch dimension and normalize each row over its 4 values; the constant rows
# have zero variance, so they map to the LayerNorm bias.
normy(a.unsqueeze(0))

In [None]:
# NOTE(review): this cell cannot work as written. Here `model` is `user.model` (rebound in the
# "by hand" section), not obviously an attention module, and even where `in_proj_bias` exists
# the slice is a plain tensor with no `.pe` attribute. Perhaps a positional-encoding buffer
# was intended -- confirm.
model.in_proj_bias[:96].pe[0]

In [None]:
# NOTE(review): `model` is `user.model` at this point, not an attention module, so this
# three-argument (query, key, value) call presumably targets an older setup -- compare
# `self_attn(inputs, inputs, inputs)` above and confirm.
outputs, attn_weights = model(inputs, inputs, inputs)

In [None]:
attn_weights

In [None]:
# Alternative attacker entry point (a single-sentence variant, judging by its name -- confirm
# in breaching.attacks), shown and scored against the true data.
reconstructed_user_data, stats = attacker.reconstruct_single_sentence([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)
user.print(reconstructed_user_data)
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

In [None]:
# Another alternative attacker entry point (`reconstruct2`), shown and scored the same way.
reconstructed_user_data, stats = attacker.reconstruct2([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)
user.print(reconstructed_user_data)
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

In [None]:
# Ground-truth sentences in reverse order (labels left untouched), presumably to see how the
# report scores a reconstruction that differs only in sentence order. `flip(0)` works for any
# number of sentences; a fixed index list like [3, 2, 1, 0] fails with fewer than 4.
permuted_true_data = dict(data=true_user_data["data"].flip(0), labels=true_user_data["labels"])

In [None]:
# Report for the reordered ground truth scored against the true data.
metrics = breaching.analysis.report(permuted_true_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

In [None]:
# Report for the true data scored against itself (a reference point for the metrics).
metrics = breaching.analysis.report(true_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)