<a href="https://colab.research.google.com/github/RAVIPATISRIVIDYA/devtraining-needit-madrid/blob/master/Split_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install syft==0.2.5

Collecting syft==0.2.5
  Downloading syft-0.2.5-py3-none-any.whl (369 kB)
[?25l[K     |▉                               | 10 kB 19.7 MB/s eta 0:00:01[K     |█▊                              | 20 kB 11.0 MB/s eta 0:00:01[K     |██▋                             | 30 kB 8.6 MB/s eta 0:00:01[K     |███▌                            | 40 kB 7.8 MB/s eta 0:00:01[K     |████▍                           | 51 kB 5.1 MB/s eta 0:00:01[K     |█████▎                          | 61 kB 5.6 MB/s eta 0:00:01[K     |██████▏                         | 71 kB 4.9 MB/s eta 0:00:01[K     |███████                         | 81 kB 5.5 MB/s eta 0:00:01[K     |████████                        | 92 kB 5.0 MB/s eta 0:00:01[K     |████████▉                       | 102 kB 5.2 MB/s eta 0:00:01[K     |█████████▊                      | 112 kB 5.2 MB/s eta 0:00:01[K     |██████████▋                     | 122 kB 5.2 MB/s eta 0:00:01[K     |███████████▌                    | 133 kB 5.2 MB/s eta 0:00:01[K

In [1]:
import torch
import syft as sy

# Hook torch so tensors gain PySyft methods such as .send() and .get()
hook = sy.TorchHook(torch)

# A virtual worker standing in for a remote machine (in a real deployment it
# would live elsewhere)
client = sy.VirtualWorker(hook, id='client')

# Send a tensor to the client; locally we keep only a pointer to it
x = torch.tensor([1, 2, 3, 4, 5])
x_pointer = x.send(client)

# Inspect the pointer and the objects the client now holds
print(x_pointer)
print(client._objects)

# Pointers behave like ordinary tensors: the addition is carried out on the
# client and we receive another pointer to the result
result = x_pointer + x_pointer
print(result)

# get() brings the result back and removes it from the client, so the
# `result` pointer must not be used afterwards
result_local = result.get()
print(result_local)
print(client._objects)

(Wrapper)>[PointerTensor | me:68580289521 -> client:89695094565]
{26639111457: <Plan Plan id:26639111457 owner:client Tags: #fss_eq_plan_1 built>
, 5442912878: <Plan Plan id:5442912878 owner:client Tags: #fss_eq_plan_2 built>
, 67438813281: <Plan Plan id:67438813281 owner:client Tags: #fss_comp_plan_1 built>
, 90905916870: <Plan Plan id:90905916870 owner:client Tags: #fss_comp_plan_2 built>
, 35310555869: <Plan Plan id:35310555869 owner:client Tags: #xor_add_1 built>
, 12468158228: <Plan Plan id:12468158228 owner:client Tags: #xor_add_2 built>
, 89695094565: tensor([1, 2, 3, 4, 5])}
(Wrapper)>[PointerTensor | me:36651760612 -> client:34002633088]
tensor([ 2,  4,  6,  8, 10])
{26639111457: <Plan Plan id:26639111457 owner:client Tags: #fss_eq_plan_1 built>
, 5442912878: <Plan Plan id:5442912878 owner:client Tags: #fss_eq_plan_2 built>
, 67438813281: <Plan Plan id:67438813281 owner:client Tags: #fss_comp_plan_1 built>
, 90905916870: <Plan Plan id:90905916870 owner:client Tags: #fss_comp_p

In [1]:
import torch
from torchvision import datasets, transforms  # this import may raise errors; if so, restart the runtime and re-run
from torch import nn, optim
import syft as sy
hook = sy.TorchHook(torch)

epochs = 10

# Seed torch's RNG before anything random happens (model initialisation below,
# and the shuffled DataLoader order)
torch.manual_seed(0)

# Data preprocessing: CIFAR-100 images are RGB, so Normalize gets one mean/std
# per channel. Some older torchvision releases pair channels with means via
# zip(), which with a 1-tuple would normalise only the first channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.CIFAR100('cifar100', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Define the two model segments of the split network.
# Sizes: a flattened 3x32x32 CIFAR image (3072 values) -> 128 -> 640 -> 100 classes.
input_size = 3072
hidden_sizes = [128, 640]
output_size = 100

models = [
    # Segment 0 (bottom): flattened image -> hidden_sizes[1]-dim activation
    nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]), nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]), nn.ReLU(),
    ),
    # Segment 1 (top): activation -> per-class log-probabilities
    nn.Sequential(
        nn.Linear(hidden_sizes[1], output_size), nn.LogSoftmax(dim=1),
    ),
]

# One SGD optimiser per segment, each bound only to that segment's parameters
optimizers = [optim.SGD(params, lr=0.07) for params in map(nn.Module.parameters, models)]

# Two virtual workers that will host the model segments
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
workers = alice, bob

# Segment i goes to model_locations[i]: the bottom segment to alice (who also
# receives the images) and the top segment to bob (who receives the labels)
model_locations = list(workers)

for location, model in zip(model_locations, models):
    model.send(location)

def train(x, target, models, optimizers):
    """Run one split-learning training step over two model segments.

    models[0] runs on x where it lives. Its activation is moved to
    models[1].location, where models[1] finishes the forward pass and the
    NLL loss is computed against target.

    Args:
        x: pointer to a batch of flattened inputs, located with models[0].
        target: pointer to the matching class labels, located with models[1].
        models: [bottom_segment, top_segment], each already sent to a worker.
        optimizers: one optimizer per segment, in the same order as models.

    Returns:
        The batch loss, detached and fetched back to the local worker.
    """
    # Training Logic

    #1) erase previous gradients (if they exist)
    for opt in optimizers:
        opt.zero_grad()

    #2) make a prediction with the bottom segment
    a = models[0](x)

    #3) move the activation signal to the next model's worker; the local
    # autograd graph of `a` does not extend to the moved copy
    remote_a = a.move(models[1].location, requires_grad=True)

    #4) make prediction on next model using received signal
    pred = models[1](remote_a)

    #5) calculate how much we missed
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)

    #6) figure out which weights caused us to miss
    loss.backward()

    #7-8) sending remote_a's gradient back to models[0].location and calling
    # a.backward(grad) on it is deliberately not done by hand: `a` was moved with
    # requires_grad=True, which in syft 0.2.x is meant to forward the received
    # gradient to the original tensor and run its backward pass automatically.
    # NOTE(review): confirm models[0]'s weights change after step 9; if they do
    # not, restore the manual version:
    #     grad_a = remote_a.grad.copy().move(models[0].location)
    #     a.backward(grad_a)

    #9) change the weights
    for opt in optimizers:
        opt.step()

    #10) return the loss so the caller can report progress
    return loss.detach().get()

for i in range(epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        # Flatten each image locally, then send the inputs to alice (bottom
        # segment) and the labels to bob (top segment)
        images_ptr = images.view(images.shape[0], -1).send(alice)
        labels_ptr = labels.send(bob)

        loss = train(images_ptr, labels_ptr, models, optimizers)
        # .item() keeps the running total a plain Python float
        running_loss += loss.item()

    # Mean per-batch loss over the epoch
    print("Epoch {} - Training loss: {}".format(i, running_loss / len(trainloader)))

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to cifar100/cifar-100-python.tar.gz


0it [00:00, ?it/s]

Extracting cifar100/cifar-100-python.tar.gz to cifar100
Epoch 0 - Training loss: 3.9301202297210693
Epoch 1 - Training loss: 3.4655303955078125
Epoch 2 - Training loss: 3.251256227493286
Epoch 3 - Training loss: 3.0955700874328613
Epoch 4 - Training loss: 2.969331741333008
Epoch 5 - Training loss: 2.8563387393951416
Epoch 6 - Training loss: 2.751477003097534
Epoch 7 - Training loss: 2.6521451473236084
Epoch 8 - Training loss: 2.5549862384796143
Epoch 9 - Training loss: 2.466583728790283
