In [1]:
# Number of passes the training loop makes over the dataset (one train() call per pass).
epochs = 4

# Simplified Split Neural Network (Chipirones)


In this example, computation graphs have been reworked to be distributed in PySyft. This means we can do away with the cumbersome backwards method. The data has been reworked to be a simple dummy dataset. This makes debugging way easier as we add horizontal distributions.

<b>Description: </b> Here we fold a multilayer SplitNN in on itself in order to accommodate the data and labels being in the same place. We demonstrate the SplitNN class with a 3-segment distribution. This time, the segments and data are laid out as follows:

<img src="images/FoldedNN.png" width="20%">

- <b>Alice</b>
    - Has Model Segment 1
    - Has Model Segment 3
    - Has the handwritten images
    - Has the image labels
- <b>Bob</b>
    - Has model Segment 2
    
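Again, we use the exact same model as we used in the previous tutorial and see the same accuracy. Neither Alice nor Bob has the full model, and Bob can't see Alice's data. Note that the toy code in this notebook uses a third worker, Claire, to hold the final model segment and the labels.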

Author:
- Adam J Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL) · GitHub:  [@H4LL](https://github.com/H4LL)

In [2]:
class SplitNN:
    """Run an ordered chain of model segments, relocating activations between them.

    Each segment is called on the previous segment's output. Between two
    segments the activation is detached and, if the next segment reports a
    different ``.location``, moved there before gradients are re-enabled on it.

    Args:
        models: ordered sequence of callable segments; every segment that has a
            successor's activation handed to it must expose ``.location``.
        optimizers: iterable of optimizers (one per segment in this notebook),
            each exposing ``zero_grad()`` and ``step()``.
    """

    def __init__(self, models, optimizers):
        self.models = models
        self.optimizers = optimizers

    def forward(self, x):
        """Push ``x`` through every segment in order and return the final output.

        The per-segment outputs (after any detach/move hand-off) are kept on
        ``self.a`` so they stay referenced after the call.

        Args:
            x: input accepted by the first segment.

        Returns:
            The output of the last segment. It is not detached or moved.

        Raises:
            IndexError: if ``self.models`` is empty.
        """
        activations = []
        last_index = len(self.models) - 1

        for index, model in enumerate(self.models):
            out = model(x)

            if index < last_index:
                # Hand-off to the next segment. The location is compared on the
                # raw output, exactly as before, and only then is it detached.
                # NOTE(review): detach() makes each hand-off a fresh autograd
                # leaf; confirm that gradients still reach earlier segments.
                next_location = self.models[index + 1].location
                if out.location == next_location:
                    out = out.detach().requires_grad_()
                else:
                    out = out.detach().move(next_location).requires_grad_()

            activations.append(out)
            x = out

        self.a = activations

        return activations[-1]

    def zero_grads(self):
        """Reset the gradients tracked by every segment's optimizer."""
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        """Apply one update step with every segment's optimizer."""
        for opt in self.optimizers:
            opt.step()

In [3]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
# Hook PySyft into torch; the workers below are created against this hook
hook = sy.TorchHook(torch)

# Virtual workers that will hold the data, the labels and the model segments
alice, bob, claire = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("alice", "bob", "claire")
)

In [4]:
# A toy dataset: four samples with four features each, sent to alice
data = torch.tensor(
    [
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 0.0],
    ],
    requires_grad=True,
).send(alice)

# One target value per sample, sent to claire
labels = torch.tensor(
    [[0.0], [0.0], [1.0], [1.0]],
    requires_grad=True,
).send(claire)

In [5]:
# Seed torch before any layer is constructed
torch.manual_seed(0)

# Layer widths of the network
input_size = 4
hidden_sizes = [3, 2]
output_size = 1

# One linear segment per consecutive pair of widths: 4->3, 3->2, 2->1
layer_sizes = [input_size, *hidden_sizes, output_size]
models = [
    nn.Linear(n_in, n_out)
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
]

# Each segment gets its own SGD optimiser over that segment's parameters
optimizers = [optim.SGD(model.parameters(), lr=0.03) for model in models]

# Send each segment to its worker: first -> alice, second -> bob, third -> claire
model_locations = [alice, bob, claire]
for model, location in zip(models, model_locations):
    model.send(location)

# Instantiate a SplitNN with the distributed segments and their respective optimisers
splitNN = SplitNN(models, optimizers)

In [6]:
def train(x, target, splitNN):
    """Run one optimisation step of ``splitNN`` on a single batch.

    Args:
        x: input handed to ``splitNN.forward``.
        target: values the prediction is compared against.
        splitNN: object exposing ``zero_grads()``, ``forward(x)`` and ``step()``.

    Returns:
        The summed squared error between the prediction and ``target``.
    """
    # 1) Clear gradients before the new pass
    splitNN.zero_grads()

    # 2) Forward pass through every segment; show where the prediction lives
    prediction = splitNN.forward(x)
    print(prediction.location)

    # 3) Sum of squared errors between prediction and target
    residual = prediction - target
    loss = (residual ** 2).sum()

    # 4) Backpropagate the loss
    loss.backward()

    # 5) Let every segment's optimiser update its weights
    splitNN.step()

    return loss

In [7]:
# Train for `epochs` passes; each pass is a single train() call on the whole dataset.
for i in range(epochs):
    # The accumulator restarts every epoch, so it ends up holding just this epoch's loss
    running_loss = 0
    loss = train(data, labels, splitNN)
    # .get() presumably retrieves the loss value from the worker holding it (PySyft) - confirm
    running_loss += loss.get()

    # Report the loss for this epoch
    print("Epoch {} - Training loss: {}".format(i, running_loss))

<VirtualWorker id:claire #objects:5>
Epoch 0 - Training loss: 1.3407090902328491
<VirtualWorker id:claire #objects:5>
Epoch 1 - Training loss: 1.2212690114974976
<VirtualWorker id:claire #objects:5>
Epoch 2 - Training loss: 1.1569126844406128
<VirtualWorker id:claire #objects:5>
Epoch 3 - Training loss: 1.1220707893371582
