## Part 06 - Federated Learning on MNIST using a CNN
--- 

### Upgrade to Federated Learning in 10 Lines of PyTorch + PySyft

Federated Learning is an exciting and fast-growing Machine Learning technique for building systems that learn from decentralized data. The idea is that the data stays in the hands of its producer (also known as the worker), which improves privacy and ownership, while the model is shared between workers. One immediate application is predicting the next word on your mobile phone as you type: you don't want the data used for training (i.e. your text messages) to be sent to a central server.

The rise of Federated Learning is therefore tightly connected to the spread of data privacy awareness, and the GDPR, which has enforced data protection in the EU since May 2018, has acted as a catalyst. To get ahead of regulation, large actors like Apple and Google have started investing massively in this technology, especially to protect mobile users' privacy, but they have not made their tools available. At OpenMined, we believe that anyone willing to conduct a Machine Learning project should be able to implement privacy-preserving tools with very little effort. We have built tools for encrypting data in a single line, as mentioned in our blog post, and we now release our Federated Learning framework, which leverages the new PyTorch 1.0 version to provide an intuitive interface for building secure and scalable models.

In this tutorial, we'll directly use the canonical example of training a CNN on MNIST with PyTorch and show how simple it is to implement Federated Learning on top of it using our PySyft library. We will go through each part of the example and highlight the code that changes.

### Imports and model specifications

First, we make the official PyTorch imports:

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

And then those specific to PySyft. In particular, we define the remote workers alice and bob.

In [3]:
import syft as sy  # <-- NEW: import the PySyft library
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionality to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: and alice

Next, we define the settings of the learning task:

In [12]:
class Arguments(object):
    def __init__(self, epochs):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = epochs
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments(10)

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

# kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

if use_cuda:
    # Create new tensors on the GPU by default; this call fails on machines without CUDA
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [18]:
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader 
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True)
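
To make the effect of `.federate((bob, alice))` more concrete, here is a minimal plain-Python sketch of distributing a dataset across workers. It is an illustration only, not PySyft's implementation: the helper name `split_across_workers` and the policy of equal, contiguous shards are assumptions made for this example.

```python
def split_across_workers(num_samples, worker_ids):
    """Assign one contiguous, equally sized shard of sample indices to each worker.

    Hypothetical helper for illustration; PySyft's own partitioning may differ.
    """
    shard_size = num_samples // len(worker_ids)
    shards = {}
    for i, worker_id in enumerate(worker_ids):
        start = i * shard_size
        shards[worker_id] = range(start, start + shard_size)
    return shards


# The MNIST training set contains 60,000 images.
shards = split_across_workers(60000, ["bob", "alice"])
for worker_id, indices in shards.items():
    print(worker_id, len(indices), indices[0], indices[-1])
```

Each worker ends up holding its own shard, and a federated loader would then yield batches that live on one worker at a time.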

In [19]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
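
The `4*4*50` in `fc1` comes from tracking the spatial size of a 28x28 MNIST image through the two convolution and pooling stages. The sketch below reproduces that arithmetic with the standard output-size formula for unpadded convolutions and pooling; the helper names `conv_out` and `pool_out` are introduced here for illustration only.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Output width/height of an unpadded pooling layer along one dimension."""
    return (size - kernel) // stride + 1


size = 28                    # MNIST images are 28x28
size = conv_out(size, 5)     # conv1: 5x5 kernel, stride 1
size = pool_out(size, 2, 2)  # max_pool2d(x, 2, 2)
size = conv_out(size, 5)     # conv2: 5x5 kernel, stride 1
size = pool_out(size, 2, 2)  # max_pool2d(x, 2, 2)

channels = 50                # output channels of conv2
flat_features = size * size * channels
print(size, flat_features)
```

The final spatial size is 4, so the flattened feature vector has 4 * 4 * 50 = 800 entries, which matches the input size of `fc1`.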

In [20]:
def train(args, model, device, federated_train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
        model.send(data.location) # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))
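
The send/train/get pattern in `train` can be hard to picture, so here is a toy simulation in plain Python. It is a sketch under stated assumptions, not PySyft code: `ToyWorker` is a hypothetical stand-in for a remote worker, the "model" is a single weight `w` of a linear model `y = w * x`, and each worker's data is private to it. Only the weight travels back and forth.

```python
class ToyWorker:
    """Hypothetical stand-in for a remote worker that owns data which never leaves it."""

    def __init__(self, name, xs, ys):
        self.name = name
        self._xs = xs  # private inputs
        self._ys = ys  # private targets

    def train_step(self, w, lr):
        """Receive the weight, run one SGD step on local data, return the new weight and loss."""
        n = len(self._xs)
        loss = sum((w * x - y) ** 2 for x, y in zip(self._xs, self._ys)) / n
        grad = sum(2 * (w * x - y) * x for x, y in zip(self._xs, self._ys)) / n
        return w - lr * grad, loss


# Both workers hold samples of the same relation y = 3 * x.
bob = ToyWorker("bob", [1.0, 2.0], [3.0, 6.0])
alice = ToyWorker("alice", [3.0, 4.0], [9.0, 12.0])

w = 0.0
for epoch in range(50):
    for worker in (bob, alice):
        # "send" the model to the data, train there, then "get" it back
        w, loss = worker.train_step(w, lr=0.02)

print(round(w, 3))  # the weight converges towards 3.0
```

As in the real loop, the coordinating code never reads `_xs` or `_ys`; it only sees the updated weight and, when it asks for it, the loss.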

In [21]:

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
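
The two numbers reported by `test` are the summed negative log-likelihood divided by the dataset size, and the fraction of samples whose arg-max prediction matches the target. The following plain-Python sketch computes the same quantities on a few made-up logits; the helper names and the toy data are assumptions for illustration and do not use PyTorch.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def evaluate(batches):
    """Return (average NLL loss, accuracy) over all samples in all batches."""
    total_loss, correct, count = 0.0, 0, 0
    for logits_batch, targets in batches:
        for logits, target in zip(logits_batch, targets):
            log_probs = log_softmax(logits)
            total_loss += -log_probs[target]  # summed, like reduction='sum'
            pred = max(range(len(log_probs)), key=log_probs.__getitem__)
            correct += int(pred == target)
            count += 1
    return total_loss / count, correct / count


# Two toy batches of 3-class logits with their target classes.
batches = [
    ([[2.0, 0.0, 0.0], [0.0, 1.0, 3.0]], [0, 2]),
    ([[0.0, 0.5, 0.0]], [2]),
]
avg_loss, accuracy = evaluate(batches)
print(avg_loss, accuracy)
```

Summing per-batch losses and dividing once at the end gives a true per-sample average even when the last batch is smaller than the others.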

In [22]:
%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")


Test set: Average loss: 0.2075, Accuracy: 9326/10000 (93%)


Test set: Average loss: 0.1123, Accuracy: 9673/10000 (97%)


Test set: Average loss: 0.0692, Accuracy: 9798/10000 (98%)


Test set: Average loss: 0.0689, Accuracy: 9770/10000 (98%)


Test set: Average loss: 0.0520, Accuracy: 9844/10000 (98%)


Test set: Average loss: 0.0474, Accuracy: 9847/10000 (98%)


Test set: Average loss: 0.0435, Accuracy: 9868/10000 (99%)


Test set: Average loss: 0.0430, Accuracy: 9860/10000 (99%)


Test set: Average loss: 0.0414, Accuracy: 9861/10000 (99%)


Test set: Average loss: 0.0425, Accuracy: 9871/10000 (99%)

CPU times: user 14min 41s, sys: 5.24 s, total: 14min 46s
Wall time: 14min 45s
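
Note that the optimizer in the last cell is plain SGD: `args.momentum` is defined but not passed to it, since (per the code comment) momentum was not supported at the time. As a reminder of what that leaves out, the sketch below contrasts the plain SGD update with the classic momentum update on a toy objective `f(w) = w**2`. The function names and the toy objective are assumptions for illustration; this is not PySyft or PyTorch code.

```python
def sgd_step(w, grad, lr):
    """Plain SGD: move against the current gradient."""
    return w - lr * grad


def momentum_step(w, grad, velocity, lr, mu):
    """SGD with momentum: accumulate a velocity, then move along it."""
    velocity = mu * velocity + grad
    return w - lr * velocity, velocity


def grad_f(w):
    """Gradient of the toy objective f(w) = w**2."""
    return 2.0 * w


lr, mu = 0.1, 0.5
w_plain = 1.0
w_momentum, velocity = 1.0, 0.0
for _ in range(20):
    w_plain = sgd_step(w_plain, grad_f(w_plain), lr)
    w_momentum, velocity = momentum_step(w_momentum, grad_f(w_momentum), velocity, lr, mu)

print(w_plain, w_momentum)
```

Both variants approach the minimum at 0 here. The momentum variant has to carry the extra `velocity` state between steps, alongside the model parameters.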
