# Breaching privacy

This notebook does the same job as the command-line tool `breach.py`, but also directly visualizes the user data and the reconstruction.

In [None]:
import torch
import hydra
from omegaconf import OmegaConf
%load_ext autoreload
%autoreload 2

import breaching
import logging, sys
# Route log records as bare messages to stdout so they show up inline in the notebook output.
logging.basicConfig(
    format='%(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger()  # root logger

### Initialize cfg object and system setup:

This composes the configuration for the chosen use case and prints a few of its options.
There are a lot of possible configurations, but there is usually no need to worry about most of them.

In [None]:
# Compose the configuration for the selected use case (via the hydra `case=...` override)
# and report the settings that matter most for this experiment.
with hydra.initialize(config_path="config"):
    cfg = hydra.compose(config_name='cfg', overrides=['case=1_single_image_small'])
    print(f'Investigating use case {cfg.case.name} with server type {cfg.case.server.name}.')
    print('Attack settings are:')
    # Without this line the header above was printed with nothing following it.
    print(OmegaConf.to_yaml(cfg.attack))

# Device/dtype setup shared by every party constructed below.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))
setup

### Modify config options here

You can use `.attribute` access to modify any of these configurations:

In [None]:
# --- simulated user: which data point(s) it holds ---
cfg.case.user.num_data_points = 1
cfg.case.user.data_idx = 1200

# --- attack: total-variation regularization scale and optimizer iteration budget ---
cfg.attack.optim.max_iterations = 8000
cfg.attack.regularization.total_variation.scale = 1e-5

# --- model architecture used to build the case ---
cfg.case.model = 'resnet20'

### Instantiate all parties

In [None]:
# Build the simulated FL case: the user, the server, the model and the loss function.
# NOTE(review): `model` and `loss_fn` are not used in the cells shown here; the attacker below
# is built from `server.model` / `server.loss` instead.
user, server, model, loss_fn = breaching.cases.construct_case(cfg.case, setup)
# The attacker is given the server's model and loss (not the user object).
attacker = breaching.attacks.prepare_attack(server.model, server.loss, cfg.attack, setup)
# Overview of the three parties.
breaching.utils.overview(server, user, attacker)

In [None]:
# Textual description of each party, one after another.
print(user, server, attacker, sep='\n')

### Simulate an attacked FL protocol

The true user data is returned only for analysis; it is never passed to the attacker.

In [None]:
# One protocol round: the server distributes its payload, the user computes its local update.
# `true_user_data` is kept only to evaluate the attack; `attacker.reconstruct` never receives it.
server_payload = server.distribute_payload()
shared_data, true_user_data = user.compute_local_updates(server_payload)

In [None]:
# Show the ground-truth user data, for visual comparison with the reconstruction further below.
user.plot(true_user_data)

### Reconstruct user data:

In [None]:
# Run the attack. It receives the server payload(s), the user's shared update(s) and the
# server's secrets -- never `true_user_data`. Payloads/updates are passed as lists (one per round).
reconstructed_user_data, stats = attacker.reconstruct([server_payload], [shared_data],
                                                      server.secrets, dryrun=cfg.dryrun)
# Reconstruction quality is scored in the "How good is the reconstruction?" cell below, which
# calls `breaching.analysis.report` with the same list-of-payloads convention used here.

In [None]:
# Show the attacker's reconstruction next to the ground truth plotted earlier.
user.plot(reconstructed_user_data)

In [None]:
# How good is the reconstruction? Score it against the ground truth kept for analysis.
metrics = breaching.analysis.report(
    reconstructed_user_data,
    true_user_data,
    [server_payload],
    server.model,
    setup=setup,
    cfg_case=cfg.case,
    order_batch=True,
    compute_full_iip=False,
)

# Kornia matching

In [None]:
from kornia.geometry import ImageRegistrator, homography_warp
import kornia.feature

In [None]:
# Register the reconstruction (src) onto the ground truth (dst) with a similarity transform.
# kornia's `ImageRegistrator.register(src_img, dst_img)` fits a model that warps src into dst,
# so the reconstruction must come first for `warp_src_into_dst(img_src)` to align it to the truth.
img_src = reconstructed_user_data['data'].detach().clone()
img_dst = true_user_data['data'].detach().clone()
registrator = ImageRegistrator('similarity')
homo = registrator.register(img_src, img_dst)
homo

In [None]:
# Warp the reconstruction into the ground-truth frame using the registrator's fitted model.
# Labels are taken from the true user data.
aligned_data = dict(data=registrator.warp_src_into_dst(img_src), labels=true_user_data['labels'])


In [None]:
# Show the registration-aligned data for comparison with the ground truth.
user.plot(aligned_data)

# LoFTR

In [None]:
# Sanity check for LoFTR: match the ground truth against a copy of itself flipped along both
# spatial axes (dims 2 and 3, i.e. rotated by 180 degrees, assuming NCHW layout), so the matches
# should encode a known transformation.
# NOTE(review): confirm the data's value range/normalization is what the pretrained matcher expects.
matcher = kornia.feature.LoFTR(pretrained="indoor")
img_dst_gray = img_dst.mean(dim=1, keepdim=True)  # collapse channels to a single (grayscale) channel
with torch.no_grad():
    correspondences_dict = matcher(dict(image0=img_dst_gray.flip(dims=[2, 3]),
                                        image1=img_dst_gray))

In [None]:
# Inspect the image0 keypoint tensor -- presumably (num_matches, 2); confirm against kornia's LoFTR docs.
correspondences_dict['keypoints0'].shape

In [None]:
import kornia.geometry.homography

In [None]:
# Estimate a homography from the LoFTR matches with the unweighted DLT solver.
# keypoints0 come from image0 (the flipped ground truth), keypoints1 from image1 (the original);
# `[None, ...]` adds a leading batch dimension. Note this overwrites `homo` from the registration cell.
# NOTE(review): DLT over all matches has no outlier rejection -- consider weighting by the matcher's
# confidences or a robust (RANSAC) estimator. Confirm the keypoints0 -> keypoints1 direction in the kornia docs.
homo = kornia.geometry.homography.find_homography_dlt(correspondences_dict['keypoints0'][None, ...], 
                                                      correspondences_dict['keypoints1'][None, ...])
homo

In [None]:
# Warp the flipped ground truth with the estimated homography; if the estimate is good, this should
# undo the flip and reproduce img_dst. Labels are taken from the true user data.
# NOTE(review): kornia's homography_warp takes a dst->src homography and, depending on the version,
# assumes normalized coordinates, while the DLT estimate above is in pixel coordinates mapping
# image0 -> image1 -- verify direction and normalization.
aligned_data = dict(data=homography_warp(img_dst.flip(dims=[2,3]), homo, img_dst.shape[-2:]), labels=true_user_data['labels'])

In [None]:
# Show the LoFTR/DLT-warped image (expected to match the un-flipped ground truth if the homography is right).
user.plot(aligned_data)

In [None]:
# Describe the whole grayscale reconstruction with a single SIFT descriptor.
# SIFTDescriptor works on square patches of exactly `patch_size` pixels, so derive the size from
# the image instead of hard-coding it (the previous fixed 224 only worked for 224x224 inputs).
assert img_src.shape[-2] == img_src.shape[-1], 'SIFTDescriptor expects square patches'
descriptor = kornia.feature.SIFTDescriptor(patch_size=img_src.shape[-1])
dscs1 = descriptor(img_src.mean(dim=1, keepdim=True))

In [None]:
# Inspect the shape of the SIFT descriptor output.
dscs1.shape

# Kornia/RANSAC-Flow Test

In [None]:
# Parameters for a RANSAC-Flow alignment test.
# NOTE(review): meanings below are inferred from the names only -- confirm against the RANSAC-Flow
# code that consumes them.
nbScale = 7  # presumably the number of scales searched in the coarse stage
coarseIter = 10000  # presumably the number of coarse (RANSAC) iterations
coarsetolerance = 0.05  # presumably the inlier tolerance of the coarse stage
minSize = 400  # presumably a minimum image size in pixels
imageNet = True # True selects ImageNet features; MoCo features could be used here instead
scaleR = 1.2  # presumably a scale ratio between successive scales