In [2]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import syft as sy
import torchvision.datasets as datasets
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from IPython.display import clear_output

## Constants

In [3]:
# Seed torch's global RNG so results are reproducible on Restart & Run All
# (Net's default weight initialisation draws from it; data-loader shuffling
# presumably does too).
SEED = 0
torch.manual_seed(SEED)

# Number of virtual workers the training set is split across.
n_workers = 2

# SGD hyper-parameters -- this cell is the single place to tune them.
lr = 0.1
batch_size = 64

## Load dataset

In [4]:
DATA_DIR = './data'

# Shared preprocessing for both splits: convert images to tensors, then
# normalise with mean 0.1307 / std 0.3081 (the commonly quoted MNIST
# training-set statistics). Defined once so the two splits cannot drift apart.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

mnist_train = datasets.MNIST(root=DATA_DIR, train=True, download=True,
                             transform=mnist_transform)
mnist_test = datasets.MNIST(root=DATA_DIR, train=False, download=True,
                            transform=mnist_transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


In [5]:
# Integer share of the training samples per worker (quotient of an even split).
n_samples_per_worker = divmod(len(mnist_train.targets), n_workers)[0]

## Distribute the data across the workers

### Declare the hook

In [6]:
# Hook PySyft into torch. Later cells rely on this for `.send()`, `.get()` and
# `.location` on models and tensors (presumably added by the hook -- see the
# PySyft TorchHook docs).
hook = sy.TorchHook(torch)

### Declare the workers

In [7]:
# One PySyft VirtualWorker per data shard, named bob0, bob1, ...
workers = []
for worker_idx in range(n_workers):
    workers.append(sy.VirtualWorker(hook, id=f"bob{worker_idx}"))

### Move the data

In [8]:
# Split the training set across the workers, then batch it. Each batch is a
# pair of pointers whose `.location` is the worker holding that data.
federated_train_set = mnist_train.federate(list(workers))
federated_train_loader = sy.FederatedDataLoader(
    federated_train_set, batch_size=batch_size, shuffle=True)

In [9]:
# The test set stays local and unfederated. Sample order does not affect the
# accuracy metric, so shuffling is disabled to give a deterministic pass.
test_loader = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=batch_size, shuffle=False)

## Create model

In [10]:
class Net(nn.Module):
    """LeNet-style CNN for 28x28 single-channel digit images.

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages followed by two fully
    connected layers; the output is log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Spatial size: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return (batch, 10) log-softmax scores; `x` is reshaped to (batch, 1, 28, 28)."""
        features = x.view(-1, 1, 28, 28)
        for conv in (self.conv1, self.conv2):
            features = F.max_pool2d(F.relu(conv(features)), 2, 2)
        hidden = F.relu(self.fc1(features.view(-1, 50 * 4 * 4)))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [11]:
# Negative log-likelihood loss; it expects log-probabilities, which is what
# Net.forward returns (log_softmax).
loss_fn = nn.NLLLoss()

## Train the model

In [15]:
# Fresh model and optimizer. The learning rate comes from the constants cell
# rather than being redefined here, so tuning it there actually takes effect.
model = Net()
opt = optim.SGD(params=model.parameters(), lr=lr)

In [19]:
n_epochs = 2
n_batches = len(federated_train_loader)

# Ensure training mode: the evaluation cell calls model.eval(), and that mode
# would persist if this cell is re-run afterwards (harmless for the current
# Net, but required once dropout/batch-norm layers are added).
model.train()

for i_epoch in range(n_epochs):
    loss_epoch = 0.0
    for i_batch, (X_ptr, Y_ptr) in enumerate(federated_train_loader):
        # Send the model to the worker holding this batch. `opt` was built from
        # model.parameters() before any send; this relies on PySyft moving the
        # module in place (assumption -- same pattern as the PySyft tutorials).
        model = model.send(X_ptr.location)

        pred = model(X_ptr)
        loss = loss_fn(pred, Y_ptr)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Retrieve the updated model and this batch's loss locally.
        model = model.get()
        loss_epoch += loss.get().item()

        # One-based counters so the last line reads e.g. "Batch 938 / 938".
        clear_output(wait=True)
        print(f"Epoch {i_epoch + 1} / {n_epochs}")
        print(f"Batch {i_batch + 1} / {n_batches}: {loss_epoch / (i_batch + 1):.4f}")
        print(f"Location {X_ptr.location}")

Epoch 1
Batch 937 / 937: 0.0456
Location <VirtualWorker id:bob1 #objects:5>


## Test the model

In [21]:
# Evaluate locally on the test set. eval() sets inference mode and no_grad()
# avoids building autograd graphs, since no backward pass happens here.
model.eval()
correct = 0
with torch.no_grad():
    for X, Y in test_loader:
        output = model(X)
        pred = output.argmax(1, keepdim=True)
        correct += pred.eq(Y.view_as(pred)).sum().item()

n_test = len(test_loader.dataset)
print('Accuracy: {}/{} ({:.0f}%)\n'.format(
    correct, n_test,
    100. * correct / n_test))

Accuracy: 9848/10000 (98%)



## Clean up: clear the objects left on the virtual workers

In [20]:
# Remove every object still stored on each virtual worker.
for worker in workers:
    worker.clear_objects()