# Breaching privacy - test parameter modifications

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching

In [None]:
import matplotlib.pyplot as plt 

### Initialize cfg object and system setup:

This composes the configuration and prints the use case, the server type, and the full attack settings.
There are many possible configuration options, but usually there is no need to worry about most of them.

In [None]:
# Compose the config from ./config with ResNet-18 on CIFAR-10 and the "modern" attack
# preset, then report which case, server type and attack settings are in effect.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name='cfg',
        overrides=['attack=modern', 'case.model=resnet18', 'case/data=CIFAR10'],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the GPU when present; cuDNN benchmarking and the dtype come from the case config.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = {'device': device, 'dtype': getattr(torch, cfg.case.impl.dtype)}
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Optional overrides, currently disabled:
#   cfg.dryrun = False
#   cfg.attack.optim.step_size = 1.0
#   cfg.case.server.model_state = 'moco'

# Case: which user sample to attack, and the model architecture.
cfg.case.user.data_idx = 0
cfg.case.model = 'resnet50'  # NOTE(review): replaces 'case.model=resnet18' from the hydra overrides

# Attack: a single trial without Langevin noise; regularizer weights re-tuned.
cfg.attack.restarts.num_trials = 1
cfg.attack.optim.langevin_noise = 0.0
cfg.attack.regularization.deep_inversion.scale = 0.01
cfg.attack.regularization.total_variation.scale = 0.05
cfg.attack.regularization.norm.scale = 0.0

### Instantiate all parties

In [None]:
# Build the simulated user and server for the configured case, then an attacker
# that is given the server's model and loss (plus the attack config).
user, server = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

In [None]:
# Print a text summary of each party: user, server, attacker (in that order).
for party in (user, server, attacker):
    print(party)

### Simulate an attacked FL protocol

In [None]:
# (index, name, shape) for every parameter of the server's model.
[(param_idx, param_name, param.shape)
 for param_idx, (param_name, param) in enumerate(server.model.named_parameters())]

In [None]:
# Show the layer whose input is visualized further down. Modules must be looked up
# by their dotted name; a bare `layers.1.0.conv1` is not valid Python (SyntaxError).
dict(server.model.named_modules())['layers.1.0.conv1']

In [None]:
def plot_map(feature_map, ax=None, eps=1e-12):
    """Display sample 0 of a 4-D feature map as an image.

    Each channel is min-max normalized to [0, 1] over dims 2 and 3. If the map has
    at least three channels, the first three are shown as RGB. Otherwise only
    channel 0 is shown, as a single-channel image.

    Args:
        feature_map: tensor indexed as (batch, channel, height, width).
        ax: optional matplotlib Axes to draw into; when None, ``plt.imshow`` is used.
        eps: lower bound on each channel's value range, so constant channels render
            as 0 instead of NaN.

    Returns:
        The value returned by ``imshow``.
    """
    fmap = feature_map.detach()  # hooked activations may require grad; don't build a graph
    min_val = fmap.amin(dim=[2, 3], keepdim=True)
    max_val = fmap.amax(dim=[2, 3], keepdim=True)
    # Constant channels (e.g. after zeroing conv weights) have max == min. Clamp the
    # range so they don't divide by zero and turn into NaN.
    value_range = (max_val - min_val).clamp_min(eps)
    renorm_map = (fmap - min_val) / value_range
    if renorm_map.shape[1] >= 3:
        image = renorm_map[0, :3].permute(1, 2, 0)
    else:
        # imshow rejects (H, W, 1) / (H, W, 2); fall back to a 2-D single-channel image.
        image = renorm_map[0, 0]
    image = image.cpu()
    show = plt.imshow if ax is None else ax.imshow
    return show(image)

In [None]:
# Hand-craft the server model's parameters:
#   * every BatchNorm2d: weight := running_var, bias := running_mean + BN_BIAS_SHIFT
#   * every Conv2d is zeroed, except modules whose name contains 'downsample.0' and
#     stem[0], which get Dirac (identity-preserving) kernels.
BN_BIAS_SHIFT = 10  # offset added to each BN bias (value from the original experiment)

with torch.no_grad():
    for name, module in server.model.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            # Write in place with copy_ rather than swapping `.data`; skip layers
            # built without affine params or running statistics.
            if module.weight is not None and module.running_var is not None:
                module.weight.copy_(module.running_var)
            if module.bias is not None and module.running_mean is not None:
                module.bias.copy_(module.running_mean + BN_BIAS_SHIFT)
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.init.zeros_(module.weight)
        if 'downsample.0' in name:
            torch.nn.init.dirac_(module.weight)
    torch.nn.init.dirac_(server.model.stem[0].weight)

In [None]:
# Filled by the hooks created with named_hook():
#   feature_shapes: layer name -> [input shape, str(module)]
#   features:       layer name -> the (detached) input tensor itself
feature_shapes = dict()
features = dict()


def named_hook(name):
    """Return a forward hook that records a module's first positional input under ``name``.

    The tensor is stored detached, so keeping it for plotting does not keep autograd
    history alive. If the module runs several times, the last call overwrites earlier ones.
    """
    def hook_fn(module, input, output):
        layer_input = input[0]
        feature_shapes[name] = [layer_input.shape, str(module)]
        features[name] = layer_input.detach()
    return hook_fn

# Hook every Conv2d/Linear of the user's model, run one round of the protocol, and
# always remove the hooks afterwards, even if the update raises, so they do not
# keep firing on later forward passes.
hooks_list = []
for name, module in user.model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        hooks_list.append(module.register_forward_hook(named_hook(name)))

try:
    server_payload = server.distribute_payload()
    # True user data is returned only for analysis.
    shared_data, true_user_data = user.compute_local_updates(server_payload)
finally:
    for hook in hooks_list:
        hook.remove()

feature_shapes

In [None]:
# Visualize the recorded input of one hooked layer, then print its shape and module.
idx = 'layers.1.0.conv1'
layer_input = features[idx]
plot_map(layer_input)
print(feature_shapes[idx])

In [None]:
# Ground-truth user data, for comparison with the reconstruction further down.
user.plot(true_user_data)

### Reconstruct user data [via optimization]:

In [None]:
# Run the attack on what the server sees: its payload, the user's shared update,
# and the server's secrets.
reconstructed_user_data, stats = attacker.reconstruct(
    server_payload,
    shared_data,
    server.secrets,
    dryrun=cfg.dryrun,
)

# How good is the reconstruction? Compare it against the true user data.
metrics = breaching.analysis.report(
    reconstructed_user_data,
    true_user_data,
    server_payload,
    server.model,
    setup,
)

In [None]:
# Data reconstructed by the attacker.
user.plot(reconstructed_user_data)

In [None]:
# Regularizers held by the attacker (their scales were set via cfg.attack.regularization).
attacker.regularizers

In [None]:
# NOTE(review): duplicates the earlier `user.plot(reconstructed_user_data)` cell; consider removing one.
user.plot(reconstructed_user_data)