In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import syft as sy
import torchvision.datasets as datasets
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import math

from IPython.display import clear_output

## Constants

In [2]:
n_workers = 100
n_workers_epoch = int(n_workers * 0.2)

lr = 0.1
batch_size = 16
test_batch_size = 100

## Load dataset

In [3]:
def transform(X_mnist):
    return X_mnist.view(-1, 28, 28).float() / 255

In [4]:
def batch_list(l, batch_size):
    return [l[i*batch_size:(i+1)*batch_size] for i in range(int(math.ceil(len(l) / batch_size)))]

In [5]:
mnist_train = datasets.MNIST(root='./data', train=True, download=True, 
                             transform=None)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, 
                             transform=transforms.Compose([
                                 transforms.ToTensor()                             
                             ]))

In [6]:
test_loader = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=test_batch_size, shuffle=True)

In [None]:
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=test_batch_size, shuffle=True)

In [7]:
n_samples_per_worker = len(mnist_train.targets) // n_workers
n_samples_train = len(mnist_train.targets)

## Put data on different workers

### Declare the hook

In [8]:
hook = sy.TorchHook(torch)

### Declare the workers

In [9]:
workers = [sy.VirtualWorker(hook, id=f"bob{i}") for i in range(n_workers)]

### Move the data

In [10]:
X_train_ptr_worker = [transform(mnist_train.data[n_samples_per_worker*i:n_samples_per_worker*(i+1)]).send(workers[i]) for i in range(n_workers)]
Y_train_ptr_worker = [mnist_train.targets[n_samples_per_worker*i:n_samples_per_worker*(i+1)].send(workers[i]) for i in range(n_workers)]

In [11]:
X_train_ptr_worker[0].shape

torch.Size([600, 28, 28])

## Create model

In [12]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [13]:
loss_fn = torch.nn.NLLLoss()

## Train the model

In [14]:
lr = 0.1

In [44]:
model = Net()
model_worker = [model.copy() for _ in range(n_workers)] # Important: starting from the same initialization
opt_worker = [optim.SGD(params=model_worker[i].parameters(), lr=lr) for i in range(n_workers)]

list_loss_worker = [list() for _ in range(n_workers)]

In [41]:
n_epochs = 20
n_local_epochs = 1

for i_epoch in range(n_epochs):
    sample_indices = list(np.random.permutation(range(n_samples_per_worker)))
    
    list_workers_epoch = np.random.choice(range(n_workers), n_workers_epoch, replace=False)
    for i_worker, worker_index in enumerate(list_workers_epoch):
        loss_worker = 0
        i_iter = 0
        
        # Update client
        for i_local_epoch in range(n_local_epochs):  # n_local_epochs=E
            list_batch_indices = batch_list(sample_indices, batch_size)
            for i_batch, batch_indices in enumerate(list_batch_indices):
                batch_indices_tensor = torch.tensor(batch_indices).send(workers[worker_index])
                X_ptr = X_train_ptr_worker[worker_index][batch_indices, :, :]
                Y_ptr = Y_train_ptr_worker[worker_index].gather(0, batch_indices_tensor)
                model_worker[worker_index] = model_worker[worker_index].send(X_ptr.location)

                pred = model_worker[worker_index](X_ptr)
                loss = loss_fn(pred, Y_ptr)

                opt_worker[worker_index].zero_grad()
                loss.backward()
                opt_worker[worker_index].step()

                model_worker[worker_index] = model_worker[worker_index].get()
                loss_worker += loss.get().data.numpy()
                i_iter += 1

                clear_output(wait=True)
                print(f"Epoch {i_epoch + 1}/{n_epochs}")
                print(f"Worker {i_worker+1}/{n_workers_epoch}")
                print(f"Batch {i_batch} / {len(mnist_train.data) // n_workers // batch_size}: {loss_worker / (i_iter):.4f}")
                print(f"Location {X_ptr.location}")
                del batch_indices_tensor

    # Update global model
    with torch.no_grad():
        for params in zip(*(model_worker[i].parameters() for i in range(n_workers))):
            new_param = torch.mean(torch.stack(params), axis=0)
            for param in params:
                param.set_(new_param)
                
    # Test global model
    

Epoch 10/10
Worker 20/20
Batch 37 / 37: 0.4889
Location <VirtualWorker id:bob78 #objects:6>


## Test the model

In [42]:
model_worker[0].eval()
correct = 0
for X, Y in test_loader:
    output = model_worker[0](X)
    pred = output.argmax(1, keepdim=True)
    correct += pred.eq(Y.view_as(pred)).sum().item()

print('Accuracy: {}/{} ({:.0f}%)\n'.format(
    correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

Accuracy: 9038/10000 (90%)



## Clean everything (in case)

In [34]:
for i in range(n_workers):
    workers[i].clear_objects()

In [35]:
X_train_ptr_worker = [transform(mnist_train.data[n_samples_per_worker*i:n_samples_per_worker*(i+1)]).send(workers[i]) for i in range(n_workers)]
Y_train_ptr_worker = [mnist_train.targets[n_samples_per_worker*i:n_samples_per_worker*(i+1)].send(workers[i]) for i in range(n_workers)]

In [36]:
model_worker = [Net() for _ in range(n_workers)]
opt_worker = [optim.SGD(params=model_worker[i].parameters(), lr=lr) for i in range(n_workers)]