# Final Project: Federated Learning with Encrypted Gradient Aggregation

This is my implementation of federated learning (FL). The code is mainly adapted from Tutorial 10 in the OpenMined PySyft GitHub repository: https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2010%20-%20Federated%20Learning%20with%20Secure%20Aggregation.ipynb

In [88]:
# NOTE: `th` (here) and `torch` (below) are two names for the same module.
import torch as th
import syft as sy
# Hook object from PySyft; it is passed to every sy.VirtualWorker created
# later in this notebook.
hook=sy.TorchHook(th)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class Parser:
    """Plain settings holder for the training runs in this notebook.

    Attributes set on every instance:
        epochs: number of epochs for the encrypted-aggregation run.
        lr: learning rate handed to the SGD optimizers.
        test_batch_size: batch size for evaluation (not read in the visible code).
        batch_size: batch size for training (not read in the visible code).
        log_interval: a progress line is printed every this many batches.
        seed: value passed to ``torch.manual_seed``.
    """

    def __init__(self):
        # Same names, values and assignment order as a series of
        # ``self.<name> = <value>`` statements.
        defaults = (
            ("epochs", 10),
            ("lr", 0.001),
            ("test_batch_size", 1),
            ("batch_size", 1),
            ("log_interval", 10),
            ("seed", 1),
        )
        for name, value in defaults:
            setattr(self, name, value)
    
# Shared settings object; the RNG is seeded from it before anything random runs.
args = Parser()
torch.manual_seed(seed=args.seed)
kwargs = dict()

# Toy dataset: four rows of two features each. In this data the label is
# always equal to the second feature.
data = th.tensor(
    [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]],
    requires_grad=True,
)
target = th.tensor(
    [[1.0], [1.0], [0.0], [0.0]],
    requires_grad=True,
)


W0621 14:05:08.907468 56792 hook.py:96] Torch was already hooked... skipping hooking process


Create a dataloader

In [95]:
# Wrap the toy tensors in a dataset and serve them in shuffled mini-batches.
# The batch size comes from the shared settings (currently 1) rather than a
# second hard-coded literal.
# NOTE: the name `train` is rebound by the `def train(epoch)` cells further down.
train = TensorDataset(data, target)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True)


In [101]:
# Display the (features, labels) tensor pair held by the loader's underlying dataset.
train_loader.dataset.tensors

(tensor([[1., 1.],
         [0., 1.],
         [1., 0.],
         [0., 0.]], requires_grad=True), tensor([[1.],
         [1.],
         [0.],
         [0.]], requires_grad=True))

Just checking what the loader yields:

In [104]:
# Print the label of every batch. Dedicated loop names are used so the
# module-level `data` / `target` tensors are not overwritten by the last batch.
for batch_features, batch_target in train_loader:
    print(batch_target)

tensor([[1.]], grad_fn=<StackBackward>)
tensor([[1.]], grad_fn=<StackBackward>)
tensor([[0.]], grad_fn=<StackBackward>)
tensor([[0.]], grad_fn=<StackBackward>)


## Create the model and the virtual workers

In [90]:
from torch import nn, optim

# A single linear layer: two input features in, one value out.
model = nn.Linear(in_features=2, out_features=1)

# Plain SGD over that layer's parameters, with the step size from the shared settings.
optimizer = optim.SGD(params=model.parameters(), lr=args.lr)

In [91]:
# Create the three virtual workers (all built with the same hook), keeping
# both the list and the individual names available.
compute_nodes = [
    sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bo", "al", "jo")
]
bo, al, jo = compute_nodes

In [355]:
# Build the distributed dataset: each batch is sent to a worker, round-robin
# over `compute_nodes`, and the resulting pointer pair is recorded.
train_distributed_dataset = []

# Dedicated loop names keep the module-level `data` / `target` tensors intact.
for batch_idx, (batch_data, batch_target) in enumerate(train_loader):
    # Resolve the destination once so features and label share a worker.
    worker = compute_nodes[batch_idx % len(compute_nodes)]
    train_distributed_dataset.append(
        (batch_data.send(worker), batch_target.send(worker))
    )

In [108]:
# Display the list of (data, target) pointer pairs built in the previous cell.
train_distributed_dataset

[((Wrapper)>[PointerTensor | me:11617755689 -> bo:90933393668],
  (Wrapper)>[PointerTensor | me:98177457693 -> bo:44289883049]),
 ((Wrapper)>[PointerTensor | me:45272425748 -> al:89313412076],
  (Wrapper)>[PointerTensor | me:84225090910 -> al:79352695087]),
 ((Wrapper)>[PointerTensor | me:4333230265 -> jo:10475716448],
  (Wrapper)>[PointerTensor | me:78673644835 -> jo:7826465726]),
 ((Wrapper)>[PointerTensor | me:31811298497 -> bo:9594562841],
  (Wrapper)>[PointerTensor | me:96610485666 -> bo:56670613824])]

In [109]:
def train(epoch):
    """Run one pass over ``train_distributed_dataset`` with the shared ``model``.

    For each (data, target) pair the model is sent to the worker that holds
    the data, takes one optimizer step there, and is fetched back.

    Args:
        epoch: Epoch number; only used in the progress message.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_distributed_dataset):
        # The model has to be on the same worker as the batch it trains on.
        worker = data.location
        model.send(worker)

        optimizer.zero_grad()
        # update the model
        pred = model(data)
        # Flatten both sides: the target is (batch, 1) while the flattened
        # prediction is (batch,). Leaving the target unflattened makes
        # mse_loss broadcast, which is only harmless at batch size 1.
        loss = F.mse_loss(pred.view(-1), target.view(-1))
        loss.backward()
        optimizer.step()
        model.get()

        if batch_idx % args.log_interval == 0:
            # The loss is still remote at this point; fetch it before formatting.
            loss = loss.get()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * data.shape[0], len(train_loader),
                       100. * batch_idx / len(train_loader), loss.item()))

In [110]:
def test():
    """Evaluate ``model`` on ``test_loader`` and print the mean squared error.

    NOTE: ``test_loader`` is not defined in the visible part of this
    notebook; it must exist at module level before this is called.
    """
    model.eval()
    test_loss = 0
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # Flatten both sides so the shapes match exactly (no broadcasting).
            test_loss += F.mse_loss(
                output.view(-1), target.view(-1), reduction='sum'
            ).item()  # sum up batch loss

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))



In [111]:
import time

In [112]:
# Time one epoch of plain federated training. perf_counter() is used because
# it is monotonic and high-resolution, which suits sub-second intervals.
t = time.perf_counter()

for epoch in range(1):
    train(epoch)

total_time = time.perf_counter() - t
print('Total', round(total_time, 2), 's')

Total 0.07 s


  response = eval(cmd)(*args, **kwargs)


# Adding encrypted aggregation

In [386]:
# One bucket of (data, target) pointer pairs per worker, in `compute_nodes` order.
remote_dataset = tuple([] for _ in compute_nodes)

# Reset the list used by the earlier, non-encrypted run; this cell does not refill it.
train_distributed_dataset = []

# Dedicated loop names keep the module-level `data` / `target` tensors intact.
for batch_idx, (batch_data, batch_target) in enumerate(train_loader):
    # Round-robin assignment; resolve index and worker once per batch.
    node_index = batch_idx % len(compute_nodes)
    worker = compute_nodes[node_index]
    remote_dataset[node_index].append(
        (batch_data.send(worker), batch_target.send(worker))
    )

    
# Update the parameters of one worker's model
def update(data, target, model, optimizer):
    """Run one optimizer step of ``model`` on the worker that holds ``data``.

    Args:
        data: Input batch; its ``location`` decides where the model is sent.
        target: Labels for ``data``.
        model: Model to train. It is sent to the worker and is NOT fetched
            back here.
        optimizer: Optimizer used to step ``model``.

    Returns:
        A ``(model, loss)`` pair: the model as left after the step, and the
        loss value fetched back and converted with ``.item()``.
    """
    model.send(data.location)
    optimizer.zero_grad()
    pred = model(data)
    # Flatten both sides: the target is (batch, 1) while the flattened
    # prediction is (batch,). Leaving the target unflattened makes mse_loss
    # broadcast, which is only harmless at batch size 1.
    loss = F.mse_loss(pred.view(-1), target.view(-1))
    loss.backward()
    optimizer.step()
    loss = loss.get().item()
    return model, loss

# One independent copy of the base model per worker, in bo / al / jo order.
models = [model.copy() for _ in range(3)]
bo_model, al_model, jo_model = models

# One SGD optimizer per copy, all using the shared learning rate.
optimizers = [
    optim.SGD(worker_model.parameters(), lr=args.lr) for worker_model in models
]
bo_optimizer, al_optimizer, jo_optimizer = optimizers

# Parameter lists, indexed the same way as `models`.
params = [list(worker_model.parameters()) for worker_model in models]

In [390]:
def train(epoch):
    """Train every worker's model on its own data, then average the parameters.

    Each round does three things:
      1. runs ``update`` once per worker on that worker's next batch and
         prints the mean loss;
      2. for each parameter, converts every worker's copy to fixed precision,
         shares it across the workers, sums the shares, and decodes the
         average;
      3. zeroes the parameters, fetches the models back, and writes the
         averaged values into every model.

    Args:
        epoch: Current epoch number (not used inside the function).
    """
    worker_count = len(compute_nodes)
    # Run only rounds for which every worker has a batch. Buckets can differ
    # in size when the batch count is not a multiple of the worker count.
    rounds = min(len(worker_batches) for worker_batches in remote_dataset)

    for data_index in range(rounds):
        # update remote models
        round_losses = []
        for remote_index in range(worker_count):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index], loss = update(
                data, target, models[remote_index], optimizers[remote_index]
            )
            round_losses.append(loss)
        print('avg loss', sum(round_losses) / len(round_losses))

        # calculate new aggregate parameters
        new_params = []
        for param_i in range(len(params[0])):
            # One shared, fixed-precision copy of this parameter per worker.
            # NOTE(review): the `+ 0` presumably produces a fresh tensor
            # instead of sharing the parameter object itself - confirm.
            spdz_params = [
                (params[remote_index][param_i] + 0)
                .fix_precision()
                .share(*compute_nodes)
                .get()
                for remote_index in range(worker_count)
            ]

            # Sum the shared values, then fetch, decode back to float and average.
            encrypted_sum = spdz_params[0]
            for shared_param in spdz_params[1:]:
                encrypted_sum = encrypted_sum + shared_param
            new_params.append(encrypted_sum.get().float_precision() / worker_count)

        # Statement order below is kept exactly as in the original:
        # zero the parameters, fetch the models, then install the averages.
        with torch.no_grad():
            for worker_params in params:
                for param in worker_params:
                    param *= 0

            for worker_model in models:
                worker_model.get()

            # set the new parameters
            for remote_index in range(worker_count):
                for param_index in range(len(params[remote_index])):
                    params[remote_index][param_index].set_(new_params[param_index])
                    
        

In [391]:
def test():
    """Evaluate ``models[0]`` on ``test_loader`` and print the mean squared error.

    NOTE: ``test_loader`` is not defined in the visible part of this
    notebook; it must exist at module level before this is called.
    """
    models[0].eval()
    test_loss = 0
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for data, target in test_loader:
            output = models[0](data)
            # Flatten both sides so the shapes match exactly (no broadcasting).
            test_loss += F.mse_loss(
                output.view(-1), target.view(-1), reduction='sum'
            ).item()  # sum up batch loss

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))

In [392]:
# Time the encrypted-aggregation run. perf_counter() is used because it is
# monotonic and high-resolution, which suits sub-second intervals.
t = time.perf_counter()

for epoch in range(1, args.epochs + 1):
    print(epoch)
    train(epoch)
    # test() is not called: no test_loader is defined in this example.

total_time = time.perf_counter() - t
print('Total', round(total_time, 2), 's')

1
avg loss 0.6913629804427425
2
avg loss 0.6891131357600292
3
avg loss 0.6869764604295293
4
avg loss 0.6847909204661846
5
avg loss 0.6826648395508528
6
avg loss 0.6804896506170431
7
avg loss 0.6783744376152754
8
avg loss 0.6762095478673776
9
avg loss 0.6741050584241748
10
avg loss 0.6719505063568553
Total 0.66 s
