# Robbing the Fed: Directly Obtaining Private Data in Federated Learning with Modified Models

This notebook shows an example of the threat model and attack described in "Robbing the Fed: Directly Obtaining Private Data in Federated Learning with Modified Models". Unlike the other examples, which consider "honest-but-curious" servers, it investigates an actively malicious server that modifies the model. As such, the attack applies to any model architecture, but how obvious the modification is (if it is noticeable at all) depends on the existing architecture onto which the malicious "Imprint" block is grafted.

In this notebook, we place the block in front of a ResNet18, as an example.


Paper URL: https://openreview.net/forum?id=fwzUgo0FM9v

This variation implements an attack against users who update their models with federated averaging. This threat model is only relevant if the user

 * forces a federated averaging update with non-zero step size (but usually the server sets the step size)
 * forces an update on all model parameters in intermediate update steps (but the server might send different instructions)
 * has a model where the gradient signal reaching the imprint block is large enough to move these parameters.

Under these conditions, the attack is less successful because the bins drift away from their optimal positions during the local updates, yet it still recovers large amounts of private data.
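To illustrate why a small local learning rate limits this drift, here is a minimal sketch of a federated-averaging user update on a toy model. It is not the breaching implementation: the scalar least-squares model, the data, and the helper names `grad` and `local_update` are hypothetical. The user runs several SGD steps and shares the parameter difference; dividing that difference by the learning rate approximately recovers the sum of per-step gradients evaluated at the server's parameters, and the approximation degrades as the learning rate grows because the parameters move between steps.

```python
import random

def grad(w, batch):
    """Mean gradient of the loss 0.5 * (w * x - y) ** 2 over (x, y) pairs."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)

def local_update(w0, data, num_local_updates, batch_size, lr):
    """Plain local SGD on disjoint consecutive batches; returns w0 - w_final."""
    w = w0
    for step in range(num_local_updates):
        batch = data[step * batch_size:(step + 1) * batch_size]
        w -= lr * grad(w, batch)
    return w0 - w

random.seed(0)
# Hypothetical private data: 64 points from y = 3x + noise.
xs = [random.gauss(0, 1) for _ in range(64)]
data = [(x, 3.0 * x + random.gauss(0, 0.1)) for x in xs]

w0 = 0.0  # parameter sent by the server
steps, batch_size = 8, 8  # mirrors num_local_updates / num_data_per_local_update_step

# What the server would see if the parameters never moved during local training:
# the sum of per-step gradients, all evaluated at w0.
reference = sum(
    grad(w0, data[s * batch_size:(s + 1) * batch_size]) for s in range(steps)
)

errors = {}
for lr in (1e-4, 1e-1):
    estimate = local_update(w0, data, steps, batch_size, lr) / lr
    errors[lr] = abs(estimate - reference) / abs(reference)
    print(f"lr={lr:g}: relative deviation from the no-drift sum {errors[lr]:.2e}")
```

With a learning rate as small as the `1e-4` used later in this notebook, the shared update stays very close to the no-drift gradient sum; with a large learning rate the deviation becomes substantial, which in the real attack corresponds to the imprint bins moving away from their intended positions.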

### Startup

In [None]:
try:
    import breaching
except ModuleNotFoundError:
    # You only really need this safety net if you want to run these notebooks directly in the examples directory
    # Don't worry about this if you installed the package or moved the notebook to the main directory.
    import os; os.chdir("..")
    import breaching
    
import torch
%load_ext autoreload
%autoreload 2

# Redirects logs directly into the jupyter notebook
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

### Initialize cfg object and system setup:

This will load the full configuration object. This includes the configuration for the use case and threat model as `cfg.case` and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overridden with `overrides=` as if they were command-line arguments.

In [None]:
cfg = breaching.get_config(overrides=["attack=imprint", "case/server=malicious-model-rtf", 
                                      "case/user=local_updates"])
          
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations for the attack, or the case:

In [None]:
cfg.case.user.num_data_points = 64 # How many data points does this user own

# FL setup:
# 8 local update steps with 8 data points each (8 x 8 = 64, matching num_data_points)
cfg.case.user.num_local_updates = 8
cfg.case.user.num_data_per_local_update_step = 8
cfg.case.user.local_learning_rate = 1e-4


cfg.case.server.model_modification.type = 'SparseImprintBlock'
cfg.case.server.model_modification.num_bins = 128 # How many bins are in the block


# How does the block interact with the model?
# The block can be placed later in the model given a position such as '4.0.conv':
cfg.case.server.model_modification.position = None  # None defaults to the first layer
# The block can also be connected in various ways to the other layers:
cfg.case.server.model_modification.connection = 'linear'



# Which linear measurement function should be used?
# We know that the input dataset is already normalized as a preprocessing step

# Knowing the distribution relatively well:
cfg.case.server.model_modification.linfunc = 'fourier' # works well for any normalized image data
cfg.case.server.model_modification.mode = 32

# # Eyeballing the distribution based on the law of large numbers:
# cfg.case.server.model_modification.linfunc = 'randn' # will work decently for anything
# cfg.case.server.model_modification.mode = None
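The choice of `linfunc` matters because bins are only useful if they split the measured values evenly. Below is a minimal pure-Python sketch under two stated assumptions: the measurement `a·x` of a normalized input is approximately standard normal, and bin boundaries sit at equally spaced quantiles of that distribution, so each of the `num_bins` bins is equally likely. A sample is recovered cleanly only if it is alone in its bin; with `N` samples and `B` equiprobable bins, the expected number of such singletons is `N * (1 - 1/B) ** (N - 1)`. The helper names are hypothetical and nothing here calls the breaching library.

```python
import bisect
import random
from statistics import NormalDist

def quantile_thresholds(num_bins):
    """Bin boundaries at equally spaced quantiles of N(0, 1).

    num_bins - 1 interior boundaries split the real line into num_bins
    intervals of equal probability mass 1 / num_bins.
    """
    std_normal = NormalDist()
    return [std_normal.inv_cdf(k / num_bins) for k in range(1, num_bins)]

def count_singletons(values, thresholds):
    """Number of values that fall alone into their bin."""
    counts = {}
    for v in values:
        idx = bisect.bisect_left(thresholds, v)  # number of boundaries below v
        counts[idx] = counts.get(idx, 0) + 1
    return sum(1 for c in counts.values() if c == 1)

num_data_points, num_bins = 64, 128  # mirrors the config values above
thresholds = quantile_thresholds(num_bins)

# Analytic expectation for equiprobable bins.
expected = num_data_points * (1 - 1 / num_bins) ** (num_data_points - 1)

# Monte Carlo check with standard-normal "measurements".
random.seed(0)
trials = 2000
simulated = sum(
    count_singletons([random.gauss(0, 1) for _ in range(num_data_points)], thresholds)
    for _ in range(trials)
) / trials
print(f"expected singletons: {expected:.1f}, simulated: {simulated:.1f}")
```

With 64 data points and 128 bins the expectation is roughly 39 singletons; more bins raise that number at the cost of a larger block. If the assumed distribution does not match the real measurements, the bins are no longer equiprobable and collisions become more frequent.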

### Instantiate all parties

The following lines generate "server", "user", and "attacker" objects and print an overview of their configurations.

In [None]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

This exchange is a simulation of a single query in a federated learning protocol. The server sends out a `server_payload` and the user computes an update based on their private local data. This user update is `shared_data`; in the simplest case it contains the parameter gradient of the model, while in this variation it reflects the outcome of several local update steps. `true_user_data` is also returned by `.compute_local_updates`, but it is never forwarded to the server or attacker and is used only for our analysis.

In [None]:
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)  

In [None]:
user.plot(true_user_data)

### Reconstruct user data:

Now we launch the attack, reconstructing user data based on only the `server_payload` and the `shared_data`. 

For this attack, we also share secret information from the malicious server with the attacker (`server.secrets`), which here is the location and structure of the imprint block.
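The principle behind the reconstruction can be sketched in a few lines. This toy version is pure Python and not the breaching implementation: the imprint block is modeled as a linear layer whose rows all equal one measurement vector `a`, with biases `-t_i` for increasing thresholds `t_i`, followed by a ReLU. As a simplifying assumption, every active unit receives the same upstream gradient `g_k` for sample `k` (hypothetical values here). Row `i` then accumulates `g_k * x_k` and bias `i` accumulates `g_k` for every sample with `a·x_k > t_i`, so subtracting adjacent rows isolates the samples whose measurement lies between `t_i` and `t_{i+1}`. If exactly one sample lies there, dividing the weight difference by the bias difference returns that sample exactly, which is why the attacker needs to know where the block is and how it is built.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def imprint_gradients(samples, upstream, a, thresholds):
    """Gradients of the toy layer h_i = relu(a . x - t_i), summed over samples.

    Assumes every active unit i receives the same upstream gradient g_k for
    sample k, so dL/dW_i = sum_k g_k * x_k and dL/db_i = sum_k g_k over the
    samples with a . x_k > t_i (the ReLU derivative is 1 there, 0 otherwise).
    """
    dim = len(a)
    grad_w = [[0.0] * dim for _ in thresholds]
    grad_b = [0.0 for _ in thresholds]
    for x, g in zip(samples, upstream):
        m = dot(a, x)
        for i, t in enumerate(thresholds):
            if m > t:
                grad_w[i] = [gw + g * xj for gw, xj in zip(grad_w[i], x)]
                grad_b[i] += g
    return grad_w, grad_b

def recover(grad_w, grad_b, eps=1e-12):
    """Invert adjacent-unit differences; the last unit is compared against zero."""
    recovered = []
    n = len(grad_b)
    for i in range(n):
        next_w = grad_w[i + 1] if i + 1 < n else [0.0] * len(grad_w[i])
        next_b = grad_b[i + 1] if i + 1 < n else 0.0
        db = grad_b[i] - next_b
        if abs(db) > eps:  # at least one sample lies in bin i
            recovered.append([(w - nw) / db for w, nw in zip(grad_w[i], next_w)])
    return recovered

random.seed(1)
dim = 5
a = [random.gauss(0, 1) for _ in range(dim)]               # measurement vector
samples = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(3)]
upstream = [random.uniform(0.5, 1.5) for _ in samples]     # hypothetical g_k

# Place thresholds so that each bin contains exactly one sample.
ms = sorted(dot(a, x) for x in samples)
thresholds = [ms[0] - 1.0] + [(lo + hi) / 2 for lo, hi in zip(ms, ms[1:])]

grad_w, grad_b = imprint_gradients(samples, upstream, a, thresholds)
recovered = recover(grad_w, grad_b)
print(f"recovered {len(recovered)} of {len(samples)} samples")
```

If two samples share a bin, the same division returns their `g_k`-weighted average instead of either sample, which is why collisions show up as blended images in the reconstruction.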

In [None]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], server.secrets, 
                                                      dryrun=cfg.dryrun)

Next we'll evaluate metrics, comparing the `reconstructed_user_data` to the `true_user_data`.

In [None]:
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)
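As an illustration of what such a comparison measures, below is a minimal sketch of peak signal-to-noise ratio (PSNR), a standard image-reconstruction metric, computed on flattened pixel lists. This is not a claim about which metrics or which batch-ordering procedure `breaching.analysis.report` uses, and the default `max_value=1.0` assumes pixel values scaled to [0, 1].

```python
import math

def psnr(reference, reconstruction, max_value=1.0):
    """PSNR in dB between two equally long sequences of pixel values."""
    if len(reference) != len(reconstruction):
        raise ValueError("inputs must have the same length")
    mse = sum((r - x) ** 2 for r, x in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10 * math.log10(max_value ** 2 / mse)

print(psnr([0.0, 0.5, 1.0], [0.0, 0.5, 1.0]))  # inf
print(psnr([0.0, 0.0], [0.1, 0.1]))            # approximately 20.0 (dB)
```

Higher values mean closer reconstructions; an exactly recovered sample, as in a bin containing a single image, yields an infinite or very large PSNR.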

And finally, we also plot the reconstructed data:

In [None]:
user.plot(reconstructed_user_data)

### Notes:
* The non-sparse default `ImprintBlock` also works here, but in that case it is probably best to also set the connection to `addition`, to prevent the local model from diverging to NaN gradients before the attack has even started.
