# Breaching privacy

This notebook does the same job as the command-line tool `simulate_breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
from collections import Counter

import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Send log records of level INFO and above to stdout, formatted as the bare message.
stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, handlers=[stdout_handler], format='%(message)s')
# Root logger handle used by the rest of the notebook.
logger = logging.getLogger()

### Initialize cfg object and system setup:

The cell below composes the configuration and prints the selected use case and server type.
There are a lot of possible configuration options, but there is usually no need to worry about most of them.

In [None]:
# Overrides applied on top of the base 'cfg' configuration for this experiment.
config_overrides = [
    'case/data=shakespeare',
    'case.model=transformer1',
    'attack.label_strategy=bias-text',
    'attack.regularization.total_variation.scale=0',
]
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=config_overrides)
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype pair handed to the case and attack constructors further down.
setup = dict(device=device, dtype=torch.float)
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Attack variant to run.
cfg.attack.attack_type = "joint-optimization"

# Which user to simulate and how many data points that user holds.
cfg.case.user.user_idx = 0
cfg.case.user.num_data_points = 21

### Instantiate all parties

In [None]:
# Build the user, server, model and loss function for the configured case.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# The attacker is prepared from the server's model and loss, not from the local `model`/`loss_fn`.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
# Print a summary of the three parties.
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

The true user data is returned only for analysis.

In [None]:
# Server side: produce the payload that is sent to the user.
server_payload = server.distribute_payload()
# User side: compute the local update from that payload. The second return value is the
# user's true data, kept here only so the reconstruction can be compared against it.
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
#grads = shared_data['gradients'][0]
#[(g.norm(), g.mean(), g.std()) for g in grads]

In [None]:
#torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)

In [None]:
# Show the ground-truth user data for later comparison with the reconstruction.
user.print(true_user_data)

In [None]:
# Shape of the ground-truth "data" entry.
true_user_data["data"].shape

# Reconstruct user data

In [None]:
# Run the attack. Payload and shared data are each wrapped in a one-element list;
# `cfg.dryrun` is forwarded unchanged as the `dryrun` flag.
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)

In [None]:
# Show the reconstruction with the same printer used for the true data above.
user.print(reconstructed_user_data)

# Reconstruct manually

In [None]:
# Labels recovered by the attack; the manual analysis below starts from these.
labels = reconstructed_user_data["labels"]

In [None]:
# Display the recovered labels.
labels

In [None]:
# First entry of the shared gradient list.
# NOTE(review): which parameter this gradient belongs to depends on the model's parameter order - confirm.
input_to_pos_embedding = shared_data["gradients"][0]
input_to_pos_embedding.shape

In [None]:
# Shape of the attention output-projection weight in the first transformer encoder layer.
user.model.transformer_encoder.layers[0].self_attn.out_proj.weight.shape

In [None]:
# Pair each shared gradient with the name of the parameter at the same position
# in the attacker's first reconstruction model.
rec_model_params = attacker._rec_models[0].named_parameters()
param_names = [param_name for param_name, _param in rec_model_params]
named_grads = {param_name: grad for param_name, grad in zip(param_names, shared_data["gradients"])}
# Overview: parameter name -> gradient shape.
[(param_name, grad.shape) for param_name, grad in named_grads.items()]

In [None]:
# Select, for every recovered label, the matching row of the model's `encoder` weight.
# NOTE(review): presumably `encoder` is the token-embedding layer - confirm against the model definition.
encoded_labels = user.model.encoder.weight[labels]
encoded_labels.shape

In [None]:
# First 32 rows of the gradient stored under "pos_encoder.embedding.weight".
# NOTE(review): 32 is presumably the sequence length of the user data - confirm and consider deriving it.
pos_grads = named_grads["pos_encoder.embedding.weight"][:32]
pos_grads.shape

In [None]:
# Broadcast every positional-gradient row against every label embedding:
# axis 0 indexes rows of `pos_grads`, axis 1 indexes entries of `encoded_labels`.
label_vectors = encoded_labels.unsqueeze(0)
grad_rows = pos_grads.unsqueeze(1)

# Sum of squared elementwise products over the feature axis.
# NOTE(review): this is not a squared dot product; confirm that this score is the intended one.
cmap = (label_vectors * grad_rows).pow(2).sum(dim=-1)
# Product of the two vector norms for each (row, label) pair, used to normalize `cmap`.
norms = encoded_labels.norm(dim=-1).unsqueeze(0) * grad_rows.norm(dim=-1)

In [None]:
# Smallest absolute entry of the score map.
cmap.abs().min()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Normalize the score map and copy it to host memory before plotting. Matplotlib
# converts its input through NumPy, which cannot read tensors on a CUDA device,
# and this notebook selects cuda:0 whenever it is available.
normalized_map = (cmap / norms).detach().cpu()
plt.imshow(normalized_map)
# Rows follow `pos_grads`, columns follow the recovered labels (see how `cmap` is broadcast).
plt.xlabel("recovered label index")
plt.ylabel("positional-embedding gradient row");

In [None]:
# Sort the true tokens and the recovered labels independently, then report the
# fraction of positions at which the two sorted sequences agree.
true_tokens_sorted = true_user_data["data"].view(-1).sort().values
recovered_sorted = labels.view(-1).sort().values
(true_tokens_sorted == recovered_sorted).sum() / labels.numel()

In [None]:
# Count how many recovered labels also occur in the true user data, respecting
# multiplicity: a token recovered more often than it truly occurs is only counted
# as often as it occurs. Counter intersection keeps the per-token minimum of both
# counts, which matches searching a pool and removing one entry per hit, but it
# runs in linear time and compares plain Python ints, not 0-dim tensors.
true_token_counts = Counter(true_user_data["data"].view(-1).tolist())
recovered_token_counts = Counter(labels.view(-1).tolist())
found_labels = sum((true_token_counts & recovered_token_counts).values())
# NOTE(review): the pool is built from "data" but normalized by the size of "labels";
# confirm that both entries hold the same number of tokens.
found_labels / true_user_data["labels"].numel()

In [None]:
# NOTE(review): presumably num_data_points (21, set in the config cell) times the 32 rows
# used for `pos_grads`, i.e. the expected total token count - confirm.
21*32