In [1]:
import logging
import argparse
import sys
import asyncio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
import os

import syft as sy
from syft.workers import websocket_client
from syft.workers.websocket_client import WebsocketClientWorker
from syft.frameworks.torch.fl import utils
from my_utils import MyWebsocketClientWorker, model_to_device, model_federated
import my_utils

LOG_INTERVAL = 25
logger = logging.getLogger("run_websocket_client")
# loss = nn.CrossEntropyLoss()

In [2]:
import run_websocket_client

hook = sy.TorchHook(torch)
raspi = {"host": "192.168.3.4", "hook": hook}
jetson_nano = {"host": "192.168.3.5", "hook": hook}
worker_a = MyWebsocketClientWorker(id='A', port=9292, **jetson_nano)
worker_b = MyWebsocketClientWorker(id='B', port=9292, **raspi)

In [3]:
worker_instances = [worker_a, worker_b]
client_devices = ['cuda', 'cpu']
for worker in worker_instances:
    worker.clear_objects_remote()

In [4]:
model_1 = run_websocket_client.ConvNet1D(input_size=400, num_classes=7)
traced_model_1 = torch.jit.trace(model_1, torch.zeros([1, 400, 3], dtype=torch.float))
model_2 = run_websocket_client.ConvNet1D(input_size=400, num_classes=7)
traced_model_2 = torch.jit.trace(model_2, torch.zeros([1, 400, 3], dtype=torch.float))
model_3 = run_websocket_client.ConvNet1D(input_size=400, num_classes=7)
traced_model_3 = torch.jit.trace(model_3, torch.zeros([1, 400, 3], dtype=torch.float))

In [5]:
traced_model_1.conv1.weight.data[0,0,:]

tensor([0.0475, 0.1686, 0.1387])

In [6]:
traced_model_2.conv1.weight.data[0,0,:]

tensor([-0.2263,  0.0264,  0.2149])

In [7]:
traced_model_1.conv1.weight.data[0,0,:] + traced_model_2.conv1.weight.data[0,0,:]

tensor([-0.1788,  0.1950,  0.3536])

In [8]:
federated_model = my_utils.model_federated({'a':traced_model_1, 'b':traced_model_2}, worker_a)

In [9]:
print(federated_model.conv1.weight.data[0,0,:].get())

tensor([-0.2263,  0.0264,  0.2149])


In [10]:
federated_model.conv1.weight.data.clone().get()

Parameter containing:
tensor([[[-2.2630e-01,  2.6393e-02,  2.1494e-01],
         [ 2.7387e-01,  7.8323e-04,  2.9508e-01],
         [ 4.4348e-02,  1.2035e-01, -2.3220e-01]],

        [[-1.5215e-03, -2.8472e-01, -4.9834e-02],
         [ 1.5670e-01, -2.9517e-01,  6.4288e-02],
         [-1.3372e-01, -1.3467e-01, -2.2655e-01]],

        [[ 4.5637e-02,  1.9559e-01, -5.3175e-02],
         [-2.4445e-01, -1.7584e-01,  2.7691e-01],
         [ 2.0775e-01, -1.5096e-01, -2.4017e-01]],

        [[ 3.2251e-01, -2.8038e-01, -1.5793e-01],
         [-5.3925e-02, -7.1292e-02,  7.4407e-02],
         [-1.7679e-01,  8.4727e-02,  2.9829e-04]],

        [[ 1.4561e-01,  2.4836e-01,  3.0048e-01],
         [ 2.0142e-01,  1.6316e-01, -2.3710e-01],
         [-3.1256e-01, -2.1548e-01, -1.5248e-01]],

        [[ 1.0724e-01, -1.3544e-01, -2.7287e-01],
         [-2.6351e-01, -2.7543e-01,  8.2919e-02],
         [-3.8436e-02, -1.1612e-01,  2.1804e-01]],

        [[ 1.2342e-01,  9.9680e-02, -9.7272e-02],
         [ 1.943