# Breaching privacy

This notebook does the same job as the command-line tool `simulate_breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
# Core dependencies; hydra/OmegaConf provide the configuration system composed below.
import torch
import hydra
from omegaconf import OmegaConf
import os
# Change the working directory to the parent folder — presumably the repository root,
# so that paths used later in the notebook resolve (TODO confirm notebook location).
os.chdir('..')
# IPython magics (notebook-only syntax): automatically reload edited modules.
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Send INFO-level records from the root logger to stdout, printing only the message text.
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

In [None]:
import plotly.express as px
import pandas as pd
import matplotlib.pyplot as plt

### Initialize cfg object and system setup:

This composes the full configuration object. 
There are many possible configuration options, but there is usually no need to worry about most of them. Below, a few options are printed.

Choose `case/data=` from `shakespeare`, `wikitext`, or `stackoverflow` here:

In [None]:
# Compose the configuration; the overrides select the data set, the malicious server and the attack.
with hydra.initialize(config_path="../config"):
    cfg = hydra.compose(
        config_name="cfg",
        overrides=["case/data=wikitext", "case/server=malicious-transformer", "attack=tag"],
    )
    print(f"Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.")

# Prefer the first GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype pair passed to the case and attack constructors below.
setup = {"device": device, "dtype": torch.float}
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Which user data is attacked: number of sequences, user index and sequence length.
user_cfg = cfg.case.user
data_cfg = cfg.case.data
user_cfg.num_data_points = 32  # how many sentences
user_cfg.user_idx = 1  # from which user
data_cfg.shape = [32]  # sequence length

# Model/tokenizer checkpoint ("huawei-noah/TinyBERT_General_4L_312D" is an alternative)
# and the masked-language-modelling task setup.
cfg.case.model = "bert-base-uncased"
cfg.case.server.pretrained = False
data_cfg.tokenizer = "bert-base-uncased"
data_cfg.task = "masked-lm"
data_cfg.vocab_size = 30522
data_cfg.disable_mlm = False
data_cfg.mlm_probability = 0.15

# Attack variant and label-recovery strategy.
attack_cfg = cfg.attack
attack_cfg.attack_type = "permutation-optimization"
attack_cfg.label_strategy = "bias-text"

# Settings of the malicious server's parameter modification.
mod_cfg = cfg.case.server.param_modification
mod_cfg.v_length = 8
mod_cfg.eps = 1e-6
mod_cfg.imprint_sentence_position = 0
mod_cfg.softmax_skew = 100_000_000
mod_cfg.sequence_token_weight = 1
mod_cfg.measurement_scale = 1

mod_cfg.equalize_token_weight = 10

### Instantiate all parties

In [None]:
# Build the case parties from cfg.case on the chosen device/dtype, prepare the attacker
# against the server's model and loss, then call breaching.utils.overview on all three.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

In [None]:
#user.model.model.cls.predictions.decoder.weight = torch.nn.Parameter(user.model.model.cls.predictions.decoder.weight.detach().clone())
#server.model.model.cls.predictions.decoder.weight = torch.nn.Parameter(server.model.model.cls.predictions.decoder.weight.detach().clone())

### Simulate an attacked FL protocol

True user data is returned only for analysis.

In [None]:
# One simulated protocol round: the server distributes its payload and the user answers
# with the shared update (shared_data). The true user data is kept for analysis only.
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Show the user's true data via the user's own print helper.
user.print(true_user_data)

# Reconstruct user data

In [None]:
# reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
#                                                       server.secrets, dryrun=cfg.dryrun)

In [None]:
# Placeholder result while the attacker.reconstruct call above is commented out:
# a dict with keys "data" and "labels", both None.
reconstructed_user_data = dict(data=None, labels=None)

In [None]:
# user.print(reconstructed_user_data)

### Check metrics:

In [None]:
# metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
#                                    server.model, cfg_case=cfg.case, setup=setup)

In [None]:
# True if the MLM decoder weight and the word-embedding weight are the same object,
# i.e. the user's model ties them (displayed as the cell output).
user.model.model.cls.predictions.decoder.weight is user.model.model.bert.embeddings.word_embeddings.weight

In [None]:
# Shapes of every shared gradient tensor (displayed as the cell output).
[g.shape for g in shared_data["gradients"]]

# Tokens from decoder bias

In [None]:
# Vocabulary size from the configuration (displayed as the cell output).
cfg.case.data.vocab_size

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

# Total number of token positions to fill: sequences x sequence length.
num_missing_tokens = num_data_points * data_shape[0]

# This is slightly modified analytic label recovery in the style of Wainakh.
# gradients[-5] is presumably the MLM decoder bias gradient; the assert below checks
# that it has one entry per vocabulary item.
bias_per_query = [shared_data["gradients"][-5]]
assert len(bias_per_query[0]) == cfg.case.data.vocab_size
token_list = []
# Stage 1: every vocabulary entry with a negative averaged bias gradient is taken once.
average_bias = torch.stack(bias_per_query).mean(dim=0)
valid_classes = (average_bias < 0).nonzero()
token_list += [*valid_classes.squeeze(dim=-1)]

# Sum of the negative bias entries spread evenly over all positions; used as the
# per-pick discount below.
m_impact = average_bias[valid_classes].sum() / num_missing_tokens

average_bias[valid_classes] = average_bias[valid_classes] - m_impact
# Stage 2: greedily fill the remaining positions with the currently most negative entry,
# discounting that entry by m_impact after each pick.
while len(token_list) < num_missing_tokens:
    selected_idx = average_bias.argmin(dim=-1)
    token_list.append(selected_idx)
    average_bias[selected_idx] -= m_impact

# Stage 1 alone can produce more tokens than there are positions; keep only as many as
# fit so that the reshape cannot fail.
tokens = torch.stack(token_list[:num_missing_tokens]).view(num_data_points, data_shape[0])
# Total token recovery:
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["data"].view(-1))

In [None]:
# Number of entries in the current valid_classes selection (displayed as the cell output).
len(valid_classes)

In [None]:
# All tokens after average_bias[selected_idx]  - m_impact< 0 are useless

In [None]:
# Unique token recovery:
unique_tokens = true_user_data["data"].view(-1).unique()
print(len(unique_tokens), len(valid_classes))
padded_classes = torch.cat([valid_classes.view(-1), torch.zeros(len(unique_tokens)-len(valid_classes))])
breaching.analysis.analysis.count_integer_overlap(padded_classes.view(-1), unique_tokens)

In [None]:
# Flatten recovered tokens and true labels into a two-column frame for plotting.
rec_labels, true_labels = tokens.view(-1), true_user_data["labels"].view(-1)
df = pd.DataFrame({"rec_labels": rec_labels.tolist(), "true_labels": true_labels.tolist()})

In [None]:
# Overlaid histograms of recovered vs. true token ids on a log-scaled count axis.
fig = px.histogram(
    df,
    x=["rec_labels", "true_labels"],
    opacity=0.8,
    log_y=True,
    marginal="violin",
    labels={"rec_labels": "Recovered tokens", "true_labels": "True tokens"},
)
fig.update_layout(
    title_text="Recovered Token Frequency",
    xaxis_title_text="Token ID",
    yaxis_title_text="Count",
    bargap=0.2,         # gap between bars of adjacent location coordinates
    bargroupgap=0.1,    # gap between bars of the same location coordinates
    barmode="overlay",  # draw both histograms on top of each other
).update_traces(opacity=0.75)  # replaces the opacity passed to px.histogram

fig.show()

In [None]:
# Overlap between the distinct label values (minus the smallest one) and valid_classes.
# NOTE(review): unique() returns sorted values, so [1:] drops the smallest label —
# presumably the ignore index for unmasked positions; verify, since this would drop a
# real token if no ignore value is present.
mask_tokens = true_user_data["labels"].view(-1).unique()[1:]
breaching.analysis.analysis.count_integer_overlap(mask_tokens.view(-1), valid_classes.view(-1))

### Aside from the masked biases, nothing useful is in the decoder bias

In [None]:
# Compare the distribution of the current average_bias over all vocabulary entries,
# over entries present in the user data, over entries absent from it, and over the
# negative entries.
data = dict(all_bias=average_bias.sort().values.tolist())
df = pd.DataFrame(data)
true_token_ids = true_user_data["data"].view(-1).unique()
true_bias = average_bias[true_token_ids].sort().values
df["true_bias"] = pd.Series(true_bias.tolist())
# Select every entry NOT among the true token ids via a boolean mask. Applying ~ to the
# integer ids computes -id - 1, i.e. it indexes from the end of the vocabulary instead
# of selecting the complement.
is_false_token = torch.ones_like(average_bias, dtype=torch.bool)
is_false_token[true_token_ids] = False
false_bias = average_bias[is_false_token].sort().values
df["false_bias"] = pd.Series(false_bias.tolist())

mask_bias = average_bias[average_bias < 0].sort().values
df["mask_bias"] = pd.Series(mask_bias.tolist())

fig = px.histogram(df, x=["false_bias", "true_bias", "mask_bias"], opacity=0.8, log_y=True, marginal="violin")
fig.show()

# Tokens from encoder

In [None]:
# Shape of the first shared gradient — presumably the word-embedding gradient
# (vocab x hidden); TODO confirm.
shared_data["gradients"][0].shape

In [None]:
# Average the first shared gradient over queries (a single query here). wte_per_query is
# defined in this cell so that it runs in top-to-bottom order; it is otherwise first set
# further down the notebook.
wte_per_query = [shared_data["gradients"][0]]
average_wte = torch.stack(wte_per_query).mean(dim=0)
average_wte.shape

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

num_missing_tokens = num_data_points * data_shape[0]

# Presumably the word-embedding gradient (one row per vocabulary entry) — TODO confirm.
wte_per_query = [shared_data["gradients"][0]]
token_list = []
# Stage 1: row norms of the averaged gradient.
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)

# Rows whose log-norm exceeds mean + 3 std are selected. If any norm is exactly zero,
# its log is -inf and the cutoff is non-finite; then all non-zero rows are used instead.
std, mean = torch.std_mean(average_wte_norm.log())
cutoff = mean + 3 * std
if not cutoff.isfinite():  # tied weights
    valid_classes = average_wte_norm.nonzero().squeeze(dim=-1)
else:  # untied weights
    valid_classes = (average_wte_norm.log() > cutoff).nonzero().squeeze(dim=-1)

# valid_classes is already 1-D; squeezing it again would turn a single selected class
# into a 0-d tensor, which cannot be unpacked.
token_list += [*valid_classes]

# Sum of the selected norms spread evenly over all positions; used as the per-pick
# discount below.
m_impact = average_wte_norm[valid_classes].sum() / num_missing_tokens

average_wte_norm[valid_classes] = average_wte_norm[valid_classes] - m_impact
# Stage 2: repeatedly pick the selected row with the largest remaining norm.
while len(token_list) < num_missing_tokens:
    selected_idx = valid_classes[average_wte_norm[valid_classes].argmax()]
    token_list.append(selected_idx)
    average_wte_norm[selected_idx] -= m_impact
# Stage 1 can already exceed the number of positions; truncate so the reshape works.
tokens = torch.stack(token_list[:num_missing_tokens]).view(num_data_points, data_shape[0])
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["data"].view(-1))

In [None]:
unique_tokens = true_user_data["data"].view(-1).unique()

# Log of the current average_wte_norm for the selected rows. view(-1) keeps a 1-D index
# even when exactly one class is selected; squeeze() would give a 0-d index, so tolist()
# would return a bare scalar and the DataFrame construction would fail.
data = dict(valid_norms=average_wte_norm.log()[valid_classes.view(-1)].tolist())

df = pd.DataFrame(data)
# Same quantity for each distinct true token ...
true_norms = average_wte_norm[unique_tokens]
df["true_norms"] = pd.Series(true_norms.log().tolist())

# ... and for every true token occurrence.
true_dist = average_wte_norm[true_user_data["data"].view(-1)]
df["true_dist"] = pd.Series(true_dist.log().tolist())


fig = px.histogram(df, x=["valid_norms", "true_dist"], opacity=0.5, log_y=False, marginal="violin")
#fig.add_vline(x=cutoff)
fig.show()

In [None]:
# freq_lookup = {token.item():freq.item() for (token, freq) in zip(*true_user_data["data"].view(-1).unique(return_counts=True))}
# freq_lookup = dict(sorted(freq_lookup.items(), key=lambda item: item[1], reverse=True))
# [freq_lookup[k.item()] for k in average_wte_norm.topk(k=15).indices]

In [None]:
# Unique-token recovery of the selected valid_classes.
unique_tokens = true_user_data["data"].view(-1).unique()
print(len(unique_tokens), len(valid_classes))
# NOTE(review): unique() is sorted, so this truncates the TRUE side to len(valid_classes),
# dropping the largest true ids; confirm this is the intended comparison.
breaching.analysis.analysis.count_integer_overlap(valid_classes.view(-1), 
                                                  unique_tokens[:len(valid_classes)])

In [None]:
# Histogram of log row-norms of the first shared gradient: all rows vs. rows of tokens
# occurring in the user data, with the current cutoff drawn as a vertical line.
row_norms = shared_data["gradients"][0].norm(dim=-1)
data = {"all_tokens": row_norms.log().tolist()}

df = pd.DataFrame(data)
true_hits = row_norms[true_user_data["data"].view(-1).unique()]
df["true_tokens"] = pd.Series(true_hits.log().tolist())
fig = px.histogram(df, x=["all_tokens", "true_tokens"], opacity=0.8, log_y=False, marginal="violin")
fig.add_vline(x=cutoff)
fig.show()

# Tokens from encoder log

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

num_missing_tokens = num_data_points * data_shape[0]

wte_per_query = [shared_data["gradients"][0]]
token_list = []
# Stage 1: select rows of the averaged gradient whose log-norm exceeds mean + 3 std,
# or all non-zero rows when that cutoff is non-finite (a zero norm gives log = -inf).
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)
std, mean = torch.std_mean(average_wte_norm.log())
cutoff = mean + 3 * std
if not cutoff.isfinite():  # tied weights
    valid_classes = average_wte_norm.nonzero().squeeze(dim=-1)
else:  # untied weights
    valid_classes = (average_wte_norm.log() > cutoff).nonzero().squeeze(dim=-1)

token_list += [*valid_classes]

# Greedy selection works on log-norms; each pick lowers the chosen entry by m_impact,
# the largest selected log-norm divided by sqrt(num_data_points).
average_wte_norm_log = average_wte_norm.log()
m_impact = average_wte_norm_log[valid_classes].max() / torch.as_tensor(num_data_points).sqrt()
# NOTE(review): if the largest selected log-norm is <= 0, m_impact <= 0 and the loop
# below keeps re-selecting the same token.

# Stage 2: fill remaining positions with the selected row of largest remaining log-norm.
while len(token_list) < num_missing_tokens:
    selected_idx = valid_classes[average_wte_norm_log[valid_classes].argmax()].squeeze()
    token_list.append(selected_idx)
    average_wte_norm_log[selected_idx] -= m_impact
# Stage 1 can already exceed the number of positions; truncate so the reshape works.
tokens = torch.stack(token_list[:num_missing_tokens]).view(num_data_points, data_shape[0])
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["data"].view(-1))

In [None]:
# Unique-token recovery of the selected valid_classes.
unique_tokens = true_user_data["data"].view(-1).unique()
print(len(unique_tokens), len(valid_classes))
# NOTE(review): this truncates the RECOVERED side to the number of unique true tokens;
# confirm this is the intended comparison.
breaching.analysis.analysis.count_integer_overlap(valid_classes.view(-1)[:len(unique_tokens)], 
                                                  unique_tokens)

In [None]:
# Histogram of raw (not log) row-norms of the first shared gradient: all rows vs. rows
# of tokens occurring in the user data, on a log-scaled count axis.
row_norms = shared_data["gradients"][0].norm(dim=-1)
data = {"all_tokens": row_norms.tolist()}

df = pd.DataFrame(data)
true_hits = row_norms[true_user_data["data"].view(-1).unique()]
df["true_tokens"] = pd.Series(true_hits.tolist())
fig = px.histogram(df, x=["all_tokens", "true_tokens"], opacity=0.8, log_y=True, marginal="violin")
# fig.add_vline(x=cutoff)
fig.show()

# Mixed Strategy

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

num_missing_tokens = num_data_points * data_shape[0]

# Part 1 - decoder bias: slightly modified analytic label recovery in the style of Wainakh.
bias_per_query = [shared_data["gradients"][-5]]
token_list = []
# Stage 1: every entry with a negative averaged bias gradient is taken once.
average_bias = torch.stack(bias_per_query).mean(dim=0)
valid_classes = (average_bias < 0).nonzero()
token_list += [*valid_classes.squeeze(dim=-1)]

# Sum of the negative bias entries spread evenly over all positions (so <= 0); used as
# the per-pick discount below.
m_impact = average_bias[valid_classes].sum() / num_missing_tokens

average_bias[valid_classes] = average_bias[valid_classes] - m_impact
# Stage 2: keep picking the most negative entry while it is still below m_impact and
# stop at the first one that is not (the break belongs to the else branch; at loop level
# it ended the loop after a single iteration).
while len(token_list) < num_missing_tokens:
    selected_idx = average_bias.argmin()
    if average_bias[selected_idx] - m_impact < 0:
        token_list.append(selected_idx)
        average_bias[selected_idx] -= m_impact
    else:
        break

# Part 2 - embedding gradient: add rows whose log-norm exceeds mean + 2.5 std.
wte_per_query = [shared_data["gradients"][0]]
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)
std, mean = torch.std_mean(average_wte_norm.log())
cutoff = mean + 2.5 * std
if cutoff.isfinite():
    uniques_from_wte = (average_wte_norm.log() > cutoff).nonzero()
else:
    # A zero norm makes log = -inf and the cutoff non-finite, so the comparison above
    # would select nothing; fall back to all non-zero rows, as the encoder-only cells do.
    uniques_from_wte = average_wte_norm.nonzero()
token_list += [*uniques_from_wte.squeeze(dim=-1)]

# Pad with token id 0 up to the number of positions (never a negative amount, with the
# same integer dtype/device as the recovered ids) and drop any excess so the reshape works.
missing_tokens = num_missing_tokens - len(token_list)
padding = torch.zeros(max(missing_tokens, 0), dtype=torch.long, device=average_bias.device)
token_list = [*token_list, *padding][:num_missing_tokens]
tokens = torch.stack(token_list).view(num_data_points, data_shape[0])
# Total token recovery:
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["data"].view(-1))

In [None]:
# Unique-token recovery of the mixed strategy.
unique_tokens = true_user_data["data"].view(-1).unique()
# NOTE: this overwrites the global valid_classes with the distinct recovered token ids.
valid_classes = tokens.view(-1).unique()

print(len(unique_tokens), len(valid_classes))
# Truncates the recovered side to the number of unique true tokens before comparing.
breaching.analysis.analysis.count_integer_overlap(valid_classes[:len(unique_tokens)], unique_tokens)

In [None]:
# Histogram of log row-norms of the first shared gradient (all rows vs. rows of tokens in
# the user data), with the mixed-strategy cutoff drawn as a vertical line.
row_norms = shared_data["gradients"][0].norm(dim=-1)
data = {"all_tokens": row_norms.log().tolist()}

df = pd.DataFrame(data)
true_hits = row_norms[true_user_data["data"].view(-1).unique()]
df["true_tokens"] = pd.Series(true_hits.log().tolist())
fig = px.histogram(df, x=["all_tokens", "true_tokens"], opacity=0.8, log_y=False, marginal="violin")
fig.add_vline(x=cutoff)
fig.show()