In [1]:
# Please ignore these variable, they only provide options for our CI system.
args = []
abort_after_one = False

# 教程：使用websockets进行联合学习，并对可能遇到的问题的可能解决方案进行联合平均

安装websocket库



## 1 启动websocket服务工作程序

每个工作程序由两部分组成，本地句柄和保存数据并执行计算的远程实例。远程部分称为Websocket服务器工作程序。

因此，首先，您需要转到cd此笔记本以及其他用于运行服务器和客户端的其他文件所在的文件夹

需要在终端中运行以下命令。

```bash
python start_websocket_servers.py
```

## 2 设置websocket客户端工作程序

导入并设置一些参数和变量。

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
import torch
from torchvision import datasets, transforms

from syft.frameworks.torch.fl import utils

In [4]:
import run_websocket_client as rwc

In [5]:
args = rwc.define_and_get_arguments(args=args)
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=64, cuda=False, epochs=2, federate_after_n_batches=50, lr=0.01, save_model=False, seed=1, test_batch_size=1000, use_virtual=False, verbose=False)


现在，让我们实例化websocket客户端工作程序，这是我们到远程工作程序的本地访问点。请注意，如果websocket服务器工作程序未在运行，则此步骤将失败。

In [6]:
hook = sy.TorchHook(torch)

kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

workers = [alice, bob, charlie]
print(workers)


[<WebsocketClientWorker id:alice #objects local:0 #objects remote: 0>, <WebsocketClientWorker id:bob #objects local:0 #objects remote: 0>, <WebsocketClientWorker id:charlie #objects local:0 #objects remote: 0>]


## 3 准备和分发训练数据
我们将使用MNIST数据集并将数据随机分配到工作人员上。对于联合培训设置而言，这是不现实的，因为在远程培训中，数据通常通常已经可用。

我们实例化了两个FederatedDataLoader，一个用于训练，一个用于MNIST数据集的测试集。

In [7]:
#run this box only if the the next box gives pipeline error
torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../官方教程/data/",
        train=True,download=True))

<torch.utils.data.dataloader.DataLoader at 0x23c29eadc70>

In [9]:
federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        "../../官方教程/data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ).federate(tuple(workers)),
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../官方教程/data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=True
)


接下来，我们需要实例化机器学习模型。这是一个具有2个卷积层和2个完全连接层的小型神经网络。它使用ReLU激活和最大池化。

In [10]:
model = rwc.Net().to(device)
print(model)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [11]:
import logging
import sys
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s")
handler.setFormatter(formatter)
logger.handlers = [handler]

## 4 让我们开始训练
现在我们准备开始联合培训。我们将分别对每个工人进行给定数量的批次培训，然后计算所得模型的联合平均值，并计算该模型的测试准确性

In [12]:
for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))
    model = rwc.train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches, 
                      abort_after_one=abort_after_one)
    rwc.test(model, device, test_loader)

Starting epoch 1/2
2021-05-08 07:30:49,922 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [0, 50]
2021-05-08 07:31:28,701 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [50, 100]
2021-05-08 07:32:06,354 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [100, 150]
2021-05-08 07:32:42,900 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [150, 200]
2021-05-08 07:33:20,036 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [200, 250]
2021-05-08 07:33:57,142 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [250, 300]
2021-05-08 07:34:21,885 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [300, 350]
2021-05-08 07:34:34,004 DEBUG run_websocket_client.py(l:130) - Starting training round, batches [350, 400]
2021-05-08 07:34:34,013 DEBUG run_websocket_client.py(l:142) - At least one worker ran out of data, stopping.
2021-05-08 07:34:37