In [1]:
import asyncio
import time
from datetime import datetime

import syft as sy
import torch
from syft.workers.websocket_client import WebsocketClientWorker

from src.websocket_client import MyWebsocketClientWorker
from src.nn_model import ConvNet1D, loss_fn

# 连接测试部分

In [None]:
# IP address of each worker machine -> id used for that worker.
client_device_mapping_id = {
    "192.168.3.5": "AA",
    "192.168.3.6": "BB",
    "192.168.3.9": "CC",
    "192.168.3.15": "DD",
    "192.168.3.16": "EE",
}

hook = sy.TorchHook(torch)

# One WebSocket client per worker, all listening on port 9292.
all_nodes = [
    WebsocketClientWorker(hook=hook, host=host_ip, port=9292, id=worker_id)
    for host_ip, worker_id in client_device_mapping_id.items()
]

# Ask every worker to clear its remote objects before the tests start.
for worker in all_nodes:
    worker.clear_objects_remote()

In [None]:
# Build the 1-D CNN and trace it with an all-zeros float example of shape
# (1, 400, 3); the traced (TorchScript) model is what gets handed to the
# training config and sent to the workers.
model = ConvNet1D(input_size=400, num_classes=7)
example_input = torch.zeros([1, 400, 3], dtype=torch.float)
traced_model = torch.jit.trace(model, example_input)

In [None]:
from src.train_config import MyTrainConfig

# Training setup: the traced model and loss function plus the optimisation
# hyper-parameters (SGD, 5 epochs, shuffled mini-batches of 32).
train_config = MyTrainConfig(
    model=traced_model,
    loss_fn=loss_fn,
    optimizer='SGD',
    epochs=5,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Time how long it takes to wrap and send the traced model and the loss
# function to the first worker. time.perf_counter() is a monotonic clock, so
# the measurement is not skewed by system clock adjustments the way
# datetime.now() differences can be.
# NOTE(review): relies on MyTrainConfig's private `_wrap_and_send_obj` helper.
start = time.perf_counter()
train_config._wrap_and_send_obj(traced_model, all_nodes[0])
train_config._wrap_and_send_obj(loss_fn, all_nodes[0])
end = time.perf_counter()
end - start  # elapsed seconds

# 中继节点测试部分

In [None]:
hook = sy.TorchHook(torch)

# Connect only to worker "AA"; the cells below use it as a relay node.
aa_kwarg = dict(hook=hook, host="192.168.3.5", port=9292, id="AA")
node_aa = WebsocketClientWorker(**aa_kwarg)

In [None]:
# Child nodes for AA, as IP address -> worker id. The "connect_child_nodes"
# command presumably makes AA open its own connections to them — confirm
# against the worker-side handler.
forward_device_mapping_id = {"192.168.3.6": "BB"}
node_aa._send_msg_and_deserialize(
    command_name="connect_child_nodes",
    forward_device_mapping_id=forward_device_mapping_id,
)


In [None]:
# Round-trip time of a "test" command sent to relay node AA (presumably
# forwarded to its child nodes — confirm on the worker side). Measured with the
# monotonic time.perf_counter() clock so clock adjustments cannot skew it.
start = time.perf_counter()
node_aa._send_msg_and_deserialize(command_name="command", command=dict(command_name="test"))
end = time.perf_counter()
end - start  # elapsed seconds

In [None]:
# Ask AA to close its child-node connections (counterpart of "connect_child_nodes").
node_aa._send_msg_and_deserialize(
    command_name="close_child_nodes",
)

# 异步测试部分

In [2]:
# IP address of each worker machine -> id used for that worker (same mapping as
# the connection-test section, repeated so this section can run on its own).
client_device_mapping_id = {
    "192.168.3.5": "AA",
    "192.168.3.6": "BB",
    "192.168.3.9": "CC",
    "192.168.3.15": "DD",
    "192.168.3.16": "EE",
}

hook = sy.TorchHook(torch)

all_nodes = []
for ip, ID in client_device_mapping_id.items():
    kwargs_websocket = {"hook": hook, "host": ip, "port": 9292, "id": ID}
    all_nodes.append(MyWebsocketClientWorker(**kwargs_websocket))

for node in all_nodes:
    node.clear_objects_remote()

command = {"command_name": "test"}


async def main(nodes=None, cmd=None):
    """Send a command to every worker concurrently and collect their replies.

    Args:
        nodes: workers to contact; defaults to the module-level ``all_nodes``
            (looked up at call time).
        cmd: command dict passed to each worker's ``async_command``; defaults to
            the module-level ``command`` (looked up at call time).

    Returns:
        A list with one ``async_command`` result per worker, in the same order
        as ``nodes`` (``asyncio.gather`` preserves argument order).

    Raises:
        Whatever the first failing ``async_command`` raises; ``asyncio.gather``
        propagates it because ``return_exceptions`` is left at False.
    """
    if nodes is None:
        nodes = all_nodes
    if cmd is None:
        cmd = command
    return await asyncio.gather(*(n.async_command(cmd) for n in nodes))

In [3]:
print(datetime.now())
await main()
print(datetime.now())

2024-06-12 13:33:19.877617
2024-06-12 13:33:25.065539
