# Federated learning with PySyft on MNIST data
In this notebook, we train a neural network on the MNIST dataset using the federated learning approach, implemented with the PySyft library.

## Prerequisites
- Familiarity with PySyft; see [this notebook for an introductory guide to PySyft](https://jovian.ai/tifeasypeasy/introducing-privacy-preserving-tool)

## Federated learning workflow

 - Edge devices receive a copy of a global model from a central server.
 - Each device trains its copy locally on the data that resides on that device.
 - Training updates the weights of the local copy on each worker; the raw data never leaves the device.
 - Each device sends its updated copy back to the central server.
 - The server aggregates the updated models it receives, which improves the global model while preserving the privacy of the data it was trained on.


![image](https://miro.medium.com/max/640/0*yI_rIRNAFTYDwFtK.png)

Figure 1. Federated learning
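The workflow above does not say how the server aggregates the updates. One common choice is a weighted average of the client weights, often called federated averaging. The sketch below simulates that in plain PyTorch, without PySyft. The helper names `local_update` and `federated_average`, the tiny linear model, and the random data are all assumptions made for illustration.

```python
import copy

import torch
import torch.nn as nn


def federated_average(global_model, client_states, client_sizes):
    """Return a state dict that is the size-weighted mean of the client state dicts."""
    total = float(sum(client_sizes))
    averaged = copy.deepcopy(global_model.state_dict())
    for key in averaged:
        averaged[key] = sum(
            state[key].float() * (n / total)
            for state, n in zip(client_states, client_sizes)
        )
    return averaged


def local_update(global_model, data, target, lr=0.1):
    """Train a private copy of the global model for one SGD step on local data."""
    local = copy.deepcopy(global_model)
    optimizer = torch.optim.SGD(local.parameters(), lr=lr)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(local(data), target)
    loss.backward()
    optimizer.step()
    return local.state_dict()


torch.manual_seed(0)
global_model = nn.Linear(4, 1)

# Two hypothetical edge devices, each holding its own private data.
client_data = [
    (torch.randn(8, 4), torch.randn(8, 1)),
    (torch.randn(24, 4), torch.randn(24, 1)),
]

# Each client trains locally and reports only its weights.
client_states = [local_update(global_model, x, y) for x, y in client_data]
client_sizes = [len(x) for x, _ in client_data]

# The server combines the weights; it never sees the raw data.
new_state = federated_average(global_model, client_states, client_sizes)
global_model.load_state_dict(new_state)
```

Weighting by the number of local samples means that the device with 24 samples contributes three times as much to the new global weights as the device with 8.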

## Use cases
- **Next word prediction**: Federated learning is used to improve word prediction on mobile devices without uploading users' data to the cloud for training. Google's Gboard implements FL by using on-device data to train and improve its next-word prediction model for its users. Watch this video to better understand how Google uses federated learning at scale. You can also read this online comic from Google AI to get a better grasp of federated learning.

- **Voice recognition**: FL is also used in voice recognition technologies. One example is Apple's Siri, and more recently Google introduced it for audio recognition in Google Assistant to improve the user experience. Watch a demonstration of FL on audio recordings for speech systems here.

In each of these use cases, the data never leaves the edge devices. This keeps the data private, safe and secure while still improving these technologies and making products smarter over time.

# Setup

## How to run the code 
Running this notebook works the same way as running the 'introducing-privacy-preserving-tool' notebook. You can either run it online on the Jovian platform [via this link](https://jovian.ai/tifeasypeasy/federated-learning-on-mnist) without installing anything, or run it locally on your machine.


In [1]:
# !pip install jovian --upgrade --quiet
# uncomment the command above if you're running on Binder

In [2]:
# This installs the syft library (the run recorded below pulled syft 0.2.9); comment out the command below if you already installed syft locally
!pip install syft

Collecting syft
  Downloading https://files.pythonhosted.org/packages/1c/73/891ba1dca7e0ba77be211c36688f083184d8c9d5901b8cd59cbf867052f3/syft-0.2.9-py3-none-any.whl (433kB)
...

## Imports and initializing hook
Let's import torch, torchvision, and the other modules we need.

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [15]:
import syft as sy  # import the syft library
hook = sy.TorchHook(torch)  # hook PyTorch so tensors and models gain send()/get()
joe = sy.VirtualWorker(hook, id="joe")  # simulated remote worker "joe"
jane = sy.VirtualWorker(hook, id="jane")  # simulated remote worker "jane"



## Load in the MNIST dataset 
Let's load the data and turn it into a federated dataset by calling the `federate()` method, which distributes the training set across the workers.

In [16]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])

train_set = datasets.MNIST(
    "~/.pytorch/MNIST_data/", train=True, download=True, transform=transform)
test_set = datasets.MNIST(
    "~/.pytorch/MNIST_data/", train=False, download=True, transform=transform)

# federate() splits the training data across the two workers
federated_train_loader = sy.FederatedDataLoader(
    train_set.federate((joe, jane)), batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=64, shuffle=True)
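To make the idea of splitting a dataset across workers concrete, here is a minimal plain-PyTorch sketch that does not use PySyft. The helper `shard_dataset` is hypothetical, the data is synthetic rather than MNIST, and the exact way `federate()` divides the samples may differ from this contiguous split.

```python
import torch
from torch.utils.data import Subset, TensorDataset


def shard_dataset(dataset, worker_ids):
    """Split a dataset into contiguous, near-equal shards, one per worker id."""
    shard_size = len(dataset) // len(worker_ids)
    shards = {}
    for i, worker_id in enumerate(worker_ids):
        start = i * shard_size
        # The last worker also takes any leftover samples.
        stop = len(dataset) if i == len(worker_ids) - 1 else start + shard_size
        shards[worker_id] = Subset(dataset, list(range(start, stop)))
    return shards


# Synthetic stand-in for MNIST: 101 fake 1x28x28 images with labels 0-9.
images = torch.randn(101, 1, 28, 28)
labels = torch.randint(0, 10, (101,))
shards = shard_dataset(TensorDataset(images, labels), ["joe", "jane"])

for worker_id, shard in shards.items():
    print(worker_id, len(shard))
```

With 101 samples and two workers, "joe" receives 50 samples and "jane" receives the remaining 51, so no sample is dropped or duplicated.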


## Define the network architecture
The network architecture is the same as in the PyTorch example tutorial: it takes a 784-dimensional tensor of pixel values for each image (28 x 28 pixels, flattened) and produces a tensor of length 10 holding the class scores for that image.

In [17]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
print(model)
optimizer = optim.SGD(model.parameters(), lr=0.01)  # plain SGD with learning rate 0.01

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
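A quick sanity check of the architecture can be useful before training. The sketch below redefines the same `Net` so that it runs on its own, feeds it a random batch shaped like MNIST images (an assumption standing in for real data), and inspects the output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # flatten each 28x28 image into 784 values
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


net = Net()
batch = torch.randn(64, 1, 28, 28)  # fake batch shaped like MNIST images
log_probs = net(batch)

# One row of 10 log-probabilities per image.
print(log_probs.shape)
# Exponentiating recovers probabilities, so every row should sum to about 1.
print(log_probs.exp().sum(dim=1)[:3])
# Total number of trainable parameters (weights plus biases of both layers).
n_params = sum(p.numel() for p in net.parameters())
print(n_params)
```

The parameter count is 784 x 500 + 500 for the first layer plus 500 x 10 + 10 for the second, which gives 397,510.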


## Train the network 
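In this distributed setup, each training batch lives on a remote worker rather than on our machine. We therefore use the batch's `location` attribute to find the worker that holds it and send the model there with `send()`. Once the forward pass, backward pass and optimizer step have run remotely, we retrieve the improved model with `get()`. The loss is also computed remotely, so we call `get()` on it before printing. Note that this loop moves a single model between the workers batch by batch; it does not average separately trained copies.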


In [18]:
n_epoch = 10 
for epoch in range(n_epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader):
        model.send(data.location)  # send the model to the worker that holds this batch
        optimizer.zero_grad()  # reset gradients from the previous step
        output = model(data)  # forward pass runs on the remote worker
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get()  # bring the updated model back
        if batch_idx % 100 == 0:
            loss = loss.get()  # the loss lives on the worker, so fetch it before printing
            print('Training Epoch: {:2d} [{:5d}/{:5d} ({:3.0f}%)]\tLoss: {:.6f}'.format(
                epoch+1, batch_idx * 64,
                len(federated_train_loader) * 64,
                100. * batch_idx / len(federated_train_loader), loss.item()))
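The training loop uses `F.nll_loss` because the network's `forward` already applies `log_softmax`. The following small check, run on random logits and arbitrary labels chosen for illustration, shows that this combination matches `F.cross_entropy` on raw scores and equals the mean negative log-probability of the true class.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # raw scores for 4 samples, 10 classes
target = torch.tensor([3, 0, 9, 1])  # arbitrary class labels

# What the notebook does: log_softmax in the model, nll_loss in the loop.
log_probs = F.log_softmax(logits, dim=1)
loss_nll = F.nll_loss(log_probs, target)

# Equivalent single call on raw logits.
loss_ce = F.cross_entropy(logits, target)

# Manual computation: mean of the negative log-probability of the true class.
loss_manual = -log_probs[torch.arange(4), target].mean()

print(loss_nll.item(), loss_ce.item(), loss_manual.item())
```

All three values agree, which is why a model that ends in `log_softmax` should be paired with `nll_loss` rather than `cross_entropy`; the latter would apply the softmax a second time.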



## Testing the trained model
Remember that the test dataset is unchanged: it stays on our local machine, unlike the training dataset, which we split between two virtual workers.

In [19]:
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        test_loss += F.nll_loss(
            output, target, reduction='sum').item()
        pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss,correct,
    len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


Test set: Average loss: 0.1762, Accuracy: 9475/10000 (95%)
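The accuracy count in the test loop relies on `argmax`, `view_as` and `eq`. The toy example below, with made-up probabilities for three samples and four classes, walks through those steps on values small enough to check by hand.

```python
import torch

# Hypothetical log-probabilities for 3 samples over 4 classes.
output = torch.log(torch.tensor([
    [0.10, 0.70, 0.10, 0.10],
    [0.60, 0.20, 0.10, 0.10],
    [0.25, 0.25, 0.40, 0.10],
]))
target = torch.tensor([1, 2, 2])

# Index of the largest log-probability per row, kept as a column of shape (3, 1).
pred = output.argmax(1, keepdim=True)

# view_as reshapes target to (3, 1) so the comparison is element by element.
correct = pred.eq(target.view_as(pred)).sum().item()
accuracy = 100.0 * correct / len(target)

# Without view_as, broadcasting compares every prediction with every label.
broadcast_shape = pred.eq(target).shape

print(pred.squeeze(1).tolist(), correct, accuracy, tuple(broadcast_shape))
```

The predictions are classes 1, 0 and 2, so two of the three samples are correct. Dropping `view_as` would silently produce a 3 x 3 comparison and an inflated count, which is why the reshape matters.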



As you can see, we reached a test accuracy of about 95% (9475 of 10000 images), which is a good result for the federated setup demonstrated in this tutorial.