In [1]:
import asyncio
import time
from datetime import datetime

import syft as sy
import torch
from syft.generic.pointers.object_wrapper import ObjectWrapper
from syft.messaging.message import ObjectMessage
from syft.workers.virtual import VirtualWorker
from syft.workers.websocket_client import WebsocketClientWorker

import my_utils
import run_websocket_client
import test2

In [2]:
# Model dimensions, named once instead of repeated as bare literals below.
INPUT_SIZE = 400
NUM_CLASSES = 7
# NOTE(review): trailing dimension of the tracing example input; presumably the
# number of per-timestep features/channels ConvNet1D expects — confirm in my_utils.
N_FEATURES = 3

# Hook torch before building the model so tensors/modules get syft's extensions.
hook = sy.TorchHook(torch)

model = my_utils.ConvNet1D(input_size=INPUT_SIZE, num_classes=NUM_CLASSES)

# torch.jit.trace records the ops executed on this dummy (all-zeros) batch of one.
example_input = torch.zeros([1, INPUT_SIZE, N_FEATURES], dtype=torch.float)
traced_model = torch.jit.trace(model, example_input)

train_config = test2.MyTrainConfig(
    model=traced_model,
    loss_fn=run_websocket_client.loss_fn,
    batch_size=32,
    shuffle=True,
    epochs=5,
    optimizer='SGD',
)

In [3]:
# Remote websocket workers, keyed by worker id. All of them listen on the same port.
WORKER_PORT = 9292
WORKER_HOSTS = {
    "AA": '192.168.3.5',
    "BB": '192.168.3.6',
    "CC": '192.168.3.9',
    "DD": '192.168.3.15',
    "EE": '192.168.3.16',
}

# One client per entry, created in the mapping's insertion order (AA..EE).
worker_list = [
    my_utils.MyWebsocketClientWorker(hook=hook, host=host, port=WORKER_PORT, id=worker_id)
    for worker_id, host in WORKER_HOSTS.items()
]
# Individual handles are still used directly by later cells.
worker_a, worker_b, worker_c, worker_d, worker_e = worker_list

In [4]:
# Remove any objects left on each remote worker by earlier runs, so the
# object counts checked later reflect only this session's sends.
for remote_worker in worker_list:
    remote_worker.clear_objects_remote()

In [5]:
async def main(workers=None, model=None):
    """Send a model to every worker via ``asyncio.gather`` and print the elapsed seconds.

    One ``train_config.async_wrap_and_send(model, worker)`` coroutine is scheduled per
    worker, and all of them are awaited together.

    Args:
        workers: Iterable of workers to send to. Defaults to the notebook-level
            ``worker_list`` when omitted.
        model: Model to send. Defaults to the notebook-level ``traced_model``
            when omitted.
    """
    # Resolve defaults at call time, so a bare ``main()`` keeps using the current
    # notebook globals, including any later re-assignment of them.
    workers = worker_list if workers is None else workers
    model = traced_model if model is None else model

    # perf_counter is monotonic. datetime.now() is wall-clock and can jump
    # (NTP sync, DST), which would corrupt the measured duration.
    start = time.perf_counter()
    # NOTE(review): in the recorded run, each send started only after the previous one
    # ended, so gather produced no overlap. async_wrap_and_send presumably performs a
    # blocking send on the event loop — confirm in test2.
    await asyncio.gather(
        *(train_config.async_wrap_and_send(model, worker) for worker in workers)
    )
    print(f"{time.perf_counter() - start:.6f}")

def send_one_by_one(workers=None, model=None):
    """Send a model to each worker sequentially, printing per-send and total timings.

    For every worker:
      1. Pop a fresh object id from ``sy.ID_PROVIDER``.
      2. Wrap the model in an ``ObjectWrapper`` under that id.
      3. Send it through ``train_config.owner``.

    Because each call uses fresh ids, repeated calls add another copy of the model
    to each worker rather than replacing it.

    Args:
        workers: Iterable of workers to send to. Defaults to the notebook-level
            ``worker_list`` when omitted.
        model: Model to send. Defaults to the notebook-level ``traced_model``
            when omitted.
    """
    # Resolve defaults at call time so a bare call keeps using the notebook globals.
    workers = worker_list if workers is None else workers
    model = traced_model if model is None else model

    # Monotonic clock for the duration. datetime.now() is kept only for the
    # human-readable per-send timestamps below.
    start = time.perf_counter()
    for worker in workers:
        obj_id = sy.ID_PROVIDER.pop()
        print(worker.id, obj_id)
        obj_with_id = ObjectWrapper(id=obj_id, obj=model)
        print(f"User-{worker.id} sending start:{datetime.now()}")
        train_config.owner.send(obj_with_id, worker)
        print(f"User-{worker.id} sending end:{datetime.now()}")
    print(f"{time.perf_counter() - start:.6f}")

In [6]:
# Top-level `await` works because IPython runs notebook cells with autoawait enabled.
# In a plain script the equivalent would be `asyncio.run(main())`.
await main()

DD 12525256734
User-DD sending start:2024-04-29 15:18:04.095944
User-DD sending end:2024-04-29 15:18:05.068342
BB 92359900092
User-BB sending start:2024-04-29 15:18:05.098518
User-BB sending end:2024-04-29 15:18:06.066663
EE 37107498580
User-EE sending start:2024-04-29 15:18:06.097061
User-EE sending end:2024-04-29 15:18:07.069747
AA 16798777998
User-AA sending start:2024-04-29 15:18:07.099349
User-AA sending end:2024-04-29 15:18:08.081867
CC 98440769202
User-CC sending start:2024-04-29 15:18:08.112037
User-CC sending end:2024-04-29 15:18:09.100431
User-BB receive:2024-04-29 15:18:09.101938
User-EE receive:2024-04-29 15:18:09.101938
User-AA receive:2024-04-29 15:18:09.101938
User-DD receive:2024-04-29 15:18:09.101938
User-CC receive:2024-04-29 15:18:09.691032
5.68475


In [7]:
tuple(w.objects_count_remote() for w in (worker_a, worker_b))  # objects now stored on AA and BB

(1, 1)

In [8]:
# Sequential baseline: send the same traced model to each worker in turn, for comparison with main().
send_one_by_one()

AA 26834778771
User-AA sending start:2024-04-29 15:18:09.763421
User-AA sending end:2024-04-29 15:18:10.462616
BB 23109114026
User-BB sending start:2024-04-29 15:18:10.462616
User-BB sending end:2024-04-29 15:18:11.143782
CC 34328767684
User-CC sending start:2024-04-29 15:18:11.144383
User-CC sending end:2024-04-29 15:18:11.822224
DD 23144463686
User-DD sending start:2024-04-29 15:18:11.822224
User-DD sending end:2024-04-29 15:18:12.471078
EE 92119353976
User-EE sending start:2024-04-29 15:18:12.471078
User-EE sending end:2024-04-29 15:18:13.120463
3.357042


In [9]:
tuple(w.objects_count_remote() for w in (worker_a, worker_b))  # each worker should now also hold the sequentially-sent copy

(2, 2)