# Research on Federated Learning Infrastructure
### by 崔凌枫

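### Background: Data fragmentation and data silos have become a bottleneck for the development of artificial intelligence. Federated learning is an approach that avoids data and platform monopolies while respecting and protecting personal privacy.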

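#### This project studies existing federated learning architectures in order to offer large enterprises a high-quality service for upgrading predictive models, and to inform the development of an algorithm inference platform.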

## What Is Federated Learning

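For example, suppose two different companies, A and B, hold different data: company A has user feature data, while company B has product feature data and label data. Under the GDPR principles, the two companies cannot simply merge their data, because the original data providers, namely their respective users, have had no opportunity to consent to it. Suppose each party builds its own task model, where each task may be classification or prediction, and that the users already approved these tasks when the data was collected. The question is then how to build a high-quality model at each of A and B.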

![title](chart1.png)

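As the figure above shows, the first part is encrypted sample alignment. Because the user bases of the two companies do not fully overlap, the system uses an encryption-based user sample alignment technique to identify the users that A and B have in common, without either party disclosing its data and without exposing the users that do not overlap. The features of these shared users can then be combined for modeling.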
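
One way to picture encrypted sample alignment is a Diffie-Hellman style private set intersection. The sketch below is illustrative only and is not taken from the original notebook: the function names (`hash_to_group`, `random_exponent`, `blind`, `align_samples`), the toy modulus `P`, and the sample IDs are all assumptions, and a production system would rely on a vetted cryptographic library and group.

```python
import hashlib
import math
import secrets

# Toy modulus (the Mersenne prime 2**127 - 1); an assumption for illustration only.
P = 2**127 - 1


def hash_to_group(user_id):
    """Map a user ID to an integer modulo P via SHA-256."""
    digest = hashlib.sha256(user_id.encode("utf-8")).digest()
    return int.from_bytes(digest, "big") % P


def random_exponent():
    """Draw a secret exponent that is invertible modulo P - 1."""
    while True:
        e = secrets.randbelow(P - 2) + 1
        if math.gcd(e, P - 1) == 1:
            return e


def blind(values, secret):
    """Raise every value to the secret exponent modulo P."""
    return [pow(v, secret, P) for v in values]


def align_samples(a_ids, b_ids):
    """Return the IDs that A learns it shares with B."""
    a_secret = random_exponent()  # known only to A
    b_secret = random_exponent()  # known only to B

    # A -> B: A's hashed IDs blinded with A's secret.
    a_blinded = blind([hash_to_group(i) for i in a_ids], a_secret)
    # B -> A: B's hashed IDs blinded with B's secret.
    b_blinded = blind([hash_to_group(i) for i in b_ids], b_secret)

    # B blinds A's values a second time (order preserved) and returns them.
    a_double = blind(a_blinded, b_secret)
    # A blinds B's values a second time; exponentiation commutes,
    # so an ID held by both parties yields the same doubly blinded value.
    b_double = set(blind(b_blinded, a_secret))

    return [uid for uid, v in zip(a_ids, a_double) if v in b_double]


if __name__ == "__main__":
    print(align_samples(["u1", "u2", "u3", "u4"], ["u3", "u4", "u5"]))
```

Neither party sends raw IDs: each side sees only values blinded with the other side's secret, and A learns nothing about B's users outside the intersection beyond how many there are.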
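
The second part is encrypted model training. Once the shared users are identified, their data can be used to train a machine learning model. To keep the data confidential during training, a third-party coordinator C is brought in for encrypted training. Taking a linear regression model as an example, training consists of the following four steps (see Figure 2b):

- Step ①: Coordinator C distributes a public key to A and B, which they use to encrypt the data exchanged during training.
- Step ②: A and B exchange, in encrypted form, the intermediate results needed to compute the gradients.
- Step ③: A and B each compute on the encrypted gradient values, while B also computes the loss from its label data. Both send their results to C, which aggregates them to compute the total gradient and decrypts it.
- Step ④: C sends the decrypted gradients back to A and B, and each party updates its own model parameters accordingly.

These steps are repeated until the loss function converges, which completes the training process. During sample alignment and model training, the data of A and B stay local, and the data exchanged during training does not leak private information. With federated learning, the two parties can therefore train a model cooperatively.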
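
The data flow of the four steps above can be sketched in plain Python. This is a hedged illustration rather than the notebook's code: encryption is deliberately omitted (in the real protocol the exchanged values would be ciphertexts and C would hold the private key), and `train_vertical`, `partial_gradient`, the toy data, and the hyperparameters are assumptions made for the example.

```python
def dot(row, weights):
    return sum(r * w for r, w in zip(row, weights))


def partial_gradient(features, residual):
    """Gradient of the halved mean squared error w.r.t. one party's weights."""
    n = len(residual)
    return [
        sum(d * row[j] for d, row in zip(residual, features)) / n
        for j in range(len(features[0]))
    ]


def train_vertical(x_a, x_b, y, lr=0.5, rounds=500):
    """Linear regression where A and B hold different feature columns.

    Only B holds the labels y. Encryption is omitted in this sketch.
    """
    w_a = [0.0] * len(x_a[0])
    w_b = [0.0] * len(x_b[0])
    losses = []
    for _ in range(rounds):
        # Step 2: each party computes partial predictions on its own features
        # (exchanged in encrypted form in the real protocol).
        u_a = [dot(row, w_a) for row in x_a]
        u_b = [dot(row, w_b) for row in x_b]
        # Step 3: B combines the partial predictions with its labels.
        residual = [a + b - t for a, b, t in zip(u_a, u_b, y)]
        losses.append(sum(d * d for d in residual) / (2 * len(y)))
        # Each party's gradient depends only on the residual and its own features;
        # in the real protocol C would decrypt these before returning them.
        grad_a = partial_gradient(x_a, residual)
        grad_b = partial_gradient(x_b, residual)
        # Step 4: each party updates its own parameters locally.
        w_a = [w - lr * g for w, g in zip(w_a, grad_a)]
        w_b = [w - lr * g for w, g in zip(w_b, grad_b)]
    return w_a, w_b, losses


if __name__ == "__main__":
    # Toy data generated from y = 2*a0 - 1*a1 + 0.5*b0 (hypothetical values).
    x_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
    x_b = [[1.0], [2.0], [0.0], [1.0]]
    y = [2.5, 0.0, 1.0, 3.5]
    w_a, w_b, losses = train_vertical(x_a, x_b, y)
    print(w_a, w_b, losses[-1])
```

The point of the split is that neither party ever needs the other's feature columns: A's update uses only `x_a` and the residual, and B's uses only `x_b` and the residual.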

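### Based on this research, we use PySyft (python syft) for data encryption: PyTorch models on two devices exchange encrypted data, and the predictive model is then upgraded. The verification code is as follows: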

In [37]:
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time
import copy
import numpy as np
import syft as sy
from syft.frameworks.torch.fl import utils
from syft.workers.websocket_client import WebsocketClientWorker

In [38]:
# Preset training parameters: 100 training epochs, batch size 8

class Parser:
    def __init__(self):
        self.epochs = 100
        self.lr = 0.001
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1
    
args = Parser()
torch.manual_seed(args.seed)

<torch._C.Generator at 0x1261427b0>

In [39]:
# Load the data. Because company data is sensitive, the open-source Boston housing prices dataset is used here for demonstration

with open('./data/boston_housing.pickle','rb') as f:
    ((x, y), (x_test, y_test)) = pickle.load(f)

x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float()
x_test = torch.from_numpy(x_test).float()
y_test = torch.from_numpy(y_test).float()


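# Standardize the features using statistics from the training set only
mean = x.mean(0, keepdim=True)
dev = x.std(0, keepdim=True)
# Leave column 3 unscaled (CHAS, a binary indicator, in the standard Boston housing column order)
mean[:, 3] = 0.
dev[:, 3] = 1.
x = (x - mean) / dev
x_test = (x_test - mean) / dev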
train = TensorDataset(x, y)
test = TensorDataset(x_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True)

In [40]:
print(x)

tensor([[-0.2719, -0.4830, -0.4352,  ...,  1.1471,  0.4475,  0.8242],
        [-0.4029,  2.9881, -1.3323,  ..., -1.7161,  0.4314, -1.3276],
        [ 0.1248, -0.4830,  1.0271,  ...,  0.7835,  0.2203, -1.3069],
        ...,
        [-0.4015,  0.9896, -0.7406,  ..., -0.7162,  0.0793, -0.6769],
        [-0.1727, -0.4830,  1.2443,  ..., -1.7161, -0.9864,  0.4203],
        [-0.4037,  2.0414, -1.2001,  ..., -1.3071,  0.2329, -1.1525]])


In [41]:
print(y)

tensor([15.2000, 42.3000, 50.0000, 21.1000, 17.7000, 18.5000, 11.3000, 15.6000,
        15.6000, 14.4000, 12.1000, 17.9000, 23.1000, 19.9000, 15.7000,  8.8000,
        50.0000, 22.5000, 24.1000, 27.5000, 10.9000, 30.8000, 32.9000, 24.0000,
        18.5000, 13.3000, 22.9000, 34.7000, 16.6000, 17.5000, 22.3000, 16.1000,
        14.9000, 23.1000, 34.9000, 25.0000, 13.9000, 13.1000, 20.4000, 20.0000,
        15.2000, 24.7000, 22.2000, 16.7000, 12.7000, 15.6000, 18.4000, 21.0000,
        30.1000, 15.1000, 18.7000,  9.6000, 31.5000, 24.8000, 19.1000, 22.0000,
        14.5000, 11.0000, 32.0000, 29.4000, 20.3000, 24.4000, 14.6000, 19.5000,
        14.1000, 14.3000, 15.6000, 10.5000,  6.3000, 19.3000, 19.3000, 13.4000,
        36.4000, 17.8000, 13.5000, 16.5000,  8.3000, 14.3000, 16.0000, 13.4000,
        28.6000, 43.5000, 20.2000, 22.0000, 23.0000, 20.7000, 12.5000, 48.5000,
        14.6000, 13.4000, 23.7000, 50.0000, 21.7000, 39.8000, 38.7000, 22.2000,
        34.9000, 22.5000, 31.1000, 28.70

In [42]:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc4 = nn.Linear(24, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x):
        x = x.view(-1, 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc4(x))
        x = self.fc3(x)
        return x

In [43]:
# Simulate device users A and B (bob and alice)
hook = sy.TorchHook(torch)
bob_worker = sy.VirtualWorker(hook, id="bob")
alice_worker = sy.VirtualWorker(hook, id="alice")
# kwargs_websocket = {"host": "localhost", "hook": hook}
# alice = WebsocketClientWorker(id='alice', port=8779, **kwargs_websocket)
# bob = WebsocketClientWorker(id='bob', port=8778, **kwargs_websocket)
compute_nodes = [bob_worker, alice_worker]



In [44]:
# Send the training batches to the simulated users in round-robin order
remote_dataset = (list(), list())

for batch_idx, (data, target) in enumerate(train_loader):
    worker_idx = batch_idx % len(compute_nodes)
    data = data.send(compute_nodes[worker_idx])
    target = target.send(compute_nodes[worker_idx])
    remote_dataset[worker_idx].append((data, target))
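
The round-robin assignment above can be illustrated without PySyft. The following sketch is a stand-in under stated assumptions: plain strings replace tensor batches, and `partition_round_robin` is a hypothetical helper that does not appear in the notebook.

```python
def partition_round_robin(batches, num_workers):
    """Assign batch i to worker i % num_workers, preserving order."""
    shards = [[] for _ in range(num_workers)]
    for batch_idx, batch in enumerate(batches):
        shards[batch_idx % num_workers].append(batch)
    return shards


if __name__ == "__main__":
    shards = partition_round_robin(["b0", "b1", "b2", "b3", "b4"], 2)
    print([len(s) for s in shards])
```

When the number of batches is not divisible by the number of workers, the first workers hold one extra batch, so any loop that indexes all shards in lockstep has to stop at the shortest shard, for example at `min(len(s) for s in shards)`.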

In [45]:
bobs_model = Net()
alices_model = Net()
bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)
alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)

In [46]:
models = [bobs_model, alices_model]
optimizers = [bobs_optimizer, alices_optimizer]

In [47]:
model = Net()
model

Net(
  (fc1): Linear(in_features=13, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=24, bias=True)
  (fc4): Linear(in_features=24, out_features=16, bias=True)
  (fc3): Linear(in_features=16, out_features=1, bias=True)
)

In [48]:
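def update(data, target, model, optimizer):
    # Move the model to the worker that holds this batch and take one SGD step there
    model.send(data.location)
    optimizer.zero_grad()
    prediction = model(data)
    loss = F.mse_loss(prediction.view(-1), target)
    loss.backward()
    optimizer.step()
    return model

def train():
    for data_index in range(len(remote_dataset[0])-1):
        # One local update per worker on its batch at this index
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])
        # Retrieve the updated models from the workers
        for model in models:
            model.get()
        # Note: this return sits inside the outer loop, so each call to train()
        # performs a single round on the first batch held by each worker
        return utils.federated_avg({
            "bob": models[0],
            "alice": models[1]
        })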
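
The idea behind averaging the two local models can be sketched with plain Python dictionaries. This is a minimal illustration of parameter averaging, assuming every model exposes the same parameter names and shapes. It is not PySyft's implementation of `utils.federated_avg`, and `average_parameters` and the parameter names are hypothetical.

```python
def average_parameters(local_params):
    """Average parameter vectors element-wise across workers.

    local_params maps a worker id to {parameter name: list of floats}.
    """
    workers = list(local_params)
    first = local_params[workers[0]]
    averaged = {}
    for name in first:
        # Group the i-th entry of this parameter from every worker together.
        columns = zip(*(local_params[w][name] for w in workers))
        averaged[name] = [sum(col) / len(workers) for col in columns]
    return averaged


if __name__ == "__main__":
    params = {
        "bob": {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
        "alice": {"fc.weight": [3.0, 6.0], "fc.bias": [1.0]},
    }
    print(average_parameters(params))
```

Each worker contributes equally in this sketch; weighting by the number of local samples is a common variant when the workers hold different amounts of data.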

In [49]:
def test(federated_model):
    federated_model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = federated_model(data)
            # Sum the squared errors per batch; the total is averaged over the test set below
            test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item()

    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}'.format(test_loss))

In [50]:

for epoch in range(args.epochs):
    start_time = time.time()
    print(f"Epoch Number {epoch + 1}")
    federated_model = train()
    model = federated_model
    test(federated_model)
    total_time = time.time() - start_time
    print('Communication time over the network', round(total_time, 2), 's\n')

Epoch Number 1
Test set: Average loss: 615.8278
Communication time over the network 0.11 s

Epoch Number 2
Test set: Average loss: 613.6289
Communication time over the network 0.05 s

Epoch Number 3
Test set: Average loss: 610.8525
Communication time over the network 0.05 s

Epoch Number 4
Test set: Average loss: 607.9232
Communication time over the network 0.06 s

Epoch Number 5
Test set: Average loss: 604.9781
Communication time over the network 0.05 s

Epoch Number 6
Test set: Average loss: 602.0598
Communication time over the network 0.05 s

Epoch Number 7
Test set: Average loss: 599.1488
Communication time over the network 0.05 s

Epoch Number 8
Test set: Average loss: 596.2221
Communication time over the network 0.06 s

Epoch Number 9
Test set: Average loss: 593.2520
Communication time over the network 0.05 s

Epoch Number 10
Test set: Average loss: 590.2224
Communication time over the network 0.05 s

Epoch Number 11
Test set: Average loss: 587.1091
Communication time over the ne

Test set: Average loss: 41.6637
Communication time over the network 0.05 s

Epoch Number 94
Test set: Average loss: 41.4593
Communication time over the network 0.05 s

Epoch Number 95
Test set: Average loss: 41.1787
Communication time over the network 0.05 s

Epoch Number 96
Test set: Average loss: 40.9285
Communication time over the network 0.05 s

Epoch Number 97
Test set: Average loss: 40.7471
Communication time over the network 0.05 s

Epoch Number 98
Test set: Average loss: 40.4833
Communication time over the network 0.05 s

Epoch Number 99
Test set: Average loss: 40.2277
Communication time over the network 0.05 s

Epoch Number 100
Test set: Average loss: 40.0888
Communication time over the network 0.05 s

