# Part 10: Federated Learning with Encrypted Gradient Aggregation

In the last few sections, we've been learning about encrypted computation by building several simple programs. In this section, we're going to return to the [Federated Learning Demo of Part 4](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%204%20-%20Federated%20Learning%20via%20Trusted%20Aggregator.ipynb), where we had a "trusted aggregator" who was responsible for averaging the model updates from multiple workers.

We will now use our new tools for encrypted computation to remove this trusted aggregator. Relying on one is less than ideal, because it assumes that we can find someone trustworthy enough to have access to this sensitive information, which is not always the case.

Thus, in this notebook, we will show how one can use Secure Multi-Party Computation to perform secure aggregation such that we don't need a "trusted aggregator".

Authors:
- Theo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel)
- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)

# Section 1: Normal Federated Learning

First, here is some code which performs classic federated learning on the Boston Housing Dataset. The code is broken into several parts.

## Setting Up

```python
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class Parser:
    """Parameters for training"""
    def __init__(self):
        self.epochs = 10
        self.lr = 0.001
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1

args = Parser()

torch.manual_seed(args.seed)
kwargs = {}  # extra DataLoader options (none needed here)
```

## Loading the Dataset

```python
with open('../data/BostonHousing/boston_housing.pickle', 'rb') as f:
    ((X, y), (X_test, y_test)) = pickle.load(f)

X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float()
# preprocessing
mean = X.mean(0, keepdim=True)
dev = X.std(0, keepdim=True)
mean[:, 3] = 0.  # the feature at column 3 is binary,
dev[:, 3] = 1.   # so we don't standardize it
X = (X - mean) / dev
X_test = (X_test - mean) / dev
train = TensorDataset(X, y)
test = TensorDataset(X_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)
```
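The preprocessing standardizes each feature with the training set's mean and standard deviation, leaves the binary column untouched, and reuses the training statistics for the test set. Here is a small pure-Python sketch of the same idea; the helper names and the toy data are hypothetical and not part of the tutorial.

```python
import statistics

def standardize_columns(rows, skip_cols=()):
    """Column-wise z-scoring; returns (scaled_rows, means, stds).

    Columns in skip_cols get mean 0 and std 1, so they pass through unchanged,
    mirroring how the tutorial leaves the binary column 3 alone.
    """
    n_cols = len(rows[0])
    means, stds = [], []
    for j in range(n_cols):
        col = [r[j] for r in rows]
        if j in skip_cols:
            means.append(0.0)
            stds.append(1.0)
        else:
            means.append(statistics.mean(col))
            # Sample standard deviation (divides by n - 1), like torch.std's default
            stds.append(statistics.stdev(col))
    return apply_standardization(rows, means, stds), means, stds

def apply_standardization(rows, means, stds):
    """Scale rows with previously computed statistics (e.g. the training set's)."""
    return [[(r[j] - means[j]) / stds[j] for j in range(len(r))] for r in rows]

train_rows = [[1.0, 0.0], [2.0, 1.0], [3.0, 0.0]]  # column 1 is binary
test_rows = [[4.0, 1.0]]

scaled_train, means, stds = standardize_columns(train_rows, skip_cols={1})
scaled_test = apply_standardization(test_rows, means, stds)
print(scaled_train)  # [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
print(scaled_test)   # [[2.0, 1.0]]
```

Reusing the training statistics for the test set keeps both sets on the same scale without leaking information from the test data.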

## Neural Network Structure

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc3 = nn.Linear(24, 1)

    def forward(self, x):
        x = x.view(-1, 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
optimizer = optim.SGD(model.parameters(), lr=args.lr)
```
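It helps to know how many values this network has, since in Section 2 every one of them is secret-shared during aggregation. A fully connected layer has `in * out` weights plus `out` biases. The helper below is a hypothetical sketch; for the real model, `sum(p.numel() for p in model.parameters())` should report the same total.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of trainable values in a fully connected layer: weights + biases."""
    return in_features * out_features + (out_features if bias else 0)

# Layer sizes of Net: 13 -> 32 -> 24 -> 1
layer_sizes = [13, 32, 24, 1]
per_layer = [linear_param_count(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]
print(per_layer)       # [448, 792, 25]
print(sum(per_layer))  # 1265
```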

## Hooking PyTorch

```python
import syft as sy

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
james = sy.VirtualWorker(hook, id="james")

workers = [bob, alice]
n_workers = len(workers)
```

**Send data to the workers**

Usually the workers would already hold their data; we only send it to them manually here for demonstration purposes.

```python
train_distributed_dataset = []

for batch_idx, (data, target) in enumerate(train_loader):
    data = data.send(workers[batch_idx % len(workers)])
    target = target.send(workers[batch_idx % len(workers)])
    train_distributed_dataset.append((data, target))
```
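The loop above deals batches out round-robin: batch `i` goes to worker `i % n_workers`. A minimal pure-Python sketch of that assignment, using placeholder batch labels and a hypothetical `round_robin` helper:

```python
def round_robin(batches, worker_ids):
    """Assign batch i to worker i % len(worker_ids), grouped by worker."""
    assignment = {w: [] for w in worker_ids}
    for batch_idx, batch in enumerate(batches):
        assignment[worker_ids[batch_idx % len(worker_ids)]].append(batch)
    return assignment

assignment_demo = round_robin(["b0", "b1", "b2", "b3", "b4"], ["bob", "alice"])
print(assignment_demo)  # {'bob': ['b0', 'b2', 'b4'], 'alice': ['b1', 'b3']}
```

Grouping by worker id, as this sketch does, is also the layout used for the data in Section 2.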

## Training Function

```python
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_distributed_dataset):
        # Send the model to the worker that holds this batch
        worker = data.location
        model.send(worker)

        optimizer.zero_grad()
        # update the model remotely
        pred = model(data)
        loss = F.mse_loss(pred.view(-1), target)
        loss.backward()
        optimizer.step()
        # Bring the updated model back
        model.get()

        if batch_idx % args.log_interval == 0:
            loss = loss.get()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```
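In this loop the model travels to each batch's owner, takes one SGD step there, and comes back; the raw data never moves. The toy sketch below mimics that pattern in pure Python with a one-parameter linear model `y_hat = w * x`. The owners' batches and the learning rate are hypothetical values chosen so that training converges to `w = 3`.

```python
def sgd_step(w, batch, lr):
    """One SGD step on mean squared error for the model y_hat = w * x."""
    # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
    return w - lr * grad

# Hypothetical private batches; the true relation is y = 3x.
owner_batches = [
    ("bob",   [(1.0, 3.0), (2.0, 6.0)]),
    ("alice", [(3.0, 9.0), (4.0, 12.0)]),
]

w = 0.0
for epoch in range(50):
    for owner, batch in owner_batches:
        # In the tutorial, model.send(worker) / model.get() wrap this step:
        # the model visits the owner and only the updated weights come back.
        w = sgd_step(w, batch, lr=0.05)

print(round(w, 3))  # 3.0
```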
        


## Testing Function

```python
def test():
    model.eval()
    test_loss = 0
    for data, target in test_loader:
        output = model(data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item()  # sum up batch loss

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))
```
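The test function sums the squared errors over every batch (`reduction='sum'`) and divides once by the dataset size. That gives the true per-example mean even when the last batch is smaller; averaging the per-batch means instead would over-weight the short batch. A pure-Python sketch with made-up predictions:

```python
def mse_sum(preds, targets):
    """Sum of squared errors, like F.mse_loss(..., reduction='sum')."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets))

preds   = [2.0, 4.0, 6.0, 8.0, 10.0]
targets = [1.0, 4.0, 7.0, 8.0, 12.0]

# Process in batches of 2 (the last batch has only one example), then divide once.
batch_size = 2
total = 0.0
for i in range(0, len(preds), batch_size):
    total += mse_sum(preds[i:i + batch_size], targets[i:i + batch_size])
avg_loss = total / len(preds)
print(avg_loss)  # 1.2
```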

## Training the Model

```python
import time
```

```python
t = time.time()

for epoch in range(1, args.epochs + 1):
    train(epoch)

total_time = time.time() - t
print('Total', round(total_time, 2), 's')
```

```
Total 7.57 s
```


## Calculating Performance

```python
test()
```

```
Test set: Average loss: 20.7677
```



# Section 2: Adding Encrypted Aggregation

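Now we're going to slightly modify this example so that the workers' model updates are aggregated using encryption. First, let's re-distribute our data, this time keeping track of which batches belong to bob and which to alice; then we'll initialize a model for each of them.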
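The idea behind secure aggregation is additive secret sharing: each value is split into random shares that add up to it modulo a large number, so a single share reveals nothing, yet the parties can add shares locally and reconstruct only the sum. The sketch below illustrates this in pure Python; the modulus `Q` and the integer updates are illustrative assumptions, not PySyft's internal values.

```python
import random

Q = 2**61 - 1  # toy modulus for share arithmetic (assumed, not PySyft's ring size)

def share(secret, n_parties):
    """Split an integer secret into n additive shares modulo Q."""
    shares = [random.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares

def reconstruct(shares):
    """Recover the secret by summing all shares modulo Q."""
    return sum(shares) % Q

# Two data owners each hold a private (integer) update.
bob_update, alice_update = 42, 58

# Each owner splits its update into one share per party (bob and alice).
bob_shares = share(bob_update, 2)
alice_shares = share(alice_update, 2)

# Each party adds the shares it holds; nobody sees the other owner's raw value.
summed_shares = [b + a for b, a in zip(bob_shares, alice_shares)]

total = reconstruct(summed_shares)
print(total)      # 100
print(total / 2)  # 50.0
```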

```python
train_distributed_dataset = {worker.id: [] for worker in workers}

for batch_idx, (data, target) in enumerate(train_loader):
    worker_idx = batch_idx % len(workers)
    worker = workers[worker_idx]
    data = data.send(worker)
    target = target.send(worker)
    train_distributed_dataset[worker.id].append((data, target))
```

We need two main functions. The first one, at the beginning of each epoch, sends the same version of the model (the baseline) to all workers.

The second one, after each worker has trained on its own data for one epoch, securely averages the updated models into a new baseline.
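Setting encryption aside, one round of these two steps is plain federated averaging: broadcast a baseline, let each worker train its copy, then take the element-wise mean. A sketch in pure Python, with models represented as lists of floats and hypothetical post-training weights standing in for local training:

```python
def broadcast(global_weights, n_workers):
    """Step 1: every worker starts the round from its own copy of the baseline."""
    return [list(global_weights) for _ in range(n_workers)]

def average(worker_weights):
    """Step 2 (in the clear here): element-wise mean of the workers' weights."""
    n = len(worker_weights)
    return [sum(ws) / n for ws in zip(*worker_weights)]

global_weights = [0.0, 0.0]
local = broadcast(global_weights, 2)
# Hypothetical results of local training on bob's and alice's data.
local[0] = [1.0, 2.0]
local[1] = [3.0, 4.0]
global_weights = average(local)
print(global_weights)  # [2.0, 3.0]
```

In the encrypted version, only `average` changes: the sums are computed on secret shares, so no single party sees an individual worker's weights.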

Now we can implement both functions and put them together into one training loop!

```python
local_model = Net()
models = [Net() for _ in range(n_workers)]
optimizers = [optim.SGD(models[i].parameters(), lr=args.lr) for i in range(n_workers)]

def send_new_models(local_model, models):
    """Overwrite every remote model's parameters with the local baseline."""
    with torch.no_grad():
        for remote_model in models:
            for new_param, remote_param in zip(local_model.parameters(), remote_model.parameters()):
                worker = remote_param.location
                remote_value = new_param.send(worker)
                # Replace the remote parameter's value in place with the new baseline
                remote_param.set_(remote_value)


def secure_aggregation(local_model, models):
    """Average the workers' parameters on secret shares and store the result in local_model."""
    with torch.no_grad():
        for local_param, *remote_params in zip(*([local_model.parameters()] + [model.parameters() for model in models])):
            # Encode each remote parameter in fixed precision and secret-share it
            # between alice and bob (james is the crypto provider); .get() brings
            # back the shared tensor, whose shares stay on alice and bob
            param_stack = remote_params[0].copy().fix_prec().share(alice, bob, crypto_provider=james).get()
            for remote_param in remote_params[1:]:
                param_stack += remote_param.copy().fix_prec().share(alice, bob, crypto_provider=james).get()
            param_stack /= len(remote_params)
            # Only the averaged value is reconstructed and decoded back to floats
            param_stack = param_stack.get().float_prec()
            local_param.set_(param_stack)


def train():
    # Initial sending of the models
    for model, worker in zip(models, workers):
        model.send(worker)

    for epoch in range(args.epochs):
        print(f'Epoch {epoch}')

        # 1. Send new version of the model
        send_new_models(local_model, models)

        # 2. Train remotely the models
        for i, worker in enumerate(workers):
            model = models[i]
            optimizer = optimizers[i]
            dataloader = train_distributed_dataset[worker.id]
            for (data, target) in dataloader:
                optimizer.zero_grad()
                pred = model(data)
                loss = F.mse_loss(pred.view(-1), target)
                loss.backward()
                optimizer.step()

        # 3. Secure aggregation of the updated models
        secure_aggregation(local_model, models)

    # Optional: retrieve the models from the workers
    for model in models:
        model.get()
```
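Secret sharing works on integers, which is why `secure_aggregation` calls `fix_prec()` before `share()` and `float_prec()` after reconstruction. The sketch below shows that encode, share, add, reconstruct, decode pipeline in pure Python. It assumes a scale factor of `10**3` and a toy modulus `Q` (neither is taken from PySyft), maps the upper half of the ring to negative numbers, and, unlike the tutorial, which divides the shared sum by the number of workers, divides only after reconstruction to keep things simple.

```python
import random

Q = 2**61 - 1     # toy ring size (assumption, not PySyft's internal value)
SCALE = 10 ** 3   # assumed: 3 fractional decimal digits

def encode(x):
    """Float -> ring element, similar in spirit to .fix_prec()."""
    return round(x * SCALE) % Q

def decode(v):
    """Ring element -> float, similar in spirit to .float_prec()."""
    if v > Q // 2:  # the upper half of the ring represents negative numbers
        v -= Q
    return v / SCALE

def share(v, n=2):
    """Split a ring element into n additive shares modulo Q."""
    parts = [random.randrange(Q) for _ in range(n - 1)]
    parts.append((v - sum(parts)) % Q)
    return parts

# One parameter value from each worker (hypothetical numbers).
worker_values = [0.25, -0.75]

# Each worker encodes and shares its value; shareholders add shares position-wise.
summed = [0, 0]
for value in worker_values:
    for k, s in enumerate(share(encode(value))):
        summed[k] = (summed[k] + s) % Q

# Reconstruct only the sum, decode it, then average.
total = decode(sum(summed) % Q)
print(total / len(worker_values))  # -0.25
```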

```python
def test():
    # Note: models[0] holds bob's weights after his local training in the last
    # epoch; the securely averaged weights from that epoch are in local_model.
    models[0].eval()
    test_loss = 0
    for data, target in test_loader:
        output = models[0](data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item()  # sum up batch loss

    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}\n'.format(test_loss))
```

```python
import time

t = time.time()

train()
test()

total_time = time.time() - t
print('Total', round(total_time, 2), 's')
```

```
Epoch 0
Epoch 1
Epoch 2
Epoch 3
Epoch 4
Epoch 5
Epoch 6
Epoch 7
Epoch 8
Epoch 9
Test set: Average loss: 17.8664

Total 4.83 s
```


# Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

### Star PySyft on GitHub

The easiest way to help our community is just by starring the repos! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for "Projects". This will show you all the top-level tickets, giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)
- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

[OpenMined's Open Collective Page](https://opencollective.com/openmined)