In [1]:
epochs = 5

# Example - Simple Vertically Partitioned Split Neural Network

- <b>Alice</b>
    - Has model Segment 1
    - Has the handwritten Images
- <b>Bob</b>
    - Has model Segment 2
    - Has the image Labels
    
Based on [SplitNN - Tutorial 3](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/advanced/split_neural_network/Tutorial%203%20-%20Folded%20Split%20Neural%20Network.ipynb) from Adam J Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL) · GitHub:  [@H4LL](https://github.com/H4LL)

Authors:
- Pavlos Papadopoulos · GitHub:  [@pavlos-p](https://github.com/pavlos-p)
- Tom Titcombe · GitHub:  [@TTitcombe](https://github.com/TTitcombe)



In [2]:
class SplitNN:
    def __init__(self, models, optimizers):
        self.models = models
        self.optimizers = optimizers

        self.data = []
        self.remote_tensors = []

    def forward(self, x):
        data = []
        remote_tensors = []

        data.append(models[0](x))

        if data[-1].location == models[1].location:
            remote_tensors.append(data[-1].detach().requires_grad_())
        else:
            remote_tensors.append(
                data[-1].detach().move(models[1].location).requires_grad_()
            )

        i = 1
        while i < (len(models) - 1):
            data.append(models[i](remote_tensors[-1]))

            if data[-1].location == models[i + 1].location:
                remote_tensors.append(data[-1].detach().requires_grad_())
            else:
                remote_tensors.append(
                    data[-1].detach().move(models[i + 1].location).requires_grad_()
                )

            i += 1

        data.append(models[i](remote_tensors[-1]))

        self.data = data
        self.remote_tensors = remote_tensors

        return data[-1]

    def backward(self):
        data = self.data
        remote_tensors = self.remote_tensors

        i = len(models) - 2
        while i > -1:
            if remote_tensors[i].location == data[i].location:
                grads = remote_tensors[i].grad.copy()
            else:
                grads = remote_tensors[i].grad.copy().move(data[i].location)

            data[i].backward(grads)
            i -= 1

    def zero_grads(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()

In [3]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from src.dataloader import PartitionDistributingDataLoader
from src.dataset import add_ids, partition_dataset



hook = sy.TorchHook(torch)


Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/home/pavlito/miniconda3/envs/pyvertical-dev/lib/python3.7/site-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.3.so'


In [4]:
# # Data preprocessing
# transform = transforms.Compose([transforms.ToTensor(),
#                               transforms.Normalize((0.5,), (0.5,)),
#                               ])
# trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Create dataset
data = add_ids(MNIST)(".", download=True, transform=ToTensor())  # add_ids adds unique IDs to data points

# Split data
data_partition1, data_partition2 = partition_dataset(data)

# Batch data
dataloader = PartitionDistributingDataLoader(data_partition1, data_partition2, batch_size=128)


In [5]:
torch.manual_seed(0)

# Define our model segments

input_size = 784
hidden_sizes = [128, 640]
output_size = 10

models = [
    nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
    ),
    nn.Sequential(nn.Linear(hidden_sizes[1], output_size), nn.LogSoftmax(dim=1)),
]

# Create optimisers for each segment and link to them
optimizers = [
    optim.SGD(model.parameters(), lr=0.03,)
    for model in models
]

# create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")

# Send Model Segments to model locations
model_locations = [alice, bob]
for model, location in zip(models, model_locations):
    model.send(location)

#Instantiate a SpliNN class with our distributed segments and their respective optimizers
splitNN = SplitNN(models, optimizers)

In [6]:
def train(x, target, splitNN):
    
    #1) Zero our grads
    splitNN.zero_grads()
    
    #2) Make a prediction
    pred = splitNN.forward(x)
    
    #3) Figure out how much we missed by
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)
    
    #4) Backprop the loss on the end layer
    loss.backward()
    
    #5) Feed Gradients backward through the nework
    splitNN.backward()
    
    #6) Change the weights
    splitNN.step()
    
    return loss

In [7]:
# for i in range(epochs):
#     running_loss = 0
#     for images, labels in trainloader:
#         images = images.send(models[0].location)
#         images = images.view(images.shape[0], -1)
#         labels = labels.send(models[-1].location)
#         loss = train(images, labels, splitNN)
#         running_loss += loss.get()

#     else:
#         print("Epoch {} - Training loss: {}".format(i, running_loss/len(trainloader)))
        
for i in range(epochs):
    running_loss = 0        
    for (data, ids1), (labels, ids2) in dataloader:
        # Train a model
        data = data.send(models[0].location)
        data = data.view(data.shape[0], -1)
        labels = labels.send(models[-1].location)
        loss = train(data, labels, splitNN)
        running_loss += loss.get()

    else:
        print("Epoch {} - Training loss: {}".format(i, running_loss/len(dataloader)))

ValueError: Expected input batch_size (34) to match target batch_size (35).

In [8]:
print("Labels pointing to: ", labels)
print("Images pointing to: ", data)

Labels pointing to:  (Wrapper)>[PointerTensor | me:68884086353 -> bob:32543195190]
Images pointing to:  (Wrapper)>[PointerTensor | me:13999244539 -> alice:82324319634]


In [9]:
data.shape

torch.Size([34, 784])

In [10]:
labels.shape

torch.Size([35])