In [None]:
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
from src.websocket_client import MyWebsocketClientWorker
from src.nn_model import ConvNet1D, loss_fn
from src.my_utils import generate_kwarg
import torch
from datetime import datetime
import asyncio

# IP address of each client device -> worker id used when connecting to it.
client_device_mapping_id = {
    "192.168.3.5": "AA",
    "192.168.3.6": "BB",
    "192.168.3.9": "CC",
    "192.168.3.15": "DD",
    "192.168.3.16": "EE",
}

async def send_command(commands, nodes):
    """Send commands to worker nodes concurrently and wait for all of them.

    ``commands[i]`` is sent to ``nodes[i]`` through ``node.async_command``.
    If exactly one command is given, it is broadcast to every node.

    Args:
        commands: Sequence of command dicts -- either one per node, or a
            single command to send to all nodes.
        nodes: Sequence of workers exposing an ``async_command`` coroutine.

    Returns:
        List of the ``async_command`` results, in the same order as ``nodes``.

    Raises:
        ValueError: if ``commands`` has neither one entry nor one entry per node.
    """
    commands = list(commands)
    nodes = list(nodes)
    if len(commands) == 1:
        # Broadcast a lone command; zip() alone would silently drop every
        # node after the first.
        commands = commands * len(nodes)
    if len(commands) != len(nodes):
        raise ValueError(
            f"got {len(commands)} commands for {len(nodes)} nodes; "
            "pass one command per node or a single command to broadcast"
        )
    return await asyncio.gather(
        *(node.async_command(cmd) for node, cmd in zip(nodes, commands))
    )

# Dissemination tree keyed by level; each entry appears to be a
# (sender, receiver) pair of device ids -- level 1 matches AA forwarding to
# BB and CC in the first dissemination step.
# NOTE(review): not referenced in the visible code, and its level-2 edges
# (BB->DD, CC->DD) differ from the second-layer dissemination commands
# (BB->DD, CC->EE) -- confirm which topology is intended.
pull_and_pull_tree = {
    1: [("AA", "BB"), ("AA", "CC")],
    2: [("BB", "DD"), ("CC", "DD")],
}
# Hook PySyft into torch; this one hook is passed to every worker created below.
hook = sy.TorchHook(torch)

# Connect to device AA.
aa_kwarg = generate_kwarg('AA')
node_aa = MyWebsocketClientWorker(hook=hook, **aa_kwarg)

# Distribute the first-layer model.
try:
    # Initialize the model on AA.
    node_aa.command(
        {
            "command_name": "model_initialization"
        }
    )

    # Start distributing the model from AA to BB and CC.
    node_aa.command(
        {
            "command_name": "model_dissemination",
            "forward_device_mapping_id": {
                "192.168.3.6": "BB",
                "192.168.3.9": "CC",
            },
        }
    )
finally:
    # Close the connection even if one of the commands raises.
    node_aa.close()

'''分发第二层模型'''
command_1 = {
    "command_name": "model_dissemination",
    "forward_device_mapping_id": {
        "192.168.3.15": "DD"
    }
}

command_2 = {
    "command_name": "model_dissemination",
    "forward_device_mapping_id": {
        "192.168.3.16": "EE"
    }
}
cmds = [command_1, command_2]

bb_kwarg = generate_kwarg('BB')
cc_kwarg = generate_kwarg('CC')
node_bb = MyWebsocketClientWorker(hook, **bb_kwarg)
node_cc = MyWebsocketClientWorker(hook, **cc_kwarg)

await send_command(commands=cmds, nodes=[node_bb, node_cc])

node_bb.close()
node_cc.close()

In [None]:
'''训练部分'''
all_nodes = []
for ip, ID in client_device_mapping_id.items():
    kwargs_websocket = {"hook": hook, "host": ip, "port": 9292, "id": ID}
    all_nodes.append(MyWebsocketClientWorker(**kwargs_websocket))

command = {
    "command_name": "train",
    "dataset_key": "HAR-1"
}

await send_command(commands=[command], nodes=all_nodes)