# Part X - Encrypted Learning WITNESS

https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
import time

In [3]:
class Arguments:
    """Container for the hyper-parameters and switches used by this notebook."""

    def __init__(self):
        # Every entry becomes a plain instance attribute, in this order.
        vars(self).update(
            batch_size=64,          # batch size handed to the training loader
            test_batch_size=1000,   # batch size handed to the test loader
            epochs=10,
            lr=0.02,                # learning rate given to the optimizer
            momentum=0.5,
            no_cuda=False,          # set True to stay on CPU even if CUDA exists
            seed=1,                 # value passed to torch.manual_seed
            log_interval=30,        # print the training loss every N batches
            save_model=False,
        )

args = Arguments()

# Seed torch's RNG so runs are repeatable.
torch.manual_seed(args.seed)

# CUDA is used only when it is both allowed by the settings and available.
use_cuda = (not args.no_cuda) and torch.cuda.is_available()

if use_cuda:
    device = torch.device("cuda")
    # Extra DataLoader options that are only passed when running on CUDA.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [4]:
def _mnist_transform():
    """Build the preprocessing pipeline applied to both MNIST splits."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])


# NOTE: despite its name, this is a plain torch DataLoader over the MNIST
# training split; nothing is distributed across workers here.
train_set = datasets.MNIST('../data', train=True, download=True,
                           transform=_mnist_transform())
federated_train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=args.batch_size, shuffle=True, **kwargs)

# The held-out split reuses the files fetched above (no download flag).
test_set = datasets.MNIST('../data', train=False,
                          transform=_mnist_transform())
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)

In [5]:
class Net(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        # Layers are created in fc1, fc2, fc3 order.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return the raw fc3 outputs.

        No softmax / log-softmax is applied to the result.
        """
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

In [10]:
def train(args, model, device, federated_train_loader, optimizer, epoch):
    """Train `model` for one pass over `federated_train_loader`.

    The loss for each batch is the squared error between the model output and
    a 10-wide one-hot encoding of the labels, summed over all entries and
    divided by the batch size.

    Args:
        args: settings object; `log_interval` and `batch_size` are read.
        model: the module to optimize; it is switched to training mode.
        device: device on which each batch and its one-hot target are placed.
        federated_train_loader: sized iterable yielding `(data, target)` pairs.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    num_batches = len(federated_train_loader)
    for batch_idx, (data, target) in enumerate(federated_train_loader):
        # Keep the batch on the same device as the model; otherwise the
        # forward pass fails as soon as the model lives on CUDA.
        data, target = data.to(device), target.to(device)

        # One-hot encode the integer labels, directly on the target device.
        target_onehot = torch.zeros(*target.shape, 10, device=device)
        target_onehot = target_onehot.scatter(1, target.view(-1, 1), 1)

        optimizer.zero_grad()

        output = model(data)
        loss = ((output - target_onehot) ** 2).sum() / output.shape[0]
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, num_batches * args.batch_size,
                100. * batch_idx / num_batches, loss.item()))
            

The test function does not change!

In [11]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    first = True
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if first:
                print(model.fc3.weight[:4, :4])
                first = False
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

### Launch the training!

In [12]:
import cProfile

In [13]:
# Profile everything from model construction through the end of training.
cp = cProfile.Profile()
cp.enable()

# Fresh network, moved to the device selected earlier in the notebook.
model = Net().to(device)

# Plain SGD with the configured learning rate; args.momentum is not passed.
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

# Only a single epoch is run here; the full schedule would be
# range(1, args.epochs + 1).
for epoch in range(1, 2):#args.epochs + 1
    train(args, model, device, federated_train_loader, optimizer, epoch)
    # Evaluation is disabled for this profiling run.
    #test(args, model, device, test_loader)
    
# Stop collecting and print the per-function timing table.
cp.disable()
cp.print_stats()

         4788713 function calls (4783754 primitive calls) in 14.979 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.087    0.087   14.977   14.977 <ipython-input-10-880310691395>:1(train)
        1    0.000    0.000    0.000    0.000 <ipython-input-13-0f37d116c0ec>:12(<module>)
        1    0.000    0.000    0.001    0.001 <ipython-input-13-0f37d116c0ec>:4(<module>)
        1    0.000    0.000    0.000    0.000 <ipython-input-13-0f37d116c0ec>:6(<module>)
        1    0.000    0.000   14.977   14.977 <ipython-input-13-0f37d116c0ec>:8(<module>)
        1    0.000    0.000    0.001    0.001 <ipython-input-5-22beaa110c32>:2(__init__)
      938    0.018    0.000    0.386    0.000 <ipython-input-5-22beaa110c32>:8(forward)
   120000    0.118    0.000    0.167    0.000 Image.py:2329(_check_size)
    60000    0.191    0.000    0.805    0.000 Image.py:2347(new)
    60000    0.229    0.000    1.548    0.000 Image.py:242

425014/424757    0.068    0.000    0.069    0.000 {built-in method builtins.len}
    60000    0.048    0.000    0.048    0.000 {built-in method builtins.max}
      939    0.001    0.000    0.043    0.000 {built-in method builtins.next}
       32    0.000    0.000    0.006    0.000 {built-in method builtins.print}
    60000    0.117    0.000    0.117    0.000 {built-in method from_buffer}
       15    0.000    0.000    0.000    0.000 {built-in method math.sqrt}
      938    0.006    0.000    0.006    0.000 {built-in method ones_like}
       64    0.000    0.000    0.000    0.000 {built-in method posix.getpid}
        1    0.001    0.001    0.001    0.001 {built-in method randperm}
     1876    0.027    0.000    0.027    0.000 {built-in method relu}
      938    0.143    0.000    0.143    0.000 {built-in method stack}
      938    0.010    0.000    0.010    0.000 {built-in method tensor}
     3752    0.005    0.000    0.005    0.000 {built-in method torch._C._get_tracing_state}
        1

Et voilà! Here you are, you have trained a model on remote data using Federated Learning!

## One Last Thing
I know there's a question you're dying to ask: **how long does it take to do Federated Learning compared to normal PyTorch?**

The computation time is actually **less than twice the time** used for normal PyTorch execution! More precisely, it takes 1.9 times longer, which is very little compared to the features we were able to add.

## Conclusion

As you can see, we modified 10 lines of code to upgrade the official PyTorch example on MNIST to a real Federated Learning setting!

Of course, there are dozens of improvements we could think of. We would like the computation to operate in parallel on the workers and to perform federated averaging, to update the central model every `n` batches only, to reduce the number of messages we use to communicate between workers, etc. These are features we're working on to make Federated Learning ready for a production environment and we'll write about them as soon as they are released!

You should now be able to do Federated Learning by yourself! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Pick our tutorials on GitHub!

We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.

- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)


### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! 

- [Join slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! If you want to start "one off" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.

- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)