# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import copy
import math

In [None]:
def _return_model_features(model, inputs):
    """Return the tensor that enters the last ``torch.nn.Hardtanh`` of ``model``.

    A temporary forward hook is placed on the last Hardtanh module (in
    ``model.named_modules()`` order), one forward pass is run, and the first
    positional input that module received is returned.

    Args:
        model: a ``torch.nn.Module`` containing at least one ``Hardtanh``.
        inputs: model input. A 3-dim tensor (presumably a single C x H x W
            image -- confirm) gets a leading batch dimension added first.

    Returns:
        The first positional input of the last Hardtanh module.

    Raises:
        ValueError: if ``model`` contains no ``torch.nn.Hardtanh`` module.
    """
    features = dict()  # filled by the hook below, keyed by module name
    if inputs.ndim == 3:
        inputs = inputs.unsqueeze(0)

    def named_hook(name):
        def hook_fn(module, input, output):
            # Record what goes *into* the module (its pre-activation).
            features[name] = input[0]
        return hook_fn

    # Search from the end so that the hook lands on the last Hardtanh.
    target = None
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Hardtanh):
            target = (name, module)
            break
    if target is None:
        raise ValueError('Model contains no torch.nn.Hardtanh module to extract features from.')
    feature_layer_name, feature_module = target

    hook = feature_module.register_forward_hook(named_hook(feature_layer_name))
    try:
        model(inputs)
    finally:
        # Always detach the hook, even if the forward pass fails, so it does not
        # linger on the model and fire during later forward passes.
        hook.remove()
    return features[feature_layer_name]

### Initialize cfg object and system setup:

This composes the configuration and prints the settings of the selected attack.
There are a lot of possible configurations, but there is usually no need to worry about most of them. Below, only the attack options are printed.

In [None]:
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=['attack=invertinggradients',
                                                      'case=1_single_image_small'])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Preferred GPU for this experiment. A hard-coded 'cuda:2' fails as soon as tensors are
# moved there on machines with fewer than three GPUs, so fall back to the first GPU in
# that case, and to the CPU when CUDA is unavailable.
preferred_gpu = 2
if torch.cuda.is_available():
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Keyword arguments used throughout the notebook for `.to(**setup)` / tensor creation.
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# User-side settings (cfg.case.user).
cfg.case.user.data_idx = 0
cfg.case.user.num_data_points = 1

# Model, data and server settings (cfg.case).
cfg.case.model = 'ConvNetSmall'
cfg.case.data.batch_size = 512
cfg.case.server.has_external_data = True

# Attack settings (cfg.attack).
cfg.attack.objective.type = 'masked-cosine-similarity'
# The total variation scale should be small for CIFAR images
cfg.attack.regularization.total_variation.scale = 1e-5

### Instantiate all parties

In [None]:
# Build the simulated user and server from the case config, then the attacker from the
# server's model and loss. Afterwards the server model is moved to the device/dtype in
# `setup`; the statement order is kept as-is because these library calls may capture
# the model in its current state.
user, server = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
server.model.to(**setup)

In [None]:
# Show the user, server and attacker, one per line.
print(user, server, attacker, sep='\n')

## Malicious server I: Modify the model architecture first

In [None]:
# Number of channel groups ("paths") used by the architecture change and the
# path initialization below; also the width of the inserted Hardtanh bottleneck.
num_paths = 10

In [None]:
# Replace the last layer with a bottleneck in front of the original last layer:
# Linear(feature_dim -> num_paths) -> Hardtanh(0, 1) -> Linear(num_paths -> feature_dim) -> old last layer.
# NOTE(review): assumes server.model.model is indexable, that model[-5] is the last conv
# layer, and that the layers between it and model[-1] turn its `out_channels` channels
# into a flat feature vector of the same size -- confirm against the ConvNetSmall definition.
feature_dim = server.model.model[-5].out_channels
server.model.model[-1] = torch.nn.Sequential(torch.nn.Linear(feature_dim, num_paths),
                                             torch.nn.Hardtanh(min_val=0, max_val=1), 
                                             torch.nn.Linear(num_paths, feature_dim),
                                             server.model.model[-1]).to(**setup)


# Give the attacker and the user copies of the modified architecture.
# NOTE(review): these copies are taken before the weight changes made in later cells;
# presumably the user receives updated parameters via server.distribute_payload() -- confirm.
attacker.model_template = copy.deepcopy(server.model)
user.model = copy.deepcopy(server.model)
user.model

## Malicious server II: Include paths

In [None]:
# Old first layer:
# new_weight = module.weight.new_zeros(module.in_channels, module.in_channels, *module.kernel_size)
# torch.nn.init.orthogonal_(new_weight)
# new_bias = module.bias.new_zeros(module.in_channels)
# fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(new_weight)
# torch.nn.init.uniform_(new_bias, -1 / math.sqrt(fan_in), 1 / math.sqrt(fan_in))

# # Replicate filters:
# replication_dim = module.out_channels // module.in_channels
# replicated_weight = torch.cat([new_weight] * replication_dim).contiguous()
# replicated_bias = torch.cat([new_bias] * replication_dim).contiguous()

# module.weight.data[:replication_dim * module.in_channels] = replicated_weight
# module.bias.data[:replication_dim * module.in_channels] = replicated_bias

In [None]:
# Rewrite every Conv2d so that its channels split into `num_paths` groups ("paths"):
# output group p of a layer only reads input group p of the previous layer (the first
# conv reads all of its input channels), and all paths of a layer share the same
# randomly drawn filter and bias. Then report how many parameters are non-zero.
input_path_width = 3  # channels per path entering the first conv (all 3 input channels)
first_conv = True
with torch.no_grad():
    for name, module in server.model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            # Initialize existing params at zero:
            # NOTE(review): assumes every conv has a bias (bias=True).
            module.weight.data = torch.zeros_like(module.weight)
            module.bias.data = torch.zeros_like(module.bias)
            
            # Output channels per path; the remaining out_channels % num_paths channels
            # stay zero. NOTE(review): becomes 0 if out_channels < num_paths.
            output_path_width = module.out_channels // num_paths
            
            # One orthogonally initialized filter bank, shared by all paths of this layer.
            new_weight = module.weight.new_zeros(output_path_width, input_path_width,  *module.kernel_size)
            torch.nn.init.orthogonal_(new_weight)

            # One bias vector, drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
            new_bias = module.bias.new_zeros(output_path_width)
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(new_weight)
            torch.nn.init.uniform_(new_bias, -1 / math.sqrt(fan_in), 1 / math.sqrt(fan_in))

            # Group channels:
            ii, io = 0, 0 # index-in index-out
            if first_conv:
                # First conv: every path reads all input channels, so this assumes
                # in_channels == input_path_width (3).
                for path in range(num_paths):
                    module.weight.data[io:io+output_path_width, :] = new_weight.clone()
                    io += output_path_width
                    ii += input_path_width  # not used in this branch
                first_conv = False
            else:
                # Block-diagonal placement: output group `path` reads only input group `path`.
                for path in range(num_paths):
                    module.weight.data[io:io+output_path_width, ii:ii+input_path_width] = new_weight.clone()
                    io += output_path_width
                    ii += input_path_width
                
            # Same bias for every path; remainder channels keep a zero bias.
            module.bias.data[:output_path_width * num_paths] = torch.cat([new_bias] * num_paths).contiguous()
            
            # Set input->output
            # This layer's group width is the next layer's input group width.
            input_path_width = output_path_width
            
            
            print(module.weight.shape, module.bias.shape)
            # Test channel:
            # Feed one random 32x32 plane into every input channel and print the first four
            # channels starting at path 0 and at path 1 for pixel (0, 0); because the paths
            # share weights and see identical inputs, the two prints should presumably match.
            inputs = torch.cat([torch.randn(1, 1, 32, 32, **setup)] * module.in_channels, dim=1)
            feats = module(inputs)
            print(feats[0,0:4, 0, 0], feats[0,output_path_width:output_path_width+4, 0, 0])
        if isinstance(module, torch.nn.Linear) and module.out_features == num_paths:
            # prep averaging layer here
            # Output unit `path` averages the `input_path_width` features of that path.
            # NOTE(review): assumes the incoming features are the last conv's channels in
            # path order (e.g. after pooling + flatten) -- confirm.
            module.weight.data = torch.zeros_like(module.weight.data)
            module.bias.data = torch.zeros_like(module.bias.data)
            new_block = module.weight.data.new_ones(input_path_width)/ input_path_width
            idx = 0
            for path in range(num_paths):
                module.weight.data[path, idx:idx+input_path_width] = new_block.clone()
                idx += input_path_width
            adaptation_layer = module  # not referenced elsewhere in the visible notebook
        if isinstance(module, torch.nn.Linear) and module.in_features == num_paths:
            # prep return layer here, all inputs need to be picked up
            # module.weight.data = torch.ones_like(module.weight.data) / num_paths
            #torch.nn.init.orthogonal_(module.weight.data)
            # module.bias.data = torch.zeros_like(module.bias.data)
            pass
            # dont mess with the return layer
            
# Parameters with magnitude above 1e-7 are counted as non-zero: overall and in Linear layers.
num_params = sum([(p.abs() > 1e-7).sum() for p in server.model.parameters()])
linear_params = sum([(p.abs() > 1e-7).sum() for m in server.model.modules() for p in m.parameters()  
                     if isinstance(m, torch.nn.Linear)])
print(f'Model architecture {server.model.__class__} loaded with {num_params:,} non-zero parameters of which '
      f'{linear_params} are in linear layers.')

# Total number of user data elements: data points times elements per data point.
target_information = cfg.case.user.num_data_points * torch.as_tensor(cfg.case.data.shape).prod()

print(f'Overall this is a data ratio of {(num_params - linear_params) / target_information:2.2f}:1 '
      f'for target shape {[cfg.case.user.num_data_points, *cfg.case.data.shape]} if pathcount was optimal.')

In [None]:
# Sanity check: run a random (1, 3, 32, 32) input through the modified model and show
# the tensor that enters its last Hardtanh.
inputs = torch.randn(1, 3, 32, 32, **setup)
feats = _return_model_features(server.model, inputs)
feats

# Compute bins and set feature distribution:

In [None]:
from statistics import NormalDist

def get_bins_by_mass(num_bins):
    """Split the standard normal distribution into ``num_bins`` equal-mass bins.

    The bin edges are the standard-normal quantiles at probabilities
    ``k / (num_bins + 2)`` for ``k = 1, ..., num_bins + 1``. Every bin therefore
    holds probability mass ``1 / (num_bins + 2)``, and each tail outside the
    outermost edges holds that same mass.

    Args:
        num_bins: number of bins to create.

    Returns:
        A tuple ``(bins, bin_sizes)`` of two lists with ``num_bins`` floats each:
        the lower edge of every bin and the width of every bin. Both lists are
        empty when ``num_bins`` is 0 or negative.
    """
    standard_normal = NormalDist(mu=0, sigma=1)
    num_slices = num_bins + 2
    # Compute each probability directly as k / num_slices rather than by
    # repeatedly adding 1 / num_slices, which accumulates rounding error
    # (e.g. 0.1 + 0.1 + 0.1 != 0.3).
    edges = [standard_normal.inv_cdf(k / num_slices) for k in range(1, num_bins + 2)]
    bin_sizes = [upper - lower for lower, upper in zip(edges, edges[1:])]
    # Drop the last edge: it is the upper edge of the last bin, not a bin start.
    return edges[:-1], bin_sizes

In [None]:
# get_bins_by_mass(10)

In [None]:
features = {}


def named_hook(name):
    """Create a forward hook that saves a module's output under ``name``.

    The returned callable has the ``(module, input, output)`` signature used by
    ``register_forward_hook`` and writes ``output`` into the global ``features`` dict.
    """
    def _save_output(module, input, output):
        features[name] = output

    return _save_output

In [None]:
# Data-dependent calibration of the modified model. Walk the Conv2d/Linear layers in
# order. For each one, run a batch of external data through the model, measure the
# scalar mean `mu` and standard deviation `std` of that layer's whole output tensor, and
# rescale the layer so that its output becomes (output - mu) / std. The walk stops at
# the first Linear with `num_paths` outputs (the layer feeding the Hardtanh), whose
# units are then mapped onto the equal-mass standard-normal bins of get_bins_by_mass.
with torch.inference_mode():
    for name, module in server.model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            hook = module.register_forward_hook(named_hook(name))

            # A new iterator is created for every layer, so each layer is calibrated on
            # whichever batch the dataloader yields first at that point.
            random_data_sample = next(iter(server.external_dataloader))[0].to(**setup)
            # random_data_sample = torch.randn(1024, 3, 32, 32, **setup)
            # random_data_sample = true_user_data['data'] #ground truth data sample for testing

            server.model(random_data_sample)
            std, mu = torch.std_mean(features[name])
            # print(f'Initial mean of layer {name} is {mu.item()}, std is {std.item()}')
            with torch.no_grad():
                # For an output y = W x + b, the normalized output (y - mu) / std equals
                # (W / std) x + (b - mu) / std, so the bias has to be shifted *and*
                # scaled. Subtracting only mu / std from b would leave a residual offset
                # of b * (1 - 1 / std) whenever the bias is non-zero.
                scale = std + 1e-8  # epsilon guards against division by a zero std
                module.weight.data /= scale
                module.bias.data -= mu
                module.bias.data /= scale

            # Re-run the same batch to report the calibrated statistics.
            server.model(random_data_sample)
            std, mu = torch.std_mean(features[name])
            print(f'Fixed mean of layer {name} is {mu.item()}, std is {std.item()}')  

            hook.remove()
            if isinstance(module, torch.nn.Linear) and module.out_features == num_paths:
                # Verify:
                print(f'Input to hardtanh before bias and scale is set: {features[name][0]}')

                adapt_module = module
                # Modify bins in this layer: unit i becomes (y_i - bins[i]) / bin_sizes[i],
                # i.e. 0 at the lower edge of bin i and 1 at its upper edge -- the range in
                # which the following Hardtanh(min_val=0, max_val=1) is not saturated.
                bins, bin_sizes = get_bins_by_mass(num_paths)
                # Safety wheels:
                #bins = [b * 2 for b in bins]
                #bin_sizes = [b * 2 for b in bin_sizes]
                #bins = torch.linspace(-1.96, 1.96, num_paths + 1)
                #bin_sizes = [bins[i + 1] - bins[i] for i in range(len(bins) - 1)]
                #bins = bins[:-1]

                # Old mod:
                module.weight.data /= torch.as_tensor(bin_sizes, **setup)[:, None]
                module.bias.data -= torch.as_tensor(bins, **setup) 
                module.bias.data /= torch.as_tensor(bin_sizes, **setup)

                # New computation with extend bin extension?:
                # I = -NormalDist(mu=0, sigma=1).inv_cdf(0.90)
                # module.weight.data *= 2 * I / torch.as_tensor(bin_sizes, **setup)[:, None]
                # module.bias.data -= torch.as_tensor(bins, **setup) 
                # module.bias.data *= 2 * I / torch.as_tensor(bin_sizes, **setup)
                # module.bias.data -= I
                break

            del features[name]

In [None]:
# Record the output of the calibrated num_paths-wide Linear (adapt_module), i.e. the
# input of the Hardtanh, for one batch of external data under features['hardtanh_input'].
with torch.inference_mode():
    hook = adapt_module.register_forward_hook(named_hook('hardtanh_input'))
    try:
        random_data_sample = next(iter(server.external_dataloader))[0].to(**setup)
        # random_data_sample = torch.randn(1024, 3, 32, 32, **setup)
        # random_data_sample = true_user_data['data'] #ground truth data sample for testing

        server.model(random_data_sample)
    finally:
        # Remove the hook even if loading data or the forward pass fails. Otherwise it
        # would keep firing on later forward passes, and re-running this cell would stack
        # a second hook on the module.
        hook.remove()

In [None]:
# Inspect the first sample of the batch: its Hardtanh input, and the same values clamped
# to [0, 1] (the Hardtanh(min_val=0, max_val=1) range used in the model).
print(features['hardtanh_input'][0])
threshold = torch.nn.functional.hardtanh(features['hardtanh_input'][0], min_val=0, max_val=1)
print(threshold)

### Threshold stats:

In [None]:
# Entries that are neither 0 nor 1 after clamping lie strictly inside Hardtanh's linear
# (non-saturated) range. First output: number of such entries per sample of the batch.
# Second output: count of such entries along dim 0.
# NOTE(review): presumably features['hardtanh_input'] is (batch, num_paths), making the
# second output a per-path count -- confirm.
threshold = torch.nn.functional.hardtanh(features['hardtanh_input'], min_val=0, max_val=1)
print(((threshold != 1) & (threshold != 0)).sum() / random_data_sample.shape[0])
((threshold != 1) & (threshold != 0)).sum(dim=0)

In [None]:
# Drop the captured activations. The hooks built by named_hook look up this global
# name, so they raise NameError if they fire after this point.
del features

### Simulate an attacked FL protocol

True user data is returned only for analysis.

In [None]:
# One simulated protocol round: the server distributes its payload and the user computes
# its local update on it. The true user data is returned only for analysis.
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)  

# Mean and standard deviation of the user's true data.
true_user_data['data'].mean(), true_user_data['data'].std()

In [None]:
# Mean and standard deviation of every item in shared_data['gradients'][0].
# NOTE(review): presumably one gradient tensor per model parameter -- confirm the layout
# of shared_data returned by user.compute_local_updates.
[(g.mean(), g.std()) for g in shared_data['gradients'][0]]

In [None]:
# Plot the user's true data.
user.plot(true_user_data)

### Reconstruct user data:

In [None]:
# Run the attack on what the server observed: its payload, the user's shared update and
# the server's secrets.
reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, 
                                                      server.secrets, dryrun=cfg.dryrun)

# How good is the reconstruction?
# Compare the reconstruction against the true user data.
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, 
                                    server_payload, server.model, user.dataloader, setup=setup,
                                    order_batch=True, compute_full_iip=False)

In [None]:
# Plot the reconstructed data.
# NOTE(review): presumably scale=True rescales values for display -- confirm.
user.plot(reconstructed_user_data, scale=True)