# Fishing for User Data in Large-Batch Federated Learning via Gradient Magnification


This notebook shows an example of an **arbitrary-batch image gradient inversion** as described in "Fishing for User Data in Large-Batch Federated Learning via Gradient Magnification". The configuration below attacks a ViT-small model (`vit_small_april`, not pretrained) with the `april_analytic` attack, and the federated learning algorithm is **fedSGD** in a **cross-device** setting.

Paper URL: https://arxiv.org/abs/2202.00580
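In fedSGD, each user computes a single gradient of its loss, averaged over its whole local batch, at the parameters sent by the server, and returns only that gradient. The following is a minimal sketch with a hypothetical one-parameter linear model and squared loss. It is not breaching's code; `per_sample_grad` and `fedsgd_user_update` are illustrative names.

```python
# Minimal fedSGD sketch (hypothetical toy model, not breaching's implementation):
# the server sends parameters (w, b), the user computes ONE gradient of the mean
# loss over its local batch, and only that averaged gradient is shared.

def per_sample_grad(w, b, x, y):
    """Gradient of the squared loss 0.5 * (w*x + b - y)**2 with respect to (w, b)."""
    residual = w * x + b - y
    return residual * x, residual

def fedsgd_user_update(w, b, batch):
    """Average the per-sample gradients over the user's batch."""
    grads = [per_sample_grad(w, b, x, y) for x, y in batch]
    n = len(grads)
    grad_w = sum(g[0] for g in grads) / n
    grad_b = sum(g[1] for g in grads) / n
    return grad_w, grad_b

# Server-side parameters and one user's private batch of (input, target) pairs.
w, b = 0.5, 0.0
batch = [(1.0, 2.0), (2.0, 1.0), (3.0, 4.0), (4.0, 3.0)]
grad_w, grad_b = fedsgd_user_update(w, b, batch)
print(grad_w, grad_b)
```

The server only observes the averaged pair. The four individual contributions are mixed together, which is why gradient inversion becomes harder as the batch grows.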

This variant fishes for the target user's data by estimating the feature distribution from a group of other users. This is especially practical in a cross-device setting, where the server has access to many users.
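The distribution-estimation step can be illustrated as follows. This is a minimal sketch under stated assumptions, not breaching's implementation. We assume a single scalar feature per data point; the hypothetical `reference` values stand in for that feature as measured on other users' data. From these values we pick a bin holding about $1/n$ of the estimated mass, where $n$ is the target user's batch size, so that on average one of the target user's points falls into it. `choose_fishing_bin` is a hypothetical helper.

```python
import math
import random
import statistics

def choose_fishing_bin(reference_features, batch_size, bin_index):
    """Return (low, high) so that ~1/batch_size of the reference mass lies in [low, high).

    reference_features: a hypothetical scalar feature measured on OTHER users' data;
    its empirical quantiles stand in for the estimated feature distribution.
    """
    cuts = statistics.quantiles(reference_features, n=batch_size)  # batch_size - 1 cut points
    edges = [-math.inf] + cuts + [math.inf]                        # batch_size roughly equal bins
    return edges[bin_index], edges[bin_index + 1]

def in_bin(value, low, high):
    return low <= value < high

rng = random.Random(0)
# Pooled feature values from the other users (toy data: standard normal).
reference = [rng.gauss(0.0, 1.0) for _ in range(5000)]
batch_size = 8  # matches cfg.case.user.num_data_points in this notebook
low, high = choose_fishing_bin(reference, batch_size, bin_index=3)

# Fraction of the reference data that falls into the chosen bin (about 1/batch_size).
mass = sum(in_bin(v, low, high) for v in reference) / len(reference)
print(f"bin [{low:.3f}, {high:.3f}) holds {mass:.3f} of the reference mass")
```

Using quantiles means the bin adapts to whatever distribution the other users reveal, without assuming a parametric form.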

#### Abstract
Federated learning (FL) has rapidly risen in popularity due to its promise of privacy and efficiency. Previous works have exposed privacy vulnerabilities in the FL pipeline by recovering user data from gradient updates. However, existing attacks fail to address realistic settings because they either 1) require 'toy' settings with very small batch sizes, or 2) require unrealistic and conspicuous architecture modifications. We introduce a new strategy that dramatically elevates existing attacks to operate on batches of arbitrarily large size, and without architectural modifications. Our model-agnostic strategy only requires modifications to the model parameters sent to the user, which is a realistic threat model in many scenarios. We demonstrate the strategy in challenging large-scale settings, obtaining high-fidelity data extraction in both cross-device and cross-silo federated learning.

### Startup

In [None]:
try:
    import breaching
except ModuleNotFoundError:
    # You only really need this safety net if you want to run these notebooks directly in the examples directory
    # Don't worry about this if you installed the package or moved the notebook to the main directory.
    import os; os.chdir("..")
    import breaching
    
import torch
import matplotlib.pyplot as plt
from omegaconf import OmegaConf, open_dict
%load_ext autoreload
%autoreload 2

# Redirects logs directly into the jupyter notebook
import logging, sys
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)], format='%(message)s')
logger = logging.getLogger()

### Initialize cfg object and system setup:

This will load the full configuration object. This includes the configuration for the use case and threat model as `cfg.case` and the hyperparameters and implementation of the attack as `cfg.attack`. All parameters can be modified below, or overridden with `overrides=` as if they were command-line arguments.

In [None]:
cfg = breaching.get_config(overrides=["case/server=malicious-fishing", "attack=april_analytic"])
          
device = torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations for the attack, or the case:

In [None]:
# In principle the attack can work with a normal split like this:
cfg.case.data.name = "ImageNet"
cfg.case.data.examples_from_split = "validation"
cfg.case.data.default_clients = 25
cfg.case.server.target_cls_idx = 0 # Which class to attack?

cfg.case.data.partition="balanced"

cfg.case.server.pretrained = False  # the attack often also works when the model is pretrained

cfg.case.user.num_data_points = 8

cfg.case.user.user_idx = 0
cfg.case.user.provide_labels = True # Mostly out of convenience

cfg.case.model="vit_small_april"

In [None]:
cfg.case.server.class_multiplier = 0.5
cfg.case.server.bias_multiplier = 0
cfg.case.server.feat_multiplier = 400

cfg.case.server.reweight_collisions = 1
cfg.case.server.reset_param_weights = False

### Instantiate all parties

The following lines construct the `server`, `user`, and `attacker` objects (along with the shared `model` and `loss_fn`) and print an overview of their configurations.

In [None]:
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
breaching.utils.overview(server, user, attacker)

### Simulate an attacked FL protocol

In this cross-device scenario, the server also has access to other users, which we simulate below:

In [None]:
additional_users = []
for user_idx in range(1, cfg.case.data.default_clients):  # The target user is user 0
    cfg.case.user.user_idx = user_idx
    extra_user = breaching.cases.construct_user(model, loss_fn, cfg.case, setup)
    additional_users.append(extra_user)
cfg.case.user.user_idx = 0  # Restore the target user's index in the shared config

We then run a modified server protocol, which first estimates the feature distribution by querying the other users and then attacks the target user with a modified parameter vector based on that estimate:

In [None]:
[shared_data], [server_payload], true_user_data = server.run_protocol(user, additional_users)

In [None]:
user.plot(true_user_data)

#### We can also evaluate the measured feature distribution:

In [None]:
plt.hist(true_user_data["distribution"]);

In [None]:
# Nonzero entries of the second-to-last parameter tensor in the server payload
server_payload["parameters"][-2][server_payload["parameters"][-2] != 0]

In [None]:
# First two entries of the last parameter tensor in the server payload
server_payload["parameters"][-1][0:2]

In [None]:
# First two entries of the user's shared gradient for the last parameter tensor
shared_data["gradients"][-1][0:2]

In [None]:
# Per-sample gradients and losses: if fishing succeeded, only one sample has a large gradient norm
from breaching.cases.malicious_modifications.classattack_utils import print_gradients_norm, cal_single_gradients
single_gradients, single_losses = cal_single_gradients(user.model, loss_fn, true_user_data, setup=setup)
print_gradients_norm(single_gradients, single_losses)
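For intuition about what `print_gradients_norm` should report when fishing succeeds, the toy sketch below relies only on the standard fact that the cross-entropy gradient with respect to the logits is softmax(z) minus the one-hot label. The logits are hypothetical stand-ins for what the server's parameter modification aims for: confident, correct predictions outside the fished bin and an uninformative prediction inside it. They are not produced by the notebook's model.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ce_grad_wrt_logits(logits, label):
    """Cross-entropy gradient with respect to the logits: softmax(z) - onehot(label)."""
    probs = softmax(logits)
    return [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]

def l2_norm(vec):
    return math.sqrt(sum(v * v for v in vec))

# Hypothetical logits for a batch of 4 samples and 3 classes, all labelled class 0.
labels = [0, 0, 0, 0]
logits = [
    [20.0, 0.0, 0.0],  # outside the fished bin: confident and correct
    [20.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],   # inside the fished bin: uniform prediction
    [20.0, 0.0, 0.0],
]
grads = [ce_grad_wrt_logits(z, y) for z, y in zip(logits, labels)]
norms = [l2_norm(g) for g in grads]
for idx, norm in enumerate(norms):
    print(f"sample {idx}: gradient norm {norm:.2e}")

# The fedSGD update averages these gradients, so it is dominated by sample 2.
batch_grad = [sum(g[k] for g in grads) / len(grads) for k in range(3)]
print("averaged gradient:", [round(v, 4) for v in batch_grad])
```

Because the confidently classified samples contribute almost nothing, the averaged gradient is effectively the single in-bin sample's gradient scaled by 1/batch size.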

### Now reconstruct a single "fished" user data point:

Now we launch the attack, reconstructing user data based only on the `server_payload`, the `shared_data`, and the `server.secrets` recorded by the malicious server.

You can interrupt the computation early to see a partial solution.

In [None]:
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data], 
                                                      server.secrets, dryrun=cfg.dryrun)
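The `april_analytic` attack configured above targets a vision transformer. To show why isolating a single data point matters, the sketch below swaps in a simpler, well-known analytic inversion of one fully connected layer $y = Wx + b$. For a single input, $\partial L/\partial W_i = (\partial L/\partial y_i)\, x$ and $\partial L/\partial b_i = \partial L/\partial y_i$, so dividing the two recovers $x$ exactly. A batch-averaged gradient yields only a weighted mixture of the inputs. The toy loss and all helper names are assumptions made for illustration.

```python
def forward(W, b, x):
    """y = W x + b for a single input vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i for row, b_i in zip(W, b)]

def single_grads(W, b, x):
    """Gradients of the toy loss L = 0.5 * sum(y_i**2), for which dL/dy_i = y_i."""
    y = forward(W, b, x)
    grad_W = [[y_i * x_j for x_j in x] for y_i in y]  # dL/dW_ij = dL/dy_i * x_j
    grad_b = list(y)                                  # dL/db_i  = dL/dy_i
    return grad_W, grad_b

def averaged_grads(W, b, batch):
    """fedSGD-style update: the mean of the per-sample gradients."""
    per_sample = [single_grads(W, b, x) for x in batch]
    n = len(batch)
    grad_W = [[sum(gW[i][j] for gW, _ in per_sample) / n for j in range(len(batch[0]))]
              for i in range(len(W))]
    grad_b = [sum(gb[i] for _, gb in per_sample) / n for i in range(len(W))]
    return grad_W, grad_b

def invert_row(grad_W, grad_b, row):
    """Analytic estimate x_hat = (dL/dW_row) / (dL/db_row); requires dL/db_row != 0."""
    return [g / grad_b[row] for g in grad_W[row]]

W = [[1.0, 0.0], [0.0, 1.0]]
b = [0.5, 0.5]

# A single data point is recovered exactly ...
x_single = invert_row(*single_grads(W, b, [2.0, 3.0]), row=0)
# ... but a batch yields a weighted mixture of its inputs.
x_batch = invert_row(*averaged_grads(W, b, [[2.0, 3.0], [4.0, 1.0]]), row=0)
print(x_single, x_batch)
```

Here `x_batch` equals $(2.5 \cdot [2, 3] + 4.5 \cdot [4, 1]) / 7$, a blend of both inputs. Fishing aims to avoid exactly this kind of blend by making one sample dominate the shared gradient.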

Next we'll evaluate metrics, comparing the `reconstructed_user_data` to the `true_user_data`.

In [None]:
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, [server_payload], 
                                    server.model, order_batch=True, compute_full_iip=False, 
                                    cfg_case=cfg.case, setup=setup)
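The exact set of metrics printed by `breaching.analysis.report` is determined by the library. As an illustration of one common image-reconstruction metric, here is a minimal PSNR sketch on flattened pixel values in $[0, 1]$. The `psnr` helper and the toy pixel lists are hypothetical, and this is not claimed to be breaching's implementation.

```python
import math

def psnr(reconstruction, reference, data_range=1.0):
    """Peak signal-to-noise ratio in dB between two equally sized pixel lists."""
    if len(reconstruction) != len(reference):
        raise ValueError("inputs must have the same number of pixels")
    mse = sum((r - t) ** 2 for r, t in zip(reconstruction, reference)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(data_range ** 2 / mse)

true_pixels = [0.0, 0.25, 0.5, 0.75, 1.0]
recon_pixels = [0.1, 0.25, 0.5, 0.75, 0.9]
print(f"PSNR: {psnr(recon_pixels, true_pixels):.2f} dB")
```

Higher is better. Each tenfold reduction in mean squared error adds 10 dB.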

And finally, we also plot the reconstructed data:

In [None]:
user.plot(reconstructed_user_data)

### Notes:
* You can use `cal_single_gradients` and `print_gradients_norm` from `breaching.cases.malicious_modifications.classattack_utils` to verify that only one of the user data points has a non-negligible gradient norm.
* This attack has a $1/e \approx 37\%$ chance of success against a single target user. Do not be alarmed if it does not work immediately (in those cases the reconstruction may immediately return NaN). In the cross-device setting, the attack can be deployed against a large number of users.
* This example shows the attack in a setting that is fast to compute, where each user has only a few images (`cfg.case.user.num_data_points = 8` above). You can also launch this attack in the general setting where each user holds a large amount of data (for example, with the same number of target-class images among them, or many more) by adjusting the data settings and waiting longer.
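The $1/e$ figure in the notes can be motivated with a simple model. We assume, since this is not stated in the notebook, that the fished bin holds probability mass $1/n$ and that the batch consists of $n$ independently drawn points. The chance that exactly one point lands in the bin is then $n \cdot \frac{1}{n}\left(1-\frac{1}{n}\right)^{n-1} = \left(1-\frac{1}{n}\right)^{n-1}$, which decreases toward $1/e \approx 0.368$ as $n$ grows. The sketch checks this numerically and with a small Monte Carlo simulation; `prob_exactly_one` and `simulate` are hypothetical helpers.

```python
import math
import random

def prob_exactly_one(batch_size):
    """P(exactly one of batch_size i.i.d. points lands in a bin of mass 1/batch_size)."""
    p = 1.0 / batch_size
    return batch_size * p * (1.0 - p) ** (batch_size - 1)

def simulate(batch_size, trials=20000, seed=0):
    """Monte Carlo estimate of the same probability."""
    rng = random.Random(seed)
    p = 1.0 / batch_size
    hits = 0
    for _ in range(trials):
        points_in_bin = sum(rng.random() < p for _ in range(batch_size))
        hits += points_in_bin == 1
    return hits / trials

for n in (8, 64, 256, 4096):
    print(n, round(prob_exactly_one(n), 4))
print("1/e =", round(1 / math.e, 4))
print("simulated (n=8):", simulate(8))
```

Small batches such as the 8 images used here give a slightly better chance than $1/e$. For large batches the success rate settles near 37%, which is why repeating the attack across many users is effective in the cross-device setting.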