In [1]:
import argparse
import asyncio
import logging
import os
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import syft as sy
from syft.frameworks.torch.fl import utils
from syft.workers import websocket_client
from syft.workers.websocket_client import WebsocketClientWorker

import my_utils
import run_websocket_client
from my_utils import MyWebsocketClientWorker, model_to_device

# Named logger; shares its name with the `run_websocket_client` module used in this notebook.
logger = logging.getLogger(name="run_websocket_client")
# Logging cadence constant (presumably steps between progress messages — not used in the cells shown).
LOG_INTERVAL = 25

In [2]:
# Connect to the remote PySyft websocket worker.
# Host/port may be overridden through environment variables so the notebook is not
# tied to a single LAN address; the defaults are the values it was originally run with.
WORKER_HOST = os.environ.get("WORKER_HOST", "192.168.3.5")
WORKER_PORT = int(os.environ.get("WORKER_PORT", "9292"))

hook = sy.TorchHook(torch)
jetson_nano = {"host": WORKER_HOST, "hook": hook}
worker_a = MyWebsocketClientWorker(id='A', port=WORKER_PORT, **jetson_nano)
# Ask the remote worker to drop any objects it still holds (e.g. from a previous run).
worker_a.clear_objects_remote()

In [3]:
INPUT_SIZE = 400   # passed to ConvNet1D and used as dim 1 of the tracing example input
NUM_CLASSES = 7
N_CHANNELS = 3     # last dim of the tracing example input — TODO confirm meaning against ConvNet1D.forward
SEED = 0

# Fix the RNG so the randomly initialised weights printed below are reproducible
# on Restart & Run All.
torch.manual_seed(SEED)


def _make_traced_convnet():
    """Build a fresh ConvNet1D and return ``(model, traced_model)``.

    The model is traced with an all-zeros example input of shape
    ``[1, INPUT_SIZE, N_CHANNELS]``.
    """
    model = run_websocket_client.ConvNet1D(input_size=INPUT_SIZE, num_classes=NUM_CLASSES)
    example_input = torch.zeros([1, INPUT_SIZE, N_CHANNELS], dtype=torch.float)
    return model, torch.jit.trace(model, example_input)


# Same construction order as before: model 1 (build + trace), then model 2.
model_1, traced_model_1 = _make_traced_convnet()
model_2, traced_model_2 = _make_traced_convnet()

In [4]:
# First conv1 weight slice of each traced model, shown side by side as a tuple.
tuple(net.conv1.weight.data[0, 0, :] for net in (traced_model_1, traced_model_2))

(tensor([-0.2016, -0.0483,  0.1988]), tensor([-0.1480, -0.0651,  0.0795]))

In [5]:
# Element-wise mean of the two slices above, computed by hand for comparison
# with the aggregated model later on.
(
    traced_model_1.conv1.weight.data[0, 0, :]
    + traced_model_2.conv1.weight.data[0, 0, :]
) / 2

tensor([-0.1748, -0.0567,  0.1391])

In [6]:
# Map a short id to each traced model; this mapping is what aggregate() receives.
model_dict = dict(a=traced_model_1, b=traced_model_2)

In [7]:
# Aggregate the traced models; how worker_a participates is defined by
# run_websocket_client.aggregate.
result_model = run_websocket_client.aggregate(
    model_dict,
    worker_a,
)

In [10]:
# Check the aggregation programmatically instead of eyeballing printed tensors:
# the aggregated conv1 weights should equal the element-wise mean of the input models.
# NOTE(review): assumes aggregate() leaves the local models in model_dict unchanged
# (still plain local tensors) — TODO confirm.
expected_conv1_weight = (
    sum(m.conv1.weight.data for m in model_dict.values()) / len(model_dict)
)
assert torch.allclose(result_model.conv1.weight.data, expected_conv1_weight), \
    "aggregated conv1 weights differ from the mean of the input models"

result_model.conv1.weight.data[0, 0, :]

tensor([-0.1748, -0.0567,  0.1391])