In [None]:
epochs = 15

# Tutorial 4 - SplitNN for Vertically Partitioned Data

<b>Recap: </b> The previous tutorial looked at building a basic SplitNN, where an NN was split into two or more segments on two or more separate hosts. However, what if the clients hold different modalities of data, as in a multi-institutional collaboration?

<b>What is a SplitNN? </b>The training of a neural network (NN) is 'split' across one or more hosts. Each model segment is a self-contained NN that feeds into the segment in front. In this case, both parties can train the model without knowing each other's data or the full details of the model.

<b>Why use a SplitNN? </b>The SplitNN has been shown to dramatically reduce the computational burden of training while maintaining higher accuracies when training over a large number of clients [[2](https://arxiv.org/abs/1812.00564)].

<b>Advantages </b>
- The accuracy should be identical to a non-split version of the same model, trained locally.
- The model is distributed, meaning all segment holders must consent in order to aggregate the model at the end of training.
- The scalability of this approach, in terms of both network and computational resources, could make this a valid alternative to federated learning (FL) and large-batch synchronous SGD (LBSGD), particularly on low-power devices.
- This could be an effective mechanism for both horizontal and vertical data distributions.
- As computational cost is already quite low, the cost of applying homomorphic encryption is also minimised.
- Only activation signals and their gradients are sent/received, meaning that malicious actors cannot use gradients of model parameters to reverse engineer the original values.

<b>Constraints </b>
- This is a new technique with little surrounding literature; a large amount of comparison and evaluation is still to be performed.
- This approach requires all hosts to remain online during the entire learning process (less feasible for hand-held devices).
- Not as established in privacy-preserving toolkits as FL and LBSGD.
- Activation signals and their corresponding gradients still have the capacity to leak information; however, this is yet to be fully addressed in the literature.

<b>Description: </b>This configuration allows multiple clients holding different modalities of data to learn distributed models without sharing data. As a concrete example, we walk through the case where radiology centers collaborate with pathology test centers and a server for disease diagnosis. Radiology centers holding imaging data train a partial model up to the cut layer. In the same way, the pathology test center holding patient test results trains a partial model up to its own cut layer. The outputs at the cut layer from both centers are then concatenated and sent to the disease diagnosis server, which trains the rest of the model. This process is repeated back and forth to complete the forward and backward propagations, so that the distributed deep learning model is trained without any party sharing its raw data. In this tutorial, we split a single flattened image into two segments to mimic different modalities of data; you can also split it into an arbitrary number of segments.
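
The data flow described above can be sketched without any networking or deep learning library. The following is a minimal, dependency-free illustration and not the PySyft implementation used in the rest of this notebook: the names `client_forward` and `server_forward` and the toy weights are hypothetical, and plain Python lists stand in for tensors.

```python
def client_forward(segment, weight):
    """Toy 'partial model' (hypothetical): scale every feature, then apply ReLU."""
    return [max(0.0, weight * x) for x in segment]


def server_forward(cut_layer_inputs):
    """Toy 'server model' (hypothetical): sum the concatenated activations."""
    return sum(cut_layer_inputs)


# One flattened sample with four features, split vertically into two halves.
sample = [1.0, -2.0, 3.0, 4.0]
segments = [sample[:2], sample[2:]]

# Each client only ever sees its own segment (the weights are made up).
weights = [0.5, 2.0]
activations = [client_forward(seg, w) for seg, w in zip(segments, weights)]

# The server receives the concatenated cut-layer outputs, never the raw features.
concatenated = [a for act in activations for a in act]
prediction = server_forward(concatenated)

print(concatenated)  # [0.5, 0.0, 6.0, 8.0]
print(prediction)    # 14.5
```

The point of the sketch is only the shape of the exchange: raw features stay with their owner, and the server works on the concatenated cut-layer outputs.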


<img src="images/SplitNN_Veritically.png" width="20%">


In this tutorial, we demonstrate the SplitNN class with a three-segment distribution [[1](https://arxiv.org/abs/1812.00564)]. This time:

- <b>$Alice_{0}$</b>
    - Has Model Segment 1
    - Has Segment 1 of the handwritten images
- <b>$Alice_{1}$</b>
    - Has Model Segment 2
    - Has Segment 2 of the handwritten images
- <b>$Bob$</b> 
    - Has Model Segment 3
    - Has the image labels
    

Author:
- Haofan Wang - GitHub: [@haofanwang](https://github.com/haofanwang)

In [None]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
hook = sy.TorchHook(torch)

In [None]:
# Data preprocessing
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

In [None]:
torch.manual_seed(0)

# define the input size for each client; each client may have a different input size
# you can also try an arbitrary number of vertical splits, [size_1, size_2, ... size_n], where n is the total number of segments
input_sizes = [392, 392]

# define the hidden layer sizes; clients could use different sizes, but here all clients share the same sizes for convenience
hidden_sizes = [128, 320]

# define the input size for the server
concatenated_size = hidden_sizes[-1]*len(input_sizes)

# define the output size, it should be the number of classes
output_size = 10

# define the client models; you can define any model here, but remember to keep each input size consistent with the size of its data segment
client_models = []
for i in range(len(input_sizes)):
  client_models.append(
      nn.Sequential(
                nn.Linear(input_sizes[i], hidden_sizes[0]),
                nn.ReLU(),
                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                nn.ReLU(),
    )
  )

# define server model
server_model = [
    nn.Sequential(
                nn.Linear(concatenated_size, output_size),
                nn.LogSoftmax(dim=1),
    )               
]

# complete model array
models = client_models + server_model

# create optimisers for each segment and link to their segment
optimizers = [
    optim.SGD(model.parameters(), lr=0.03,)
    for model in models
]

# create client workers
client_locations = []
for i in range(len(input_sizes)):
  alice = sy.VirtualWorker(hook, id="alice_"+str(i))
  client_locations.append(alice)

# create server worker
server_location = []
bob = sy.VirtualWorker(hook, id="bob")
server_location.append(bob)

# Send Model Segments to starting locations
model_locations = client_locations + server_location

for model, location in zip(models, model_locations):
    model.send(location)
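
When the feature vector is cut into segments of different sizes, each client's column range follows from the running sum of `input_sizes`. A small sketch of that bookkeeping is shown below; `segment_bounds` is a hypothetical helper introduced here for illustration, not part of PySyft or PyTorch, and the uneven three-way split is an invented example.

```python
from itertools import accumulate


def segment_bounds(sizes):
    """Return one (start, stop) column range per segment (hypothetical helper)."""
    stops = list(accumulate(sizes))
    starts = [0] + stops[:-1]
    return list(zip(starts, stops))


# The two equal halves of a flattened 28x28 MNIST image used in this tutorial.
print(segment_bounds([392, 392]))       # [(0, 392), (392, 784)]

# An invented uneven three-way split of the same 784 features.
print(segment_bounds([200, 300, 284]))  # [(0, 200), (200, 500), (500, 784)]
```

The same ranges, scaled to the cut-layer width instead of the input width, describe which columns of the concatenated activation tensor belong to which client.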

In [None]:
def train(images, target, models, optimizers):
      # Training Logic

      #1) erase previous gradients (if they exist)
      for opt in optimizers:
        opt.zero_grad()

      #2) define an empty tensor to hold the concatenated client outputs
      remote_a = torch.zeros(images[0].shape[0],concatenated_size)

      #3) make a prediction on each client model
      output_a = []
      for i in range(len(input_sizes)):
        a = models[i](images[i])
        output_a.append(a)
        remote_a[:,hidden_sizes[1]*i:hidden_sizes[1]*(i+1)] = a.copy().get()
      remote_a = remote_a.detach().send(models[-1].location).requires_grad_()

      #4) make a prediction on the server model using the received signal
      pred = models[-1](remote_a)

      #5) calculate the loss against the labels passed in as `target`
      criterion = nn.NLLLoss()
      loss = criterion(pred, target)

      #6) backpropagate through the server model, down to the received activations
      loss.backward()

      #7) send the gradient of each received activation slice back to the client model that produced it
      for i in range(len(input_sizes)):
        grad_a = remote_a.grad[:,hidden_sizes[1]*i:hidden_sizes[1]*(i+1)].copy().move(models[i].location)
        output_a[i].backward(grad_a)

      #8) update params
      for opt in optimizers:
          opt.step()
      
      #9) return the loss so the caller can report progress
      return loss.detach().get()
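
Step 7 relies on the chain rule: the gradient of the loss with respect to the concatenated activations can be cut into per-client slices, and each slice is all a client needs to finish backpropagation locally. The sketch below checks this on a toy scalar model, using central finite differences on the unsplit loss as the reference. Everything in it (the weights, `loss_fn`, the one-weight linear "models") is made up for illustration, and it does not use PySyft.

```python
def loss_fn(w, v, x, y):
    """Squared-error loss of a toy split model (hypothetical).

    w: one scalar weight per client, v: server weights,
    x: one scalar feature per client, y: target.
    """
    activations = [w_i * x_i for w_i, x_i in zip(w, x)]        # client side
    pred = sum(v_i * a_i for v_i, a_i in zip(v, activations))  # server side
    return 0.5 * (pred - y) ** 2


w, v = [0.5, -1.0], [2.0, 3.0]
x, y = [1.0, 2.0], 1.0

# Server side: gradient of the loss w.r.t. each cut-layer activation.
activations = [w_i * x_i for w_i, x_i in zip(w, x)]
pred = sum(v_i * a_i for v_i, a_i in zip(v, activations))
grad_activations = [(pred - y) * v_i for v_i in v]

# Client side: each client combines only its own slice with its own input.
grad_w = [g_i * x_i for g_i, x_i in zip(grad_activations, x)]

# Reference: central finite differences on the full (unsplit) loss.
eps = 1e-6
numeric = []
for i in range(len(w)):
    w_plus, w_minus = list(w), list(w)
    w_plus[i] += eps
    w_minus[i] -= eps
    numeric.append((loss_fn(w_plus, v, x, y) - loss_fn(w_minus, v, x, y)) / (2 * eps))

print(grad_w)   # [-12.0, -36.0]
print(numeric)  # numerically close to grad_w
```

Note that the clients never see the server weights `v` or the target `y`; the gradient slice already carries everything they need from the other side of the cut.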

In [None]:
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # flatten images
        images = images.view(images.shape[0], -1)

        # split each flattened image into one segment per client,
        # tracking the running column offset so any number of segments works
        split_images = []
        start = 0
        for i in range(len(input_sizes)):
            stop = start + input_sizes[i]
            split_image = images[:, start:stop].send(model_locations[i])
            split_image = split_image.view(images.shape[0], -1)
            split_images.append(split_image)
            start = stop

        # image labels live on the server worker
        labels = labels.send(bob)

        loss = train(split_images, labels, models, optimizers)
        running_loss += loss

    # report the mean batch loss for this epoch
    print("Epoch {} - Training loss: {}".format(epoch, running_loss/len(trainloader)))

In [1]:
def test(models, dataloader, dataset_name):
    for model in models:
        model.eval()
    correct = 0
    with torch.no_grad():
        for images, target in dataloader:
            # flatten images
            images = images.view(images.shape[0], -1)
            # split each flattened image into one segment per client
            split_images = []
            start = 0
            for i in range(len(input_sizes)):
                stop = start + input_sizes[i]
                split_image = images[:, start:stop].send(model_locations[i])
                split_image = split_image.view(images.shape[0], -1)
                split_images.append(split_image)
                start = stop
            # define an empty tensor to hold the concatenated client outputs
            remote_a = torch.zeros(split_images[0].shape[0], concatenated_size)
            # make a prediction on each client model
            for i in range(len(input_sizes)):
                a = models[i](split_images[i])
                remote_a[:, hidden_sizes[1]*i:hidden_sizes[1]*(i+1)] = a.copy().get()
            remote_a = remote_a.detach().send(models[-1].location)
            # make a prediction on the server model using the received signal
            output = models[-1](remote_a).get()
            # count the number of correct predictions
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    print("{}: Accuracy {}/{} ({:.0f}%)".format(dataset_name,
                                                correct,
                                                len(dataloader.dataset),
                                                100. * correct / len(dataloader.dataset)))

In [None]:
testset = datasets.MNIST('mnist', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
test(models, testloader, "Test set")