# Breaching privacy

This notebook does the same job as the command-line tool `simulate_breach.py`, but also directly visualizes the user data and the reconstruction.

In [40]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()



### Initialize cfg object and system setup:

This composes the configuration and prints which use case and server type were selected.
There are many possible configuration options, but there is usually no need to change most of them.

Choose `shakespeare`, `wikitext`, or `stackoverflow` as `case/data=` here:

In [41]:
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=['case/data=shakespeare',
                                                      'case.model=transformer3',
                                                      'case.server.has_external_data=True'])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=torch.float)

Investigating use case single_imagenet with server type honest_but_curious.


### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [42]:
cfg.case.user.num_data_points = 5  # Number of sentences held by the user
cfg.case.user.user_idx = 0  # Index of the user whose data is attacked
cfg.case.data.shape = [30]  # Sequence length in tokens

### Instantiate all parties

In [43]:
model, loss_fn = breaching.cases.construct_model(cfg.case.model, cfg.case.data, pretrained=True)
# Server:
server = breaching.cases.construct_server(model, loss_fn, cfg.case, setup)

Now processing user ALL_S_WELL_THAT_ENDS_WELL_CELIA.



## Modify the model here:

In [44]:
import math
from breaching.attacks.analytic_transformer_utils import *  # feature_distribution, make_imprint_layer, set_MHA, ...

proportion = 8 / 96  # Fraction of embedding dimensions set aside (also passed as v_proportion below)
embedding_dim = model.encoder.embedding_dim
portion = int(proportion * embedding_dim)
weights = torch.randn(embedding_dim)
weights[:portion] = 0  # Zero the first `portion` entries before standardizing
std, mu = torch.std_mean(weights)
measurement = (weights - mu) / std / math.sqrt(embedding_dim)  # Here's our linear measurement
std, mean = feature_distribution(model, server, measurement)  # Estimated from the server's external data
make_imprint_layer(model, measurement, mean, std)

Computing feature distribution before the linear1 layer from external data.
Feature mean is -0.053457196801900864, feature std is 0.6983630657196045.
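`make_imprint_layer` receives the measurement together with the mean and standard deviation of the measured feature. In imprint-style attacks, the bias of each unit acts as a threshold on that feature. Spacing the thresholds at quantiles of the estimated distribution makes every bin roughly equally likely to be hit.

The sketch below is only an assumption about how such thresholds can be chosen. `gaussian_bin_thresholds` is a hypothetical helper, not part of `breaching`. It uses the standard library's `statistics.NormalDist`.

```python
from statistics import NormalDist

def gaussian_bin_thresholds(mean, std, num_bins):
    """Thresholds splitting N(mean, std^2) into num_bins equally likely bins.

    Hypothetical helper: assumes biases sit at quantiles of the measured
    feature distribution. Returned in decreasing order.
    """
    dist = NormalDist(mu=mean, sigma=std)
    quantiles = [dist.inv_cdf(k / num_bins) for k in range(1, num_bins)]
    return sorted(quantiles, reverse=True)

# Example with the feature statistics printed above and an arbitrary bin count
thresholds = gaussian_bin_thresholds(-0.0535, 0.698, num_bins=8)
print(thresholds)
```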


In [45]:
set_MHA(model, server, sequence_token_weight=100, pos=0, attention_block=1, v_proportion=proportion)
std, mean = feature_distribution(model, server, measurement, block_num=1)
make_imprint_layer(model, measurement, mean, std, block_num=1, self_attn=True)
# Zero out the linear1 weights of encoder layer 2
model.transformer_encoder.layers[2].linear1.weight.data.zero_()

torch.Size([5, 30])
torch.Size([5, 30, 96])
Computing feature distribution before the linear1 layer from external data.
Feature mean is -0.0047038705088198185, feature std is 0.025192659348249435.


In [46]:
# Finalize changes:
model = server.vet_model(model)

In [47]:
# Construct the user here:
user = breaching.cases.construct_user(model, loss_fn, cfg.case, setup)


### Simulate an attacked FL protocol

The true user data is returned only for analysis, to compare against the reconstruction; the server works only with `shared_data`.

In [48]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload) 

# Reconstruct user data

### Preliminaries: collecting gradients and recovering embeddings at different layers
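The bag-of-words step (`leaked_tokens` in the next cell) works because an embedding lookup only sends gradient to the rows of tokens that were actually used. Below is a minimal NumPy sketch of that effect. The vocabulary, the token sequence, and the helpers `embedding_grad` and `leaked_token_ids` are all hypothetical.

```python
import numpy as np

def embedding_grad(tokens, upstream, vocab_size):
    """Gradient of an embedding lookup w.r.t. the embedding table.

    Row t accumulates the upstream gradients of every position holding token t;
    rows of unused tokens stay exactly zero.
    """
    grad = np.zeros((vocab_size, upstream.shape[1]))
    np.add.at(grad, tokens, upstream)  # Unbuffered scatter-add handles repeated tokens
    return grad

def leaked_token_ids(grad):
    """Ids of tokens whose embedding row received a non-zero gradient."""
    return np.nonzero((grad != 0).any(axis=1))[0]

tokens = np.array([5, 2, 5, 7])       # Hypothetical token ids of one short sequence
upstream = np.ones((len(tokens), 3))  # Upstream gradient per position (embedding dim 3)
grad = embedding_grad(tokens, upstream, vocab_size=10)
print(leaked_token_ids(grad))         # [2 5 7]
```

Token 5 occurs twice but is listed once. The non-zero rows reveal which tokens occur, not where they occur in the sequence.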

In [49]:
# Map each parameter name to the gradient shared by the user
grad_dict = {name: grad for (name, _), grad in zip(model.named_parameters(), shared_data['gradients'])}
# Embedding rows with a non-zero gradient belong to tokens that occur in the user data
leaked_tokens = ((grad_dict['encoder.weight'] != 0).sum(dim=1) > 0).nonzero().squeeze().cpu()  # Bag of words tokens


def recover_linear1_inputs(layer_idx):
    """Recover the inputs to linear1 of encoder layer `layer_idx` from its weight and bias gradients."""
    prefix = f'transformer_encoder.layers.{layer_idx}.linear1'
    weight_grad = grad_dict[f'{prefix}.weight'].detach().clone().cpu()
    bias_grad = grad_dict[f'{prefix}.bias'].detach().clone().cpu()
    # Difference consecutive rows: row i becomes row i minus row i-1 (row 0 is unchanged)
    weight_grad[1:] = weight_grad[1:] - weight_grad[:-1]
    bias_grad[1:] = bias_grad[1:] - bias_grad[:-1]
    valid_classes = bias_grad != 0
    return weight_grad[valid_classes, :] / bias_grad[valid_classes, None]


pos_recs = recover_linear1_inputs(0)  # Here are our reconstructed positionally encoded features
no_pos_recs = model.transformer_encoder.layers[0].norm1(
    model.encoder(leaked_tokens) * math.sqrt(model.encoder.embedding_dim)).cpu()
attn_recs = recover_linear1_inputs(1)  # Here are our reconstructed post-MHA embeddings
print(pos_recs.shape)
print(attn_recs.shape)

torch.Size([145, 96])
torch.Size([145, 96])
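Why does dividing the differenced weight gradient by the differenced bias gradient return an input? Consider a toy imprint layer where every row of the weight equals the measurement `m` and the thresholds decrease from row to row. Unit i then fires for every input whose projection exceeds threshold i. Consecutive rows therefore differ by exactly the inputs that fall between two thresholds.

The NumPy sketch below makes some simplifying assumptions:

- The loss is the plain sum of the ReLU outputs.
- The thresholds are placed by hand so that each bin holds exactly one input.

The helper names are hypothetical. The real layer is built by `make_imprint_layer`, whose exact construction may differ.

```python
import numpy as np

def imprint_gradients(x, m, thresholds):
    """Weight and bias gradients of sum(ReLU(W @ x_n + b)) over all inputs x_n.

    Every row of W equals the measurement m and b[i] = -thresholds[i].
    """
    proj = x @ m                                  # (n_inputs,)
    active = proj[None, :] > thresholds[:, None]  # (n_bins, n_inputs): ReLU is on
    weight_grad = active.astype(float) @ x        # Row i: sum of inputs active in unit i
    bias_grad = active.sum(axis=1).astype(float)  # Row i: number of active inputs
    return weight_grad, bias_grad

def recover_inputs(weight_grad, bias_grad):
    """Difference consecutive rows, then divide weight by bias gradients."""
    weight_grad = weight_grad.copy()
    bias_grad = bias_grad.copy()
    weight_grad[1:] = weight_grad[1:] - weight_grad[:-1]
    bias_grad[1:] = bias_grad[1:] - bias_grad[:-1]
    valid = bias_grad != 0                        # Bins that caught at least one input
    return weight_grad[valid] / bias_grad[valid, None]

rng = np.random.default_rng(0)
d = 4
x = rng.normal(size=(3, d))                       # Three hypothetical token features
m = rng.normal(size=d)
m /= np.linalg.norm(m)                            # Unit-norm measurement direction
proj = np.sort(x @ m)[::-1]                       # Projections, largest first
# Decreasing thresholds with exactly one projection between consecutive values
thresholds = np.concatenate([[proj[0] + 1.0], (proj[:-1] + proj[1:]) / 2, [proj[-1] - 1.0]])
recovered = recover_inputs(*imprint_gradients(x, m, thresholds))
print(np.allclose(recovered, x[np.argsort(-(x @ m))]))  # True: inputs ordered by decreasing projection
```

If two inputs land in the same bin, the division returns their average instead of either input. This is one reason the number of bins matters.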


### Group words into sentences

In [50]:
# Let's group words into sentence groups
import numpy as np
from scipy.optimize import linear_sum_assignment  # Optimal one-to-one matching instead of a greedy search

# First, map positionally encoded tokens to the recovered "attended" embeddings
# by correlating their embedding dimensions from `portion` onwards
pos_attn_coeffs = torch.zeros((len(pos_recs), len(attn_recs)))
for i in range(len(pos_recs)):
    for j in range(len(attn_recs)):
        pos_attn_coeffs[i, j] = np.corrcoef(pos_recs[i][portion:].detach().numpy(), attn_recs[j][portion:].detach().numpy())[0, 1]
row_ind, col_ind = linear_sum_assignment(pos_attn_coeffs.numpy(), maximize=True)
pos_lookup = {y: pos_recs[x] for x, y in zip(row_ind, col_ind)}  # attn_recs index -> matched pos_recs row

# Next, map the attention embeddings to groups by correlating their first `portion` dimensions
group_coeffs = torch.zeros((len(attn_recs), len(attn_recs)))
for i in range(len(attn_recs)):
    for j in range(len(attn_recs)):
        group_coeffs[i, j] = np.corrcoef(attn_recs[i][:portion].detach().numpy(), attn_recs[j][:portion].detach().numpy())[0, 1]
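The nested loops compute Pearson correlations one pair at a time. `linear_sum_assignment(..., maximize=True)` then picks the one-to-one matching with the largest total correlation.

Below is a NumPy sketch of both pieces:

- `cross_correlation` is a vectorized equivalent of the `np.corrcoef` loops.
- `best_assignment` is a brute-force stand-in, only viable for a handful of rows, that spells out the objective SciPy solves efficiently.

Both helper names and the toy data are hypothetical.

```python
import itertools
import numpy as np

def cross_correlation(a, b):
    """Pearson correlation of every row of a with every row of b.

    Entry [i, j] equals np.corrcoef(a[i], b[j])[0, 1].
    """
    a = a - a.mean(axis=1, keepdims=True)
    b = b - b.mean(axis=1, keepdims=True)
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

def best_assignment(scores):
    """Column matched to each row in the maximum-total-score perfect matching (brute force)."""
    n = scores.shape[0]
    best = max(itertools.permutations(range(n)),
               key=lambda perm: sum(scores[i, perm[i]] for i in range(n)))
    return list(best)

rng = np.random.default_rng(1)
attn = rng.normal(size=(4, 50))                      # Hypothetical "attended" embeddings
perm = [2, 0, 3, 1]
pos = attn[perm] + 0.01 * rng.normal(size=(4, 50))   # Shuffled, slightly noisy copies
coeffs = cross_correlation(pos, attn)
print(best_assignment(coeffs))                       # [2, 0, 3, 1]
```

On the same `coeffs`, `linear_sum_assignment(coeffs, maximize=True)` should return this matching as `(row_ind, col_ind)` without enumerating all n! permutations.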

In [67]:
from collections import defaultdict

group_dict = dict()  # attn_recs index -> group id
num_groups = 0
seen = set()
for i in range(group_coeffs.shape[0]):
    if i not in seen:
        flag = group_coeffs[i].argmax().item()  # Most correlated embedding, as a plain int usable as a dict key
        # Which threshold to pick here? There may be a better way than a fixed cutoff.
        new_group = (group_coeffs[i] >= 0.95).nonzero().tolist()
        print(i, len(new_group))
        new_group = [x[0] for x in new_group]
        if flag in group_dict:
            group_num = group_dict[flag]  # Join the existing group of the best match
        else:
            group_num = num_groups
            num_groups += 1
        for x in new_group:
            group_dict[x] = group_num
            seen.add(x)

new_group_dict = defaultdict(list)  # group id -> matched positionally encoded features
for k, v in group_dict.items():
    if k in pos_lookup:
        new_group_dict[v].append(pos_lookup[k])

0 30
1 29
2 55
57 1
63 29
129 1
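The grouping loop above depends on the order in which rows are visited and on the fixed 0.95 cutoff. One order-independent alternative, which this notebook does not use, treats every pair with correlation at or above the threshold as an edge and takes connected components, for example with a small union-find. Below is a sketch with a hypothetical `group_by_threshold` helper and a toy correlation matrix.

```python
import numpy as np

def group_by_threshold(coeffs, threshold=0.95):
    """Connected components of the graph with an edge wherever coeffs[i, j] >= threshold.

    Returns one group id per row, numbered in order of first appearance.
    """
    n = coeffs.shape[0]
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # Path halving keeps the trees shallow
            i = parent[i]
        return i

    for i, j in zip(*np.nonzero(coeffs >= threshold)):
        ri, rj = find(int(i)), find(int(j))
        if ri != rj:
            parent[rj] = ri  # Merge the two components
    roots = [find(i) for i in range(n)]
    relabel = {}
    return [relabel.setdefault(r, len(relabel)) for r in roots]

coeffs = np.array([[1.00, 0.97, 0.10, 0.00],
                   [0.97, 1.00, 0.20, 0.10],
                   [0.10, 0.20, 1.00, 0.99],
                   [0.00, 0.10, 0.99, 1.00]])
print(group_by_threshold(coeffs))  # [0, 0, 1, 1]
```

The threshold still has to be chosen. Components can also chain together through intermediate rows, so a slightly too low cutoff may merge two sentences.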


In [56]:
user.print(true_user_data)

Yonder comes my master, your brother.But do not so. I have five hundred crowns,I scarce can speak to thank you for myself
.
[Coming forward] Sweet masters, be patient; for your father'sremembrance, be at accord.
Is 'old dog'
 my reward? Most true, I have lost my teeth inCome not within these doors; within this roof
The enemy of all your graces lives
.
Your brother- no, no brother; yet the son-
Yet not the son; I will not call him son
Of him I
 was about to call his father-
Hath heard your praises; and this night he means
To burn the lodging where you use to lie,


In [62]:
sentences = []
for group in new_group_dict.keys(): 
    sentences.append(recover_from_group(model, server, new_group_dict[group], no_pos_recs, leaked_tokens))

In [63]:
# Finally, package the result as a dict with keys `data` and `labels`, like the user data
reconstructed_user_data = dict(data=sentences, labels=None)

In [64]:
user.print(reconstructed_user_data)

 wasonder comes my master your.But do not soath I this five hundred crowns beI scarce can speak to thank you for myselfY have
 was about to call his father-
Hath heard your praises; and he means theseTo burn the where you useold lie, nightComing
.
[But forward],TheH speak patient where for heard father's trueembranceathYet accord call teeth son 'Of him'YourComing- Sweet mastersonder yet Most your theserem the I will not gr praises dogaces no brother be; thank canIs
 at
 my? Most true, I have lost teeth inCome not within these doors;'roof
The enemy of all your graces livesonder this
 reward


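### Check metrics (currently fails)

The report below raises a `ValueError` because the number of predictions (29) does not match the number of references (30).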

In [68]:
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, cfg_case=cfg.case, setup=setup)

ValueError: Mismatch in the number of predictions (29) and references (30)
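While `breaching.analysis.report` rejects the mismatch, a rough position-wise token accuracy can still be computed. Each reconstructed sequence is padded or truncated to the length of its reference. The helper below is hypothetical and not part of `breaching`. It works on lists of token ids, counts missing reconstructed sequences as entirely wrong, and ignores extra ones.

```python
def padded_token_accuracy(predicted, reference, pad_id=-1):
    """Position-wise token accuracy that tolerates length mismatches.

    Each predicted sequence is truncated or padded with pad_id (assumed never to be
    a real token id) to the length of its reference.
    """
    correct = total = 0
    for idx, ref in enumerate(reference):
        pred = list(predicted[idx]) if idx < len(predicted) else []
        pred = pred[:len(ref)] + [pad_id] * (len(ref) - min(len(pred), len(ref)))
        correct += sum(p == r for p, r in zip(pred, ref))
        total += len(ref)
    return correct / total if total else 0.0

print(padded_token_accuracy([[1, 2, 3]], [[1, 2, 3, 4], [5, 6]]))  # 0.5
```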