# Part X - Encrypted Learning

https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
import time

In [3]:
# Hook PyTorch with PySyft and create the three virtual workers used in the cells below.
import syft as sy  # <-- NEW: import the Pysyft library
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: and alice
james = sy.VirtualWorker(hook, id="james")  # <-- NEW: and james, passed as crypto_provider in the share() calls below

In [4]:
# Demo: convert a float tensor to fixed precision and share it between bob and alice,
# with james as crypto provider. The cell output shows the result wrapped in an
# AutogradTensor holding one pointer per share holder.
torch.tensor([1.]).fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)

(Wrapper)>AutogradTensor>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
	-> (Wrapper)>[PointerTensor | me:35537720908 -> bob:44344244096]
	-> (Wrapper)>[PointerTensor | me:20482258853 -> alice:88038887721]
	*crypto provider: james*

In [5]:
# Demo: same pipeline, but the fixed-precision tensor is first sent to bob and share()
# is then called on the resulting pointer. The cell output shows that `data` is a
# pointer to an object living on bob.
data = torch.tensor([1.]).fix_precision().send(bob).share(bob, alice, crypto_provider=james, requires_grad=True)
data

(Wrapper)>AutogradTensor>[PointerTensor | me:94873097749 -> bob:87670374928]

In [6]:
# Fetch the object behind the pointer; the cell output shows it is the shared tensor
# (shares on bob and alice, crypto provider james).
data.get()

AutogradTensor>(Wrapper)>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
	-> (Wrapper)>[PointerTensor | me:29078829087 -> bob:88158707378]
	-> (Wrapper)>[PointerTensor | me:60224767788 -> alice:21514300422]
	*crypto provider: james*

In [7]:
class Arguments:
    """Plain container for the hyper-parameters and switches of this run."""

    def __init__(self):
        # Batch sizes for the training and the test loader.
        self.batch_size = 64
        self.test_batch_size = 1000

        # Optimisation settings.
        self.epochs = 10
        self.lr = 0.02
        self.momentum = 0.5

        # Runtime settings: CUDA opt-out and RNG seed.
        self.no_cuda = False
        self.seed = 1

        # Reporting / persistence settings.
        self.log_interval = 30
        self.save_model = False

args = Arguments()

# CUDA is used only when it has not been disabled in args and is actually available.
use_cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the RNG so runs are repeatable.
torch.manual_seed(args.seed)

# Pick the device and the extra DataLoader options that go with it.
if use_cuda:
    device = torch.device("cuda")
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [8]:
# Training data: the MNIST training split is converted to tensors, normalised, and
# federated across bob and alice before being wrapped in a FederatedDataLoader.
# NOTE(review): (0.1307,), (0.3081,) are presumably the MNIST mean/std - confirm.
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

# Test data: a regular torch DataLoader over the MNIST test split (train=False) with the
# same transform; it is not federated. No download=True here, so it relies on the files
# already being present under '../data'.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

In [9]:
class Net(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10 with ReLU in between.

    ``forward`` returns the output of the last linear layer as is; no
    softmax / log-softmax is applied.
    """

    def __init__(self):
        super().__init__()
        n_pixels = 28 * 28
        self.fc1 = nn.Linear(n_pixels, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Flatten every sample to a 784-long vector before the linear layers.
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [10]:
# Selects the code path in train() and in the training loop:
# False -> the model is sent to the worker holding each batch;
# True  -> data, targets and model are converted to fixed precision and shared
#          between bob and alice with james as crypto provider.
ENCRYPTED = False

In [11]:
def train(args, model, device, federated_train_loader, optimizer, epoch):
    """Train ``model`` on the first batches of ``federated_train_loader``.

    For every batch the loss is the squared error between the model output and
    the one-hot encoded target, summed and divided by the batch size. The time
    spent on the optimisation step is printed for every batch, and a progress
    line is printed for every 5th batch. The function returns as soon as a
    batch with index greater than 9 has been processed.

    Behaviour depends on the module-level ``ENCRYPTED`` flag:

    * ``False``: the model is sent to the worker holding the batch, trained
      there, and fetched back after the step.
    * ``True``: data and one-hot targets are fetched, converted to fixed
      precision and shared between ``bob`` and ``alice`` (crypto provider
      ``james``); the model is used as passed in.

    Args:
        args: Configuration object; only ``batch_size`` is read (for logging).
        model: The network to train.
        device: Not used by this function.
        federated_train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepping the parameters of ``model``.
        epoch: Epoch number, used in the progress line only.
    """

    def encode_one_hot(labels):
        # One row per label, 10 columns, a single 1 at the label's index.
        blank = torch.zeros(*labels.shape, 10)
        return blank.scatter(1, labels.view(-1, 1), 1)

    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader):  # batches live on remote workers
        if ENCRYPTED:
            # Encrypt data TODO remote_get
            data = data.get().fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)
            target = target.get()
            target_onehot = encode_one_hot(target)
            target_onehot = target_onehot.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)
        else:
            worker = data.location
            model.send(worker)
            # The labels are fetched to build the one-hot targets locally ...
            target = target.get()
            target_onehot = encode_one_hot(target)
            # ... then both are sent to the worker that holds the batch.
            target = target.send(worker)
            target_onehot = target_onehot.send(worker)

        tick = time.time()

        optimizer.zero_grad()
        output = model(data)
        loss = ((output - target_onehot)**2).sum()/output.shape[0]
        loss.backward()
        optimizer.step()

        print(time.time() - tick)

        if not ENCRYPTED:
            model.get()

        if batch_idx % 5 == 0:
            if ENCRYPTED:
                loss = loss.get().float_precision()
            else:
                loss = loss.get()  # <-- NEW: get the loss back
            n_batches = len(federated_train_loader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, n_batches * args.batch_size,
                100. * batch_idx / n_batches, loss.item()))

        if batch_idx > 9:
            return
            

The test function does not change!

In [12]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

### Launch the training !

In [13]:
%%time
# Build the model and a plain SGD optimizer; args.momentum is not passed (see TODO).
model = Net().to(device)

optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment
if ENCRYPTED:
    # In encrypted mode the optimizer is converted to fixed precision as well.
    optimizer = optimizer.fix_precision() 

for epoch in range(1, args.epochs + 1):
    if ENCRYPTED:
        # Convert the model to fixed precision and share it between bob and alice
        # (crypto provider james) before training.
        # NOTE(review): the optimizer was built from the parameters of the unshared
        # model - confirm it still updates the parameters after fix_precision().share().
        model = model.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)
    train(args, model, device, federated_train_loader, optimizer, epoch)
    if ENCRYPTED:
        # Fetch the model back and convert it to floats so test() can run it on local data.
        model = model.get().float_precision()
    test(args, model, device, test_loader)

0.008700132369995117
0.008021116256713867
0.008151054382324219
0.009250879287719727
0.00811910629272461
0.008960962295532227
0.008023977279663086
0.009498119354248047
0.008041858673095703
0.007991790771484375
0.008270978927612305

Test set: Average loss: -0.1972, Accuracy: 5010/10000 (50%)

0.0077588558197021484
0.009422779083251953
0.008527040481567383
0.009042024612426758
0.010618209838867188
0.008131027221679688
0.008008956909179688
0.00802302360534668
0.00804281234741211
0.008756875991821289
0.008182048797607422

Test set: Average loss: -0.2996, Accuracy: 6415/10000 (64%)

0.008003711700439453
0.008285760879516602
0.008233785629272461
0.008534908294677734
0.009061336517333984
0.00869131088256836
0.008514881134033203
0.00860285758972168
0.008605241775512695
0.008798837661743164
0.008642911911010742

Test set: Average loss: -0.3668, Accuracy: 7060/10000 (71%)

0.008746862411499023
0.009335994720458984
0.008994102478027344
0.009067058563232422
0.009556055068969727
0.009107828140258789

KeyboardInterrupt: 

Et voilà! Here you are, you have trained a model on remote data using Federated Learning!

## One Last Thing
I know there's a question you're dying to ask: **how long does it take to do Federated Learning compared to normal PyTorch?**

The computation time is actually **less than twice the time** used for normal PyTorch execution! More precisely, it takes 1.9 times longer, which is very little compared to the features we were able to add.

## Conclusion

As you can see, we modified 10 lines of code to upgrade the official PyTorch example on MNIST to a real Federated Learning setting!

Of course, there are dozens of improvements we could think of. We would like the computation to operate in parallel on the workers and to perform federated averaging, to update the central model every `n` batches only, to reduce the number of messages we use to communicate between workers, etc. These are features we're working on to make Federated Learning ready for a production environment and we'll write about them as soon as they are released!

You should now be able to do Federated Learning by yourself! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways! 

### Star PySyft on GitHub

The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Pick our tutorials on GitHub!

We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.

- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)


### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! 

- [Join slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! If you want to start "one off" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.

- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)