In [None]:
# Number of passes over the training set made by the training loop below.
epochs = 15

# Tutorial 4 - SplitNN for Vertically Partitioned Data

<b>Recap:</b> The previous tutorial looked at building a basic SplitNN, where an NN was split into two segments on two separate hosts. However, what if several clients, each holding a different modality of data, want to collaborate across institutions?

<b>What is Vertically Partitioned Data? </b> Data is said to be vertically partitioned when several organizations own different attributes or modalities of information for the same set of entities.

<b>Why use Partitioned Data? </b> Partitioning allows organizations holding different modalities of data to learn distributed models without sharing that data. Partitioning schemes are also traditionally used to reduce the size of the data each client handles, by splitting the data and distributing the pieces among the clients.

<b>Description: </b>This configuration allows multiple clients holding different modalities of data to learn distributed models without sharing data. As a concrete example, we walk through the case where radiology centers collaborate with pathology test centers and a server for disease diagnosis. Radiology centers holding imaging data modalities train a partial model up to the cut layer. In the same way, the pathology test center holding patient test results trains a partial model up to its own cut layer. The outputs at the cut layer from both these centers are then concatenated and sent to the disease diagnosis server, which trains the rest of the model. This process is continued back and forth to complete the forward and backward propagations in order to train the distributed deep learning model without sharing each other's raw data. In this tutorial, we split a single flattened image into two segments to mimic different modalities of data; you can also split it into an arbitrary number of segments.


<img src="images/SplitNN_Veritically.png" width="20%">


In this tutorial, we demonstrate the SplitNN class with a 3 segment distribution [[1](https://arxiv.org/abs/1812.00564)]. This time:

- <b>$Alice_{0}$</b>
    - Has Model Segment 1
    - Has the handwritten images segment 1
- <b>$Alice_{1}$</b>
    - Has Model Segment 2
    - Has the handwritten images segment 2
- <b>$Bob$</b> 
    - Has Model Segment 3
    - Has the image labels

Author:
- Haofan Wang - GitHub: [@haofanwang](https://github.com/haofanwang)

In [None]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
# Create the PySyft hook around torch; it is passed to every VirtualWorker created later.
hook = sy.TorchHook(torch)

In [None]:
class SplitNN(torch.nn.Module):
    """Split neural network over vertically partitioned inputs.

    ``models`` is an ordered list of segments: every entry except the last is
    a client segment and the last entry is the server segment. Each client
    segment processes its own input; the client outputs are concatenated
    along dim 1 and fed to the server segment.

    Args:
        models: list of model segments, clients first and server last. Every
            segment is expected to expose a ``location`` attribute.
        optimizers: list of optimizers that are zeroed and stepped together.
    """

    def __init__(self, models, optimizers):
        # Initialise nn.Module before assigning any attribute on self.
        super().__init__()
        self.models = models
        self.optimizers = optimizers
        # outputs[i] holds the latest output of models[i]; the last slot is the server's.
        self.outputs = [None]*len(self.models)

    def forward(self, x):
        """Run every client on its input, then the server on the concatenation.

        Args:
            x: sequence indexed by client; ``x[i]`` is fed to ``self.models[i]``.

        Returns:
            The output of the last segment, ``self.models[-1]``.
        """
        n_clients = len(self.models) - 1
        # save output of each client
        for i in range(n_clients):
            self.outputs[i] = self.models[i](x[i])
        # concatenate a fetched copy of every client output along the feature axis
        self.concatenated_input = torch.cat(
            [self.outputs[i].copy().get() for i in range(n_clients)], dim=1
        )
        # detach from the client graphs and send to the server's location; the
        # gradient collected on this tensor is what backward() hands back to the clients
        self.concatenated_input = (
            self.concatenated_input.detach().send(self.models[-1].location).requires_grad_()
        )
        # make a prediction on the server model using the received signal
        self.outputs[-1] = self.models[-1](self.concatenated_input)
        return self.outputs[-1]

    def backward(self):
        """Hand each client its slice of the gradient at the cut layer.

        Reads ``self.concatenated_input.grad``, so the loss computed from the
        value returned by ``forward`` must have been back-propagated first.
        Client ``i`` receives the columns that its own output occupied in the
        concatenation, i.e. the columns starting at the summed width of all
        earlier clients' outputs.
        """
        offset = 0
        for i in range(len(self.models)-1):
            # width of client i's block, fetched once per client
            width = self.outputs[i].copy().get().size()[1]
            grad_i = self.concatenated_input.grad[:, offset:offset+width].copy().move(self.models[i].location)
            self.outputs[i].backward(grad_i)
            # the next client's block starts where this one ends
            offset += width

    def zero_grads(self):
        """Call ``zero_grad()`` on every optimizer."""
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        """Call ``step()`` on every optimizer."""
        for opt in self.optimizers:
            opt.step()

    def train(self):
        """Call ``train()`` on every segment."""
        for model in self.models:
            model.train()

    def eval(self):
        """Call ``eval()`` on every segment."""
        for model in self.models:
            model.eval()

    @property
    def location(self):
        """Location of the first segment, or ``None`` when there are no segments."""
        return self.models[0].location if self.models else None

In [None]:
# Data preprocessing: convert to tensors, then normalise with mean 0.5 and std 0.5
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
trainset = datasets.MNIST(root='mnist', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=64, shuffle=True)

In [None]:
# fix the RNG seed before any layer is created
torch.manual_seed(0)

# number of input features held by each client; the list may have any length
# and its entries do not have to be equal
input_sizes = [392, 392]

# widths of the two hidden layers; for convenience every client uses the same pair
hidden_sizes = [128, 320]

# the server takes the concatenation of every client's last hidden layer
concatenated_size = len(input_sizes) * hidden_sizes[-1]

# the output size should be the number of classes
output_size = 10

# one model per client; any architecture works as long as its first layer
# matches the width of that client's input segment
client_models = [
    nn.Sequential(
        nn.Linear(in_features, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
    )
    for in_features in input_sizes
]

# the server model, wrapped in a list so it can be appended to the client list
server_model = [
    nn.Sequential(
        nn.Linear(concatenated_size, output_size),
        nn.LogSoftmax(dim=1),
    )
]

# full ordered list of segments: clients first, server last
models = client_models + server_model

# one SGD optimiser per segment, in the same order as `models`
optimizers = []
for segment in models:
    optimizers.append(optim.SGD(segment.parameters(), lr=0.03))

# one virtual worker per client
client_locations = [
    sy.VirtualWorker(hook, id="alice_"+str(k))
    for k in range(len(input_sizes))
]

# the server worker
bob = sy.VirtualWorker(hook, id="bob")
server_location = [bob]

# send every segment to the worker at the same position
model_locations = client_locations + server_location

for segment, worker in zip(models, model_locations):
    segment.send(worker)

# instantiate a SplitNN with the distributed segments and their respective optimizers
splitNN = SplitNN(models, optimizers)

In [None]:
def train(images, target, splitNN):
    """Run one optimisation step of ``splitNN`` on a single batch.

    Args:
        images: per-client inputs, passed unchanged to ``splitNN.forward``.
        target: labels the prediction is scored against with ``nn.NLLLoss``.
        splitNN: object providing ``zero_grads``, ``forward``, ``backward``
            and ``step``.

    Returns:
        The loss computed for this batch.
    """
    # 1) zero our grads
    splitNN.zero_grads()

    # 2) make a prediction
    pred = splitNN.forward(images)

    # 3) calculate the loss against the labels passed in as `target`
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)

    # 4) backprop the loss on the end layer
    loss.backward()

    # 5) feed gradients backward through the network
    splitNN.backward()

    # 6) update the weights
    splitNN.step()

    return loss

In [None]:
# `epoch` is kept distinct from the inner loop variables so the progress line
# reports the real epoch number.
for epoch in range(epochs):
    running_loss = 0
    splitNN.train()
    for images, labels in trainloader:
        # flatten data: one row of features per image
        images = images.view(images.shape[0], -1)
        # cut the feature axis into consecutive per-client segments; `offset`
        # is the running start column, so any number of segments lines up
        split_images = []
        offset = 0
        for size, location in zip(input_sizes, model_locations):
            split_image = images[:, offset:offset+size].send(location)
            split_image = split_image.view(images.shape[0], -1)
            split_images.append(split_image)
            offset += size
        # send the image labels to bob
        labels = labels.send(bob)
        loss = train(split_images, labels, splitNN)
        running_loss += loss.get()

    print("Epoch {} - Training loss: {}".format(epoch, running_loss/len(trainloader)))

In [None]:
def test(model, dataloader, dataset_name):
    """Evaluate ``model`` on ``dataloader`` and print its accuracy.

    Each batch is flattened, cut into consecutive feature segments of the
    widths listed in ``input_sizes``, and every segment is sent to the entry
    of ``model_locations`` at the same position before the model is called.

    Args:
        model: callable taking the list of segments and returning an object
            whose ``get()`` yields the per-class scores; must provide ``eval()``.
        dataloader: iterable of ``(data, target)`` batches with a ``dataset``
            attribute supporting ``len()``.
        dataset_name: label printed in front of the accuracy line.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in dataloader:
            # flatten data
            data = data.view(data.shape[0], -1)
            # `offset` is the running start column, so every segment begins
            # where the previous one ended, for any number of segments
            split_images = []
            offset = 0
            for size, location in zip(input_sizes, model_locations):
                split_image = data[:, offset:offset+size].send(location)
                split_image = split_image.view(data.shape[0], -1)
                split_images.append(split_image)
                offset += size

            output = model(split_images).get()
            # index of the highest score per row is the predicted class
            pred = output.data.max(1, keepdim=True)[1]
            # accumulate a plain int rather than a tensor
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    print("{}: Accuracy {}/{} ({:.0f}%)".format(dataset_name, 
                                                correct,
                                                len(dataloader.dataset), 
                                                100. * correct / len(dataloader.dataset)))

In [None]:
# load the MNIST test split with the same transform and report accuracy on it
testset = datasets.MNIST(root='mnist', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=64, shuffle=False)
test(splitNN, testloader, "Test set")