# Parte 8: Aprendizaje federado con agregación de gradiente cifrada

En el anterior ejemplo, teníamos un "agregador de confianza" que era responsable de promediar las actualizaciones del modelo de varios trabajadores. Esto no es ideal porque asume que podemos encontrar a alguien lo suficientemente confiable para tener acceso a esta información confidencial. Este no es siempre el caso.

Por lo tanto, en este ejemplo, mostraremos cómo se puede usar SMPC para realizar la agregación de manera que no necesitemos un "agregador de confianza".



# Sección 1: Aprendizaje federado normal

Primero, mostraré un código que realiza el aprendizaje federado normal en el conjunto de datos de viviendas de Boston. Este código se divide en varias partes.

### Setting Up

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader

print(torch.__version__)

# Training settings, written as (flag, type, default, metavar, help) rows and
# registered on the parser in a single pass.
parser = argparse.ArgumentParser(description='PyTorch Example')
for flag, kind, default, metavar, help_text in (
        ('--batch-size', int, 8, 'N',
         'input batch size for training (default: 8)'),
        ('--test-batch-size', int, 8, 'N',
         'input batch size for testing (default: 8)'),
        ('--epochs', int, 10, 'N',
         'number of epochs to train (default: 10)'),
        ('--lr', float, 0.001, 'LR',
         'learning rate (default: 0.001)'),
        ('--momentum', float, 0.0, 'M',
         'SGD momentum (default: 0.0)'),
        ('--seed', int, 1, 'S',
         'random seed (default: 1)'),
        ('--log-interval', int, 10, 'N',
         'how many batches to wait before logging training status'),
):
    parser.add_argument(flag, type=kind, default=default, metavar=metavar,
                        help=help_text)

# An empty argument list is parsed on purpose, so every setting takes its
# default regardless of how the interpreter was started.
args = parser.parse_args([])

# Seed torch's random number generator from the configured seed.
torch.manual_seed(args.seed)

# Extra keyword arguments forwarded to the DataLoader constructors (none).
kwargs = {}

0.3.1


# Cargando el Dataset

In [2]:
import pickle
# Load the pre-split Boston housing arrays. The context manager closes the
# file even if unpickling raises.
# NOTE: pickle.load can execute arbitrary code -- only load a file you trust.
with open('../other/data/boston_housing.pickle', 'rb') as f:
    ((X, y), (X_test, y_test)) = pickle.load(f)

# Convert the numpy arrays to float tensors.
X = torch.from_numpy(X).type(torch.FloatTensor)
y = torch.from_numpy(y).type(torch.FloatTensor)
X_test = torch.from_numpy(X_test).type(torch.FloatTensor)
y_test = torch.from_numpy(y_test).type(torch.FloatTensor)

# Preprocessing: standardize each feature column with statistics computed on
# the training inputs only; the same statistics are applied to the test inputs.
mean = X.mean(0, keepdim=True)
dev = X.std(0, keepdim=True)
mean[:, 3] = 0.  # the feature at column 3 is binary,
dev[:, 3] = 1.   # so it is left unstandardized (shift 0, scale 1)
X = (X - mean) / dev
X_test = (X_test - mean) / dev

# Wrap the tensors in datasets and mini-batch loaders.
# NOTE(review): the names `train` and `test` are later rebound to functions
# of the same name further down in this file.
train = TensorDataset(X, y)
test = TensorDataset(X_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

# Estructura de la red neuronal

In [3]:
class Net(nn.Module):
    """Three-layer fully connected network: 13 inputs -> 32 -> 24 -> 1 output."""

    def __init__(self):
        super(Net, self).__init__()
        # Layers are created in input-to-output order.
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc3 = nn.Linear(24, 1)

    def forward(self, x):
        """Return the network output for ``x``.

        The input is first reshaped to rows of 13 values; the two hidden
        layers use ReLU and the final layer is left linear.
        """
        rows = x.view(-1, 13)
        hidden = F.relu(self.fc1(rows))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Instantiate the network and a plain SGD optimizer over its parameters; the
# learning rate and momentum come from the parsed training settings.
model = Net()
optimizer = optim.SGD(model.parameters(), momentum=args.momentum, lr=args.lr)

# Hooking PyTorch

In [4]:
import syft
import syft as sy
from syft.core import utils
import torch
import torch.nn.functional as F
import json
import random
from syft.core.frameworks.torch import utils as torch_utils
from torch.autograd import Variable
# Create the PySyft hook and keep a handle on the local worker it exposes.
# NOTE(review): presumably the hook patches torch so tensors gain
# send()/get() -- confirm against the PySyft version in use.
hook = sy.TorchHook(verbose=False)
me = hook.local_worker

# Two virtual workers stand in for the remote data owners.
bob = sy.VirtualWorker(id="bob", hook=hook, is_client_worker=False)
alice = sy.VirtualWorker(id="alice", hook=hook, is_client_worker=False)
me.is_client_worker = False

compute_nodes = [bob, alice]

# Introduce every compute node to each of its peers.
for node in compute_nodes:
    node.add_workers([peer for peer in compute_nodes if peer is not node])


**Enviar los datos a cada trabajador** <br>

In [5]:
# Batches of the training set, each one sent to a compute node.
train_distributed_dataset = []

for batch_idx, (data, target) in enumerate(train_loader):
    # Batches are dealt out to the compute nodes in round-robin order.
    destination = compute_nodes[batch_idx % len(compute_nodes)]

    data = Variable(data)
    target = Variable(target.float())

    # Inputs and targets of one batch always go to the same node.
    data.send(destination)
    target.send(destination)
    train_distributed_dataset.append((data, target))

## Training Function

In [6]:
def train(epoch):
    """Run one pass over ``train_distributed_dataset`` with the global model.

    For every batch the model is sent to the worker reported by
    ``data.location``, gradients are computed there, and the model is fetched
    back before the optimizer step. The loss is printed every
    ``args.log_interval`` batches.

    Args:
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    n_batches = len(train_loader)

    for batch_idx, (data, target) in enumerate(train_distributed_dataset):
        # Move the model to wherever this batch lives.
        model.send(data.location)

        optimizer.zero_grad()
        loss = F.mse_loss(model(data), target.float())
        loss.backward()

        # Bring the model back before applying the update.
        model.get()
        optimizer.step()

        if batch_idx % args.log_interval != 0:
            continue

        # The loss is fetched only when it is about to be printed.
        loss.get()
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch,
            batch_idx * args.batch_size,
            n_batches * args.batch_size,
            100. * batch_idx / n_batches,
            loss.data[0]))
        


# Testing Function

In [7]:
def test():
    """Evaluate the global ``model`` on ``test_loader`` and print the result.

    The squared error is summed over every batch and divided by the number of
    test samples, so the printed value is the average loss per sample.
    Nothing is returned.
    """
    model.eval()
    test_loss = 0
    for data, target in test_loader:
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        # size_average=False makes mse_loss return the batch sum, not the mean
        test_loss += F.mse_loss(output, target.float(), size_average=False).data[0]

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))


# Entrenando el modelo

In [8]:
import time

In [9]:
%%time
t = time.time()
args.epochs = 10

# Reset the timing counters kept as attributes on the torch module.
# NOTE(review): presumably these are incremented by PySyft instrumentation --
# nothing in this file updates them.
for timer_name in ('encode_timer', 'handle_call_timer', 'execute_call_timer'):
    setattr(torch, timer_name, 0)

for epoch in range(1, args.epochs + 1):
    train(epoch)

total_time = time.time() - t

# Report each counter in seconds and as a share of the total wall time.
for label, timer_name in (('Encoding', 'encode_timer'),
                          ('Handling', 'handle_call_timer'),
                          ('Execute call', 'execute_call_timer')):
    elapsed = getattr(torch, timer_name)
    print(label, round(elapsed, 2), 's', round(elapsed / total_time * 100, 2), '%')
print('Total', round(total_time, 2), 's')

Encoding 0 s 0.0 %
Handling 0 s 0.0 %
Execute call 0 s 0.0 %
Total 19.16 s
CPU times: user 18.5 s, sys: 387 ms, total: 18.9 s
Wall time: 19.2 s


# Rendimiento

In [10]:
# Print the average test-set loss of the model trained above.
test()


Test set: Average loss: 20.7802



# Sección 2: Añadiendo Encrypted Aggregation

Ahora vamos a modificar este ejemplo ligeramente para agregar gradientes usando cifrado. La diferencia principal se reduce en realidad a 1 o 2 líneas de código en la función train(). Por el momento, vamos a volver a procesar nuestros datos e inicializar un modelo para bob y alice.

In [11]:
# One batch list per compute node: remote_dataset[i] holds the batches that
# were sent to compute_nodes[i].
remote_dataset = ([], [])

for batch_idx, (data, target) in enumerate(train_loader):
    # Batches alternate between the compute nodes.
    slot = batch_idx % len(compute_nodes)

    data = Variable(data)
    target = Variable(target.float())
    data.send(compute_nodes[slot])
    target.send(compute_nodes[slot])
    remote_dataset[slot].append((data, target))

def update(data, target, model, optimizer):
    """Perform one optimizer step of ``model`` on the worker holding ``data``.

    The model is sent to ``data.location`` and is NOT fetched back here; the
    caller is responsible for retrieving it later.

    Args:
        data: input batch; must expose a ``location`` attribute.
        target: targets matching ``data``.
        model: the network to train.
        optimizer: optimizer bound to ``model``'s parameters.

    Returns:
        The same ``model`` object that was passed in.
    """
    model.send(data.location)
    optimizer.zero_grad()
    loss = F.mse_loss(model(data), target.float())
    loss.backward()
    optimizer.step()
    return model

# Each worker gets its own independently initialized copy of the network.
bobs_model = Net()
alices_model = Net()

# Parallel lists, indexed like compute_nodes (bob first, then alice).
models = [bobs_model, alices_model]
params = [list(net.parameters()) for net in models]
optimizers = [
    optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    for net in models
]
bobs_optimizer, alices_optimizer = optimizers


## Construyendo la lógica de entrenamiento


### Part A: Train:

In [12]:
# Position, within each worker's batch list, of the batch to train on.
data_index = 0

# Update the remote models. This step could be repeated several times before
# aggregating, but here every worker's model is updated exactly once.
for remote_index in range(len(compute_nodes)):
    worker_model = models[remote_index]
    worker_optimizer = optimizers[remote_index]
    data, target = remote_dataset[remote_index][data_index]
    models[remote_index] = update(data, target, worker_model, worker_optimizer)


### Part B: Encrypted Aggregation

In [13]:
# Container that will receive the averaged parameters, one entry per
# model parameter.
new_params = []

In [14]:
# Iterate through each parameter position of the model (every worker's model
# has the same parameter list, so worker 0's length is used).
for param_i in range(len(params[0])):

    # Collect this parameter, in shared form, from each worker.
    spdz_params = list()
    for remote_index in range(len(compute_nodes)):

        # Select the identical parameter from each worker and copy it
        # (adding 0 produces a new tensor instead of aliasing the original).
        copy_of_parameter = params[remote_index][param_i].data+0

        # Since SMPC can only work with integers (not floats), we need
        # to use integers to store decimal information. In other words,
        # we need to use "Fixed Precision" encoding.
        # Fix its precision (read more about Fixed Precision encodings).
        fixed_precision_param = copy_of_parameter.fix_precision()

        # Now we encrypt it on the remote machine. Note that
        # fixed_precision_param is ALREADY a pointer. Thus, when
        # we call share, it actually encrypts the data that the
        # pointer is pointing TO. This returns a POINTER to the
        # MPC shared object, which we need to fetch.
        encrypted_param = fixed_precision_param.share(bob, alice)

        # Now we fetch the pointer to the MPC shared value.
        param = encrypted_param.get()

        # Save the parameter so we can average it with the same parameter
        # from the other workers.
        spdz_params.append(param)

    # Average params from multiple workers, fetch them to the local machine,
    # decrypt and decode (from fixed precision) back into a floating point number.
    # NOTE(review): the sum and the division by 2 hard-code exactly two workers.
    new_param = (spdz_params[0] + spdz_params[1]).get().decode()/2

    # Save the new averaged parameter.
    new_params.append(new_param)


### Part C: Cleanup

In [15]:
# Loop variables are deliberately NOT called `model` / `param`: at notebook
# top level those names would overwrite the `model` global created for
# Section 1 (and the `param` global from the aggregation cell).

# Zero every worker's parameters before the models are fetched back.
for worker_params in params:
    for worker_param in worker_params:
        worker_param.data *= 0

# Bring every worker's model back to the local machine.
for worker_model in models:
    worker_model.get()

# Overwrite each model's parameters with the averaged values, so all models
# start the next round from the same weights.
for remote_index in range(len(compute_nodes)):
    for param_index in range(len(params[remote_index])):
        params[remote_index][param_index].data.set_(new_params[param_index])

## Todo junto

Y ahora que conocemos cada paso, podemos ponerlo todo junto en un ciclo de entrenamiento.

In [16]:

def train(epoch):
    """Run one epoch of federated training with encrypted parameter averaging.

    Each round: (1) every worker's model takes one update step on that
    worker's next batch, (2) each parameter is fixed-precision encoded,
    secret-shared between bob and alice, summed in shared form, then fetched,
    decoded and halved to obtain the average, (3) the remote parameters are
    zeroed, the models are fetched back, and every model is overwritten with
    the averaged parameters.

    Args:
        epoch: current epoch number. Accepted for interface compatibility with
            the training loop; it is not used inside the function.
    """
    # Every round consumes one batch from each worker, so the number of rounds
    # is bounded by the worker holding the fewest batches. (The previous
    # `len(remote_dataset[0]) - 1` gave the same count when worker 0 had one
    # extra batch, but dropped a valid final round when the counts were equal.)
    n_rounds = min(len(worker_batches) for worker_batches in remote_dataset)

    for data_index in range(n_rounds):
        # update remote models
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])

        new_params = list()

        for param_i in range(len(params[0])):

            # copy, fixed-precision encode and secret-share this parameter
            # from every worker
            spdz_params = list()
            for remote_index in range(len(compute_nodes)):
                spdz_params.append((params[remote_index][param_i].data+0).fix_precision().share(bob, alice).get())

            # NOTE(review): the sum and the division by 2 assume exactly
            # two workers
            new_param = (spdz_params[0] + spdz_params[1]).get().decode()/2
            new_params.append(new_param)

        # zero the parameters before the models are fetched back
        for model in params:
            for param in model:
                param.data *= 0

        for model in models:
            model.get()

        # load the averaged parameters into every worker's model
        for remote_index in range(len(compute_nodes)):
            for param_index in range(len(params[remote_index])):
                params[remote_index][param_index].data.set_(new_params[param_index])

In [17]:
def test():
    """Evaluate ``models[0]`` on ``test_loader`` and print the result.

    The squared error is summed over every batch and divided by the number of
    test samples, so the printed value is the average loss per sample.
    Nothing is returned.
    """
    models[0].eval()
    test_loss = 0
    for data, target in test_loader:
        data, target = Variable(data, volatile=True), Variable(target)
        output = models[0](data)
        # size_average=False makes mse_loss return the batch sum, not the mean
        test_loss += F.mse_loss(output, target.float(), size_average=False).data[0]

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))


In [18]:
%%time
t = time.time()
args.epochs = 10

# Reset the timing counters kept as attributes on the torch module.
# NOTE(review): presumably these are incremented by PySyft instrumentation --
# nothing in this file updates them.
for timer_name in ('encode_timer', 'handle_call_timer', 'execute_call_timer'):
    setattr(torch, timer_name, 0)

# Train for the configured number of epochs, evaluating after each one.
for epoch in range(1, args.epochs + 1):
    print(epoch)
    train(epoch)
    test()

total_time = time.time() - t

# Report each counter in seconds and as a share of the total wall time.
for label, timer_name in (('Encoding', 'encode_timer'),
                          ('Handling', 'handle_call_timer'),
                          ('Execute call', 'execute_call_timer')):
    elapsed = getattr(torch, timer_name)
    print(label, round(elapsed, 2), 's', round(elapsed / total_time * 100, 2), '%')
print('Total', round(total_time, 2), 's')

1

Test set: Average loss: 545.0826

2

Test set: Average loss: 225.9891

3

Test set: Average loss: 27.1742

4

Test set: Average loss: 20.2662

5

Test set: Average loss: 18.2706

6

Test set: Average loss: 17.2924

7

Test set: Average loss: 16.8838

8

Test set: Average loss: 16.6306

9

Test set: Average loss: 16.5443

10

Test set: Average loss: 16.6038

Encoding 0 s 0.0 %
Handling 0 s 0.0 %
Execute call 0 s 0.0 %
Total 74.08 s
CPU times: user 1min 12s, sys: 1.39 s, total: 1min 13s
Wall time: 1min 14s
