In [1]:
import logging
import argparse
import sys
import asyncio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
import os

import syft as sy
from syft.workers import websocket_client
from syft.workers.websocket_client import WebsocketClientWorker
from syft.frameworks.torch.fl import utils
from my_utils import MyWebsocketClientWorker, model_to_device
import my_utils

# Notebook-level configuration.
# NOTE(review): LOG_INTERVAL is not referenced by any cell visible in this
# notebook — presumably kept for parity with run_websocket_client; confirm.
LOG_INTERVAL = 25
# Logger shares its name with the run_websocket_client module.
logger = logging.getLogger("run_websocket_client")
# loss = nn.CrossEntropyLoss()

In [2]:
import run_websocket_client

# Connection settings for the remote worker. The defaults are the values this
# notebook was originally run with; set JETSON_NANO_HOST / JETSON_NANO_PORT in
# the environment to target a different device without editing the notebook.
JETSON_NANO_HOST = os.environ.get("JETSON_NANO_HOST", "192.168.3.5")
JETSON_NANO_PORT = int(os.environ.get("JETSON_NANO_PORT", "9292"))

# Hook torch for PySyft, then open the websocket connection to worker 'A'.
hook = sy.TorchHook(torch)
jetson_nano = {"host": JETSON_NANO_HOST, "hook": hook}
worker_a = MyWebsocketClientWorker(id='A', port=JETSON_NANO_PORT, **jetson_nano)

# Start from a clean remote object store so the "#objects remote" counts
# displayed by later cells only reflect what this notebook sends.
worker_a.clear_objects_remote()

In [3]:
def _build_traced_convnet(input_size=400, num_classes=7, channels=3):
    """Create a fresh ConvNet1D and its TorchScript trace.

    Args:
        input_size: Value passed to ConvNet1D as ``input_size``; also used as
            the middle dimension of the example input for tracing.
        num_classes: Value passed to ConvNet1D as ``num_classes``.
        channels: Last dimension of the example input tensor.

    Returns:
        A ``(model, traced_model)`` pair, where ``traced_model`` comes from
        ``torch.jit.trace`` on an all-zero input of shape
        ``[1, input_size, channels]``.
    """
    net = run_websocket_client.ConvNet1D(input_size=input_size, num_classes=num_classes)
    example_input = torch.zeros([1, input_size, channels], dtype=torch.float)
    return net, torch.jit.trace(net, example_input)


# Two independently initialised models standing in for two clients' local
# models; they are built one after the other, in the same order as before.
model_1, traced_model_1 = _build_traced_convnet()
model_2, traced_model_2 = _build_traced_convnet()

In [4]:
# Peek at the same conv1 weight slice ([0, 0, :]) of both local models so the
# aggregated result can be compared against them later.
tuple(
    traced.conv1.weight.data[0, 0, :]
    for traced in (traced_model_1, traced_model_2)
)

(tensor([ 0.3322,  0.2962, -0.0353]), tensor([ 0.1603, -0.2100, -0.2958]))

In [5]:
# Element-wise mean of the two local weight slices, computed locally as the
# reference value for the remote aggregation.
kernel_1 = traced_model_1.conv1.weight.data[0, 0, :]
kernel_2 = traced_model_2.conv1.weight.data[0, 0, :]
(kernel_1 + kernel_2) / 2

tensor([ 0.2462,  0.0431, -0.1656])

In [6]:
# Local models to be aggregated, keyed by a short client label.
model_dict = dict(a=traced_model_1, b=traced_model_2)

In [7]:
# Model instance that is handed to AggregatedConfig as the federated model;
# it is traced with the same all-zero example input shape as the local models.
federated_model = run_websocket_client.ConvNet1D(input_size=400, num_classes=7)
federated_example_input = torch.zeros([1, 400, 3], dtype=torch.float)
traced_model = torch.jit.trace(federated_model, federated_example_input)

In [8]:
# Bundle the local models and the federated model into one config object.
aggregation_inputs = {
    "model_dict": model_dict,
    "federated_model": traced_model,
}
aggregate_config = my_utils.AggregatedConfig(**aggregation_inputs)

In [9]:
# Show the worker repr before anything is sent (recorded: 0 remote objects).
worker_a

<MyWebsocketClientWorker id:A #objects local:0 #objects remote: 0>

In [10]:
# Ship the models held by the config to worker A.
# NOTE(review): the recorded remote object count goes from 0 to 3 after this
# call — presumably the two local models plus the federated model; confirm in
# my_utils.AggregatedConfig.send_model.
aggregate_config.send_model(worker_a)

In [11]:
# Re-display the worker to see the effect of send_model (recorded: 3 remote objects).
worker_a

<MyWebsocketClientWorker id:A #objects local:0 #objects remote: 3>

In [12]:
# Serialise the config into a plain structure of object ids.
# NOTE(review): despite the `_tuple` suffix, the recorded value is a dict with
# the keys 'ID', 'model_id_list' and 'federated_model_id'.
aggregate_config_tuple = aggregate_config.simplify(aggregate_config.owner, aggregate_config)

In [13]:
# Inspect the simplified config (ids of the models to aggregate and of the federated model).
aggregate_config_tuple

{'ID': 84767216257,
 'model_id_list': (1, (10364018464, 50439197396)),
 'federated_model_id': 8536554964}

In [14]:
# Send the simplified config itself to worker A; returns a pointer to the
# remote object and that object's remote id.
# NOTE(review): this calls a private (underscore-prefixed) helper directly;
# the recorded remote object count goes from 3 to 4 after this call.
ptr, ID = aggregate_config._wrap_and_send_obj(aggregate_config_tuple, worker_a)

In [15]:
# Re-display the worker after sending the config (recorded: 4 remote objects).
worker_a

<MyWebsocketClientWorker id:A #objects local:0 #objects remote: 4>

In [16]:
# Inspect the returned pointer and id; in the recorded output the pointer's
# remote location on A equals ID.
ptr, ID

([CallablePointer | me:9737995215 -> A:63699163400], 63699163400)

In [17]:
# aggregate_config.detail(worker_a, aggregate_config_tuple)

In [18]:
# Send the "set_aggregate_config" command to worker A, passing the remote id
# of the config object sent above.
# NOTE(review): the command handler lives on the worker side and is not
# visible here — presumably it registers that object as the active config.
worker_a._send_msg_and_deserialize("set_aggregate_config", ID=ID)

In [19]:
# Send the "_check_aggregate_config" command to worker A.
# NOTE(review): presumably a remote sanity check of the stored config; no
# output was recorded for this cell — confirm the handler's behaviour.
worker_a._send_msg_and_deserialize("_check_aggregate_config")

In [20]:
# Send the "model_aggregation" command to worker A.
# NOTE(review): presumably averages the sent models into the federated model
# on the worker — the weights fetched afterwards match the local average
# computed earlier in this notebook; confirm in the worker implementation.
worker_a._send_msg_and_deserialize("model_aggregation")

In [21]:
# Retrieve the object stored under the federated model's id from worker A and
# unwrap it via `.obj`.
# NOTE(review): whether `get` removes the object from the worker is not
# visible here — confirm before re-running this cell.
model = aggregate_config.get(aggregate_config_tuple['federated_model_id'], worker_a).obj

In [22]:
# Pass the fetched model through model_to_device with target 'cpu'
# (presumably moves it to the CPU so it can be inspected locally — confirm in my_utils).
new_model = my_utils.model_to_device(model, 'cpu')

In [26]:
# Same conv1 weight slice as inspected for the local models; the recorded
# output equals the locally computed average of the two local slices.
new_model.conv1.weight[0, 0, :]

tensor([ 0.2462,  0.0431, -0.1656], grad_fn=<SliceBackward>)

In [27]:
# Display the module tree of the fetched model (recorded output lists conv1, pool, fc1 and fc2).
new_model

ConvNet1D(
  original_name=ConvNet1D
  (conv1): Conv1d(original_name=Conv1d)
  (pool): MaxPool1d(original_name=MaxPool1d)
  (fc1): Linear(original_name=Linear)
  (fc2): Linear(original_name=Linear)
)