In [0]:
# Number of full passes over `trainloader` made by the training loop further down.
epochs = 15

# Tutorial 4 - SplitNN for Vertically Partitioned Data

<b>Recap:</b> The previous tutorial looked at building a basic SplitNN, where an NN was split into two segments on two separate hosts. However, what if several clients take part in a multi-modal, multi-institutional collaboration, each holding its own share of the data?

<b>Description: </b>Here we simply use two copies of the same image to represent the multi-modal data. We demonstrate the SplitNN class with a 3-segment distribution, illustrated below.



<img src="images/SplitNN_Veritically.png" width="20%">


In this tutorial, we demonstrate the SplitNN class with a 3-segment distribution [[1](https://arxiv.org/abs/1812.00564)]. This time:

- <b>$Alice_{0}$</b>
    - Has Model Segment 1
    - Has the handwritten images
- <b>$Alice_{1}$</b>
    - Has Model Segment 2
    - Has the handwritten images
- <b>$Bob$</b> 
    - Has Model Segment 3
    - Has the image labels
    
We use the exact same model as we used in the previous tutorial, only this time we have two clients and one host.


Author:
- Haofan Wang - GitHub: [@haofanwang](https://github.com/haofanwang)

In [0]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
# Build the PySyft hook around torch; it is passed to every VirtualWorker created later.
hook = sy.TorchHook(torch)

In [0]:
# Data preprocessing
# Every image is converted to a tensor and normalised with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])

# Training split, reshuffled on every epoch.
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Held-out split for the evaluation cell at the end of the notebook. That cell
# reads `testloader`, so it has to be defined here for the notebook to survive
# a Restart & Run All. Order does not matter for accuracy, hence no shuffling.
testset = datasets.MNIST('mnist', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

In [0]:
torch.manual_seed(0)

# Layer widths. The last segment consumes both client outputs side by side,
# so its input width (hidden_sizes[2]) is twice a client's output width.
input_size = 784
hidden_sizes = [128, 320, 640]
output_size = 10


def _client_segment():
    """Return a fresh client-side segment: input_size -> hidden_sizes[0] -> hidden_sizes[1]."""
    return nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
    )


# Segments are built in the order client 0, client 1, label holder.
models = [
    _client_segment(),
    _client_segment(),
    nn.Sequential(
        nn.Linear(hidden_sizes[2], output_size),
        nn.LogSoftmax(dim=1),
    ),
]

# One SGD optimiser per segment, each bound to that segment's parameters only.
optimizers = [optim.SGD(segment.parameters(), lr=0.03) for segment in models]

# Create the three virtual workers.
alice_0, alice_1, bob = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("alice_0", "alice_1", "bob")
)
workers = alice_0, alice_1, bob

# Starting location of each segment: segment k goes to the k-th worker.
model_locations = list(workers)

for model, location in zip(models, model_locations):
    model.send(location)

In [0]:
def train(images_0, images_1, target, models, optimizers):
    """Run one SplitNN training step on a single batch and return its loss.

    The two client segments (``models[0]`` and ``models[1]``) each process
    their own input, their activations are joined column-wise and fed to the
    final segment (``models[2]``), and the gradient of the joined activation
    is split and sent back so each client segment can backpropagate.

    Reads the notebook-level ``hidden_sizes`` list to size and split the
    joined activation.

    Args:
        images_0: Flattened image batch for ``models[0]``; the training loop
            passes a tensor that was sent to ``alice_0``.
        images_1: Flattened image batch for ``models[1]``; the training loop
            passes a tensor that was sent to ``alice_1``.
        target: Labels for the batch; the training loop passes a tensor that
            was sent to ``bob``.
        models: The three model segments, in the order described above.
        optimizers: One optimiser per segment.

    Returns:
        The detached loss of this batch, fetched back from the remote worker.
    """
    # 1) erase previous gradients (if they exist)
    for opt in optimizers:
        opt.zero_grad()

    # 2) make a prediction with each client segment
    a_0 = models[0](images_0)
    a_1 = models[1](images_1)

    # 3) break the computation graph link, and send the activation signal to the next model
    remote_a_0 = a_0.detach().move(models[2].location).requires_grad_()
    remote_a_1 = a_1.detach().move(models[2].location).requires_grad_()
    # Join both activations in one local buffer: client 0 fills the first
    # hidden_sizes[1] columns and client 1 the rest. The batch size comes from
    # the function's own argument, not from a notebook-level variable.
    remote_a = torch.zeros(images_0.shape[0], hidden_sizes[2])
    remote_a[:, :hidden_sizes[1]] = remote_a_0.copy().get()
    remote_a[:, hidden_sizes[1]:] = remote_a_1.copy().get()
    remote_a = remote_a.detach().send(models[2].location).requires_grad_()

    # 4) make prediction on next model using received signal
    pred = models[2](remote_a)

    # 5) calculate how much we missed, against the labels passed in as `target`
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)

    # 6) figure out which weights caused us to miss
    loss.backward()

    # 7) send gradient of the received activation signal to the model behind;
    #    each client gets the columns that correspond to its own activation
    grad_a_0 = remote_a.grad[:, :hidden_sizes[1]].copy().move(models[0].location)
    grad_a_1 = remote_a.grad[:, hidden_sizes[1]:hidden_sizes[2]].copy().move(models[1].location)

    # 8) backpropagate on bottom model given this gradient
    a_0.backward(grad_a_0)
    a_1.backward(grad_a_1)

    # 9) change the weights
    for opt in optimizers:
        opt.step()

    # 10) hand the batch loss back to the caller
    return loss.detach().get()

In [8]:
for i in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # Both clients receive the same batch, flattened to one row per image.
        flat_shape = (images.shape[0], -1)
        images_0 = images.send(alice_0).view(*flat_shape)
        images_1 = images.send(alice_1).view(*flat_shape)
        # The labels go to bob, who also hosts the last model segment.
        labels = labels.send(bob)

        loss = train(images_0, images_1, labels, models, optimizers)
        running_loss += loss

    # Report the mean per-batch loss once the epoch's batches are exhausted.
    print("Epoch {} - Training loss: {}".format(i, running_loss/len(trainloader)))

Epoch 0 - Training loss: 0.5656647682189941
Epoch 1 - Training loss: 0.2735804319381714
Epoch 2 - Training loss: 0.21353811025619507
Epoch 3 - Training loss: 0.17315588891506195
Epoch 4 - Training loss: 0.1451188325881958
Epoch 5 - Training loss: 0.12480196356773376
Epoch 6 - Training loss: 0.10766370594501495
Epoch 7 - Training loss: 0.09647190570831299
Epoch 8 - Training loss: 0.0861542671918869
Epoch 9 - Training loss: 0.077510304749012
Epoch 10 - Training loss: 0.0705472007393837
Epoch 11 - Training loss: 0.06403899192810059
Epoch 12 - Training loss: 0.05939096584916115
Epoch 13 - Training loss: 0.05365265905857086
Epoch 14 - Training loss: 0.04996408149600029


In [0]:
def test(models, dataloader, dataset_name):
    """Print the classification accuracy of the split model over ``dataloader``.

    Each batch is sent to the notebook-level workers ``alice_0`` and
    ``alice_1``, run through the two client segments, joined column-wise and
    passed to the final segment. Reads the notebook-level ``hidden_sizes``
    list to size and fill the joined activation.

    Args:
        models: The three model segments (client 0, client 1, label holder).
        dataloader: Loader yielding ``(images, target)`` batches; its
            ``dataset`` length is used as the denominator.
        dataset_name: Label printed in front of the accuracy line.
    """
    for model in models:
        model.eval()

    correct = 0
    with torch.no_grad():
        # Iterate the loader that was passed in, not a notebook-level variable.
        for images, target in dataloader:
            images_0 = images.send(alice_0)
            images_0 = images_0.view(images_0.shape[0], -1)
            images_1 = images.send(alice_1)
            images_1 = images_1.view(images_1.shape[0], -1)

            a_0 = models[0](images_0)
            a_1 = models[1](images_1)

            # No gradients are needed here, so the activations are only moved
            # and joined: client 0 fills the first hidden_sizes[1] columns,
            # client 1 the remaining ones.
            remote_a_0 = a_0.detach().move(models[2].location)
            remote_a_1 = a_1.detach().move(models[2].location)
            remote_a = torch.zeros(images.shape[0], hidden_sizes[2])
            remote_a[:, :hidden_sizes[1]] = remote_a_0.copy().get()
            remote_a[:, hidden_sizes[1]:] = remote_a_1.copy().get()
            remote_a = remote_a.detach().send(models[2].location)

            output = models[2](remote_a).get()
            # Index of the largest log-probability is the predicted class.
            pred = output.data.max(1, keepdim=True)[1]
            # Accumulate a plain int so the formatting and the percentage
            # below do not depend on tensor arithmetic rules.
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    total = len(dataloader.dataset)
    print("{}: Accuracy {}/{} ({:.0f}%)".format(dataset_name,
                                                correct,
                                                total,
                                                100. * correct / total))

In [12]:
# Print the accuracy of the trained segments over the batches in `testloader`.
test(models, testloader, "Test set")

Test set: Accuracy 9738/10000 (97%)
