# Breaching privacy

This notebook does the same job as the cmd-line tool `simulate_breach.py`, but also directly visualizes the user data and reconstruction

In [1]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

In [2]:
import numpy as np
from scipy.optimize import linear_sum_assignment

### Initialize cfg object and system setup:

This will print out all configuration options. 
There are a lot of possible configurations, but there is usually no need to worry about most of these. Below, a few options are printed.

Choose `case/data=` `shakespeare`, `wikitext`over `stackoverflow` here:

In [3]:
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=["case/data=wikitext", "case/server=malicious-transformer",
                                                      "case.model=transformer3p",
                                                      "attack=decepticon"])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
          
device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=torch.float)
setup

Investigating use case single_imagenet with server type malicious_transformer_parameters.


{'device': device(type='cpu'), 'dtype': torch.float32}

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [4]:
cfg.case.user.num_data_points = 4 # How many sentences?
cfg.case.user.user_idx = 1 # From which user?
cfg.case.data.shape = [16] # This is the sequence length

cfg.case.data.tokenizer = "word-level"

cfg.case.server.has_external_data = True

cfg.case.server.param_modification.v_length = 8
cfg.case.server.param_modification.imprint_sentence_position = 0
cfg.case.server.param_modification.softmax_skew = 100
cfg.case.server.param_modification.sequence_token_weight = 1

### Instantiate all parties

In [5]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20)
Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20)
Model architecture transformer3p loaded with 10,751,281 parameters and 0 buffers.
Overall this is a data ratio of  167989:1 for target shape [4, 16] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 4

    Threat model:
    User provides labels: False
    User provides buffers: False
    User provides number of data points: True

    Data:
    Dataset: wikitext
    user: 1
    
        
Server (of type MaliciousTransformerServer) with settings:
    Threat model: Malicious (Parameters)
    Number of planned queries: 1
    Has external/public data: True

    Model:
        model specification: transformer3p
        model state: default
        

    Secrets: {}
    

### Simulate an attacked FL protocol

True user data is returned only for analysis

In [6]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

Computing feature distribution before the linear1 layer from external data.
Feature mean is -0.00920235924422741, feature std is 0.7685369253158569.


In [7]:
user.print(true_user_data)

[CLS] the tower building of the little rock arsenal, also known as u. s
. arsenal building, is a building located in macarthur park in downtown little rock,
arkansas. built in 1 8 4 0, it was part of little rock '
s first military installation. since its decommissioning, the tower building has housed two museums


## Run through the initial transformer blocks "by hand":

In [8]:
inputs = true_user_data["data"]

In [9]:
trafo_inputs = user.model.pos_encoder(user.model.encoder(true_user_data["data"]))#[0, 0, :]
trafo_inputs

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.9314, -1.6871,  0.1320],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.7709,  0.9466, -0.4144],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.6361,  1.3109, -0.1161],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.3998,  0.6305, -0.8165],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.7691,  0.5120, -1.2598],
         [ 0.0000,  0.0000,  0.0000,  ...,  1.2931, -1.1679, -0.8926]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  1.0273, -1.6562,  0.0911],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.6242,  0.9126, -0.4203],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.5350,  1.3296, -0.2147],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.2974,  0.6015, -0.8886],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.7996,  0.5429, -1.3360],
         [ 0.0000,  0.0000,  0.0000,  ...,  1.3342, -1.2882, -0.9508]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  1.0000, -1.5464,  0.0250],
         [ 0.0000,  0.0000,  0.0000,  ..., -0

In [10]:
attn_outputs, attn_weights = user.model.transformer_encoder.layers[0].self_attn(trafo_inputs, trafo_inputs, trafo_inputs)
attn_outputs

tensor([[[-0.3244, -0.6605,  1.0048,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3244, -0.6605,  1.0048,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3244, -0.6605,  1.0048,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.3244, -0.6605,  1.0048,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3244, -0.6605,  1.0048,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3244, -0.6605,  1.0048,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.3177, -0.5200,  0.9559,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3177, -0.5200,  0.9559,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3177, -0.5200,  0.9559,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.3177, -0.5200,  0.9559,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3177, -0.5200,  0.9559,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3177, -0.5200,  0.9559,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.2305, -0.5085,  1.0406,  ...,  0.0000,  0.0000,  0.0000],
         [-0.2305, -0.5085,  1.0406,  ...,  0

In [11]:
attn_weights

tensor([[[8.7500e-01, 9.8678e-40, 9.0902e-42,  ..., 1.2116e-09,
          0.0000e+00, 9.7909e-12],
         [8.7500e-01, 9.8678e-40, 9.0902e-42,  ..., 1.2116e-09,
          0.0000e+00, 9.7909e-12],
         [8.7500e-01, 9.8678e-40, 9.0902e-42,  ..., 1.2116e-09,
          0.0000e+00, 9.7909e-12],
         ...,
         [8.7500e-01, 9.8678e-40, 9.0902e-42,  ..., 1.2116e-09,
          0.0000e+00, 9.7909e-12],
         [8.7500e-01, 9.8678e-40, 9.0902e-42,  ..., 1.2116e-09,
          0.0000e+00, 9.7909e-12],
         [8.7500e-01, 9.8678e-40, 9.0902e-42,  ..., 1.2116e-09,
          0.0000e+00, 9.7909e-12]],

        [[8.7500e-01, 8.9123e-43, 5.3950e-43,  ..., 4.5055e-11,
          0.0000e+00, 1.6312e-11],
         [8.7500e-01, 8.9123e-43, 5.3950e-43,  ..., 4.5055e-11,
          0.0000e+00, 1.6312e-11],
         [8.7500e-01, 8.9123e-43, 5.3950e-43,  ..., 4.5055e-11,
          0.0000e+00, 1.6312e-11],
         ...,
         [8.7500e-01, 8.9123e-43, 5.3950e-43,  ..., 4.5055e-11,
          0.000

In [12]:
residuals = attn_outputs + trafo_inputs
linear_inputs = user.model.transformer_encoder.layers[0].norm1(residuals)
linear_inputs

tensor([[[-0.3979, -0.7212,  0.8809,  ...,  0.8102, -1.7089,  0.0411],
         [-0.3703, -0.7442,  1.1081,  ..., -0.8670,  1.0433, -0.4705],
         [-0.2083, -0.5361,  1.0881,  ..., -0.5123,  1.3866, -0.0052],
         ...,
         [-0.3324, -0.6513,  0.9286,  ...,  0.3547,  0.5735, -0.7994],
         [-0.3629, -0.6865,  0.9168,  ..., -0.7911,  0.4423, -1.2635],
         [-0.2498, -0.5409,  0.9014,  ...,  1.1511, -0.9804, -0.7419]],

        [[-0.4007, -0.5943,  0.8185,  ...,  0.8869, -1.6821, -0.0094],
         [-0.3722, -0.5984,  1.0520,  ..., -0.7150,  1.0036, -0.4870],
         [-0.2130, -0.4130,  1.0466,  ..., -0.4279,  1.4161, -0.1111],
         ...,
         [-0.3346, -0.5264,  0.8734,  ...,  0.2489,  0.5373, -0.8761],
         [-0.3555, -0.5506,  0.8732,  ..., -0.8203,  0.4748, -1.3378],
         [-0.2448, -0.4208,  0.8630,  ...,  1.1920, -1.0890, -0.7955]],

        [[-0.3190, -0.5871,  0.9066,  ...,  0.8675, -1.5879, -0.0726],
         [-0.2854, -0.5956,  1.1327,  ..., -0

In [13]:
linear_inputs[0, :, 0:8]

tensor([[-0.3979, -0.7212,  0.8809, -2.3001, -0.1344,  0.1542, -0.7872, -1.3975],
        [-0.3703, -0.7442,  1.1081, -2.5695, -0.0657,  0.2680, -0.8204, -1.5260],
        [-0.2083, -0.5361,  1.0881, -2.1366,  0.0589,  0.3514, -0.6029, -1.2217],
        [-0.2831, -0.5859,  0.9144, -2.0643, -0.0363,  0.2339, -0.6476, -1.2192],
        [-0.1551, -0.5184,  1.2817, -2.2925,  0.1410,  0.4653, -0.5925, -1.2783],
        [-0.5559, -0.9293,  0.9212, -2.7531, -0.2515,  0.0819, -1.0055, -1.7105],
        [-0.2768, -0.6090,  1.0370, -2.2310, -0.0060,  0.2905, -0.6767, -1.3038],
        [-0.2696, -0.6062,  1.0614, -2.2495,  0.0047,  0.3050, -0.6748, -1.3101],
        [-0.1975, -0.5337,  1.1320, -2.1751,  0.0765,  0.3765, -0.6022, -1.2368],
        [-0.3093, -0.6608,  1.0809, -2.3771, -0.0228,  0.2909, -0.7324, -1.3960],
        [-0.3099, -0.6189,  0.9121, -2.1277, -0.0581,  0.2177, -0.6819, -1.2652],
        [-0.2177, -0.5382,  1.0499, -2.1034,  0.0435,  0.3296, -0.6036, -1.2087],
        [-0.2800

In [14]:
linear_inputs.shape

torch.Size([4, 16, 96])

### Simulate breached features

In [15]:
permutation = torch.randperm(32) # torch.randperm(32) # torch.arange(32)
num_breached_embeddings = 20
reverse_perm = torch.argsort(permutation[:num_breached_embeddings])
permutation

tensor([13, 27, 11, 25, 19, 16, 26,  0,  5, 23, 12,  3,  8, 29, 15, 21, 31, 30,
         6, 20, 17,  2, 22, 18, 24,  4,  1,  7, 14,  9, 10, 28])

In [16]:
seq_features = linear_inputs.permute(0, 1, 2).reshape(-1, 96)[:, :8][permutation][:num_breached_embeddings]
seq_features.shape

torch.Size([20, 8])

In [17]:
corrs = torch.as_tensor(np.corrcoef(seq_features.detach()))

In [18]:
group_dict = dict()
num_groups = 0
seen = set()
for i in range(corrs.shape[0]):
    if i not in seen:
        flag = corrs[i].argmax()
        # What threshhold to pick here? there should be a better way?
        new_group = (corrs[i] >= 0.98).nonzero().tolist()
        print(i, len(new_group))
        new_group = [x[0] for x in new_group]
        if flag in group_dict:
            group_num = corrs[flag]
        else:
            group_num = num_groups
            num_groups += 1
        for x in new_group:
            group_dict[x] = group_num
            seen.add(x)

0 20


In [19]:
shape= [cfg.case.user.num_data_points, cfg.case.data.shape[0]]
sentence_labels = -torch.ones(corrs.shape[0], dtype=torch.long)
already_assigned = set()
for idx in range(corrs.shape[0]):
    if idx not in already_assigned:
        matches = (corrs[idx] >= 0.98).nonzero().squeeze(0)

        if len(matches) > 0:
            filtered_matches = torch.as_tensor([m for m in matches if m not in already_assigned])
            if len(filtered_matches) > shape[1]:
                filtered_matches = corrs[idx][filtered_matches].topk(k=shape[1]).indices
            sentence_labels[filtered_matches] = idx
sentence_labels

tensor([18, 19, 19, 19, 19, 19, 19, 18, 19, 19, 19, 19, 18, 19, 19, 19, 19, 19,
        18, 19])

# Reconstruct user data

In [20]:
attacker.cfg.sentence_algorithm = "k-means"

In [21]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)
user.print(reconstructed_user_data)
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

Recovered tokens [[5, 6, 7, 10, 11, 12, 16, 17, 22, 26, 29, 31, 32, 35, 38, 40], [50, 62, 63, 64, 72, 291, 310, 400, 494, 566, 846, 940, 993, 1084, 1495, 1936], [1, 6, 7, 7, 7, 1084, 1084, 5470, 6084, 6489, 8107, 8323, 18637, 21489, 21964, 22724], [5, 5, 6, 7, 7, 11, 12, 12, 566, 566, 846, 1084, 1084, 1936, 1936, 6084]] through strategy decoder-bias.
Recovered 57 embeddings with positional data from imprinted layer.
Assigned [16, 16, 9, 16] breached embeddings to each sentence.
tensor([-0.0451,  0.0095,  0.1452, -0.0913,  0.0872, -0.0906,  0.1386,  0.0664,
        -0.2368,  0.1638, -0.1356,  0.1633,  0.2619, -0.0111, -0.0466,  0.0112])
tensor([-0.0452, -0.0115, -0.0464,  0.1905,  0.1806,  0.0787, -0.0445,  0.0157,
        -0.0892, -0.2465,  0.0881,  0.0053,  0.1328, -0.0156,  0.1420,  0.2849])
tensor([0.8729, 0.8895, 0.8947, 0.8516, 0.8649, 0.8759, 0.2515, 0.1670, 0.1569])
tensor([0.0152, 0.1146, 0.8779, 0.0922, 0.8684, 0.0483, 0.1908, 0.8309, 0.8562,
        0.8446, 0.8470, 0.8952, 0.

# Manually compute attention

In [22]:
inputs = torch.randn(10, 4, 96)
inputs = user.model.pos_encoder.pe[0][None, 0:4, :]
inputs[0]

AttributeError: 'LearnablePositionalEmbedding' object has no attribute 'pe'

In [None]:
user.model.transformer_encoder.layers[0].self_attn.out_proj.weight.data.mul_(10000)

In [None]:
Q = user.model.transformer_encoder.layers[0].self_attn.in_proj_weight[:96, :]
K = user.model.transformer_encoder.layers[0].self_attn.in_proj_weight[96:192, :]
V = user.model.transformer_encoder.layers[0].self_attn.in_proj_weight[192:, :]
q_b = user.model.transformer_encoder.layers[0].self_attn.in_proj_bias[:96] * 10000
k_b = user.model.transformer_encoder.layers[0].self_attn.in_proj_bias[96:192]
v_b = user.model.transformer_encoder.layers[0].self_attn.in_proj_bias[192:]

O =  user.model.transformer_encoder.layers[0].self_attn.out_proj.weight.data

In [None]:
self_attn = user.model.transformer_encoder.layers[0].self_attn

In [None]:
self_attn.batch_first = True

In [None]:
Q.shape, inputs[0].T.shape, V.shape, K.shape, q_b.shape

In [None]:
inputs[0, 0,8:16]

In [None]:
Qv = ((Q@inputs[0].T).T + q_b)
Kv = ((K@inputs[0].T).T + k_b)
Vv = ((V@inputs[0].T).T + v_b)

In [None]:
M = (Qv @ Kv.T).softmax(dim=1)
M

In [None]:
((Qv @ Kv.T).softmax(dim=1) @ Vv) @ O

In [None]:
outputs, attn_outputs = self_attn(inputs, inputs, inputs)

In [None]:
user.model.transformer_encoder.layers[0].norm1.weight.shape

In [None]:
user.model.transformer_encoder.layers[0].norm1(outputs + inputs)

In [None]:
normy = torch.nn.LayerNorm(4)

In [None]:
a = torch.tensor([[1, 2, 1, 2], [5, 5, 5, 5], [7, 7,7, 7]]).float()
a

In [None]:
normy(a[None])

In [None]:
model.in_proj_bias[:96].pe[0]

In [None]:
outputs, attn_weights = model(inputs, inputs, inputs)

In [None]:
attn_weights

In [23]:
reconstructed_user_data, stats = attacker.reconstruct_single_sentence([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)
user.print(reconstructed_user_data)
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

Recovered tokens [[5, 6, 7, 10, 11, 12, 16, 17, 22, 26, 29, 31, 32, 35, 38, 40], [50, 62, 63, 64, 72, 291, 310, 400, 494, 566, 846, 940, 993, 1084, 1495, 1936], [1, 6, 7, 7, 7, 1084, 1084, 5470, 6084, 6489, 8107, 8323, 18637, 21489, 21964, 22724], [5, 5, 6, 7, 7, 11, 12, 12, 566, 566, 846, 1084, 1084, 1936, 1936, 6084]] through strategy decoder-bias.


TypeError: list indices must be integers or slices, not list

In [24]:
reconstructed_user_data, stats = attacker.reconstruct2([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)
user.print(reconstructed_user_data)
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

Recovered tokens [[5, 6, 7, 10, 11, 12, 16, 17, 22, 26, 29, 31, 32, 35, 38, 40], [50, 62, 63, 64, 72, 291, 310, 400, 494, 566, 846, 940, 993, 1084, 1495, 1936], [1, 6, 7, 7, 7, 1084, 1084, 5470, 6084, 6489, 8107, 8323, 18637, 21489, 21964, 22724], [5, 5, 6, 7, 7, 11, 12, 12, 566, 566, 846, 1084, 1084, 1936, 1936, 6084]] through strategy decoder-bias.


TypeError: list indices must be integers or slices, not list

In [None]:
permuted_true_data = dict(data=true_user_data["data"][[3, 2, 1, 0]], labels=true_user_data["labels"])

In [None]:
metrics = breaching.analysis.report(permuted_true_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

In [None]:
metrics = breaching.analysis.report(true_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)