## Split Learning with CIFAR-10: Private Set Intersection

This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) (SL) using the 
CIFAR-10 dataset and the [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html).

We assume one client holds the images, and the other client holds the labels to compute losses and accuracy metrics. 
Activations and the corresponding gradients are exchanged between the clients using NVFlare.

<img src="./figs/split_learning.svg" alt="Split learning setup" width="300"/>

## Private Set Intersection

In this example, each client holds a randomly selected subset of the training indices. 
To find the overlapping data indices between the clients participating in split learning, we can use a Private Set Intersection (PSI) technique. First, let's discuss what PSI is.

### What is PSI?

According to [Wikipedia](https://en.wikipedia.org/wiki/Private_set_intersection): "The Private set intersection is a
secure multiparty computation cryptographic technique that allows two parties holding sets to compare encrypted versions 
of these sets in order to compute the intersection. In this scenario, neither party reveals anything to the counterparty
except for the elements in the intersection."

![psi.png](./figs/psi.jpg)

### What are the use cases for PSI in federated learning?

There are many use cases for PSI. For federated machine learning, we are particularly interested in the 
following:

* **Vertical Learning** -- User ID matching

![user_id_match.png](./figs/user_id_intersect.png)

* **Vertical Learning** -- Feature overlap discovery
  - Site-1: Features A, B, C, D
  - Site-2: Features E, A, C, F, X, Y, Z
  - Overlapping features: A, C

* **Federated Statistics** -- Distinct value counts of categorical features
  - feature = email address -> discover how many distinct email addresses there are
  - feature = country -> discover how many distinct countries there are

  *Example*
  - site-1: feature: country, total distinct countries = 20
  - site-2: feature: country, total distinct countries = 100
  - site-1 and site-2 overlapping distinct countries = 10

  => Total distinct countries = 20 + 100 - overlapping countries = 120 - 10 = 110
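
The feature-overlap and distinct-count examples above can be reproduced with plain Python sets. This is only a sketch of *what* is computed: it is not private, and the names `overlapping_features` and `total_distinct` are illustrative rather than part of NVFlare.

```python
# Toy, non-private illustration of the quantities PSI provides in the use cases above.
site_1_features = {"A", "B", "C", "D"}
site_2_features = {"E", "A", "C", "F", "X", "Y", "Z"}
overlapping_features = site_1_features & site_2_features


def total_distinct(count_1: int, count_2: int, overlap: int) -> int:
    """Inclusion-exclusion: |A union B| = |A| + |B| - |A intersect B|."""
    return count_1 + count_2 - overlap


print(sorted(overlapping_features))  # ['A', 'C']
print(total_distinct(20, 100, 10))  # 110
```

In a real deployment, only the size (or content) of the overlap would come from PSI; the per-site counts are computed locally.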
  
In the federated statistics use case, PSI is used inside the Federated Statistics operations.

For the use case in this example, user ID matching for vertical FL or split learning, we can run the PSI calculation directly as a preprocessing step in a separate NVFlare job.

## PSI Protocol

There are many protocols that can be used for PSI.

For our implementation in `nvflare/app_opt/psi`, the PSI protocol is based on an algorithm that combines [ECDH](https://en.wikipedia.org/wiki/Elliptic-curve_Diffie%E2%80%93Hellman),
Bloom filters, and Golomb Compressed Sets.

The algorithm was developed by [OpenMined PSI](https://github.com/OpenMined/PSI) for two-party PSI.
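
To build intuition for the Diffie-Hellman idea behind this kind of protocol, here is a minimal toy sketch. It is **not** the OpenMined or NVFlare implementation: it swaps elliptic curves for modular exponentiation with an illustrative prime, omits the Bloom filter and Golomb Compressed Set compression, and is not secure. All names (`P`, `hash_to_group`, `mask`, `toy_dh_psi`) are hypothetical. The key property shown is that masking with two secret exponents commutes, so both parties arrive at the same doubly masked value for a shared item.

```python
import hashlib
import secrets

# Toy modulus for illustration only (a Mersenne prime); NOT a secure group choice.
P = 2**127 - 1


def hash_to_group(item: str) -> int:
    """Map an item to an integer in [1, P - 1]."""
    digest = hashlib.sha256(item.encode("utf-8")).digest()
    return int.from_bytes(digest, "big") % (P - 1) + 1


def mask(values, key: int):
    """Raise each group element to a party's secret exponent."""
    return [pow(v, key, P) for v in values]


def toy_dh_psi(client_items, server_items):
    """Return the client items that also appear in the server's set."""
    client_key = secrets.randbelow(P - 3) + 2
    server_key = secrets.randbelow(P - 3) + 2

    # Client -> server: client items masked with the client key.
    client_masked = mask([hash_to_group(x) for x in client_items], client_key)
    # Server -> client: the client's values masked again (order preserved),
    # plus the server's own items masked with the server key.
    client_double = mask(client_masked, server_key)
    server_masked = mask([hash_to_group(y) for y in server_items], server_key)
    # Client: mask the server's values with the client key and compare.
    server_double = set(mask(server_masked, client_key))
    return [x for x, d in zip(client_items, client_double) if d in server_double]


print(toy_dh_psi(["1", "2", "3", "4"], ["3", "4", "5"]))
```

Neither side sends raw items or its secret key in this exchange; only masked values cross the wire, and the client learns which of its own items are shared.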

We took the two-party direct-communication PSI protocol and extended it to the federated computing setting, where all exchanges are
funneled via a central FL server. We can also support multi-party PSI via a pair-wise approach, reducing the multi-party intersection computation to several two-party PSI operations.
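
One way to picture the pair-wise reduction is as a fold over the sites, where each step is a single two-party PSI. The sketch below uses a plain set intersection as a stand-in for the two-party step; `two_party_psi` and `multi_party_psi` are hypothetical names, and the order and scheduling NVFlare actually uses may differ.

```python
from functools import reduce


def two_party_psi(left, right):
    """Stand-in for one two-party PSI run; a plain (non-private) intersection here."""
    right_set = set(right)
    return [x for x in left if x in right_set]


def multi_party_psi(site_items):
    """Fold the sites' item lists into one intersection, two parties at a time."""
    return reduce(two_party_psi, site_items)


sites = [[1, 2, 3, 4, 5], [2, 3, 5, 8], [5, 3, 9]]
print(multi_party_psi(sites))  # [3, 5]
```

With `n` sites, this chain needs `n - 1` two-party operations, and the running intersection can only shrink at each step.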

Please refer to the [PSI README](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/psi/README.md) for more details on the PSI protocol.

### Install requirements
If you haven't already, install the required packages.

In [None]:
!pip install --upgrade pip
!pip install -r ./requirements.txt

### Download and split the CIFAR-10 dataset
To simulate a vertically split dataset, we first download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and distribute it between the two clients, assuming an `OVERLAP` of 10,000 samples between the two clients' datasets.

In [None]:
SPLIT_DIR = "/tmp/cifar10_vert_splits"
OVERLAP = "10000"
!python ./cifar10_split_data_vertical.py --split_dir $SPLIT_DIR --overlap $OVERLAP
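
As an optional sanity check of the generated splits, the following sketch loads the index files and recomputes the overlap directly. It assumes that `site-1.npy`, `site-2.npy`, and `overlap.npy` each hold a 1-D array of unique sample indices; `summarize_splits` is a hypothetical helper, not part of the example code. This shortcut only works in simulation, where one process can read both sites' files.

```python
import os

import numpy as np


def summarize_splits(split_dir):
    """Load the per-site index files and compare their overlap to overlap.npy."""
    site_1 = np.load(os.path.join(split_dir, "site-1.npy"))
    site_2 = np.load(os.path.join(split_dir, "site-2.npy"))
    overlap = np.load(os.path.join(split_dir, "overlap.npy"))
    # intersect1d returns the sorted, unique common values.
    common = np.intersect1d(site_1, site_2)
    return {
        "site-1": len(site_1),
        "site-2": len(site_2),
        "overlap_file": len(overlap),
        "recomputed_overlap": len(common),
        "consistent": np.array_equal(np.sort(overlap), common),
    }
```

After the split script has run, `print(summarize_splits(SPLIT_DIR))` should report the same overlap size in `overlap_file` and `recomputed_overlap` if the assumption about the file contents holds.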

### Run Job API

Now that we have prepared the data, let's use the Job API to create a PSI Job.

We need a couple of components specific to the PSI computation on both the server and the client side, such as `DhPSIController` on the server and `PSIExecutor` on the clients.

To start, we use a general `FedJob`.

In [None]:
from nvflare.job_config.api import FedJob

# nvflare components
from nvflare.app_common.psi.dh_psi.dh_psi_controller import DhPSIController
from nvflare.app_common.psi.psi_executor import PSIExecutor
from nvflare.app_opt.psi.dh_psi.dh_psi_task_handler import DhPSITaskHandler
from nvflare.app_common.psi.file_psi_writer import FilePSIWriter

# custom code for this example
from src.psi.cifar10_local_psi import Cifar10LocalPSI

job = FedJob(name="cifar10_psi")

# add server component
job.to_server(DhPSIController())

# add client components for two sites
n_clients = 2
for i in range(n_clients):
    site_name = f"site-{i+1}"

    # add the client PSI components by id so that other components can reference them
    psi_writer_id = job.as_id(FilePSIWriter(output_path="psi/intersection.txt"))
    local_psi_id = job.as_id(Cifar10LocalPSI(psi_writer_id=psi_writer_id, data_path=f"/tmp/cifar10_vert_splits/{site_name}.npy"))
    psi_algo_id = job.as_id(DhPSITaskHandler(local_psi_id=local_psi_id))

    # now that we have the ids of all required components, we can add the executor to the client
    job.to(PSIExecutor(psi_algo_id=psi_algo_id), site_name)

    print(f"added components for {site_name}")

job.export_job("/tmp/nvflare/jobs/job_config")

Now that we have created the job, we can run the simulation:

In [None]:
job.simulator_run("/tmp/nvflare/cifar10_psi")

The result will be saved in each client's working directory under `psi/intersection.txt`.
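
If the workspace layout on your installation differs from the paths used in the check that follows, a small search can locate the result files. This is a sketch using only the standard library; `find_intersection_files` is a hypothetical helper name.

```python
from pathlib import Path


def find_intersection_files(workspace):
    """Return every intersection.txt below the simulator workspace, sorted by path."""
    return sorted(Path(workspace).rglob("intersection.txt"))
```

For example, `find_intersection_files("/tmp/nvflare/cifar10_psi")` should list one file per site after the simulation has finished.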

### Check the PSI result
We can check the correctness of the result by comparing it to the generated ground-truth overlap, saved in `overlap.npy`.

In [None]:
import os
import numpy as np

gt_overlap = np.load(os.path.join(SPLIT_DIR, "overlap.npy"))

psi_overlap_1 = np.loadtxt("/tmp/nvflare/cifar10_psi/site-1/simulate_job/site-1/psi/intersection.txt")
psi_overlap_2 = np.loadtxt("/tmp/nvflare/cifar10_psi/site-2/simulate_job/site-2/psi/intersection.txt")
                     
print("gt_overlap", gt_overlap, f"n={len(gt_overlap)}")
print("psi_overlap_1", psi_overlap_1, f"n={len(psi_overlap_1)}")
print("psi_overlap_2", psi_overlap_2, f"n={len(psi_overlap_2)}")

intersect_1 = np.intersect1d(psi_overlap_1, gt_overlap, assume_unique=True)
intersect_2 = np.intersect1d(psi_overlap_2, gt_overlap, assume_unique=True)

print(f"Found {100*len(intersect_1)/len(gt_overlap):.1f}% of the overlapping sample ids for site-1.")
print(f"Found {100*len(intersect_2)/len(gt_overlap):.1f}% of the overlapping sample ids for site-2.")
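
The percentages above measure how many ground-truth ids were found, but not whether PSI returned any ids outside the ground truth. A stricter check, sketched below with a hypothetical helper `compare_ids`, counts both kinds of mismatch. It assumes the ids are integers (note that `np.loadtxt` returns floats by default, so the helper casts them).

```python
import numpy as np


def compare_ids(psi_ids, gt_ids):
    """Count ground-truth ids PSI missed and ids PSI returned that are not in the ground truth."""
    psi = np.unique(np.asarray(psi_ids).astype(np.int64))
    gt = np.unique(np.asarray(gt_ids).astype(np.int64))
    missing = np.setdiff1d(gt, psi)  # in the ground truth, but not found by PSI
    extra = np.setdiff1d(psi, gt)  # found by PSI, but not in the ground truth
    return {
        "missing": len(missing),
        "extra": len(extra),
        "exact_match": len(missing) == 0 and len(extra) == 0,
    }
```

Calling `compare_ids(psi_overlap_1, gt_overlap)` and `compare_ids(psi_overlap_2, gt_overlap)` should report `exact_match` as `True` for both sites when PSI recovered exactly the ground-truth overlap.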

Next, we'll use the intersection indices computed by PSI in our [split learning](./split_learning.ipynb) example.