# Robbing the Fed: Directly Obtaining Private Data in Federated Learning with Modified Models

This notebook shows an example of the threat model and attack described in "Robbing the Fed: Directly Obtaining Private Data in Federated Learning with Modified Models". Unlike the other examples, which assume an "honest-but-curious" server, this one investigates an actively malicious server that modifies the model. As such, the attack applies to any model architecture, but how conspicuous the modification is depends on the existing architecture onto which the malicious "Imprint" block is grafted.

In this notebook, we place the block in front of a ResNet18, as an example. The attack can also be conceptualized as merely a "malicious parameters" attack against a model which contains, for example, only fully connected layers or only convolutions without strides and fully connected layers.


Paper URL: https://openreview.net/forum?id=fwzUgo0FM9v

This variation implements the one-shot attack shown in Fig. 3 of the paper. We aggregate 16,384 data points here, but this number is limited only by the amount of time you are willing to wait for the aggregated gradients to be computed (aside from numerical issues, which are likely to appear at astronomical batch sizes).

The expected fraction of cases in which this attack recovers a unique data point *is always 1/e ≈ 37%*. So please do not be surprised if the steps in this notebook do not work the first time you run them.
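
The 1/e figure can be sanity-checked with a short calculation. The sketch below rests on a modelling assumption that is not taken from the `breaching` code: the single bin captures each of the `n` aggregated data points independently with probability `1/n`. Under that assumption, the chance that exactly one point lands in the bin is `(1 - 1/n)^(n-1)`, which approaches 1/e as `n` grows. The function name `prob_unique_hit` is hypothetical.

```python
import math


def prob_unique_hit(n: int) -> float:
    """Probability that exactly one of n points lands in a bin of mass 1/n.

    Assumes each point falls into the bin independently (binomial model).
    """
    p = 1.0 / n
    # n ways to pick the single point that hits, all others must miss
    return n * p * (1.0 - p) ** (n - 1)


for n in (64, 1024, 16384):
    print(n, round(prob_unique_hit(n), 4))
print("1/e", round(1 / math.e, 4))
```

Under this model, the success rate barely changes with the number of aggregated data points, which is consistent with the statement above.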

### Startup

In [None]:
try:
    import breaching
except ModuleNotFoundError:
    # You only really need this safety net if you want to run these notebooks directly in the examples directory
    # Don't worry about this if you installed the package or moved the notebook to the main directory.
    import os; os.chdir("..")
    import breaching
    
import torch
%load_ext autoreload
%autoreload 2

# Redirects logs directly into the jupyter notebook
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

### Initialize cfg object and system setup:

This will load the full configuration object. This includes the configuration for the use case and threat model as `cfg.case` and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overridden with `overrides=` as if they were command-line arguments.

In [None]:
cfg = breaching.get_config(overrides=["attack=imprint", "case=8_industry_scale_fl", 
                                      "case/server=malicious-model-rtf"])
          
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations for the attack, or the case:

In [None]:
users = 256
cfg.case.user.num_data_points = 64 # How many data points per user
cfg.case.data.examples_from_split = 'training' # We have to draw images from the ImageNet training set

cfg.case.server.model_modification.type = 'OneShotBlock'
cfg.case.server.model_modification.num_bins = cfg.case.user.num_data_points * users

In [None]:
cfg.case.server.model_modification.num_bins = 128 # How many bins are in the block (overrides the value set in the previous cell)


# How does the block interact with the model?
# The block can be placed later in the model given a position such as  '4.0.conv':
cfg.case.server.model_modification.position = None  # None defaults to the first layer
# The block can also be connected in various ways to the other layers:
cfg.case.server.model_modification.connection = 'addition'

cfg.case.server.model_modification.linfunc = 'fourier' 
cfg.case.server.model_modification.mode = 32

### Instantiate all parties

The following lines generate "server", "user" and "attacker" objects and print an overview of their configurations.

In [None]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

This exchange is a simulation of a single query in a federated learning protocol. The server sends out a `server_payload` and the user computes an update based on their private local data. This user update is `shared_data` and contains, for example, the parameter gradient of the model in the simplest case. `true_user_data` is also returned by `.compute_local_updates`, but of course not forwarded to the server or attacker and only used for (our) analysis.
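
The shape of this exchange can be illustrated with a toy stand-in. This is a minimal sketch, not the `breaching` implementation: the functions `toy_distribute_payload` and `toy_compute_local_update`, the dictionary keys, and the least-squares model are all hypothetical and chosen only to show which pieces of information cross the user boundary.

```python
def toy_distribute_payload(weights):
    """Server side: send a copy of the current model parameters."""
    return {"parameters": list(weights)}


def toy_compute_local_update(payload, private_data):
    """User side: average gradient of 0.5 * (w.x - y)^2 over private examples."""
    w = payload["parameters"]
    grad = [0.0] * len(w)
    for x, y in private_data:
        residual = sum(wi * xi for wi, xi in zip(w, x)) - y
        for j, xj in enumerate(x):
            grad[j] += residual * xj / len(private_data)
    shared = {"gradients": grad}
    # The second return value never leaves the user; it is for analysis only
    return shared, private_data


payload = toy_distribute_payload([0.0, 0.0])
private = [([1.0, 2.0], 1.0), ([3.0, 4.0], 2.0)]
shared, true_data = toy_compute_local_update(payload, private)
print(shared["gradients"])
```

Only `shared` would be visible to the server; the attack has to work from that and from the payload it sent out.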

In [None]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# user.plot(true_user_data) # Not a great idea for this batch size

### Reconstruct user data:

Now we launch the attack, reconstructing user data based on only the `server_payload` and the `shared_data`. 

For this attack, we also share secret information from the malicious server with the attacker (`server.secrets`), which here is the location and structure of the imprint block.
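
To give some intuition for why the structure of the block matters, here is a toy sketch of the bin-difference idea as we understand it from the paper. It is not the `breaching` implementation and makes several simplifying assumptions: every row of the layer computes the same measurement (the mean of the input) minus a per-row cutoff, a ReLU follows, and the upstream gradient for a given data point is identical across rows. The names `imprint_gradients` and `recover_from_bins` are hypothetical.

```python
def imprint_gradients(points, cutoffs, upstream):
    """Summed gradients (grad_W, grad_b) of rows h_i = relu(mean(x) - cutoffs[i])."""
    dim = len(points[0])
    grad_w = [[0.0] * dim for _ in cutoffs]
    grad_b = [0.0 for _ in cutoffs]
    for x, g in zip(points, upstream):
        measurement = sum(x) / dim
        for i, cutoff in enumerate(cutoffs):
            if measurement > cutoff:  # the ReLU of row i is active for this point
                grad_b[i] += g
                for j in range(dim):
                    grad_w[i][j] += g * x[j]
    return grad_w, grad_b


def recover_from_bins(grad_w, grad_b, eps=1e-12):
    """Difference adjacent rows; a bin holding one point returns that point."""
    recovered = {}
    for i in range(len(grad_b) - 1):
        db = grad_b[i] - grad_b[i + 1]
        if abs(db) > eps:  # skip bins that no data point fell into
            recovered[i] = [(a - b) / db for a, b in zip(grad_w[i], grad_w[i + 1])]
    return recovered


points = [[0.1, 0.2, 0.3], [0.5, 0.7, 0.9], [0.6, 0.6, 0.6]]
cutoffs = [0.0, 0.25, 0.5, 1.0]
upstream = [0.5, -1.0, 2.0]  # arbitrary per-point upstream gradients
grad_w, grad_b = imprint_gradients(points, cutoffs, upstream)
recovered = recover_from_bins(grad_w, grad_b)
print(recovered)
```

In this toy setup the first bin contains a single point and returns it up to floating-point error, the second bin is empty, and the third bin contains two points and returns only a weighted mixture of them. This is why a unique hit is needed for a clean reconstruction.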

In [None]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], server.secrets, 
                                                      dryrun=cfg.dryrun)

In [None]:
found_data = dict(data = reconstructed_user_data['data'][1:2], labels=None)
user.plot(found_data)

#### Identify which user image this is 

In [None]:
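# Scan the user's dataset for the image closest (in L2 distance) to the recovered one
matches = dict()
for idx, (data, label) in enumerate(user.dataloader.dataset):
    matches[idx] = torch.dist(found_data['data'], data.to(**setup))
    if matches[idx] < 1:
        break  # Close enough to count as a match, so stop searching early
    if idx % 1000 == 0:
        print(f'Currently at index {idx}')
# Pick the index with the smallest distance seen so far
idx = min(matches, key=matches.get)
print(idx)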
true_data = user.dataloader.dataset[idx]
matching_user_data = dict(data = true_data[0][None,...], labels=true_data[1])
user.plot(matching_user_data)

This saves a lot of compute as we can now compute metrics for the matching data point only:

In [None]:
metrics = breaching.analysis.report(found_data, matching_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)
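
For orientation, image-reconstruction reports commonly include the mean squared error (MSE) and the peak signal-to-noise ratio (PSNR) between the recovered and the true image. The sketch below shows how these two quantities are defined on flattened pixel lists. It is an illustration only, and whether `breaching.analysis.report` computes them in exactly this way is an assumption. The helper names `mse` and `psnr` are hypothetical.

```python
import math


def mse(a, b):
    """Mean squared error between two equally long pixel sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr(a, b, max_value=1.0):
    """Peak signal-to-noise ratio in dB, for pixel values in [0, max_value]."""
    err = mse(a, b)
    if err == 0:
        return float("inf")  # identical images
    return 10 * math.log10(max_value ** 2 / err)


true_pixels = [0.0, 1.0]
found_pixels = [0.1, 0.9]
print(mse(true_pixels, found_pixels), psnr(true_pixels, found_pixels))
```

A higher PSNR means the recovered image is closer to the original, and an exact match gives an infinite value.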

### Notes:
*  `OneShotBlockSparse` is also an option, which has only half as many parameters, but requires a HardTanh to be constructed over two layers.