# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching

import copy
import matplotlib.pyplot as plt 

### Initialize cfg object and system setup:

This will load the full configuration. 
There are a lot of possible configuration options, but there is usually no need to worry about most of them. Below, only the attack settings are printed.

In [None]:
# Compose the hydra configuration from the local `config` directory, picking the
# attack and the use case through overrides, and print the attack settings.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(
        config_name='cfg',
        overrides=[
            'attack=invertinggradients',
            'case=1_single_image_small',
        ],
    )
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark

# Device/dtype keyword arguments that are passed on to the library and to `.to(**setup)`.
setup = {'device': device, 'dtype': getattr(torch, cfg.case.impl.dtype)}
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# --- Use case: user data, model and server state ---
cfg.case.user.data_idx = 0
cfg.case.user.num_data_points = 10
cfg.case.model = 'ConvNetSmall'
cfg.case.server.model_state = 'trained'

# --- Attack: objective and optimizer ---
cfg.attack.objective.type = 'cosine-similarity'
cfg.attack.objective.scale = 1
cfg.attack.optim.signed = False

# --- Attack: regularization ---
# The total variation scale should be small for CIFAR images
cfg.attack.regularization.total_variation.scale = 1e-4

### Instantiate all parties

In [None]:
# Build the user and the server for the configured use case, then prepare the attacker
# from the server's model and loss together with the attack configuration.
user, server = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

In [None]:
# Print the textual summary of every party taking part in the simulation.
for party in (user, server, attacker):
    print(party)

# Move the server model to the configured device/dtype. Being the last expression
# of the cell, the value returned by `.to(...)` is displayed as well.
server.model.to(**setup)

## Malicious server: Modify the model parameters here

In [None]:
# Malicious server step: overwrite the parameters of every Conv2d and Linear layer of the
# server model in place.
# NOTE(review): `bias_set` is assigned here but never read in the visible part of the notebook.
bias_set = False
# no_grad keeps autograd from tracking the in-place parameter edits below.
with torch.no_grad():
    for module in server.model.modules():
        # if isinstance(module, torch.nn.BatchNorm2d):
            # module.weight.data = module.running_var.data.clone()
            # module.bias.data = module.running_mean.data.clone() + 10
        if isinstance(module, torch.nn.Conv2d):
# Grouping stuff:
#             num_groups = module.out_channels // module.in_channels
#             surviving_features = module.weight[:module.in_channels, 0:1]
#             torch.nn.init.orthogonal_(surviving_features)
#             module.weight.data = torch.zeros_like(module.weight)
                
#             idx = 0
#             for group in range(num_groups):
#                 module.weight.data[idx:idx+module.in_channels, group:group+1] = surviving_features
#                 idx += module.in_channels

# Other replication stuff:            
            # NOTE(review): this orthogonal init acts on a slice of the old weight, but both
            # `initial_filters` and `module.weight.data` are reassigned a few lines below, so
            # its result looks discarded. It presumably still advances the RNG and thereby
            # influences the Linear init further down - confirm before removing it.
            initial_filters = module.weight[0:1]
            torch.nn.init.orthogonal_(initial_filters)
            #torch.nn.init.dirac_(initial_filters)
            # torch.nn.init.constant_(initial_filters, 1.0)
            #print(initial_filters.data)
            # Tile a 3x3 identity matrix into a single filter of shape (1, in_channels, 3, 3).
            # assumes 3x3 kernels and groups == 1 - TODO confirm for other architectures
            initial_filters = torch.eye(3, **setup).repeat(1, module.in_channels, 1, 1)
            # Every output channel receives the same filter; biases are reset to zero.
            module.weight.data = torch.cat([initial_filters] * module.out_channels).contiguous()
            module.bias.data = torch.zeros_like(module.bias.data)
            
        if isinstance(module, torch.nn.Linear):
            # module.bias.data = torch.zeros_like(module.bias.data)
            # module.weight.data = torch.cat([module.weight.data[0:1]] * module.weight.shape[0], dim=0).contiguous()
            # Linear layers get an orthogonal weight matrix; their bias is left untouched.
            torch.nn.init.orthogonal_(module.weight.data)
            
#     for module in user.model.modules():
#         if isinstance(module, torch.nn.Conv2d):
#             module.groups = module.in_channels        
#     for module in attacker.model_template.modules():
#         if isinstance(module, torch.nn.Conv2d):
#             module.groups = module.in_channels   

### Mess up activations?

In [None]:
class ModdedHardTanh(torch.nn.Module):
    """Hardtanh activation whose clipped output is shifted by one and then halved.

    With the default bounds of -1 and 1 the result lies in [0, 1].

    Args:
        min_val: Lower clipping bound handed to the wrapped ``torch.nn.Hardtanh``.
        max_val: Upper clipping bound handed to the wrapped ``torch.nn.Hardtanh``.
    """

    def __init__(self, min_val=-1, max_val=1):
        super().__init__()
        clip = torch.nn.Hardtanh(min_val, max_val)
        self.hardtanh = clip
        # Mirror the clipping bounds on the wrapper itself.
        self.min_val, self.max_val = clip.min_val, clip.max_val

    def forward(self, inputs):
        """Clip ``inputs`` with the wrapped Hardtanh and return ``(clipped + 1) / 2``."""
        clipped = self.hardtanh(inputs)
        shifted = clipped + 1
        return shifted / 2

In [None]:
def convert_relu_to(model, activation=torch.nn.Sigmoid, args=()):
    """Recursively replace every ``torch.nn.ReLU`` submodule of ``model`` in place.

    Args:
        model: Module whose (nested) children are searched. ``model`` itself is never
            replaced, only its descendants.
        activation: Callable that builds the replacement module; it is called once per
            replaced ReLU so that no instance is shared between positions.
        args: Positional arguments forwarded to ``activation``. Any iterable works; the
            default is an immutable empty tuple instead of a shared mutable list.

    Returns:
        None. ``model`` is modified in place.
    """
    # Iterate over a snapshot so that swapping children cannot disturb the iteration.
    for child_name, child in list(model.named_children()):
        if isinstance(child, torch.nn.ReLU):
            setattr(model, child_name, activation(*args))
        else:
            convert_relu_to(child, activation, args)
            

# Replace the ReLUs of the server, user and attacker models with ModdedHardTanh
# modules constructed with the positional arguments (0, 1).
new_activation = ModdedHardTanh
args = [0, 1]

for relu_model in (server.model, user.model, attacker.model_template):
    convert_relu_to(relu_model, activation=new_activation, args=args)

In [None]:
# Display the server model so that the replaced activations can be inspected.
server.model

### Space biases

In [None]:
# Caches that are filled by the forward hooks created through `named_hook`.
features = {}     # name -> output of the hooked module
in_features = {}  # name -> first positional input of the hooked module


def named_hook(name):
    """Return a forward hook that stores a module's first input and its output under ``name``."""

    def _record(module, inputs, output):
        in_features[name] = inputs[0]
        features[name] = output

    return _record

In [None]:
# Calibrate every convolution of the server model on one batch: rescale its weights by the
# measured standard deviation of its outputs, then spread its biases evenly over a range.

# The calibration batch is the same for every layer, so it is selected once up front.
try:
    random_data_sample = true_user_data['data']  # ground truth data sample for testing
except NameError:
    # `true_user_data` is only created by the FL simulation cell further down, so it does
    # not exist yet after "Restart & Run All". Fall back to Gaussian noise in that case.
    # assumes 3x32x32 inputs (CIFAR-sized) - TODO confirm for other use cases
    random_data_sample = torch.randn(1024, 3, 32, 32, **setup)

# Bias spacing of the most recently calibrated convolution. Reset on every run so that a
# stale value from an earlier execution of this cell can never be picked up.
bin_val = None

with torch.no_grad():
    for name, module in server.model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            hook = module.register_forward_hook(named_hook(name))
            try:
                # Measure the layer's output statistics with its bias zeroed.
                module.bias.data = torch.zeros_like(module.bias)

                server.model(random_data_sample)
                std, mu = torch.std_mean(features[name])
                print(f'mean of layer {name} is {mu.item()}, std is {std.item()}')

                # Divide the weights by the measured std (epsilon guards against std == 0).
                module.weight.data = module.weight.data / (std + 1e-6)
                # One evenly spaced bias per output channel.
                # NOTE(review): 1.96 is presumably the two-sided 95% bound of a standard
                # normal - confirm the intended range.
                module.bias.data = torch.linspace(-1.96 - mu, 1.96 + mu, module.bias.numel()).to(**setup)
                bin_val = module.bias.data[1] - module.bias.data[0]
                print(bin_val)

                # Report the statistics again after the modification.
                server.model(random_data_sample)
                std, mu = torch.std_mean(features[name])
                print(f'mean of layer {name} is {mu.item()}, std is {std.item()}')
            finally:
                # Always detach the hook, even if a forward pass failed, so that hooks do
                # not pile up on the module across re-runs.
                hook.remove()

        # Clip following activations to one bias bin. named_modules() also yields the
        # Hardtanh wrapped inside a ModdedHardTanh, which is the module whose bounds the
        # forward pass actually uses. Activations visited before any convolution has been
        # calibrated have no bin width yet and are left unchanged.
        if isinstance(module, (torch.nn.Hardtanh, ModdedHardTanh)) and bin_val is not None:
            module.min_val = 0
            module.max_val = bin_val

In [None]:
# Compare the cached conv1 outputs at spatial position (0, 0) with the bias of
# server.model.model[2]: first across channels for sample 0, then across samples for channel 0.
# NOTE(review): presumably server.model.model[2] is the 'model.conv1' layer - confirm.
print(features['model.conv1'][0,:,0,0] - server.model.model[2].bias)
print(features['model.conv1'][:,0,0,0] - server.model.model[2].bias[0])

In [None]:
def plot_map(feature_map):
    """Show the first (up to three) channels of the first sample of ``feature_map`` as an image.

    Every channel of every sample is rescaled to [0, 1] over dims 2 and 3 before plotting.
    A channel that is constant over those dims is mapped to 0 instead of NaN.

    Args:
        feature_map: 4D tensor; dims 2 and 3 are treated as the image plane and dim 1 as
            the channel dimension (only ``[:3]`` of it is drawn).
    """
    min_val = feature_map.amin(dim=[2, 3], keepdim=True)
    max_val = feature_map.amax(dim=[2, 3], keepdim=True)
    value_range = max_val - min_val
    # A constant channel has max == min, which would produce 0 / 0 = NaN. Divide by 1
    # there instead; non-constant channels are normalized exactly as before.
    value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    renorm_map = (feature_map - min_val) / value_range
    image = renorm_map[0, :3].permute(1, 2, 0).detach().cpu()
    print(image.shape)
    plt.imshow(image)
    plt.show()

In [None]:
def plot_grid_features(data):
    """Draw every 2D map in ``data`` on a square grid of subplots.

    Args:
        data: Tensor that is iterated along dim 0; every item is passed to ``imshow``.
            The grid side length is the ceiling of the square root of ``data.shape[0]``.
    """
    # At least a 1x1 grid, so that an empty input still yields a valid (blank) figure.
    grid_shape = max(1, int(torch.as_tensor(data.shape[0]).sqrt().ceil()))
    s = 10
    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid; otherwise plt.subplots
    # returns a bare Axes object, which has no `.flatten()`.
    fig, axes = plt.subplots(grid_shape, grid_shape, figsize=(s, s), squeeze=False)
    flat_axes = axes.flatten()
    # Hide ticks and frames everywhere, including grid cells that stay empty.
    for axis in flat_axes:
        axis.axis('off')
    for im, axis in zip(data, flat_axes):
        # detach() so that tensors which still track gradients can be plotted as well.
        axis.imshow(im.detach().cpu())

In [None]:
# Show all cached conv1 output channels of the first sample on a grid.
plot_grid_features(features['model.conv1'][0])

In [None]:
# Input to the first convolution: plotted as is.
plot_map(in_features['model.conv0'])

# Inputs to the later convolutions: summed over the channel dimension before plotting.
for layer_name in ('model.conv1', 'model.conv2', 'model.conv3'):
    plot_map(in_features[layer_name].sum(dim=1, keepdim=True))

In [None]:
# Give the attacker and the user their own deep copies of the server model, so that both
# hold the server's current architecture and parameters; then display the user's copy.
attacker.model_template = copy.deepcopy(server.model)
user.model = copy.deepcopy(server.model)
user.model

### Simulate an attacked FL protocol

True user data is returned only for analysis

In [None]:
# The server distributes its payload and the user computes its local update from it.
# `true_user_data` is returned next to the shared data and is used for analysis only.
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)  

In [None]:
# (mean, std) of every tensor in the first entry of the shared gradients.
[(g.mean(), g.std()) for g in shared_data['gradients'][0]]

In [None]:
# Plot the user's true data for later comparison with the reconstruction.
user.plot(true_user_data)

### Reconstruct user data:

In [None]:
# Run the attack on the server payload and the user's shared data; the server's secrets
# and the `dryrun` flag from the config are passed along.
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, 
                                                      server.secrets, dryrun=cfg.dryrun)

In [None]:
# How good is the reconstruction?
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, 
                                    server_payload, server.model, setup, order_batch=False)

In [None]:
# Plot the reconstructed data as returned by the attacker.
user.plot(reconstructed_user_data)

In [None]:
# How good is the reconstruction?
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, 
                                    server_payload, server.model, setup, order_batch=True)
ordered_user_data = dict(data=reconstructed_user_data['data'][metrics['order']],
                         labels=reconstructed_user_data['labels'])

In [None]:
# Plot the reconstruction after reordering it with `metrics['order']`.
user.plot(ordered_user_data)

PSNR without parameter modifications and 10 data points: 16-17