# Breaching privacy

This notebook does the same job as the cmd-line tool `simulate_breach.py`, but also directly visualizes the user data and reconstruction

In [25]:
import torch
import hydra
from omegaconf import OmegaConf
import os
os.chdir('..')
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
import plotly.express as px
import pandas as pd

### Initialize cfg object and system setup:

This will print out all configuration options. 
There are a lot of possible configurations, but there is usually no need to worry about most of these. Below, a few options are printed.

Choose `case/data=` `shakespeare`, `wikitext`over `stackoverflow` here:

In [48]:
with hydra.initialize(config_path="../config"):
    cfg = hydra.compose(config_name='cfg', overrides=['case/data=wikitext', 
                                                      'attack=tag'])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
          
device = torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=torch.float)
setup

Investigating use case single_imagenet with server type honest_but_curious.


{'device': device(type='cpu'), 'dtype': torch.float32}

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [28]:
cfg.case.user.num_data_points = 27 # How many sentences?
cfg.case.user.user_idx = 50 # From which user?
cfg.case.data.shape = [512] # This is the sequence length

cfg.case.model="transformer3"
cfg.case.server.pretrained=False
cfg.case.data.tokenizer = "gpt2"
cfg.case.data.task =  "causal-lm"
# cfg.case.data.vocab_size =  30522
cfg.case.data.disable_mlm=False
cfg.case.data.mlm_probability =  0.1

cfg.attack.attack_type = "permutation-optimization"
cfg.attack.label_strategy = "bias-text"

cfg.case.server.has_external_data=True

### Instantiate all parties

In [29]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Reusing dataset wikitext (/home/jonas/data/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Model architecture gpt2 loaded with 124,439,808 parameters and 12,582,924 buffers.
Overall this is a data ratio of    9002:1 for target shape [27, 512] given that num_queries=1.
User (of type UserSingleStep) with settings:
    Number of data points: 27

    Threat model:
    User provides labels: False
    User provides buffers: False
    User provides number of data points: True

    Data:
    Dataset: wikitext
    user: 50
    
        
Server (of type HonestServer) with settings:
    Threat model: Honest-but-curious
    Number of planned queries: 1
    Has external/public data: True

    Model:
        model specification: gpt2
        model state: default
        public buffers: True

    Secrets: {}
    
A

In [30]:
batch = next(iter(server.external_dataloader))

In [31]:
batch["input_ids"].shape

torch.Size([128, 512])

### Simulate an attacked FL protocol

True user data is returned only for analysis

In [32]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

Computing user update in model mode: eval.


In [33]:
user.print(true_user_data)

 Sir Robert Eric Mortimer Wheeler CH, CIE, MC, TD, FSA, FRS, FBA ( 10 September 1890 – 22 July 1976 ) was a British archaeologist and officer in the British Army. Over the course of his career, he served as Director of both the National Museum of Wales and London Museum, Director @-@ General of the Archaeological Survey of India, and the founder and Honorary Director of the Institute of Archaeology in London, further writing twenty @-@ four books on archaeological subjects. 
 Born in Glasgow to a middle @-@ class family, Wheeler was raised largely in Yorkshire before relocating to London in his teenage years. After studying Classics at University College London ( UCL ), he began working professionally in archaeology, specializing in the Romano @-@ British period. During World War I he volunteered for service in the Royal Artillery, being stationed on the Western Front, where he rose to the rank of major and was awarded the Military Cross. Returning to Britain, he obtained his doctorate

In [34]:
true_user_data["labels"].unique()

tensor([   11,    12,    13,  ..., 50004, 50067, 50099])

# Reconstruct user data

In [35]:
# reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
#                                                       server.secrets, dryrun=cfg.dryrun)

In [36]:
# Finally return a dict with keys data and labels
reconstructed_user_data = dict(data=None, labels=None)

In [37]:
# user.print(reconstructed_user_data)

In [38]:
true_user_data["labels"].numel()

13824

### Check metrics:

In [39]:
# metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
#                                    server.model, cfg_case=cfg.case, setup=setup)

In [40]:
[g.shape for g in shared_data["gradients"]]

[torch.Size([50257, 768]),
 torch.Size([1024, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 2304]),
 torch.Size([2304]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 3072]),
 torch.Size([3072]),
 torch.Size([3072, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 2304]),
 torch.Size([2304]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 3072]),
 torch.Size([3072]),
 torch.Size([3072, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 2304]),
 torch.Size([2304]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 3072]),
 torch.Size([3072]),
 torch.Size([3072, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 2304]),
 torch.Size([2304]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768]),


# Tokens from decoder bias

In [19]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

num_missing_tokens = num_data_points * data_shape[0]

# This is slightly modified analytic label recovery in the style of Wainakh
bias_per_query = [shared_data["gradients"][-1]]
token_list = []
# Stage 1
average_bias = torch.stack(bias_per_query).mean(dim=0)
valid_classes = (average_bias < 0).nonzero()
token_list += [*valid_classes.squeeze(dim=-1)]
# tokens_in_input = shared_data["gradients"][0].norm(dim=-1).nonzero().squeeze(dim=-1)
# for token in tokens_in_input:
#     if token not in token_list:
#         token_list.append(token)

m_impact = average_bias[valid_classes].sum() / num_missing_tokens

average_bias[valid_classes] = average_bias[valid_classes] - m_impact
# Stage 2
while len(token_list) < num_missing_tokens:
    selected_idx = average_bias.argmin()
    token_list.append(selected_idx)
    average_bias[selected_idx] -= m_impact
tokens = torch.stack(token_list).view(num_data_points, data_shape[0])

In [20]:
# Total token recovery:
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["data"].view(-1))

0.9696180555555556

In [21]:
# Unique token recovery:
unique_tokens = true_user_data["data"].view(-1).unique()
print(len(valid_classes), len(unique_tokens))
padded_classes = torch.cat([valid_classes.view(-1), torch.zeros(len(unique_tokens)-len(valid_classes))])
breaching.analysis.analysis.count_integer_overlap(padded_classes.view(-1), unique_tokens)

3070 3075


0.9983739837398374

In [22]:
rec_labels = tokens.view(-1)
true_labels = true_user_data["labels"].view(-1)
df = pd.DataFrame({"Recovered Token Distribution":rec_labels.tolist(), 
                   "True Token Distribution":true_labels.tolist()})

In [24]:
fig = px.histogram(df, x=["True Token Distribution", "Recovered Token Distribution", 
                          ], opacity=0.5,log_y=True, marginal="violin") # , nbins=50000
fig.update_layout(
    #title_text='Recovered Token Frequency', # title of plot
    xaxis_title_text='Token ID', # xaxis label
    yaxis_title_text='Count', # yaxis label
    #bargap=0.2, # gap between bars of adjacent location coordinates
    #bargroupgap=0.1, # gap between bars of the same location coordinates
    barmode='overlay',
    showlegend=True,
)
fig.update_layout(legend=dict(
    yanchor="top",
    y=0.75,
    xanchor="right",
    x=0.99
))
# fig.update_traces(opacity=1.0)

fig.show()
fig.write_image(f"visualization_token_overlap_{cfg.case.model}.pdf")

In [None]:
import plotly.figure_factory as ff

# Group data together
hist_data = [df["true_labels"].tolist(), df["rec_labels"].tolist()]

group_labels = ['True Distribution', 'Estimated Distribution']

# Create distplot with custom bin_size
fig = ff.create_distplot(hist_data, group_labels)
fig.show()

fig.update_xaxes(title_text="x-axis in logarithmic scale", type="log")
fig.update_yaxes(title_text="y-axis in logarithmic scale", type="log")


In [113]:
fig = px.violin(df, x=["rec_labels", "true_labels"], box=True, # draw box plot inside the violin
                points='all', # can be 'outliers', or False
               )


fig.update_layout(
    title="Plot Title",
    xaxis_title="X Axis Title",
    yaxis_title="Y Axis Title",
    legend_title="Legend Title",
    font=dict(
        family="Courier New, monospace",
        size=18,
        color="RebeccaPurple"
    )
)
fig.show()

# Tokens from encoder

In [41]:
shared_data["gradients"][0].shape

torch.Size([50257, 768])

In [42]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

num_missing_tokens = num_data_points * data_shape[0]

wte_per_query = [shared_data["gradients"][0]]
token_list = []
# Stage 1
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)
std, mean = torch.std_mean(average_wte_norm.log())
cutoff = mean + 1.5 * std
valid_classes = (average_wte_norm.log() > cutoff).nonzero()
token_list += [*valid_classes.squeeze(dim=-1)]

top2 = average_wte_norm.log().topk(k=2).values
# m_impact = top2[0] - top2[1]
m_impact = average_wte_norm[valid_classes].sum() / num_missing_tokens
# average_wte_norm_log[valid_classes] = average_wte_norm_log[valid_classes] - m_impact

average_wte_norm[valid_classes] = average_wte_norm[valid_classes] - m_impact
# Stage 2
while len(token_list) < num_missing_tokens:
    selected_idx = valid_classes[average_wte_norm[valid_classes].argmax()].squeeze()
    token_list.append(selected_idx)
    # print(selected_idx, average_wte_norm_log[selected_idx])
    average_wte_norm[selected_idx] -= m_impact
tokens = torch.stack(token_list).view(num_data_points, data_shape[0])
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["labels"].view(-1))

0.9252748842592593

In [43]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

num_missing_tokens = num_data_points * data_shape[0]

wte_per_query = [shared_data["gradients"][0]]
token_list = []
# Stage 1
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)
std, mean = torch.std_mean(average_wte_norm.log())
cutoff = mean + 3.5 * std
valid_classes = (average_wte_norm.log() > cutoff).nonzero()
token_list += [*valid_classes.squeeze(dim=-1)]

top2 = average_wte_norm.log().topk(k=2).values
# m_impact = top2[0] - top2[1]
m_impact = average_wte_norm[valid_classes].sum() / num_missing_tokens
# average_wte_norm_log[valid_classes] = average_wte_norm_log[valid_classes] - m_impact

average_wte_norm[valid_classes] = average_wte_norm[valid_classes] - m_impact
# Stage 2
while len(token_list) < num_missing_tokens:
    selected_idx = valid_classes[average_wte_norm[valid_classes].argmax()].squeeze()
    token_list.append(selected_idx)
    # print(selected_idx, average_wte_norm_log[selected_idx])
    average_wte_norm[selected_idx] -= m_impact
tokens = torch.stack(token_list).view(num_data_points, data_shape[0])
breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["labels"].view(-1))

0.7914496527777778

In [44]:
unique_tokens = true_user_data["data"].view(-1).unique()
print(len(valid_classes), len(unique_tokens))
breaching.analysis.analysis.count_integer_overlap(valid_classes.view(-1)[:len(unique_tokens)], 
                                                  unique_tokens)

843 3075


RuntimeError: The size of tensor a (843) must match the size of tensor b (3075) at non-singleton dimension 0

In [None]:
data = {"Gradients of all Embeddings":shared_data["gradients"][0].norm(dim=-1).log().tolist()}

df = pd.DataFrame(data)
true_hits = shared_data["gradients"][0].norm(dim=-1)[true_user_data["data"].view(-1).unique()]
df["Gradients of Embeddings of User Data"] = pd.Series(true_hits.log().tolist())
fig = px.histogram(df, x=["Gradients of all Embeddings", "Gradients of Embeddings of User Data"], 
                   opacity=0.8,log_y=False, marginal="violin")
fig.add_vline(x=cutoff)

fig.update_layout(
    #title_text='Recovered Token Frequency', # title of plot
    xaxis_title_text='Log of Token Embedding Norm', # xaxis label
    yaxis_title_text='Count', # yaxis label
    #bargap=0.2, # gap between bars of adjacent location coordinates
    #bargroupgap=0.1, # gap between bars of the same location coordinates
    barmode='overlay',
    showlegend=True,
)
fig.update_layout(legend=dict(
    yanchor="top",
    y=0.75,
    xanchor="right",
    x=0.99
))
# fig.update_traces(opacity=1.0)
fig.write_image(f"visualization_log_of_tied_norms_{cfg.case.model}.pdf")
fig.show()

In [None]:
rec_labels = tokens.view(-1)
true_labels = true_user_data["labels"].view(-1)
df = pd.DataFrame({"Recovered Token Distribution":rec_labels.tolist(), 
                   "True Token Distribution":true_labels.tolist()})


fig = px.histogram(df, x=["Recovered Token Distribution", 
                          "True Token Distribution"], opacity=0.5,log_y=True, marginal="violin") # , nbins=50000
fig.update_layout(
    #title_text='Recovered Token Frequency', # title of plot
    xaxis_title_text='Token ID', # xaxis label
    yaxis_title_text='Count', # yaxis label
    #bargap=0.2, # gap between bars of adjacent location coordinates
    #bargroupgap=0.1, # gap between bars of the same location coordinates
    barmode='overlay',
    showlegend=True,
)
fig.update_layout(legend=dict(
    yanchor="top",
    y=0.75,
    xanchor="right",
    x=0.99
))
# fig.update_traces(opacity=1.0)

fig.show()
fig.write_image(f"visualization_token_overlap_{cfg.case.model}.pdf")

In [46]:
os.getcwd()

'/home/jonas'

In [47]:
fig = px.histogram(df, x=["True Token Distribution", "Recovered Token Distribution", 
                          ], opacity=0.5,log_y=True, marginal="violin") # , nbins=50000
fig.update_layout(
    #title_text='Recovered Token Frequency', # title of plot
    xaxis_title_text='Token ID', # xaxis label
    yaxis_title_text='Count', # yaxis label
    #bargap=0.2, # gap between bars of adjacent location coordinates
    #bargroupgap=0.1, # gap between bars of the same location coordinates
    barmode='overlay',
    showlegend=True,
)
fig.update_layout(legend=dict(
    yanchor="top",
    y=0.75,
    xanchor="right",
    x=0.99
))
# fig.update_traces(opacity=1.0)

fig.show()
fig.write_image(f"visualization_token_overlap_{cfg.case.model}.pdf")

# Mixed Strategy

Uniques from encoder, frequencies from decoder bias

In [None]:
data_shape = cfg.case.data.shape
num_data_points = cfg.case.user.num_data_points

num_missing_tokens = num_data_points * data_shape[0]

wte_per_query = [shared_data["gradients"][-2]]
token_list = []
# Stage 1
average_wte_norm = torch.stack(wte_per_query).mean(dim=0).norm(dim=1)
std, mean = torch.std_mean(average_wte_norm.log())
cutoff = mean + 2.5 * std
valid_classes = (average_wte_norm.log() > cutoff).nonzero()
token_list += [*valid_classes.squeeze(dim=-1)]


bias_per_query = [shared_data["gradients"][-1]]
# Stage 1
average_bias = torch.stack(bias_per_query).mean(dim=0)

m_impact = average_bias[valid_classes].sum() / num_missing_tokens

average_bias[valid_classes] = average_bias[valid_classes] - m_impact
# Stage 2
while len(token_list) < num_missing_tokens:
    selected_idx = valid_classes[average_bias[valid_classes].argmin()].squeeze()
    # selected_idx = average_bias.argmin()
    token_list.append(selected_idx)
    average_bias[selected_idx] -= m_impact
tokens = torch.stack(token_list).view(num_data_points, data_shape[0])

breaching.analysis.analysis.count_integer_overlap(tokens.view(-1), true_user_data["labels"].view(-1))

In [None]:
unique_tokens = true_user_data["data"].view(-1).unique()
print(len(valid_classes), len(unique_tokens))
breaching.analysis.analysis.count_integer_overlap(valid_classes.view(-1)[:len(unique_tokens)], 
                                                  unique_tokens)

In [None]:
rec_labels = tokens.view(-1)
true_labels = true_user_data["data"].view(-1)
df = pd.DataFrame(dict(rec_labels=rec_labels.tolist(), true_labels=true_labels.tolist()))


fig = px.histogram(df, x=["rec_labels", "true_labels"], opacity=0.8,log_y=True, marginal="violin",
                  labels={'rec_labels':'Recovered tokens', "true_labels": "True tokens"})
fig.update_layout(
    title_text='Recovered Token Frequency', # title of plot
    xaxis_title_text='Token ID', # xaxis label
    yaxis_title_text='Count', # yaxis label
    bargap=0.2, # gap between bars of adjacent location coordinates
    bargroupgap=0.1, # gap between bars of the same location coordinates
    barmode='overlay'
)
fig.update_traces(opacity=0.75)

fig.show()