## Split Learning with CIFAR-10

If you haven't already, please follow the steps in the [PSI](./federated_private_set_intersection.ipynb) example to prepare the data.

Now that we have the data intersections, we can start with the actual [split learning](https://arxiv.org/abs/1810.06060).

Again, we use the CIFAR-10 dataset. We assume one client holds the images, and the other client holds the labels needed to compute losses and accuracy metrics.
Activations and the corresponding gradients are exchanged between the clients using NVFlare.

<img src="./figs/split_learning.svg" alt="Split learning setup" width="300"/>

### Implementation

To implement a "SplitNN" for split learning, we take a standard CNN for CIFAR-10 classification ([ModerateCNN](./src/splitnn/split_nn.py)) and split its forward/backward pass into two parts:

1. Convolutional layers
2. Fully connected layers

The convolutional layers are optimized only on the client holding the images, while the fully connected layers are optimized only on the client holding the labels. For details, see the [SplitNN](./src/splitnn/split_nn.py) code.

```python
class SplitNN(ModerateCNN):
    def __init__(self, split_id):
        ...
        if self.split_id == 0:
            self.split_forward = self.conv_layer
        elif self.split_id == 1:
            self.split_forward = self.fc_layer
        else:
            ...

    def forward(self, x):
        x = self.split_forward(x)
        return x
```
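
As a sketch of what splitting the backward pass means, the exchange can be written with the chain rule. The notation is introduced here for illustration only and does not appear in the code: $f_{\theta_1}$ stands for the convolutional part on the image client, $g_{\theta_2}$ for the fully connected part on the label client, and $\ell$ for the loss.

$$
a = f_{\theta_1}(x), \qquad \hat{y} = g_{\theta_2}(a), \qquad \mathcal{L} = \ell(\hat{y}, y)
$$

$$
\frac{\partial \mathcal{L}}{\partial \theta_1} = \left(\frac{\partial a}{\partial \theta_1}\right)^{\top} \frac{\partial \mathcal{L}}{\partial a}
$$

The label client can compute $\partial \mathcal{L} / \partial \theta_2$ and $\partial \mathcal{L} / \partial a$ locally, since it holds $y$ and receives $a$. It only needs to send $\partial \mathcal{L} / \partial a$ back, which is enough for the image client to update $\theta_1$ without ever seeing the labels.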

### Peer-to-peer Communication

To enable direct **peer-to-peer** communication between the clients, we will utilize NVFlare's low-level communication API. The [CIFAR10LearnerSplitNN](./src/splitnn/cifar10_learner_splitnn.py) class handles the execution of the forward and backward pass depending on the `split_id` specified by the client ID (see the Job API configuration below).

To proceed with split learning, the client holding the images ("site-1") needs to send activations to the client holding the corresponding labels ("site-2"). We can use the [Aux channel](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.private.aux_runner.html#nvflare.private.aux_runner.AuxRunner.send_aux_request), i.e., `engine.send_aux_request()`, to pass that information directly between the clients, thereby implementing a peer-to-peer communication channel. The result of the request includes the gradients computed in the backward pass on "site-2", which allows "site-1" to continue optimizing its part of the SplitNN.

```python
            # send to other side
            result = engine.send_aux_request(
                targets=self.other_client,
                topic=SplitNNConstants.TASK_TRAIN_LABEL_STEP,
                request=data_shareable,
                timeout=SplitNNConstants.TIMEOUT,
                fl_ctx=fl_ctx,
            )
```

Note that each Aux request requires a topic handler to be registered on the receiving side. See the `initialize()` routine in [CIFAR10LearnerSplitNN](./src/splitnn/cifar10_learner_splitnn.py) for details.

```python
engine.register_aux_message_handler(
    topic=SplitNNConstants.TASK_TRAIN_LABEL_STEP,
    message_handle_func=self._aux_train_label_side,
)
```
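
To make the request/response pattern concrete, here is a minimal, NVFlare-free sketch. `ToyEngine`, `LabelSide`, `ImageSide`, and the topic name are hypothetical stand-ins invented for this illustration, and each model half is reduced to a single scalar weight. The sketch only mirrors the shape of the exchange: the image side sends an activation, and the registered handler on the label side replies with the gradient with respect to that activation.

```python
TOPIC = "train_label_step"  # hypothetical topic name, not an NVFlare constant


class ToyEngine:
    """Hypothetical in-process message bus; NOT the NVFlare engine."""

    def __init__(self):
        self._handlers = {}

    def register_handler(self, topic, handler):
        self._handlers[topic] = handler

    def send_request(self, topic, request):
        # Deliver the request to the registered handler and return its reply.
        return self._handlers[topic](request)


class LabelSide:
    """Holds the labels and the second model half: y_hat = w2 * a."""

    def __init__(self, labels, w2=0.5, lr=0.02):
        self.labels, self.w2, self.lr = labels, w2, lr

    def train_label_step(self, request):
        a = request["activation"]
        y = self.labels[request["sample_id"]]
        y_hat = self.w2 * a
        loss = (y_hat - y) ** 2
        dloss_dyhat = 2.0 * (y_hat - y)
        grad_a = dloss_dyhat * self.w2        # gradient w.r.t. the received activation
        self.w2 -= self.lr * dloss_dyhat * a  # local update of the label-side weight
        return {"grad_activation": grad_a, "loss": loss}


class ImageSide:
    """Holds the inputs and the first model half: a = w1 * x."""

    def __init__(self, inputs, engine, w1=0.5, lr=0.02):
        self.inputs, self.engine, self.w1, self.lr = inputs, engine, w1, lr

    def train_step(self, sample_id):
        x = self.inputs[sample_id]
        a = self.w1 * x  # forward pass of the first half
        reply = self.engine.send_request(TOPIC, {"sample_id": sample_id, "activation": a})
        # Chain rule: dL/dw1 = dL/da * da/dw1, with da/dw1 = x.
        self.w1 -= self.lr * reply["grad_activation"] * x
        return reply["loss"]


# Toy data following y = 2 * x; the two sides share only sample IDs.
inputs = {0: 1.0, 1: 2.0}
labels = {0: 2.0, 1: 4.0}

engine = ToyEngine()
label_side = LabelSide(labels)
engine.register_handler(TOPIC, label_side.train_label_step)
image_side = ImageSide(inputs, engine)

losses = [image_side.train_step(i % 2) for i in range(200)]
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.2e}")
print(f"w1 * w2 = {image_side.w1 * label_side.w2:.4f}")
```

Neither class ever reads the other's data: the image side never sees a label, and the label side never sees an input, only its activation.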

See [Chapter 9: Implementing peer-to-peer (P2P) communication](../../../chapter-9_flare_low_level_apis/09.3_p2p_communication/p2p_communication.ipynb) for more details on using Aux channels.

### Run simulated split-learning experiments
Next, we use the `intersection.txt` files computed in the previous step to align the datasets on each participating site so that both sites train on the same samples in the same order.
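
The alignment idea can be sketched as follows. This is an illustration under stated assumptions, not the logic of `CIFAR10LearnerSplitNN`: it assumes each intersection file lists one sample ID per line, and the helper names `load_intersection` and `align` are hypothetical.

```python
import tempfile
from pathlib import Path


def load_intersection(path):
    """Read one sample ID per line (assumed format), skipping blank lines."""
    return [line.strip() for line in Path(path).read_text().splitlines() if line.strip()]


def align(records, intersection_ids):
    """Return the local records for the shared IDs in a deterministic order."""
    # Sorting on both sites makes row k refer to the same sample everywhere.
    return [records[sample_id] for sample_id in sorted(intersection_ids)]


with tempfile.TemporaryDirectory() as tmp:
    # Each site may have written the same IDs in a different order.
    file_1 = Path(tmp) / "site-1_intersection.txt"
    file_2 = Path(tmp) / "site-2_intersection.txt"
    file_1.write_text("7\n3\n12\n")
    file_2.write_text("12\n7\n3\n")

    # Toy local data: each site also holds samples the other one lacks.
    images = {"3": "img_3", "5": "img_5", "7": "img_7", "12": "img_12"}
    labels = {"3": 0, "7": 1, "9": 4, "12": 2}

    aligned_images = align(images, load_intersection(file_1))
    aligned_labels = align(labels, load_intersection(file_2))

print(list(zip(aligned_images, aligned_labels)))
```

The IDs are sorted as strings here, so the order is not numeric. That is fine for this purpose, because the only requirement is that both sites use the same order.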

Using the Job API, we can define the previously generated intersection file as input for each site.


In [5]:
from nvflare.job_config.api import FedJob

# nvflare components
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_common.workflows.splitnn_workflow import SplitNNController
from nvflare.app_common.executors.splitnn_learner_executor import SplitNNLearnerExecutor

# custom code for this example
from src.splitnn.split_nn import ModerateCNN, SplitNN
from src.splitnn.cifar10_learner_splitnn import CIFAR10LearnerSplitNN

num_rounds = 15625
batch_size = 64

job = FedJob(name="cifar10_splitnn")

# add server components
job.to_server(
    SplitNNController(
        num_rounds=num_rounds,
        batch_size=batch_size,
        start_round=0,
        persistor_id=job.as_id(PTFileModelPersistor(model=ModerateCNN())),
        task_timeout=0,
        shareable_generator_id=job.as_id(FullModelShareableGenerator()),
    )
)
job.to_server(ValidationJsonGenerator(), id="json_generator")

# add client components for two sites
n_clients = 2
for i in range(n_clients):
    site_name = f"site-{i+1}"

    # split_id=0 -> convolutional part (images), split_id=1 -> fully connected part (labels)
    learner_id = job.as_id(
        CIFAR10LearnerSplitNN(
            dataset_root="/tmp/cifar10",
            intersection_file=f"/tmp/nvflare/cifar10_psi/{site_name}/simulate_job/{site_name}/psi/intersection.txt",
            lr=0.01,
            model=SplitNN(split_id=i),
        )
    )

    job.to(
        SplitNNLearnerExecutor(learner_id=learner_id),
        site_name,
        tasks=["_splitnn_task_init_model_", "_splitnn_task_train_"],
    )

job.export_job("/tmp/nvflare/jobs/job_config")

To run the experiment, execute:

In [None]:
job.simulator_run("/tmp/nvflare/cifar10_splitnn")

The site holding the labels can compute accuracy and losses, which can be visualized in TensorBoard.

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

%tensorboard --logdir /tmp/nvflare/cifar10_splitnn

The resulting training and validation curves with an overlap of 10,000 samples are shown below. The training should take about half an hour to complete on an A100 GPU.

![Split learning training curves](./figs/sl_training_curve_o10000.png)
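
As a rough sanity check on the training length, the configured values can be converted into epochs. This assumes that each round processes exactly one mini-batch of `batch_size` samples drawn from the intersection; that assumption is not stated in this notebook.

```python
num_rounds = 15625  # from the job configuration above
batch_size = 64     # from the job configuration above
overlap = 10_000    # number of samples shared by both sites

# Assumption: one round == one mini-batch of `batch_size` samples.
samples_seen = num_rounds * batch_size
epochs = samples_seen / overlap

print(f"samples processed: {samples_seen:,}")
print(f"approximate epochs over the overlap: {epochs:.0f}")
```

Under that assumption, the run processes 1,000,000 samples, i.e., about 100 passes over the 10,000 overlapping samples.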

## Summary

This example demonstrates a complete split learning workflow using the CIFAR-10 dataset, consisting of two main parts:

### Part 1: Private Set Intersection (PSI)
- Implements PSI using ECDH, Bloom Filters, and Golomb Compressed Sets algorithms.
- Used to find overlapping data indices between two clients.
- Each client holds different parts of the CIFAR-10 dataset (images vs labels).
- PSI helps identify the common samples that can be used for training.
- Results in `intersection.txt` files containing the overlapping sample indices.

### Part 2: Split Learning
- Implements split learning where one client holds images and another holds labels.
- Uses the intersection indices from PSI to align datasets.
- Activations and gradients are exchanged between clients via NVFlare.
- Training progress can be monitored through TensorBoard.

Now, let's [recap](../../07.3_recap/recap.ipynb) what you learned in this chapter.