In [1]:
import logging
import argparse
import sys
import asyncio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
import os

import syft as sy
from syft.workers import websocket_client
from syft.workers.websocket_client import WebsocketClientWorker
from syft.frameworks.torch.fl import utils
from my_utils import MyWebsocketClientWorker

# Logger for this run; it shares its name with the imported run_websocket_client helper module.
logger = logging.getLogger("run_websocket_client")

# NOTE(review): not referenced by any visible cell; presumably a batch interval for progress
# logging — confirm it is still needed or remove it.
LOG_INTERVAL = 25

In [2]:
import run_websocket_client

In [3]:
# Hook PySyft into torch; the resulting hook is handed to every websocket worker created below.
hook = sy.TorchHook(torch)

In [4]:
# Connection settings for the two edge devices (the port is supplied per worker below).
# Hosts can be overridden with RASPI_HOST / JETSON_NANO_HOST so the notebook can target another
# network without editing code; the defaults are the original LAN addresses.
raspi = {"host": os.environ.get("RASPI_HOST", "192.168.3.4"), "hook": hook}
jetson_nano = {"host": os.environ.get("JETSON_NANO_HOST", "192.168.3.5"), "hook": hook}

In [5]:
# Worker "A" lives on the Jetson Nano; the training loop pairs it with the 'cuda' entry of client_devices.
worker_a = MyWebsocketClientWorker(**jetson_nano, id="A", port=9292)

In [6]:
# Worker "B" lives on the Raspberry Pi; the training loop pairs it with the 'cpu' entry of client_devices.
worker_b = MyWebsocketClientWorker(**raspi, id="B", port=9292)

In [7]:
# Workers and the device each one trains on, matched by position (the training loop zips them).
worker_instances = [worker_a, worker_b]
client_devices = ['cuda', 'cpu']
# zip() silently drops unmatched entries, so fail fast if the two lists ever drift apart.
assert len(worker_instances) == len(client_devices), "each worker needs exactly one device"

In [8]:
# Ask each remote worker to clear its stored objects before training starts.
for remote_worker in worker_instances:
    remote_worker.clear_objects_remote()

In [9]:
# Seed torch before building the model so the initial weights — the starting point that is
# shipped to every worker — are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)
model = run_websocket_client.ConvNet1D(input_size=400, num_classes=7)

In [10]:
# Trace the model with an all-zeros dummy batch; the traced module is what
# fit_model_on_worker receives in the training loop.
# NOTE(review): the example shape is (1, 400, 3); 400 presumably has to match input_size=400
# given to ConvNet1D — confirm the axis layout and the meaning of the last axis.
traced_model = torch.jit.trace(
    model,
    torch.zeros((1, 400, 3), dtype=torch.float32),
)

In [11]:
test_num = 5
for curr_round in range(1, 5+1):
    results = await asyncio.gather(
            *[
                run_websocket_client.fit_model_on_worker(
                    worker=worker,
                    traced_model=traced_model,
                    batch_size=16,
                    curr_round=curr_round,
                    max_nr_batches=10,
                    lr=0.0001,
                    device=client_device
                )
                for worker, client_device in zip(worker_instances, client_devices)
            ]
        )

    test_models = curr_round % test_num == 0 or curr_round == 1
    if test_models:
        print(results)

User-A Training start time: 2024-02-21 21:56:54.518450
User-B Training start time: 2024-02-21 21:56:55.440020
User-A Training end time: 2024-02-21 21:57:03.925978
User-B Training end time: 2024-02-21 21:57:07.674353
[('A', RecursiveScriptModule(
  original_name=Module
  (conv1): RecursiveScriptModule(original_name=Module)
  (pool): RecursiveScriptModule(original_name=Module)
  (fc1): RecursiveScriptModule(original_name=Module)
  (fc2): RecursiveScriptModule(original_name=Module)
), tensor(1.9069, requires_grad=True), 9.407528), ('B', RecursiveScriptModule(
  original_name=Module
  (conv1): RecursiveScriptModule(original_name=Module)
  (pool): RecursiveScriptModule(original_name=Module)
  (fc1): RecursiveScriptModule(original_name=Module)
  (fc2): RecursiveScriptModule(original_name=Module)
), tensor(1.9021, requires_grad=True), 12.234333)]
User-A Training start time: 2024-02-21 21:57:07.675370
User-B Training start time: 2024-02-21 21:57:08.470950
User-A Training end time: 2024-02-21 2