In [None]:
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
from src.websocket_client import MyWebsocketClientWorker
from src.nn_model import ConvNet1D, loss_fn
from src.my_utils import generate_kwarg
import torch
from datetime import datetime
import asyncio

# Static mapping from each edge device's IP address to its short device ID.
# The training section iterates this mapping to open one worker per device.
client_device_mapping_id = dict(
    [
        ("192.168.3.5", "AA"),
        ("192.168.3.6", "BB"),
        ("192.168.3.9", "CC"),
        ("192.168.3.15", "DD"),
        ("192.168.3.16", "EE"),
    ]
)

async def send_command(commands, nodes):
    """Send one command to each node concurrently and wait for all of them.

    ``commands[i]`` is sent to ``nodes[i]`` through ``node.async_command``;
    all requests are scheduled together with ``asyncio.gather``.

    Args:
        commands: Sequence of command dicts, one per node.
        nodes: Sequence of workers exposing an ``async_command`` coroutine
            method, in the same order as ``commands``.

    Returns:
        List of the values returned by each ``async_command`` call, in the
        same order as ``nodes``.

    Raises:
        ValueError: If ``commands`` and ``nodes`` differ in length. Without
            this check ``zip`` would silently drop the unmatched entries.
    """
    commands = list(commands)
    nodes = list(nodes)
    if len(commands) != len(nodes):
        raise ValueError(
            f"got {len(commands)} commands for {len(nodes)} nodes; "
            "they must match one-to-one"
        )
    return await asyncio.gather(
        *(node.async_command(cmd) for node, cmd in zip(nodes, commands))
    )

# Planned two-level dissemination tree: level -> list of (parent, child) edges.
# NOTE(review): this dict is not referenced anywhere in this notebook, and its
# level 2 lists ('CC', 'DD') while the second-layer dissemination cell actually
# sends CC's model to 'EE'. Confirm which is intended. The name is probably a
# typo for "push_and_pull"; it is kept as-is in case other code refers to it.
pull_and_pull_tree = {
    1: [('AA', 'BB'), ('AA', 'CC')],
    2: [('BB', 'DD'), ('CC', 'DD')]
}
# PySyft torch hook, shared by every worker created in this notebook.
hook = sy.TorchHook(torch)

# Connect to device AA, the parent of every level-1 edge above.
aa_kwarg = generate_kwarg('AA')
node_aa = MyWebsocketClientWorker(hook=hook, **aa_kwarg)

# --- Disseminate the first-layer model ---
# Initialise the model on AA.
node_aa.command(
    {
        "command_name": "model_initialization"
    }
)

In [None]:
# Start disseminating the model: AA forwards it to BB and CC.
node_aa.command(
    dict(
        command_name="model_dissemination",
        forward_device_id=['BB', 'CC'],
    )
)

In [None]:
# Disconnect from AA before connecting to the second-layer nodes.
node_aa.close()

# --- Disseminate the second-layer model ---
# BB forwards the model to DD, and CC forwards it to EE.
command_1 = {
    "command_name": "model_dissemination",
    "forward_device_id": ['DD']
}

command_2 = {
    "command_name": "model_dissemination",
    "forward_device_id": ['EE']
}
# Paired by position with the node list handed to send_command:
# command_1 -> node_bb, command_2 -> node_cc.
cmds = [command_1, command_2]

bb_kwarg = generate_kwarg('BB')
cc_kwarg = generate_kwarg('CC')
# Pass the hook by keyword, like every other worker construction here.
node_bb = MyWebsocketClientWorker(hook=hook, **bb_kwarg)
node_cc = MyWebsocketClientWorker(hook=hook, **cc_kwarg)

In [None]:
await send_command(commands=cmds, nodes=[node_bb, node_cc])

In [None]:
# Close the second-layer connections; BB and CC are reopened with connect()
# in the model-collection section.
node_bb.close()
node_cc.close()

In [None]:
# --- Training ---
# Websocket port used for every device when opening the training connections.
# NOTE(review): hard-coded here rather than taken from generate_kwarg; confirm
# it matches the port the device servers listen on.
WEBSOCKET_PORT = 9292

# One worker per device in client_device_mapping_id, in mapping order.
all_nodes = [
    MyWebsocketClientWorker(hook=hook, host=ip, port=WEBSOCKET_PORT, id=device_id)
    for ip, device_id in client_device_mapping_id.items()
]

all_nodes

In [None]:
command = {
    "command_name": "train",
    "dataset_key": "HAR-1"
}

await send_command(commands=[command]*5, nodes=all_nodes)

In [None]:
# Show the worker list again after the training command has completed.
all_nodes

In [None]:
# Close every connection opened for training.
for node in all_nodes:
    node.close()

In [None]:
# --- Model collection ---
# Connect to DD and EE, the two nodes that received the second-layer model.
node_dd = MyWebsocketClientWorker(hook=hook, **generate_kwarg("DD"))
node_ee = MyWebsocketClientWorker(hook=hook, **generate_kwarg("EE"))

# Each node sends its model back to the node it received it from
# (DD -> BB, EE -> CC), with "aggregation" set to True for both.
command_1 = {
    "command_name": "model_collection",
    "forward_device_id": ["BB"],
    "aggregation": True
}

command_2 = {
    "command_name": "model_collection",
    "forward_device_id": ["CC"],
    "aggregation": True
}

In [None]:
await send_command(commands=[command_1, command_2], nodes=[node_dd, node_ee])

In [None]:
# Reconnect to CC and BB and list the models each one has stored.
node_cc.connect()
node_bb.connect()

command = {
    "command_name": "show_stored_models",
}
# Return both responses as a tuple so the cell displays each of them. A bare
# call that is not the cell's last expression has its return value discarded.
# The calls are still made in the same order: CC first, then BB.
node_cc.command(command), node_bb.command(command)


In [None]:
# Send the models held by CC and BB back to the root AA.
# NOTE(review): only BB's command sets "aggregation": True, and CC's command is
# sent first. Presumably AA aggregates once BB's upload arrives after CC's;
# keep this order until the server-side semantics are confirmed.
# Only BB's response is displayed, because CC's call is not the last expression.
command_1 = {
    "command_name": "model_collection",
    "forward_device_id": ["AA"],
    "aggregation": False
}

command_2 = {
    "command_name": "model_collection",
    "forward_device_id": ["AA"],
    "aggregation": True
}

node_cc.command(command_1)
node_bb.command(command_2)


In [None]:
# Reconnect to AA and list the models it has stored after collection.
# The command is written out here instead of reusing the global `command`
# from an earlier cell, so the result does not depend on cell execution order.
node_aa.connect()
node_aa.command({"command_name": "show_stored_models"})