# Split Learning with CIFAR-10

## 1. Download and split the CIFAR-10 dataset
To simulate a vertically split dataset, we first download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and distribute it between the two clients.

In [6]:
%env SPLIT_DIR=/tmp/cifar10_vert_splits
%env OVERLAP=10000
!python3 ./cifar10_split_data_vertical.py --split_dir ${SPLIT_DIR} --overlap ${OVERLAP}

env: SPLIT_DIR=/tmp/cifar10_vert_splits
env: OVERLAP=10000
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: Partition CIFAR-10 dataset into vertically with 10000 overlapping samples.
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar10/cifar-10-python.tar.gz
100.0%
Extracting /tmp/cifar10/cifar-10-python.tar.gz to /tmp/cifar10
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/overlap.npy
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/site-1.npy
INFO:Cifar10VerticalDataSplitter:[identity=local, run=_]: save /tmp/cifar10_vert_splits/site-2.npy
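
The splitting script itself is not shown in this notebook. As a rough sketch of the idea, assuming each site receives a list of sample indices and the two lists share exactly `OVERLAP` entries, the index partition could look like the following. The function name and the even division of the non-overlapping indices are assumptions for illustration, not the behavior of `cifar10_split_data_vertical.py`.

```python
import random


def split_indices_with_overlap(n_samples, n_overlap, seed=0):
    """Hypothetical helper: build two index lists that share n_overlap IDs."""
    if n_overlap > n_samples:
        raise ValueError("n_overlap cannot exceed n_samples")
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)

    # The first n_overlap shuffled indices are given to both sites.
    overlap = indices[:n_overlap]
    rest = indices[n_overlap:]

    # Assumption: the remaining indices are divided evenly between the sites.
    half = len(rest) // 2
    site_1 = overlap + rest[:half]
    site_2 = overlap + rest[half:]

    # Shuffle each list so the shared IDs are not stored in a common order.
    rng.shuffle(site_1)
    rng.shuffle(site_2)
    return overlap, site_1, site_2


overlap, site_1, site_2 = split_indices_with_overlap(n_samples=100, n_overlap=10)
print(len(overlap), len(site_1), len(site_2))
```

Because each site's list is shuffled independently, neither site can tell from position alone which of its indices are shared. Finding them is the job of the next step.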


## 2. Run private set intersection
We use NVFlare's FL simulator to run the following experiments.

To find the overlapping data indices between the clients participating in split learning, we run private set intersection (PSI) on each site's sample indices.
The overlap itself is the randomly selected subset of the training indices generated in the previous step.

In [7]:
import os
#from nvflare import SimulatorRunner
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="job_configs/cifar10_psi",
    workspace="/tmp/nvflare/cifar10_psi",
    n_clients=2,
    threads=2
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-02-13 19:33:37,341 - SimulatorRunner - INFO - Create the Simulator Server.
2023-02-13 19:33:37,447 - nvflare.fuel.hci.server.hci - INFO - Starting Admin Server localhost on Port 43373
2023-02-13 19:33:37,453 - SimulatorServer - INFO - starting insecure server at localhost:36789
2023-02-13 19:33:37,456 - SimulatorRunner - INFO - Deploy the Apps.
2023-02-13 19:33:37,459 - SimulatorRunner - INFO - Create the simulate clients.
2023-02-13 19:33:37,506 - ClientManager - INFO - Client: New client site-1@127.0.0.1 joined. Sent token: 58824fe4-990b-4ae1-8355-272c98e8a74b.  Total clients: 1
2023-02-13 19:33:37,509 - FederatedClient - INFO - Successfully registered client:site-1 for project simulator_server. Token:58824fe4-990b-4ae1-8355-272c98e8a74b SSID:
2023-02-13 19:33:37,551 - ClientManager - INFO - Client: New client site-2@127.0.0.1 joined. Sent token: 429e493d-7565-45e5-a7fa-a9021decb438.  Total clients: 2
2023-02-13 19:33:37,555 - FederatedClient - INFO - Successfully registered cli


2023-02-13 19:33:38,659 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job]: Server runner starting ...
2023-02-13 19:33:38,661 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job]: starting workflow DhPSIController (<class 'nvflare.app_common.workflows.dh_psi_controller.DhPSIController'>) ...
2023-02-13 19:33:38,663 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job]: Workflow DhPSIController (<class 'nvflare.app_common.workflows.dh_psi_controller.DhPSIController'>) started
2023-02-13 19:33:38,664 - DhPSIController - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController]: PSI control flow started.
2023-02-13 19:33:38,665 - DhPSIController - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController]: start pre workflow
2023-02-13 19:33:38,666 - DhPSIWorkFlow - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController]: pre_process on task PSI
2023-02-13 19:33:38,668 - DhPSIController - IN

E0213 19:33:40.609877062   72784 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers

2023-02-13 19:33:42,763 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController, peer=site-1, peer_run=simulate_job, task_name=PSI, task_id=d9c01c72-8954-4ac0-9af4-de37876e7d85]: assigned task to client site-1: name=PSI, id=d9c01c72-8954-4ac0-9af4-de37876e7d85
2023-02-13 19:33:42,770 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController, peer=site-1, peer_run=simulate_job, task_name=PSI, task_id=d9c01c72-8954-4ac0-9af4-de37876e7d85]: sent task assignment to client
2023-02-13 19:33:42,772 - SimulatorServer - INFO - GetTask: Return task: PSI to client: site-1 (58824fe4-990b-4ae1-8355-272c98e8a74b) 
2023-02-13 19:33:42,774 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=DhPSIController, peer=site-2, peer_run=simulate_job, task_name=PSI, task_id=ed5b48be-0b5a-4297-b6b7-1c664ca8e568]: assigned task to client site-2: name=PSI, id=ed5b48be-0b5a-4297-b6b7-1c664ca8e568
2023-02-13 19:33:42,777 - ServerRu

The result will be saved in each client's working directory in `intersection.txt`.
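
The log above names a `DhPSIController`, which suggests a Diffie-Hellman style PSI protocol. The following toy sketch illustrates the commutative-blinding idea behind such protocols: each party raises hashed items to a secret exponent, and items that both parties hold end up with the same doubly blinded value. This is not NVFlare's implementation. The modulus, the hashing scheme, and all function names are illustrative assumptions, and the code is not secure.

```python
import hashlib
import math
import random

# Toy modulus (a Mersenne prime). An assumption for illustration only.
P = 2**127 - 1


def hash_to_group(item):
    """Map an item to an integer in [1, P - 1]."""
    digest = hashlib.sha256(str(item).encode()).digest()
    return int.from_bytes(digest, "big") % (P - 1) + 1


def random_key(rng):
    """Pick an exponent coprime to P - 1 so that blinding is one-to-one."""
    while True:
        key = rng.randrange(2, P - 1)
        if math.gcd(key, P - 1) == 1:
            return key


def blind(values, key):
    return [pow(v, key, P) for v in values]


def toy_dh_psi(items_a, items_b, seed=0):
    """Return the items of A that B also holds, using only blinded values."""
    rng = random.Random(seed)  # not a cryptographic RNG; toy only
    key_a, key_b = random_key(rng), random_key(rng)

    # A -> B: H(x)^a for every item x of A.
    a_once = blind([hash_to_group(x) for x in items_a], key_a)
    # B -> A: (H(x)^a)^b in the same order, plus H(y)^b for B's own items.
    a_twice = blind(a_once, key_b)
    b_once = blind([hash_to_group(y) for y in items_b], key_b)
    # A computes (H(y)^b)^a and matches it against its doubly blinded items.
    b_twice = set(blind(b_once, key_a))
    return [x for x, t in zip(items_a, a_twice) if t in b_twice]


print(toy_dh_psi([1, 5, 9, 12], [12, 3, 5, 7]))
```

Exponentiation commutes, so `(H(x)^a)^b` equals `(H(x)^b)^a`. Matching the doubly blinded values therefore reveals the shared items without either party sending its raw IDs.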

### Check the PSI result
We can check the correctness of the result by comparing it to the generated ground-truth overlap, saved in `overlap.npy`.

In [8]:
import os
import numpy as np

split_dir = os.environ["SPLIT_DIR"]
gt_overlap = np.load(os.path.join(split_dir, "overlap.npy"))

psi_overlap_1 = np.loadtxt("/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt")
psi_overlap_2 = np.loadtxt("/tmp/nvflare/cifar10_psi/simulate_job/site-2/psi/intersection.txt")
                     
print("gt_overlap", gt_overlap, f"n={len(gt_overlap)}")
print("psi_overlap_1", psi_overlap_1, f"n={len(psi_overlap_1)}")
print("psi_overlap_2", psi_overlap_2, f"n={len(psi_overlap_2)}")

intersect_1 = np.intersect1d(psi_overlap_1, gt_overlap, assume_unique=True)
intersect_2 = np.intersect1d(psi_overlap_2, gt_overlap, assume_unique=True)

print(f"Found {100*len(intersect_1)/len(gt_overlap):.1f}% of the overlapping sample ids for site-1.")
print(f"Found {100*len(intersect_2)/len(gt_overlap):.1f}% of the overlapping sample ids for site-2.")

gt_overlap [11841 19602 45519 ... 47278 37020  2217] n=10000
psi_overlap_1 [ 4481. 45431. 46253. ... 34846.   179.  7277.] n=10000
psi_overlap_2 [38639. 10733. 31911. ... 12172. 46167.   865.] n=10000
Found 100.0% of the overlapping sample ids for site-1.
Found 100.0% of the overlapping sample ids for site-2.
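
The percentages above measure how many ground-truth IDs each site recovered. A complementary check, sketched below, also reports any IDs that PSI returned but that are not in the ground truth. The helper name is hypothetical; it accepts any iterable of numeric IDs, including the float arrays returned by `np.loadtxt`.

```python
def compare_id_sets(found, expected):
    """Hypothetical helper: compare two collections of sample IDs as sets."""
    # Cast to int so that 4481.0 (from a text file) matches 4481.
    found_ids = {int(i) for i in found}
    expected_ids = {int(i) for i in expected}
    return {
        "missing": sorted(expected_ids - found_ids),
        "unexpected": sorted(found_ids - expected_ids),
        "exact_match": found_ids == expected_ids,
    }


report = compare_id_sets(found=[4481.0, 179.0, 7277.0], expected=[179, 7277, 4481])
print(report)
```

In the notebook this could be called as `compare_id_sets(psi_overlap_1, gt_overlap)`. An empty `missing` list and an empty `unexpected` list together mean that the PSI result equals the ground truth as a set.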


## 3. Run simulated split-learning experiments
Next we use the `intersection.txt` files to align the datasets on each participating site in order to do split learning.
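
Note that the two `intersection.txt` files printed earlier list the same IDs in different orders, so alignment needs an ordering that both sites can derive independently. A minimal sketch, assuming each site sorts the shared IDs and looks up its local row for each one, is shown below. How `CIFAR10LearnerSplitNN` actually aligns its data is not shown in this notebook, and the function name here is hypothetical.

```python
def align_to_intersection(local_ids, intersection_ids):
    """Return local row positions of the shared IDs, in sorted-ID order."""
    # Map each local sample ID to its row position on this site.
    position = {int(sid): row for row, sid in enumerate(local_ids)}
    # Sorting gives an order both sites can compute without communicating.
    ordered = sorted(int(sid) for sid in intersection_ids)
    return [position[sid] for sid in ordered]


# Toy data: each site stores its samples in a different local order.
site_1_ids = [30, 10, 20, 40]
site_2_ids = [20, 50, 10, 30]

rows_1 = align_to_intersection(site_1_ids, [30.0, 10.0, 20.0])
rows_2 = align_to_intersection(site_2_ids, [20.0, 30.0, 10.0])

# Row k on site-1 and row k on site-2 now refer to the same sample ID.
print([site_1_ids[r] for r in rows_1])
print([site_2_ids[r] for r in rows_2])
```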
The [config_fed_client.json](./job_configs/cifar10_splitnn/site-1/config/config_fed_client.json) takes as input the previously generated intersection file for each site.
```
    {
        "id": "cifar10-learner",
        "path": "pt.learners.cifar10_learner_splitnn.CIFAR10LearnerSplitNN",
        "args": {
            "dataset_root": "{DATASET_ROOT}",
            "intersection_file": "{INTERSECTION_FILE}",
            "lr": 1e-2,
            "model": {"path": "pt.networks.split_nn.SplitNN", "args":  {"split_id":  0}},
            "timeit": true
        }
    }
```
To set the filename automatically, run:

In [5]:
!for i in {1..2}; \
do \
  CONFIG_FILE=job_configs/cifar10_splitnn/site-${i}/config/config_fed_client.json; \
  INTERSECTION_FILE=/tmp/nvflare/cifar10_psi/simulate_job/site-${i}/psi/intersection.txt; \
  python3 ./set_intersection_file.py --config_file ${CONFIG_FILE} --intersection_file ${INTERSECTION_FILE}; \
done

Modified job_configs/cifar10_splitnn/site-1/config/config_fed_client.json to use INTERSECTION_FILE=/tmp/nvflare/cifar10_psi/simulate_job/site-1/psi/intersection.txt
Modified job_configs/cifar10_splitnn/site-2/config/config_fed_client.json to use INTERSECTION_FILE=/tmp/nvflare/cifar10_psi/simulate_job/site-2/psi/intersection.txt
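
The source of `set_intersection_file.py` is not shown here. As a sketch of what such a step might do, assuming it substitutes the `{INTERSECTION_FILE}` placeholder in the JSON config, a placeholder replacement could look like this. The function name and the validation step are assumptions.

```python
import json
from pathlib import Path


def fill_placeholder(config_file, placeholder, value):
    """Hypothetical helper: replace a placeholder string in a JSON config."""
    path = Path(config_file)
    text = path.read_text()
    if placeholder not in text:
        raise ValueError(f"{placeholder} not found in {config_file}")

    # Escape the value as a JSON string, then drop the surrounding quotes
    # because the placeholder already sits inside a quoted string.
    escaped = json.dumps(value)[1:-1]
    new_text = text.replace(placeholder, escaped)

    json.loads(new_text)  # fail early if the result is not valid JSON
    path.write_text(new_text)
```

A call such as `fill_placeholder(config_file, "{INTERSECTION_FILE}", intersection_file)` would then rewrite one site's config in place. Running it a second time raises `ValueError`, because the placeholder is gone after the first run.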


To run the experiment, execute:

In [6]:
import os
from nvflare import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="job_configs/cifar10_splitnn",
    workspace="/tmp/nvflare/cifar10_splitnn",
    n_clients=2,
    threads=2
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)

2023-02-01 11:21:52,457 - SimulatorRunner - INFO - Create the Simulator Server.
2023-02-01 11:21:52,479 - Cell - INFO - server: creating listener on grpc://localhost:53913
2023-02-01 11:21:52,481 - Cell - INFO - server: created backbone external listener for grpc://localhost:53913
2023-02-01 11:21:52,482 - ConnectorManager - INFO - 349396: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 11:21:52,483 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:35a77933-6e70-4677-9204-9b715b7ba6d1 is starting in PASSIVE mode
2023-02-01 11:21:52,985 - Cell - INFO - server: created backbone internal listener for tcp://localhost:30342
2023-02-01 11:21:52,989 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector AioGrpcDriver:707ec2a5-2630-4497-ac3f-a125e9ff4d1b is starting in PASSIVE mode
2023-02-01 11:21:52,993 - nvflare.fuel.f3.communicator - INFO - Communicator is started for local endpoint: server
2023-02-01 11:21:53,

E0201 11:21:57.617844691  349556 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers
E0201 11:21:57.634208997  349557 fork_posix.cc:76]           Other threads are currently calling into gRPC, skipping fork() handlers


2023-02-01 11:21:58,720 - Cell - INFO - site-1.simulate_job: created backbone internal connector to tcp://localhost:39242 on parent
2023-02-01 11:21:58,721 - ConnectorManager - INFO - 349565: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 11:21:58,721 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:a736b5be-2f44-4c08-b49f-f12a5a5dc528 is starting in PASSIVE mode
2023-02-01 11:21:58,732 - Cell - INFO - site-2.simulate_job: created backbone internal connector to tcp://localhost:37420 on parent
2023-02-01 11:21:58,733 - ConnectorManager - INFO - 349566: Try start_listener Listener resources: {'secure': False, 'host': 'localhost', 'ports': ['30000-40000']}
2023-02-01 11:21:58,733 - nvflare.fuel.f3.sfm.conn_manager - INFO - Connector TcpDriver:0442591d-0aff-4908-8e48-3e800181502c is starting in PASSIVE mode
2023-02-01 11:21:59,222 - Cell - INFO - site-1.simulate_job: created backbone internal listener for tc

The site containing the labels can compute accuracy and losses, which can be visualized in TensorBoard.

In [9]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

%tensorboard --logdir /tmp/nvflare/cifar10_splitnn

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 349977), started 0:00:00 ago. (Use '!kill 349977' to kill it.)
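
For intuition about why only the label-holding site can compute the loss, the following toy sketch runs split learning with two scalar "models" in plain Python. Site-1 holds the features and the first weight, and site-2 holds the labels and the second weight. Only the activation and its gradient cross the boundary between them. This is an illustrative assumption about the general technique, not the `SplitNN` network or the NVFlare workflow used above.

```python
def train_split_toy(xs, ys, lr=0.01, epochs=200):
    """Toy split learning: site-1 owns w1 and xs, site-2 owns w2 and ys."""
    w1, w2 = 0.5, 0.5
    losses = []
    for _ in range(epochs):
        total = 0.0
        for x, y in zip(xs, ys):
            # site-1 (features only): forward pass, then send h to site-2.
            h = w1 * x

            # site-2 (labels only): finish the forward pass, compute the loss.
            pred = w2 * h
            err = pred - y
            total += err * err
            grad_w2 = 2 * err * h
            grad_h = 2 * err * w2  # the only value sent back to site-1

            # site-1: backward pass using the received gradient.
            grad_w1 = grad_h * x

            w1 -= lr * grad_w1
            w2 -= lr * grad_w2
        losses.append(total / len(xs))
    return w1, w2, losses


xs = [0.5, 1.0, 1.5, 2.0]
ys = [3 * x for x in xs]  # target relationship: y = 3x
w1, w2, losses = train_split_toy(xs, ys)
print(f"w1*w2={w1 * w2:.4f}, first loss={losses[0]:.4f}, last loss={losses[-1]:.6f}")
```

Site-1 never sees `y` or the loss value, and site-2 never sees `x`. That matches the observation that the metrics are logged only by the site holding the labels.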