# Split Learning with CIFAR-10

### Setup

Install the required packages for training in the current Jupyter kernel:

In [1]:
import sys
# Uncomment the next line to install the requirements into the current kernel
# !{sys.executable} -m pip install -r ./requirements.txt

Add the custom files of this example (in `src`) and some reused files from the [CIFAR-10](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10) examples to the Python path (`sys.path`):

In [2]:
import os
sys.path.append(os.path.join(os.getcwd(), "src"))
sys.path.append(os.path.join(os.getcwd(), "..", "..", "examples", "advanced", "cifar10"))

In [3]:
try:
    from oneshotVFL.cifar10_vertical_data_splitter import Cifar10VerticalDataSplitter
except ImportError as e:
    raise ImportError("PYTHONPATH is not set properly") from e

## 1. Download and split the CIFAR-10 dataset
To simulate a vertically split dataset, we first download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and distribute it between the two clients, with an `OVERLAP` of 10,000 samples between the two clients' datasets.
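The splitting script itself is not shown in this notebook. As a rough illustration of what "two clients with a known overlap" means, the following plain-Python sketch assigns each of two sites a random set of sample ids that share exactly `overlap` ids. The function `make_vertical_split`, its arguments, and the 5,000 site-only ids per site are hypothetical choices for illustration, not the interface or logic of `cifar10_split_data_vertical.py`. Only the sample-id bookkeeping is shown; how features and labels are divided between the sites is left out.

```python
import random


def make_vertical_split(n_samples, overlap, n_extra_per_site, seed=0):
    """Hypothetical sketch: pick sample ids for two sites with a known overlap.

    Both sites receive the same `overlap` ids plus `n_extra_per_site` ids
    that only they hold, so the true intersection is exactly the shared ids.
    """
    if overlap + 2 * n_extra_per_site > n_samples:
        raise ValueError("not enough samples for the requested split")
    rng = random.Random(seed)
    ids = list(range(n_samples))
    rng.shuffle(ids)
    shared = ids[:overlap]
    only_1 = ids[overlap:overlap + n_extra_per_site]
    only_2 = ids[overlap + n_extra_per_site:overlap + 2 * n_extra_per_site]
    site_1 = shared + only_1
    site_2 = shared + only_2
    # Shuffle each site's list so the shared ids are not simply at the front
    rng.shuffle(site_1)
    rng.shuffle(site_2)
    return site_1, site_2, sorted(shared)


# CIFAR-10 has 50,000 training images; 5,000 site-only ids is an assumption
site_1_ids, site_2_ids, gt_overlap = make_vertical_split(
    n_samples=50_000, overlap=10_000, n_extra_per_site=5_000
)
print(len(site_1_ids), len(site_2_ids), len(gt_overlap))  # → 15000 15000 10000
```

Keeping the ground-truth overlap (here `gt_overlap`) is what makes it possible to verify the PSI result later on.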

In [4]:
SPLIT_DIR = "/tmp/cifar10_vert_splits"
OVERLAP = "10000"
%run ./cifar10_split_data_vertical.py --split_dir $SPLIT_DIR --overlap $OVERLAP

ModuleNotFoundError: No module named 'splitnn'

## 2. Run private set intersection
We are using NVFlare's FL simulator to run the following experiments.

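To find the overlapping data indices between the clients participating in split learning, we run private set intersection (PSI) on the sample indices held by each client. Each client holds a randomly selected subset of the training indices, and PSI identifies the indices common to both clients without revealing the remaining ones.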
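NVFlare's PSI implementation is not shown in this notebook. To illustrate the idea only, here is a toy Diffie-Hellman-style PSI in plain Python: each party blinds hashed ids with a private exponent, and because exponentiation commutes, common ids produce identical values after both parties have blinded them. The group (integers modulo the Mersenne prime 2**127 - 1), the hashing scheme, and the function names are assumptions for illustration; this is not NVFlare's protocol and it is not secure for real use.

```python
import hashlib
import secrets

# Toy group: integers modulo the Mersenne prime 2**127 - 1.
# Assumption: chosen for readability only; this is NOT a secure PSI setup.
P = 2**127 - 1


def hash_to_group(item):
    """Map an item to a nonzero element modulo P via SHA-256."""
    digest = hashlib.sha256(str(item).encode()).digest()
    return int.from_bytes(digest, "big") % (P - 1) + 1


def blind(values, secret):
    """Raise each group element to a party's secret exponent."""
    return [pow(v, secret, P) for v in values]


def toy_dh_psi(items_a, items_b):
    """Return the items of party A that party B also holds.

    Relies on commutativity: (H(x)^a)^b == (H(x)^b)^a (mod P).
    """
    a = secrets.randbelow(P - 3) + 2  # A's private exponent
    b = secrets.randbelow(P - 3) + 2  # B's private exponent
    items_a = list(items_a)

    # A sends H(x)^a; B raises it to b and returns it in the same order
    double_blinded_a = blind(blind([hash_to_group(x) for x in items_a], a), b)
    # B sends H(y)^b; A raises it to a
    double_blinded_b = set(blind(blind([hash_to_group(y) for y in items_b], b), a))

    # Matching double-blinded values tell A which of its items are shared
    return [x for x, v in zip(items_a, double_blinded_a) if v in double_blinded_b]


print(toy_dh_psi(range(0, 20), range(10, 30)))
```

Neither party sends its raw ids; only blinded values are exchanged, and A learns which of its own ids are in the intersection.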

In [5]:
import os
from nvflare import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="jobs/cifar10_psi",
    workspace="/tmp/nvflare/cifar10_psi",
    n_clients=2,
    threads=2
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-06-08 04:31:53,148 - SimulatorRunner - INFO - Create the Simulator Server.
2023-06-08 04:31:53,151 - Cell - INFO - server: creating listener on tcp://0:44685
2023-06-08 04:31:53,153 - Cell - INFO - server: created backbone external listener for tcp://0:44685
2023-06-08 04:31:53,154 - ConnectorManager - INFO - 37350: Try start_listener Listener resources: {'secure': False, 'host': 'localhost'}
2023-06-08 04:31:53,155 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00002 PASSIVE tcp://0:53061] is starting
2023-06-08 04:31:53,657 - Cell - INFO - server: created backbone internal listener for tcp://localhost:53061
2023-06-08 04:31:53,659 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00001 PASSIVE tcp://0:44685] is starting
2023-06-08 04:31:53,740 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 41251
2023-06-08 04:31:53,741 - SimulatorRunner - INFO - Deploy the Apps.
2023-06-08 04:31:53,745 - SimulatorRunner - INFO - Create the simulat

The result is saved in each client's working directory as `psi/intersection.txt`, e.g., `/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt` for `site-1`.

### Check the PSI result
We can check the correctness of the result by comparing it to the generated ground-truth overlap, saved in `overlap.npy`.

In [6]:
import os
import numpy as np

gt_overlap = np.load(os.path.join(SPLIT_DIR, "overlap.npy"))

psi_overlap_1 = np.loadtxt("/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt")
psi_overlap_2 = np.loadtxt("/tmp/nvflare/cifar10_psi/simulate_job/site-2/psi/intersection.txt")
                     
print("gt_overlap", gt_overlap, f"n={len(gt_overlap)}")
print("psi_overlap_1", psi_overlap_1, f"n={len(psi_overlap_1)}")
print("psi_overlap_2", psi_overlap_2, f"n={len(psi_overlap_2)}")

intersect_1 = np.intersect1d(psi_overlap_1, gt_overlap, assume_unique=True)
intersect_2 = np.intersect1d(psi_overlap_2, gt_overlap, assume_unique=True)

print(f"Found {100*len(intersect_1)/len(gt_overlap):.1f}% of the overlapping sample ids for site-1.")
print(f"Found {100*len(intersect_2)/len(gt_overlap):.1f}% of the overlapping sample ids for site-2.")

gt_overlap [11841 19602 45519 ... 47278 37020  2217] n=10000
psi_overlap_1 [ 4481. 45431. 46253. ... 34846.   179.  7277.] n=10000
psi_overlap_2 [38639. 10733. 31911. ... 12172. 46167.   865.] n=10000
Found 100.0% of the overlapping sample ids for site-1.
Found 100.0% of the overlapping sample ids for site-2.


## 3. Run simulated Oneshot VFL experiments
Next, we use the `intersection.txt` files to align the datasets on each participating site before training. Each site's [config_fed_client.json](./jobs/cifar10_oneshotVFL/site-1/config/config_fed_client.json) takes the previously generated intersection file as input, as in this learner component:
```json
{
    "id": "cifar10-learner",
    "path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
    "args": {
        "dataset_root": "{DATASET_ROOT}",
        "intersection_file": "{INTERSECTION_FILE}",
        "lr": 1e-2,
        "model": {"path": "pt.networks.split_nn.SplitNN", "args": {"split_id": 0}},
        "timeit": true
    }
}
```
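The PSI results printed earlier list the common ids in a different order on each site (`psi_overlap_1` starts with 4481, `psi_overlap_2` with 38639), so each site has to put its aligned samples into a shared order before training. A minimal sketch of one way to do this, assuming `intersection.txt` holds one sample id per line (as the `np.loadtxt` call above suggests) and that sorting by id is an acceptable shared order. The helpers `read_intersection` and `align_rows` are hypothetical, and the learner used in this example may align its data differently.

```python
import os
import tempfile


def read_intersection(path):
    """Read sample ids from a text file with one id per line (assumed format)."""
    with open(path) as f:
        # int(float(...)) accepts both "4481" and "4481.0"
        return [int(float(line)) for line in f if line.strip()]


def align_rows(local_ids, intersection_ids):
    """Return local row indices of the intersected samples, in sorted-id order.

    Sorting gives every site the same sample order without exchanging
    anything beyond the PSI result itself.
    """
    row_of = {sample_id: row for row, sample_id in enumerate(local_ids)}
    return [row_of[sample_id] for sample_id in sorted(intersection_ids)]


# Two sites store the same common ids (3, 5, 7) in different orders
with tempfile.TemporaryDirectory() as tmp:
    path_1 = os.path.join(tmp, "site-1_intersection.txt")
    path_2 = os.path.join(tmp, "site-2_intersection.txt")
    with open(path_1, "w") as f:
        f.write("7\n3\n5\n")
    with open(path_2, "w") as f:
        f.write("5\n7\n3\n")
    ids_1 = read_intersection(path_1)
    ids_2 = read_intersection(path_2)

site_1_local_ids = [3, 9, 5, 7]  # ids held by site-1, in its local row order
site_2_local_ids = [7, 2, 5, 3]  # ids held by site-2, in its local row order
rows_1 = align_rows(site_1_local_ids, ids_1)
rows_2 = align_rows(site_2_local_ids, ids_2)

# Both sites now visit the shared samples in the same order
print([site_1_local_ids[r] for r in rows_1], [site_2_local_ids[r] for r in rows_2])  # → [3, 5, 7] [3, 5, 7]
```

Without a shared order, row *i* on one site would correspond to a different image than row *i* on the other, and the split model would be trained on mismatched pairs.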
To write each site's intersection file path into its config automatically, run:

In [4]:
!for i in {1..2}; \
do \
  CONFIG_FILE=jobs/cifar10_oneshotVFL/site-${i}/config/config_fed_client.json; \
  INTERSECTION_FILE=/tmp/nvflare/cifar10_psi/simulate_job/site-${i}/psi/intersection.txt; \
  python3 ./set_intersection_file.py --config_file ${CONFIG_FILE} --intersection_file ${INTERSECTION_FILE}; \
done

Modified jobs/cifar10_oneshotVFL/site-1/config/config_fed_client.json to use INTERSECTION_FILE=/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt
Modified jobs/cifar10_oneshotVFL/site-2/config/config_fed_client.json to use INTERSECTION_FILE=/tmp/nvflare/cifar10_psi/simulate_job/site-2/psi/intersection.txt
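The contents of `set_intersection_file.py` are not shown here; judging from its output, it writes the intersection file path into each site's `config_fed_client.json`. A hypothetical stand-in, assuming the learner sits in a top-level `components` list and has an `intersection_file` entry in its `args` (as in the fragment above). The function names are illustrative, not the script's actual code.

```python
import json


def set_intersection_file(config, intersection_file):
    """Set `args.intersection_file` on every component that declares it.

    Hypothetical stand-in for set_intersection_file.py; returns the number
    of components that were updated.
    """
    updated = 0
    for component in config.get("components", []):  # assumed top-level key
        args = component.get("args", {})
        if "intersection_file" in args:
            args["intersection_file"] = intersection_file
            updated += 1
    return updated


def update_config_file(config_file, intersection_file):
    """Load a JSON config from disk, update it in place, and write it back."""
    with open(config_file) as f:
        config = json.load(f)
    set_intersection_file(config, intersection_file)
    with open(config_file, "w") as f:
        json.dump(config, f, indent=4)


# In-memory example using the learner fragment shown earlier
config = {
    "components": [
        {
            "id": "cifar10-learner",
            "path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
            "args": {"dataset_root": "{DATASET_ROOT}", "intersection_file": "{INTERSECTION_FILE}"},
        }
    ]
}
n = set_intersection_file(config, "/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt")
print(n, config["components"][0]["args"]["intersection_file"])
```

Matching on the `intersection_file` key rather than on a component `id` keeps the sketch working if the learner is renamed, at the cost of updating every component that happens to declare that argument.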


To run the experiment, execute:

In [6]:
import os
from nvflare import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="jobs/cifar10_oneshotVFL",
    workspace="/tmp/nvflare/cifar10_oneshotVFL",
    n_clients=2,
    threads=2
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-06-09 20:22:01,186 - SimulatorRunner - INFO - Create the Simulator Server.
2023-06-09 20:22:01,190 - Cell - INFO - server: creating listener on tcp://0:46593
2023-06-09 20:22:01,191 - Cell - INFO - server: created backbone external listener for tcp://0:46593
2023-06-09 20:22:01,192 - ConnectorManager - INFO - 122013: Try start_listener Listener resources: {'secure': False, 'host': 'localhost'}
2023-06-09 20:22:01,193 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00002 PASSIVE tcp://0:5177] is starting
2023-06-09 20:22:01,694 - Cell - INFO - server: created backbone internal listener for tcp://localhost:5177
2023-06-09 20:22:01,696 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector [CH00001 PASSIVE tcp://0:46593] is starting
2023-06-09 20:22:01,775 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 35687
2023-06-09 20:22:01,776 - SimulatorRunner - INFO - Deploy the Apps.
2023-06-09 20:22:01,780 - SimulatorRunner - INFO - Create the simulate