In [1]:
#!pip install flask
#!pip install numpy
#!pip install tensorflow
#!pip install flask-socketio
#!pip install "python-socketio[async_client]"
#!pip install simple-websocket

In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

from flask import Flask
from flask_socketio import SocketIO, send, emit
import socketio
import asyncio

from typing import List

from requests import get

In [3]:
#sio = socketio.AsyncClient(logger=True, engineio_logger=True)
cli = socketio.Client(logger=True, engineio_logger=True)
app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
server = SocketIO(app, logger=True)

computeNodes = ["http://127.0.0.1:5001/", "http://127.0.0.1:5002/"]
#dispatchIP = get('https://api.ipify.org').text
dispatchIP = "http://127.0.0.1:5000/"
print("DispatchIP:", dispatchIP)

DispatchIP: http://127.0.0.1:5000/


In [4]:
def partition(model: tf.keras.Model, partit: List[int]) -> tf.keras.Model:
    #To slice properly, add the first and last layers to the partition list
    partitions = [0]
    partitions.extend(partit)
    partitions.append(len(model.layers) - 1)
    parts = []
    for i in range(1, len(partitions)):
        part = model.layers[partitions[i-1]:partitions[i]]
        parts.append(part)
    
    models = []
    for p in range(len(parts)):
        if p == 0:
            inpt = keras.Input(tensor=model.input)
            print(inpt)
        else:
            inpt = keras.Input(tensor=models[p-1].output)
        print(inpt)
        print([layer.output for layer in parts[p]])
        models.append(
            keras.Model(
                inputs=inpt,
                outputs=[layer.output for layer in parts[p]]
            )
        )
        print(models)
    return models

async def dispatchModels(models: List[tf.keras.Model], nodeIPs: List[str]) -> None:
    for i in range(len(models)):
        client = socketio.AsyncClient(logger=True, engineio_logger=True)
        model_json = models[i].to_json()
        print(nodeIPs[i])
        await client.connect(nodeIPs[i], namespaces=['/recv_model'], auth={"name": "dispatcher"})
        if i != len(models) - 1:
            nextNode = nodeIPs[i + 1]
        else:
            # Reached the end of the nodes, the last node needs to point back to the dispatcher
            nextNode = dispatchIP
        
        print("Reached emit")
        await client.emit("dispatch", data=(model_json, nextNode), namespace='/recv_model')
    await client.sleep(2)

def startDistEdgeInference(client, model_input: tf.Tensor, send_to: str):
    print("Starting inference")
    client.connect(send_to, namespaces=['/recv_data'])
    client.emit("data", data=model_input.numpy().tolist(), namespace='/recv_data')

In [5]:
in_list = [[1, 2], [3, 4], [5, 6]]
inpt = tf.convert_to_tensor(in_list)

layer_a = tf.keras.layers.Dense(3, kernel_initializer=tf.constant_initializer(1.))
layer_b = tf.keras.layers.Dense(1, kernel_initializer=tf.constant_initializer(1.))
layer_c = tf.keras.layers.Dense(2, kernel_initializer=tf.constant_initializer(1.))

In [6]:
model = tf.keras.Sequential()
model.add(layer_a)
model.add(layer_b)
model.add(layer_c)
out = model(inpt)
print("Local run: {}".format(out))
models_to_dispatch = partition(model, [1])

Local run: [[ 9.  9.]
 [21. 21.]
 [33. 33.]]
KerasTensor(type_spec=TensorSpec(shape=(3, 2), dtype=tf.int32, name='dense_input'), name='dense_input', description="created by layer 'input_1'")
KerasTensor(type_spec=TensorSpec(shape=(3, 2), dtype=tf.int32, name='dense_input'), name='dense_input', description="created by layer 'input_1'")
[<KerasTensor: shape=(3, 3) dtype=float32 (created by layer 'dense')>]
[<keras.engine.functional.Functional object at 0x000001882079A970>]
KerasTensor(type_spec=TensorSpec(shape=(3, 3), dtype=tf.float32, name=None), name='dense/BiasAdd:0', description="created by layer 'input_2'")
[<KerasTensor: shape=(3, 1) dtype=float32 (created by layer 'dense_1')>]
[<keras.engine.functional.Functional object at 0x000001882079A970>, <keras.engine.functional.Functional object at 0x00000188207FB9A0>]


In [7]:
@server.on('data', namespace="/recv_data")
def got_result(data):
    print("Done distributing, result is {}".format(data))

@server.on('connect', namespace="/recv_data")
def data_connect():
    print("Previous node connected")
    startDistEdgeInference(cli, inpt, computeNodes[0])

In [8]:
async def main():
    await dispatchModels(models_to_dispatch, computeNodes)
    #await startDistEdgeInference(sio, inpt, computeNodes[0])

In [9]:
if __name__ == '__main__':
    await main()
    server.run(app, debug=True, port=5000, use_reloader=False)

Signal handler is unsupported
Attempting polling connection to http://127.0.0.1:5001/socket.io/?transport=polling&EIO=4
Polling connection accepted with {'sid': 'vMczZ05sacIllQWjAAAA', 'upgrades': ['websocket'], 'pingTimeout': 20000, 'pingInterval': 25000}
Engine.IO connection established
Sending packet MESSAGE data 0/recv_model,{"name":"dispatcher"}
Attempting WebSocket upgrade to ws://127.0.0.1:5001/socket.io/?transport=websocket&EIO=4
WebSocket upgrade was successful
Received packet NOOP data 
Received packet MESSAGE data 0/recv_model,{"sid":"BzTVokXOrQlRWsqRAAAB"}
Namespace /recv_model is connected
Emitting event "dispatch" [/recv_model]
Sending packet MESSAGE data 2/recv_model,["dispatch","{\"class_name\": \"Functional\", \"config\": {\"name\": \"model\", \"layers\": [{\"class_name\": \"InputLayer\", \"config\": {\"batch_input_shape\": [3, 2], \"dtype\": \"int32\", \"sparse\": false, \"ragged\": false, \"name\": \"input_1\"}, \"name\": \"input_1\", \"inbound_nodes\": []}, {\"class

http://127.0.0.1:5001/
Reached emit
http://127.0.0.1:5002/


Received packet NOOP data 
Received packet MESSAGE data 0/recv_model,{"sid":"6l1lO-7c30D9jFDrAAAB"}
Namespace /recv_model is connected
Emitting event "dispatch" [/recv_model]
Sending packet MESSAGE data 2/recv_model,["dispatch","{\"class_name\": \"Functional\", \"config\": {\"name\": \"model_1\", \"layers\": [{\"class_name\": \"InputLayer\", \"config\": {\"batch_input_shape\": [3, 3], \"dtype\": \"float32\", \"sparse\": false, \"ragged\": false, \"name\": \"input_2\"}, \"name\": \"input_2\", \"inbound_nodes\": []}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_1\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 1, \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"Constant\", \"config\": {\"value\": 1.0}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}, \"name\"

Reached emit
 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: on


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [26/Sep/2021 23:58:50] "GET /socket.io/?transport=polling&EIO=4&t=1632725928.6653295 HTTP/1.1" 200 -
Attempting polling connection to http://127.0.0.1:5001/socket.io/?transport=polling&EIO=4
Polling connection accepted with {'sid': 'huNLflh9BmomUPrJAAAC', 'upgrades': ['websocket'], 'pingTimeout': 20000, 'pingInterval': 25000}
Engine.IO connection established
Sending packet MESSAGE data 0/recv_data,
Attempting WebSocket upgrade to ws://127.0.0.1:5001/socket.io/?transport=websocket&EIO=4
WebSocket upgrade was successful
Received packet NOOP data 
Received packet MESSAGE data 0/recv_data,{"sid":"Jv_Fle3t-fIR1422AAAD"}
Namespace /recv_data is connected
Emitting event "data" [/recv_data]
Sending packet MESSAGE data 2/recv_data,["data",[[1,2],[3,4],[5,6]]]
received event "data" from VIpPzdI2GnCanvQNAAAB [/recv_data]


Previous node connected
Starting inference
Done distributing, result is [[9.0], [21.0], [33.0]]


Received packet PING data 
Sending packet PONG data 
