# Breaching privacy

This notebook does the same job as the command-line tool `simulate_breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

In [None]:
import plotly.express as px
import pandas as pd

### Initialize cfg object and system setup:

This composes the configuration with Hydra and prints the selected use case and server type.
There are a lot of possible configurations, but there is usually no need to worry about most of them. A few options are modified further below.

Choose the dataset with the `case/data=` override here, e.g. `shakespeare` or `wikitext` instead of `stackoverflow`:

In [None]:
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name='cfg',
        overrides=['case/data=stackoverflow', 'attack=tag'],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = {"device": device, "dtype": torch.float}
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# --- User data ---
cfg.case.user.num_data_points = 32  # how many sentences
cfg.case.user.user_idx = 0  # from which user
cfg.case.data.shape = [128]  # sequence length

# --- Model and task ---
cfg.case.model = "transformer3t"
cfg.case.server.pretrained = False
cfg.case.data.tokenizer = "bert-base-uncased"
cfg.case.data.task = "causal-lm"
# cfg.case.data.vocab_size = 30522
cfg.case.data.disable_mlm = False
cfg.case.data.mlm_probability = 0.1

# --- Attack ---
cfg.attack.attack_type = "permutation-optimization"
cfg.attack.label_strategy = "bias-text"

# --- Server ---
# Enables server.external_dataloader, which is inspected further below.
cfg.case.server.has_external_data = True

### Instantiate all parties

In [None]:
# Build the FL case (user, server, model, loss) and an attacker bound to the server's model and loss.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(
    server.model,
    server.loss,
    cfg.attack,
    setup,
)
breaching.utils.overview(server, user, attacker)

In [None]:
# Draw the first batch from the server's external dataloader (available because has_external_data=True was set above).
batch = next(iter(server.external_dataloader))

In [None]:
# Shape of the external batch's input ids (presumably (batch_size, sequence_length) — TODO confirm).
batch["input_ids"].shape

### Simulate an attacked FL protocol

The true user data is returned only for analysis, i.e. to evaluate the attack.

In [None]:
# The server distributes its payload; the user computes its local update on it.
# The second return value is the ground truth and is used here only for evaluation.
server_payload = server.distribute_payload()
(shared_data, true_user_data) = user.compute_local_updates(server_payload)

In [None]:
# Display the ground-truth user data using the user object's print method.
user.print(true_user_data)

In [None]:
# Distinct label values in the user's ground truth (sorted).
torch.unique(true_user_data["labels"])

# Reconstruct user data

In [None]:
# Disabled: the attacker-based reconstruction is not run in this notebook; the next cell sets a placeholder instead.
# reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
#                                                       server.secrets, dryrun=cfg.dryrun)

In [None]:
# Placeholder result with the same keys ("data", "labels") a reconstruction would return; both are unset.
reconstructed_user_data = {"data": None, "labels": None}

In [None]:
# Disabled: reconstructed_user_data is only a placeholder whose data and labels are None.
# user.print(reconstructed_user_data)

In [None]:
# Total number of label entries in the ground truth.
true_user_data["labels"].nelement()

### Check metrics:

In [None]:
# Disabled: metrics need an actual reconstruction, but reconstructed_user_data is only a placeholder.
# metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
#                                    server.model, cfg_case=cfg.case, setup=setup)

In [None]:
# Shape of every shared gradient tensor, in the order they are shared.
[grad.shape for grad in shared_data["gradients"]]

# Tokens from decoder bias

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

# One recovered token is needed per position of every user sequence.
num_missing_tokens = num_data_points * data_shape[0]

# This is slightly modified analytic label recovery in the style of Wainakh
bias_per_query = [shared_data["gradients"][-1]]
# Stage 1: every class whose averaged bias gradient is negative is assumed to occur at least once.
# torch.stack copies, so the in-place updates to average_bias below do not touch shared_data.
average_bias = torch.stack(bias_per_query).mean(dim=0)
valid_classes = (average_bias < 0).nonzero()
token_list = [*valid_classes.squeeze(dim=-1)]

# Estimated bias contribution of a single token occurrence (negative: only negative entries are summed).
m_impact = average_bias[valid_classes].sum() / num_missing_tokens
# Account for the one occurrence already assigned to each valid class in Stage 1.
average_bias[valid_classes] = average_bias[valid_classes] - m_impact
# Stage 2: give each remaining position to the class with the most negative residual bias.
while len(token_list) < num_missing_tokens:
    selected_idx = average_bias.argmin()
    token_list.append(selected_idx)
    average_bias[selected_idx] -= m_impact
# Stage 1 alone can produce more candidates than positions; drop the surplus (highest token ids)
# so the reshape below is always valid.
tokens = torch.stack(token_list[:num_missing_tokens]).view(num_data_points, data_shape[0])

In [None]:
# Total token recovery: overlap between all recovered tokens and all true input tokens.
recovered_flat = tokens.view(-1)
true_flat = true_user_data["data"].view(-1)
breaching.analysis.analysis.count_integer_overlap(recovered_flat, true_flat)

In [None]:
# Unique token recovery: compare the recovered classes with the set of distinct true tokens.
unique_tokens = true_user_data["data"].view(-1).unique()
print(len(valid_classes), len(unique_tokens))
# Truncate first so the padding size below can never be negative.
candidate_classes = valid_classes.view(-1)[:len(unique_tokens)]
num_padding = len(unique_tokens) - len(candidate_classes)
# new_zeros keeps the integer dtype and device of the candidates: a plain torch.zeros would be a
# float CPU tensor, promoting the result to float (or failing when the candidates live on the GPU).
# NOTE(review): 0 is itself a valid token id, so padding may be counted as a match if 0 occurs in
# unique_tokens — confirm against count_integer_overlap.
padded_classes = torch.cat([candidate_classes, valid_classes.new_zeros(num_padding)])
breaching.analysis.analysis.count_integer_overlap(padded_classes.view(-1), unique_tokens)

In [None]:
# Pair each recovered token with the true label at the same flattened position.
rec_labels, true_labels = tokens.view(-1), true_user_data["labels"].view(-1)
df = pd.DataFrame({"rec_labels": rec_labels.tolist(), "true_labels": true_labels.tolist()})

In [None]:
# Overlaid histograms of recovered vs. true token ids, log-scaled counts with violin marginals.
axis_names = {'rec_labels': 'Recovered tokens', "true_labels": "True tokens"}
fig = px.histogram(
    df,
    x=list(axis_names),
    labels=axis_names,
    marginal="violin",
    log_y=True,
    opacity=0.8,
)
fig.update_layout(
    title_text='Recovered Token Frequency',
    xaxis_title_text='Token ID',
    yaxis_title_text='Count',
    bargap=0.2,  # spacing between bars at neighbouring x positions
    bargroupgap=0.1,  # spacing between bars that share an x position
    barmode='overlay',
).update_traces(opacity=0.75)
fig.show()

# Tokens from encoder

In [None]:
# Shape of the second-to-last shared gradient, which the encoder-based recovery below uses.
shared_data["gradients"][-2].size()

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

# One recovered token is needed per position of every user sequence.
num_missing_tokens = num_data_points * data_shape[0]

wte_per_query = [shared_data["gradients"][-2]]
# Stage 1: flag tokens whose gradient-row norm is an outlier on a log scale.
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)
log_wte_norm = average_wte_norm.log()
# Rows with zero gradient give log(0) = -inf, which would turn mean/std (and thus the cutoff)
# into -inf/NaN and flag nothing. Estimate the statistics from finite values only.
std, mean = torch.std_mean(log_wte_norm[torch.isfinite(log_wte_norm)])
cutoff = mean + 2.5 * std  # 2.5 standard deviations above the mean log-norm
valid_classes = (log_wte_norm > cutoff).nonzero()
if len(valid_classes) == 0:
    raise RuntimeError("No token exceeded the embedding-gradient cutoff; cannot recover tokens.")
token_list = [*valid_classes.squeeze(dim=-1)]

# Estimated norm contribution of a single token occurrence.
m_impact = average_wte_norm[valid_classes].sum() / num_missing_tokens
# Account for the one occurrence already assigned to each flagged token in Stage 1.
average_wte_norm[valid_classes] = average_wte_norm[valid_classes] - m_impact
# Stage 2: among the flagged tokens only, give each remaining position to the largest residual norm.
candidates = valid_classes.squeeze(dim=-1)
while len(token_list) < num_missing_tokens:
    selected_idx = candidates[average_wte_norm[candidates].argmax()]
    token_list.append(selected_idx)
    average_wte_norm[selected_idx] -= m_impact
# Stage 1 can flag more tokens than there are positions; drop the surplus so the reshape is valid.
tokens = torch.stack(token_list[:num_missing_tokens]).view(num_data_points, data_shape[0])
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["labels"].view(-1))

In [None]:
# Unique token recovery: compare the first len(unique_tokens) flagged classes with the distinct true tokens.
unique_tokens = torch.unique(true_user_data["data"].view(-1))
print(len(valid_classes), len(unique_tokens))
num_unique = len(unique_tokens)
flagged = valid_classes.view(-1)
breaching.analysis.analysis.count_integer_overlap(flagged[:num_unique], unique_tokens)

In [None]:
# Distribution of log gradient-row norms over the whole vocabulary vs. the tokens that truly occur.
log_row_norms = shared_data["gradients"][-2].norm(dim=-1).log()
true_token_ids = true_user_data["data"].view(-1).unique()

df = pd.DataFrame(dict(all_tokens=log_row_norms.tolist()))
# The true-token series is shorter than the vocabulary; pandas aligns on the index and pads with NaN.
df["true_tokens"] = pd.Series(log_row_norms[true_token_ids].tolist())
fig = px.histogram(df, x=["all_tokens", "true_tokens"], opacity=0.8, log_y=False, marginal="violin")
# cutoff comes from the most recently executed token-recovery cell; pass plotly a plain float.
fig.add_vline(x=float(cutoff))
fig.update_layout(
    title_text="Gradient row norms: all tokens vs. true tokens",
    xaxis_title_text="log(gradient row norm)",
    yaxis_title_text="Count",
)
fig.show()

In [None]:
# Recovered tokens paired position-wise with the true input tokens.
rec_labels, true_labels = tokens.view(-1), true_user_data["data"].view(-1)
df = pd.DataFrame({"rec_labels": rec_labels.tolist(), "true_labels": true_labels.tolist()})

# Overlaid histograms of recovered vs. true token ids, log-scaled counts with violin marginals.
axis_names = {'rec_labels': 'Recovered tokens', "true_labels": "True tokens"}
fig = px.histogram(
    df,
    x=list(axis_names),
    labels=axis_names,
    marginal="violin",
    log_y=True,
    opacity=0.8,
)
fig.update_layout(
    title_text='Recovered Token Frequency',
    xaxis_title_text='Token ID',
    yaxis_title_text='Count',
    bargap=0.2,  # spacing between bars at neighbouring x positions
    bargroupgap=0.1,  # spacing between bars that share an x position
    barmode='overlay',
).update_traces(opacity=0.75)
fig.show()

# Mixed Strategy

The set of unique tokens is taken from the encoder gradient; how often each one occurs is estimated from the decoder bias gradient.

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

# One recovered token is needed per position of every user sequence.
num_missing_tokens = num_data_points * data_shape[0]

wte_per_query = [shared_data["gradients"][-2]]
# Stage 1 (encoder): the set of unique tokens = outliers of the log gradient-row norms.
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)
log_wte_norm = average_wte_norm.log()
# Rows with zero gradient give log(0) = -inf, which would turn mean/std (and thus the cutoff)
# into -inf/NaN and flag nothing. Estimate the statistics from finite values only.
std, mean = torch.std_mean(log_wte_norm[torch.isfinite(log_wte_norm)])
cutoff = mean + 2.5 * std  # 2.5 standard deviations above the mean log-norm
valid_classes = (log_wte_norm > cutoff).nonzero()
if len(valid_classes) == 0:
    raise RuntimeError("No token exceeded the embedding-gradient cutoff; cannot recover tokens.")
token_list = [*valid_classes.squeeze(dim=-1)]

# Stage 2 (decoder bias): distribute the remaining positions among the flagged tokens.
bias_per_query = [shared_data["gradients"][-1]]
average_bias = torch.stack(bias_per_query).mean(dim=0)

# Estimated bias contribution of one token occurrence, from the flagged tokens' total bias.
m_impact = average_bias[valid_classes].sum() / num_missing_tokens
# Account for the one occurrence already assigned to each flagged token in Stage 1.
average_bias[valid_classes] = average_bias[valid_classes] - m_impact
candidates = valid_classes.squeeze(dim=-1)
while len(token_list) < num_missing_tokens:
    # Only flagged tokens are eligible; the lowest residual bias gets the next position.
    selected_idx = candidates[average_bias[candidates].argmin()]
    token_list.append(selected_idx)
    average_bias[selected_idx] -= m_impact
# Stage 1 can flag more tokens than there are positions; drop the surplus so the reshape is valid.
tokens = torch.stack(token_list[:num_missing_tokens]).view(num_data_points, data_shape[0])

breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["labels"].view(-1))

In [None]:
# Unique token recovery for the mixed strategy: the flagged set is the encoder-based one.
unique_tokens = torch.unique(true_user_data["data"].view(-1))
print(len(valid_classes), len(unique_tokens))
flagged_prefix = valid_classes.view(-1)[:len(unique_tokens)]
breaching.analysis.analysis.count_integer_overlap(flagged_prefix, unique_tokens)

In [None]:
# Recovered tokens paired position-wise with the true input tokens.
rec_labels, true_labels = tokens.view(-1), true_user_data["data"].view(-1)
df = pd.DataFrame({"rec_labels": rec_labels.tolist(), "true_labels": true_labels.tolist()})

# Overlaid histograms of recovered vs. true token ids, log-scaled counts with violin marginals.
axis_names = {'rec_labels': 'Recovered tokens', "true_labels": "True tokens"}
fig = px.histogram(
    df,
    x=list(axis_names),
    labels=axis_names,
    marginal="violin",
    log_y=True,
    opacity=0.8,
)
fig.update_layout(
    title_text='Recovered Token Frequency',
    xaxis_title_text='Token ID',
    yaxis_title_text='Count',
    bargap=0.2,  # spacing between bars at neighbouring x positions
    bargroupgap=0.1,  # spacing between bars that share an x position
    barmode='overlay',
).update_traces(opacity=0.75)
fig.show()