In [4]:
!pip install flask
!pip install numpy
!pip install tensorflow
!pip install flask-socketio
!pip install "python-socketio[asyncio_client]"

Collecting tensorflow
  Downloading tensorflow-2.6.0-cp39-cp39-win_amd64.whl (423.3 MB)
Collecting tensorboard~=2.6
  Downloading tensorboard-2.6.0-py3-none-any.whl (5.6 MB)
Collecting h5py~=3.1.0
  Downloading h5py-3.1.0-cp39-cp39-win_amd64.whl (2.7 MB)
Collecting flatbuffers~=1.12.0
  Using cached flatbuffers-1.12-py2.py3-none-any.whl (15 kB)
Collecting absl-py~=0.10
  Using cached absl_py-0.13.0-py3-none-any.whl (132 kB)
Collecting wrapt~=1.12.1
  Using cached wrapt-1.12.1.tar.gz (27 kB)
Collecting grpcio<2.0,>=1.37.0
  Downloading grpcio-1.39.0-cp39-cp39-win_amd64.whl (3.2 MB)
Collecting google-pasta~=0.2
  Using cached google_pasta-0.2.0-py3-none-any.whl (57 kB)
Collecting six~=1.15.0
  Using cached six-1.15.0-py2.py3-none-any.whl (10 kB)
Collecting tensorflow-estimator~=2.6
  Downloading tensorflow_estimator-2.6.0-py2.py3-none-any.whl (462 kB)
Collecting termcolor~=1.1.0
  Using cached termcolor-1.1.0.tar.gz (3.9 kB)
Collecting keras~=2.6
  Downloading keras-2.6.0-py2.py3-none-an

In [5]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

from flask import Flask, render_template
from flask_socketio import SocketIO, send, emit
import socketio

import time
from typing import List

from requests import get

In [None]:
# python-socketio client the dispatcher uses to talk to the compute nodes.
# It must be created here, before `socketio` is rebound a few lines below.
sio = socketio.AsyncClient()

# Flask application and the Flask-SocketIO server wrapped around it.
app = Flask(__name__)
# NOTE(review): hard-coded placeholder secret; load it from the environment or
# a config file before running this anywhere shared.
app.config['SECRET_KEY'] = 'secret!'
# NOTE: this rebinds the name `socketio` from the python-socketio *module* to
# the Flask-SocketIO *server instance*. From here on `socketio.<x>` refers to
# the server object, so `socketio.AsyncClient` is no longer reachable.
socketio = SocketIO(app)

# Addresses of the compute nodes, in pipeline order (node i forwards to node
# i + 1). Starts empty: an annotated empty list rather than `[str]`, which
# would be a list containing the `str` type itself.
# TODO: populate with the real node addresses before dispatching.
computeNodes: List[str] = []

# Public IP of this machine as reported by the ipify service; it is handed to
# the last node as its "next node". The timeout keeps the cell from hanging
# indefinitely if the service is unreachable.
dispatchIP = get('https://api.ipify.org', timeout=10).text

In [None]:
def partition(model: tf.keras.Model, partitions: List[int]) -> List[tf.keras.Model]:
    """Split a linear (layer-after-layer) Keras model into consecutive sub-models.

    Each sub-model reuses the original layer objects, so weights are shared
    with ``model``. Only simple chains of layers are supported (for example a
    ``Sequential`` model); branching topologies are not handled.

    Args:
        model: A built model, i.e. one whose ``input`` attribute is defined.
        partitions: Strictly ascending layer indices at which to cut. An index
            ``k`` means a new sub-model starts at ``model.layers[k]``. Every
            index must satisfy ``0 < k < len(model.layers)``. The list is not
            modified.

    Returns:
        ``len(partitions) + 1`` models. The first one takes ``model.input``;
        each later one takes a fresh ``keras.Input`` matching the shape and
        dtype of the previous sub-model's output.

    Raises:
        ValueError: If the cut points are not strictly ascending, fall outside
            the valid range, or the model has no layers.
    """
    layer_count = len(model.layers)

    # Bracket the cut points with the first index and the one-past-the-end
    # index so that consecutive pairs are half-open slices covering every
    # layer, including the last one. A new list is built so the caller's
    # `partitions` is left untouched.
    boundaries = [0, *partitions, layer_count]
    if any(start >= stop for start, stop in zip(boundaries, boundaries[1:])):
        raise ValueError(
            "partitions must be strictly ascending layer indices in the range "
            "1..{}; got {}".format(layer_count - 1, list(partitions))
        )

    models = []
    for start, stop in zip(boundaries, boundaries[1:]):
        if start == 0:
            inpt = model.input
        else:
            # Stand-alone input for this stage (batch dimension dropped), so
            # the sub-model does not depend on the previous stage's graph.
            previous_output = models[-1].output
            inpt = keras.Input(
                shape=tuple(previous_output.shape[1:]),
                dtype=previous_output.dtype,
            )

        # Chain this stage's layers from the stage input to a single output.
        x = inpt
        for layer in model.layers[start:stop]:
            if isinstance(layer, keras.layers.InputLayer):
                # Input placeholders are not callable stages; skip them.
                continue
            x = layer(x)

        models.append(keras.Model(inputs=inpt, outputs=x))
    return models

def dispatchModels(client: "socketio.AsyncClient", models: "List[tf.keras.Model]", nodeIPs: List[str]) -> None:
    """Send each sub-model to its compute node, then connect to the first node for data.

    For every model ``i`` the client connects to ``nodeIPs[i]`` on the
    ``/recv_model`` namespace, emits a ``"dispatch"`` event carrying the
    model's JSON and the address of the next node in the chain, and
    disconnects. The last node is given ``dispatchIP`` as its next node so the
    result returns to the dispatcher. Finally the client connects to
    ``nodeIPs[0]`` on the ``/recv_data`` namespace.

    Args:
        client: Socket.IO client used for all connections.
        models: Sub-models in pipeline order.
        nodeIPs: Node addresses; ``nodeIPs[i]`` receives ``models[i]``.

    Raises:
        ValueError: If there are fewer node addresses than models, or none at all.

    NOTE(review): the annotations are quoted so they are not evaluated when
    the function is defined. By that point the notebook has rebound the name
    `socketio` to the Flask-SocketIO server instance, so evaluating
    `socketio.AsyncClient` would fail.

    NOTE(review): on python-socketio's ``AsyncClient``, ``connect``, ``emit``
    and ``disconnect`` are coroutines; called without ``await`` as they are
    here, they do not run. Pass a synchronous ``socketio.Client``, or convert
    this function and its caller to ``async``/``await`` together.
    """
    # Fail before any network traffic instead of part-way through the dispatch.
    required = max(len(models), 1)
    if len(nodeIPs) < required:
        raise ValueError(
            "need at least {} node address(es), got {}".format(required, len(nodeIPs))
        )

    last = len(models) - 1
    for i, model in enumerate(models):
        model_json = model.to_json()
        client.connect(nodeIPs[i], namespaces=['/recv_model'], auth={"name": "dispatcher"})
        # Every node forwards to its successor; the final node points back to
        # the dispatcher.
        nextNode = nodeIPs[i + 1] if i != last else dispatchIP
        # `namespace` takes a single namespace string, not a list.
        client.emit("dispatch", data={"model": model_json, "nextNode": nextNode}, namespace='/recv_model')
        client.disconnect()

    # Leave the client connected to the head of the pipeline, ready for input data.
    client.connect(nodeIPs[0], namespaces=['/recv_data'], auth={"name": "dispatcher"})

def startDistEdgeInference(client: "socketio.AsyncClient", model_input: "tf.Tensor") -> None:
    """Emit the input tensor to the first compute node as a ``"data"`` event.

    The tensor is converted to nested Python lists via ``.numpy().tolist()``
    and sent on the ``/recv_data`` namespace. The client is expected to be
    connected to that namespace already (``dispatchModels`` does this as its
    final step).

    Args:
        client: Connected Socket.IO client.
        model_input: Tensor to feed into the first sub-model.

    NOTE(review): the annotations are quoted so they are not evaluated when
    the function is defined; the notebook rebinds the name `socketio` to the
    Flask-SocketIO server instance before this definition runs.

    NOTE(review): with python-socketio's ``AsyncClient``, ``emit`` is a
    coroutine and has to be awaited to take effect.
    """
    payload = model_input.numpy().tolist()
    # `emit` accepts one `namespace` string. Authentication is supplied when
    # connecting, not per event, so no `auth` is passed here.
    client.emit("data", data=payload, namespace='/recv_data')

In [None]:
# Sample input: three rows of two features each, kept as a plain list as well
# as a tensor because the socket handlers send the list form.
in_list = [[1, 2], [3, 4], [5, 6]]
inpt = tf.convert_to_tensor(in_list)

# Three Dense layers with 3, 1 and 2 units; every kernel weight is initialised
# to 1.0 so the outputs are deterministic.
layer_a, layer_b, layer_c = (
    tf.keras.layers.Dense(units, kernel_initializer=tf.constant_initializer(1.))
    for units in (3, 1, 2)
)

In [None]:
# Stack the three Dense layers into a single sequential reference model.
model = tf.keras.Sequential([layer_a, layer_b, layer_c])

# Run the complete model locally on the sample input and show the result.
out = model(inpt)
print("Local run: {}".format(out))

# Split the model at layer index 1 to obtain the sub-models to dispatch.
models_to_dispatch = partition(model, [1])

In [None]:
@socketio.on('from node', namespace="/net")
def handle_data(node, data):
    """Handle a 'from node' event on the /net namespace.

    Data reported by "part1" is re-emitted as a 'to part2' event; data
    reported by "part2" is printed as the final result and the process exits.
    Any other node name is only logged.
    """
    print("Data handled")
    if node == "part2":
        # Last stage has answered: report the result and stop.
        print("Done distributing, result is {}".format(data))
        exit(0)
    elif node == "part1":
        # First stage has answered: pass its output on to the second stage.
        print("Sending to part2")
        forwarded = {"server": data}
        socketio.emit('to part2', forwarded, namespace="/net")

@socketio.on('connect', namespace="/recv_model")
def client_connect(auth):
    """Log a new connection on /recv_model and start the pipeline when "part2" joins.

    Args:
        auth: Authentication payload supplied by the connecting client. A
            dict with a "name" key is expected; anything else (including
            ``None``, which is what a client that sends no auth data
            produces) is treated as an unnamed client.
    """
    print("Auth: ", auth)
    # Tolerate clients that connect without an auth dict or without a "name"
    # entry instead of raising inside the connect handler.
    name = auth.get("name") if isinstance(auth, dict) else None
    print("{} has connected".format(name))
    if name == "part2":
        # Once part2 has connected, send the sample input to part1.
        print("Emitting to part1")
        socketio.emit('to part1', {"server": in_list}, namespace="/net")

In [None]:
if __name__ == '__main__':
    # Send each sub-model to its compute node and connect to the first node.
    dispatchModels(sio, models_to_dispatch, computeNodes)
    # startDistEdgeInference takes the client as its first argument; the
    # original call passed only the tensor and raised a TypeError.
    startDistEdgeInference(sio, inpt)
    # Serve the Flask-SocketIO application on port 5000.
    socketio.run(app, port=5000)