# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Send INFO-level (and above) log records to stdout as bare messages,
# then keep a handle on the root logger.
stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, handlers=[stdout_handler], format="%(message)s")
logger = logging.getLogger()

In [None]:
import matplotlib.pyplot as plt
from statistics import NormalDist
import math

### Initialize cfg object and system setup:

This loads the full configuration and prints the attack settings.
There are a lot of possible configuration options, but there is usually no need to worry about most of them.

In [None]:
# Compose the hydra config named 'cfg' with the imprint attack and the
# malicious-model server overrides, then report what was selected.
config_overrides = ["attack=imprint", "case/server=malicious-model"]
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name="cfg", overrides=config_overrides)
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# Device/dtype keyword arguments passed to later cells; the bare name displays the dict.
setup = {"device": device, "dtype": torch.float}
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Batch size targeted by the attack; reused by the config cell below.
target_batch_size = 2 ** 6

# The same value is passed for both arguments of the guarantee helper
# (presumably number of data points and number of bins — TODO confirm against
# breaching.analysis.imprint_guarantee).
estimate_prob = breaching.analysis.imprint_guarantee.one_shot_guarantee(
    target_batch_size,
    target_batch_size,
)

summary = (f'{estimate_prob:.2%} of 1-shot attacks will perfectly uncover a single data point '
           f'out of the given batch of {target_batch_size} data points.')
print(summary)

In [None]:
# The user holds exactly one batch of the targeted size.
cfg.case.user.num_data_points = target_batch_size

cfg.case.model = "none"

# Server-side model modification, edited through a handle on the config node.
modification = cfg.case.server.model_modification
modification.type = "OneShotBlock"
modification.num_bins = target_batch_size
modification.position = None  # '4.0.conv'
modification.connection = "add"
modification.linfunc = "fourier"
modification.mode = 32
modification.gain = 1
cfg.case.server.model_gain = 1

# Local differential privacy options: gradient noise and per-example clipping
# are both set to 0.0, with a gaussian distribution selected.
privacy = cfg.case.user.local_diff_privacy
privacy.gradient_noise = 0.0
privacy.per_example_clipping = 0.0
privacy.distribution = "gaussian"

### Instantiate all parties

In [None]:
# Build the user, server, model and loss function from the case config and the device/dtype setup.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# Prepare the attacker from the server's model and loss together with the attack config.
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
# Report on the three parties (output format is defined by breaching.utils.overview).
breaching.utils.overview(server, user, attacker)

In [None]:
# Display the `bins` attribute of the module at index 1 of the user's model
# (presumably the block inserted by the server's model modification — TODO confirm).
user.model[1].bins

### Simulate an attacked FL protocol

True user data is returned only for analysis

In [None]:
# The server produces its payload for the user.
server_payload = server.distribute_payload()
# The user computes its local update from that payload; `true_user_data` is kept only for analysis.
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# First entry of the gradients shared by the user.
grads = shared_data['gradients'][0]
# Display (norm, mean, std) as Python floats for every gradient tensor.
[tuple(stat.item() for stat in (g.norm(), g.mean(), g.std())) for g in grads]

In [None]:
# Overall 2-norm of the update: the 2-norm of the stacked per-tensor 2-norms.
per_tensor_norms = torch.stack([g.norm(2) for g in grads])
per_tensor_norms.norm(2)

# Reconstruct user data

In [None]:
# Run the attack on the single payload/update pair, handing over the server's secrets.
reconstructed_user_data, stats = attacker.reconstruct(
    [server_payload],
    [shared_data],
    server.secrets,
    dryrun=cfg.dryrun,
)

In [None]:
# Keep only the reconstruction at index 1 (the 1:2 slice keeps the leading dimension);
# no labels are attached to it.
found_data = {"data": reconstructed_user_data["data"][1:2], "labels": None}
user.plot(found_data, scale=False)

### Identify the index of the true user data point that is closest to this reconstruction:

In [None]:
# Find the true user data point closest to the reconstruction.
# Distances come from torch.dist (its default norm) and are stored as plain
# Python floats, so the lookup table holds no tensors. The loop variable has
# its own name so that `idx` only ever refers to the best match.
matches = {
    candidate_idx: torch.dist(found_data['data'], data).item()
    for candidate_idx, data in enumerate(true_user_data['data'])
}
idx = min(matches, key=matches.get)
print(idx)

# Re-add the leading dimension to the selected data point and take the matching
# label slice, so the match has the same layout as `found_data`.
matching_user_data = dict(
    data=true_user_data['data'][idx][None, ...],
    labels=true_user_data['labels'][idx:idx + 1],
)
user.plot(matching_user_data, scale=False)

In [None]:
# How good is the reconstruction? Compare it against the matched true data point.
metrics = breaching.analysis.report(
    found_data,
    matching_user_data,
    [server_payload],
    server.model,
    order_batch=True,
    compute_full_iip=False,
    cfg_case=cfg.case,
    setup=setup,
)

## Ground-truth (GT) analysis

In [None]:
# First weight row of the `linear0` layer of the module at index 1 of the user's model.
block_lin = user.model[1].linear0.weight[0]

# NOTE(review): the config cell sets gradient_noise = 0.0; if the user object
# builds no noise generator in that case, calling `.sample` on it would fail.
# The noise term is therefore only added when a generator actually exists.
noise_generator = getattr(user, "generator", None)

# One projected value per item of the user's dataset, filled batch by batch.
vals = torch.zeros(len(user.dataloader.dataset))
counter = 0
with torch.inference_mode():
    for image, _ in user.dataloader:
        image = image.to(**setup)
        # Normalize by the sum over the whole batch tensor, then add twice the sampled noise.
        image = image / image.sum()
        if noise_generator is not None:
            image = image + 2 * noise_generator.sample(image.shape)
        B = image.shape[0]
        # Inner product of every flattened sample with the weight row.
        vals[counter:counter+B] = (block_lin[None, :] * image.flatten(start_dim=1)).sum(dim=1).cpu()
        counter += B

In [None]:
from scipy.stats import laplace

In [None]:
# Standard deviation and mean of the projected values (displayed in the next cell).
std_data, mu_data = torch.std_mean(vals)

fig, ax = plt.subplots()

# Histogram of the projected values. density=True normalizes the bars to a
# probability density, hence the y-axis label.
n, bins, patches = ax.hist(vals.numpy(), 250, density=True, facecolor='royalblue', alpha=0.75)

ax.set_xlabel('Values')
ax.set_ylabel('Probability density')
ax.set_title('Distribution on GT data')

# Evaluate the reference curves over the x-range chosen by the histogram.
xmin, xmax = ax.get_xlim()
x = torch.linspace(xmin, xmax, 100)
x_grid = x.numpy()  # explicit numpy view for scipy / matplotlib

# Laplace distribution fitted to the observed values.
loc, scale = laplace.fit(vals.numpy())
p = laplace.pdf(x_grid, loc, scale)
ax.plot(x_grid, p, 'r', linewidth=2, label='Laplacian estimate on true data distribution')

# Reference Laplace distribution centered at 0 with scale 1/sqrt(2).
p = laplace.pdf(x_grid, 0, 1/math.sqrt(2))
ax.plot(x_grid, p, 'k', linewidth=2, label='Laplacian distribution (scaled = 1/sqrt(2))')

ax.legend()
plt.show()

In [None]:
# Display the standard deviation and mean of `vals` computed in the previous cell.
std_data, mu_data