# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching

### Initialize cfg object and system setup:

This loads the full configuration and prints part of it. 
There are many possible configuration options, but there is usually no need to worry about most of them. Below, only the use case, the server type, and the attack settings are printed.

In [None]:
# Compose the hydra configuration `cfg` (config_path "config", no overrides) and
# report which use case, server type and attack settings were loaded.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=[])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    print(OmegaConf.to_yaml(cfg.attack))

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
# Device/dtype keyword arguments that later cells hand to the library and to tensor constructors.
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup  # last expression of the cell, so the notebook displays it

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# Example overrides, left disabled:
# cfg.dryrun = False
# cfg.attack.optim.step_size=1.0
# Overrides that are active for this run:
cfg.case.user.data_idx = 0
cfg.attack.optim.signed=False

### Instantiate all parties

In [None]:
# Build the user and the server for the configured case, then prepare the attacker
# from the server's model and loss together with the attack configuration.
user, server = breaching.cases.construct_case(cfg.case, setup)
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)

In [None]:
# Print a summary of each party, one per line.
print(user, server, attacker, sep='\n')

### Simulate an attacked FL protocol

In [None]:
# The server distributes its payload; the user computes its local update from that payload.
server_payload = server.distribute_payload()
# True user data is returned only for analysis.
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Visualize the true user data via the user's own plotting helper.
user.plot(true_user_data)

### Reconstruct user data:

In [None]:
# Run the attack on the server payload and the data shared by the user;
# the `dryrun` flag is taken from the configuration.
reconstructed_user_data, stats = attacker.reconstruct(server_payload, server.secrets, shared_data, dryrun=cfg.dryrun)

# How good is the reconstruction?
metrics = breaching.analysis.report(reconstructed_user_data, 
                                    true_user_data, server_payload, server.model, setup)

In [None]:
# Visualize the reconstructed data with the same plotting helper used for the true user data.
user.plot(reconstructed_user_data)

In [None]:
# NOTE(review): this repeats the report call from the reconstruction cell with the same
# arguments and rebinds `metrics` — likely redundant; confirm before removing.
metrics = breaching.analysis.report(reconstructed_user_data, true_user_data, server_payload, server.model, setup)

In [None]:
# Lazy imports: the registration and metric tooling is only needed from this cell onwards.
import skimage.feature
import skimage.measure
import skimage.transform
import matplotlib.pyplot as plt
from breaching.analysis.metrics import psnr_compute

# ORB keypoint detector/descriptor extractor, limited to 800 keypoints.
descriptor_extractor = skimage.feature.ORB(n_keypoints=800)

In [None]:
# Data mean (dm) and standard deviation (ds) from the server payload, with singleton axes
# added in front and behind so they broadcast against 4-D image batches.
# Assumes `mean` and `std` are 1-D per-channel sequences — TODO confirm.
dm, ds = (
    torch.as_tensor(getattr(server_payload['data'], stat), **setup)[None, :, None, None]
    for stat in ('mean', 'std')
)

In [None]:
# Undo the normalization (x * ds + dm), clip the result to [0, 1] and move it to the CPU.
rec_denormalized = (reconstructed_user_data['data'] * ds + dm).clamp(0, 1).cpu()
ground_truth_denormalized = (true_user_data['data'] * ds + dm).clamp(0, 1).cpu()

In [None]:
# Register the first reconstructed image onto the first ground-truth image using ORB keypoints.
img_np, ref_np = rec_denormalized[0].numpy(), ground_truth_denormalized[0].numpy()  # move to numpy
descriptor_extractor.detect_and_extract(ref_np.mean(axis=0))  # average over axis 0: grayscale input for ORB
keypoints_src, descriptors_src = descriptor_extractor.keypoints, descriptor_extractor.descriptors
descriptor_extractor.detect_and_extract(img_np.mean(axis=0))
keypoints_tgt, descriptors_tgt = descriptor_extractor.keypoints, descriptor_extractor.descriptors

# Match reference ("src") descriptors against reconstruction ("tgt") descriptors with cross-checking.
matches = skimage.feature.match_descriptors(descriptors_src, descriptors_tgt, cross_check=True)
# Look for a Euclidean (rotation + translation) transform and search with RANSAC over matches:
model_robust, inliers = skimage.measure.ransac(
    (keypoints_tgt[matches[:, 1]], keypoints_src[matches[:, 0]]),
    skimage.transform.EuclideanTransform,
    min_samples=len(matches) - 1,
    residual_threshold=4,
    max_trials=2500,
)
# NOTE(review): the keypoints are handed to the transform exactly as ORB returns them;
# confirm that their coordinate order matches what EuclideanTransform / warp expect.
# warp treats the trailing axis as channels, so the channel-first image is transposed to
# (H, W, C) first, as the later warp cells in this notebook do; the output keeps the
# reference's spatial size.
warped_img = skimage.transform.warp(img_np.transpose(1, 2, 0), model_robust, output_shape=ref_np.shape[1:])


In [None]:
# Display the transform object returned by the RANSAC search.
model_robust

In [None]:
# Warp the reconstruction (transposed to H x W x C) with the RANSAC transform,
# wrapping at the image borders and using interpolation order 1.
warped1 = skimage.transform.warp(img_np.transpose(1,2,0), model_robust, mode='wrap', order=1)
plt.imshow(warped1)
# PSNR between the reference and the warped reconstruction; shown as the cell's output.
psnr_compute(torch.as_tensor(ref_np.transpose(1, 2, 0)).contiguous(), torch.as_tensor(warped1), batched=True)

In [None]:
# For comparison, fit a Euclidean transform to all matched keypoints directly, without RANSAC.
tform = skimage.transform.EuclideanTransform()
tform.estimate(keypoints_tgt[matches[:, 1]], keypoints_src[matches[:, 0]])

# Warp the reconstruction (transposed to H x W x C) with the directly fitted transform,
# wrapping at the image borders.
warped2 = skimage.transform.warp(img_np.transpose(1,2,0), tform, mode='wrap')
plt.imshow(warped2)
# PSNR between the reference and the warped reconstruction; shown as the cell's output.
psnr_compute(torch.as_tensor(ref_np.transpose(1, 2, 0)).contiguous(), torch.as_tensor(warped2), batched=True)

In [None]:
# Show the ground-truth reference image, transposed to H x W x C for imshow.
plt.imshow(ref_np.transpose(1, 2, 0))

In [None]:
import skimage.registration

In [None]:
# Estimate the translation between the two grayscale images (mean over axis 0) by phase
# cross-correlation, with 10x upsampling. The reconstruction is passed as the reference
# image and the ground truth as the moving image.
shift, error, diffphase = skimage.registration.phase_cross_correlation(
    reference_image=img_np.mean(axis=0),
    moving_image=ref_np.mean(axis=0),
    upsample_factor=10,
)

In [None]:
# Display the shift estimated by the phase cross-correlation.
shift

In [None]:
# Show the unregistered reconstruction, transposed to H x W x C for imshow.
plt.imshow(img_np.transpose(1, 2, 0))

In [None]:
# Build a pure-translation transform from the negated phase-correlation shift and warp the
# reconstruction (transposed to H x W x C) with it, using interpolation order 3 and wrapping at the borders.
# NOTE(review): `shift` presumably follows the axis order of the 2-D arrays passed to
# phase_cross_correlation (row, col), while the transform's translation is presumably (x, y);
# confirm both the component order and the sign used here.
tform = skimage.transform.EuclideanTransform(translation=-shift)
warped_fft = skimage.transform.warp(img_np.transpose(1, 2, 0), tform, order=3, mode='wrap', 
                                    preserve_range=False)
plt.imshow(warped_fft)
# PSNR between the reference and the shifted reconstruction; shown as the cell's output.
psnr_compute(torch.as_tensor(ref_np.transpose(1, 2, 0)).contiguous(), torch.as_tensor(warped_fft), batched=True)