In [2]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import syft as sy
import torchvision.datasets as datasets
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import math

from IPython.display import clear_output

## Constants

In [27]:
# Number of simulated workers the training set is split across.
# NOTE(review): the outputs recorded further down (20000 samples per worker,
# 1250 batches, worker "bob2") look like they came from a run with 3 workers,
# not 10 - re-run the notebook after changing this value.
n_workers = 10

# Optimisation hyper-parameters: mini-batch size and SGD learning rate.
batch_size, lr = 16, 0.1

## Load dataset

In [28]:
def transform(X_mnist):
    """Convert raw MNIST pixel data into float images divided by 255.

    The input is viewed as a stack of 28x28 images, cast to float and scaled by
    1/255 (presumably mapping 8-bit pixel values into [0, 1]).
    """
    images = X_mnist.view(-1, 28, 28)
    return images.float().div(255)

In [29]:
def batch_list(l, batch_size):
    """Split `l` into consecutive slices of at most `batch_size` items.

    Returns a list of slices of `l`; the last one may be shorter when
    `len(l)` is not a multiple of `batch_size`.
    """
    n_batches = int(math.ceil(len(l) / batch_size))
    batches = []
    for k in range(n_batches):
        start = k * batch_size
        batches.append(l[start:start + batch_size])
    return batches

In [30]:
# Training split: no per-item transform, because the raw `.data` / `.targets`
# tensors are sliced directly when the shards are sent to the workers.
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=None)

# Test split: items go through ToTensor so they can be batched by a DataLoader.
mnist_test = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [31]:
# Loader used only for evaluation. Shuffling is disabled: the accuracy is a sum
# over the whole test set, so the order does not matter, and a fixed order keeps
# the evaluation pass deterministic.
test_loader = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=batch_size,
    shuffle=False,
)

In [32]:
# Size of the full training set, and the equal share assigned to each worker
# (integer division: a remainder, if any, is not assigned to anyone).
n_samples_train = len(mnist_train.targets)
n_samples_per_worker = n_samples_train // n_workers

## Put data on different workers

### Declare the hook

In [33]:
# Create the PySyft hook for torch; it is passed to every VirtualWorker below.
# NOTE(review): later cells call .send()/.get() on tensors and models, and
# model.copy(); presumably those methods are provided by this hook - confirm
# against the PySyft version in use.
hook = sy.TorchHook(torch)



### Declare the workers

In [34]:
# One in-process virtual worker per participant, with ids "bob0", "bob1", ...
workers = [
    sy.VirtualWorker(hook, id=f"bob{i_worker}")
    for i_worker in range(n_workers)
]

### Move the data

In [35]:
# Contiguous, equally sized shard of the training set for each worker.
worker_slices = [
    slice(n_samples_per_worker * i_worker, n_samples_per_worker * (i_worker + 1))
    for i_worker in range(n_workers)
]

# Images are scaled locally with transform() and then sent; each list holds, per
# worker, whatever .send() returns (used later as a handle with a `.location`).
X_train_ptr_worker = [
    transform(mnist_train.data[shard]).send(workers[i_worker])
    for i_worker, shard in enumerate(worker_slices)
]
Y_train_ptr_worker = [
    mnist_train.targets[shard].send(workers[i_worker])
    for i_worker, shard in enumerate(worker_slices)
]

In [36]:
# Sanity check: shape of the first worker's image shard, which should be
# (n_samples_per_worker, 28, 28) given the slicing and transform() above.
# NOTE(review): the recorded output shows 20000 rows, which presumably comes from
# an earlier run with a different n_workers - re-run to refresh it.
X_train_ptr_worker[0].shape

torch.Size([20000, 28, 28])

## Create model

In [37]:
class Net(nn.Module):
    """Small convolutional classifier for 28x28 single-channel images.

    Two 5x5 convolutions (20, then 50 output channels), each followed by ReLU
    and 2x2 max-pooling, feed a 500-unit hidden layer and a 10-way output layer.
    The forward pass returns log-probabilities (log-softmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4*4*50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        # Insert the channel dimension expected by Conv2d: (N, 1, 28, 28).
        images = x.view(-1, 1, 28, 28)
        # 28 -> 24 after the 5x5 convolution, -> 12 after pooling.
        feats = F.max_pool2d(F.relu(self.conv1(images)), 2, 2)
        # 12 -> 8 after the 5x5 convolution, -> 4 after pooling.
        feats = F.max_pool2d(F.relu(self.conv2(feats)), 2, 2)
        # Flatten the 50 feature maps of size 4x4 for the dense layers.
        flat = feats.view(-1, 4*4*50)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [38]:
# Negative log-likelihood loss; it pairs with the log-softmax output of Net.forward.
loss_fn = nn.NLLLoss()

## Train the model

In [39]:
# NOTE(review): this repeats the `lr` already set in the constants cell. Being the
# later assignment, this is the value the optimizers below actually pick up; keep
# the two in sync or drop one of them.
lr = 0.1

In [40]:
# One reference network, copied once per worker. Important: every worker must
# start from the same initialization.
model = Net()
model_worker = [model.copy() for _ in range(n_workers)]

# A dedicated SGD optimizer per worker, bound to that worker's own parameters.
opt_worker = [
    optim.SGD(params=local_model.parameters(), lr=lr)
    for local_model in model_worker
]

In [41]:
n_epochs = 1

for i_epoch in range(n_epochs):
    loss_epoch = 0
    n_losses = 0  # number of per-worker losses accumulated in loss_epoch so far
    # Every worker holds n_samples_per_worker samples, so a single permutation of
    # local indices is drawn per epoch and reused for all workers.
    sample_indices = list(np.random.permutation(range(n_samples_per_worker)))
    list_batch_indices = batch_list(sample_indices, batch_size)
    n_batches = len(list_batch_indices)
    for i_batch, batch_indices in enumerate(list_batch_indices):
        # Train each worker individually
        for i_worker in range(n_workers):
            # Labels are selected with gather(), so the index tensor is sent to the
            # worker that holds them; images are selected by plain indexing.
            batch_indices_tensor = torch.tensor(batch_indices).send(workers[i_worker])
            X_ptr = X_train_ptr_worker[i_worker][batch_indices, :, :]
            Y_ptr = Y_train_ptr_worker[i_worker].gather(0, batch_indices_tensor)
            # Move this worker's model to wherever its batch lives for the update step.
            model_worker[i_worker] = model_worker[i_worker].send(X_ptr.location)

            pred = model_worker[i_worker](X_ptr)
            loss = loss_fn(pred, Y_ptr)

            opt_worker[i_worker].zero_grad()
            loss.backward()
            opt_worker[i_worker].step()

            # Bring the updated model and the loss value back for averaging / logging.
            model_worker[i_worker] = model_worker[i_worker].get()
            loss_epoch += loss.get().data.numpy()
            n_losses += 1

            clear_output(wait=True)
            print(f"Epoch {i_epoch}")
            # Average over the losses actually accumulated so far; dividing by
            # n_workers * (i_batch + 1) under-reports the loss until the last
            # worker of the batch has been processed.
            print(f"Batch {i_batch} / {n_batches}: {loss_epoch / n_losses:.4f}")
            print(f"Location {X_ptr.location}")
            del batch_indices_tensor

        # Parameter averaging: replace every worker's parameters with the mean
        # over all workers, one parameter tensor at a time.
        with torch.no_grad():
            for params in zip(*(model_worker[i].parameters() for i in range(n_workers))):
                new_param = torch.mean(torch.stack(params), dim=0)
                for param in params:
                    # set_() makes `param` share the storage of its argument; pass a
                    # clone so the workers do not all end up aliasing one tensor.
                    param.set_(new_param.clone())

Epoch 0
Batch 173 / 1250: 0.8274
Location <VirtualWorker id:bob2 #objects:6>


KeyboardInterrupt: 

## Test the model

In [244]:
# Evaluate the first worker's model. After each completed training batch the
# averaging step gives all workers the same parameters, so any worker would do.
model_worker[0].eval()
correct = 0
# Inference only: disable autograd so no graph is built for the test batches.
with torch.no_grad():
    for X, Y in test_loader:
        output = model_worker[0](X)
        # Predicted class = index of the largest log-probability.
        pred = output.argmax(1, keepdim=True)
        correct += pred.eq(Y.view_as(pred)).sum().item()

print('Accuracy: {}/{} ({:.0f}%)\n'.format(
    correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

Accuracy: 8950/10000 (90%)



## Clean everything (in case)

In [253]:
# Ask each of the n_workers workers to clear the objects it holds.
for i_worker in range(n_workers):
    current_worker = workers[i_worker]
    current_worker.clear_objects()

In [254]:
# The workers were just cleared, so send the training shards again.
# Contiguous, equally sized shard of the training set for each worker.
worker_slices = [
    slice(n_samples_per_worker * i_worker, n_samples_per_worker * (i_worker + 1))
    for i_worker in range(n_workers)
]

# Images are scaled locally with transform() before being sent; labels are sent as-is.
X_train_ptr_worker = [
    transform(mnist_train.data[shard]).send(workers[i_worker])
    for i_worker, shard in enumerate(worker_slices)
]
Y_train_ptr_worker = [
    mnist_train.targets[shard].send(workers[i_worker])
    for i_worker, shard in enumerate(worker_slices)
]

In [255]:
# Re-create the models the same way as the initial setup: one reference network
# copied to every worker. Building an independent Net() per worker would give each
# worker a different random initialization, which breaks the "same starting point"
# requirement of parameter averaging.
model = Net()
model_worker = [model.copy() for _ in range(n_workers)]
# Fresh optimizers bound to the new parameters.
opt_worker = [optim.SGD(params=model_worker[i].parameters(), lr=lr) for i in range(n_workers)]