In [1]:
#!pip install flask
#!pip install numpy
#!pip install tensorflow
#!pip install flask-socketio
!pip install "python-socketio[client]"

Collecting websocket-client>=0.54.0
  Downloading websocket_client-1.2.1-py2.py3-none-any.whl (52 kB)
Installing collected packages: websocket-client
Successfully installed websocket-client-1.2.1


In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

from flask import Flask
from flask_socketio import SocketIO, send, emit
import socketio

from typing import List

from requests import get

In [2]:
sio = socketio.Client()
app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)

computeNodes = ["http://localhost:5001", "http://localhost:5002"]
#dispatchIP = get('https://api.ipify.org').text
dispatchIP = "http://localhost:5000"
print("DispatchIP:", dispatchIP)

DispatchIP: http://localhost:5000


In [3]:
def partition(model: tf.keras.Model, partit: List[int]) -> tf.keras.Model:
    #To slice properly, add the first and last layers to the partition list
    partitions = [0]
    partitions.extend(partit)
    partitions.append(len(model.layers) - 1)
    parts = []
    for i in range(1, len(partitions)):
        part = model.layers[partitions[i-1]:partitions[i]]
        parts.append(part)
    
    models = []
    for p in range(len(parts)):
        if p == 0:
            inpt = keras.Input(tensor=model.input)
            print(inpt)
        else:
            inpt = keras.Input(tensor=models[p-1].output)
        print(inpt)
        print([layer.output for layer in parts[p]])
        models.append(
            keras.Model(
                inputs=inpt,
                outputs=[layer.output for layer in parts[p]]
            )
        )
        print(models)
    return models

def dispatchModels(client, models: List[tf.keras.Model], nodeIPs: List[str]) -> None:
    for i in range(len(models)):
        model_json = models[i].to_json()
        print(nodeIPs[i])
        client.connect(nodeIPs[i], namespaces=['/recv_model'], auth={"name": "dispatcher"}, wait_timeout=5)
        if i != len(models) - 1:
            nextNode = nodeIPs[i + 1]
        else:
            # Reached the end of the nodes, the last node needs to point back to the dispatcher
            nextNode = dispatchIP
        
        print("Reached emit")
        client.emit("dispatch", data=(model_json, nextNode), namespace='/recv_model')
        client.sleep(1)
        print("Reached disconnect")
        client.disconnect()
        print(client.connected)
    
    client.sleep(1)
    client.connect(nodeIPs[0], namespaces=['/recv_data'], auth={"name": "dispatcher"}, wait_timeout=5)

def startDistEdgeInference(client, model_input: tf.Tensor):
    print("Starting inference")
    client.emit("data", data=model_input.numpy().tolist(), namespace='/recv_data')

In [4]:
in_list = [[1, 2], [3, 4], [5, 6]]
inpt = tf.convert_to_tensor(in_list)

layer_a = tf.keras.layers.Dense(3, kernel_initializer=tf.constant_initializer(1.))
layer_b = tf.keras.layers.Dense(1, kernel_initializer=tf.constant_initializer(1.))
layer_c = tf.keras.layers.Dense(2, kernel_initializer=tf.constant_initializer(1.))

In [5]:
model = tf.keras.Sequential()
model.add(layer_a)
model.add(layer_b)
model.add(layer_c)
out = model(inpt)
print("Local run: {}".format(out))
models_to_dispatch = partition(model, [1])

Local run: [[ 9.  9.]
 [21. 21.]
 [33. 33.]]
KerasTensor(type_spec=TensorSpec(shape=(3, 2), dtype=tf.int32, name='dense_input'), name='dense_input', description="created by layer 'input_1'")
KerasTensor(type_spec=TensorSpec(shape=(3, 2), dtype=tf.int32, name='dense_input'), name='dense_input', description="created by layer 'input_1'")
[<KerasTensor: shape=(3, 3) dtype=float32 (created by layer 'dense')>]
[<keras.engine.functional.Functional object at 0x0000017C0F6E2160>]
KerasTensor(type_spec=TensorSpec(shape=(3, 3), dtype=tf.float32, name=None), name='dense/BiasAdd:0', description="created by layer 'input_2'")
[<KerasTensor: shape=(3, 1) dtype=float32 (created by layer 'dense_1')>]
[<keras.engine.functional.Functional object at 0x0000017C0F6E2160>, <keras.engine.functional.Functional object at 0x0000017C0F7701C0>]


In [6]:
@socketio.on('data', namespace="/recv_data")
def got_result(data):
    print("Done distributing, result is {}".format(data))

@socketio.on('connect', namespace="/recv_data")
def connect():
    print("Previous node connected")

In [7]:
if __name__ == '__main__':
    dispatchModels(sio, models_to_dispatch, computeNodes)
    startDistEdgeInference(sio, inpt)
    socketio.run(app, port=5000)

http://localhost:5001
Reached emit
Reached disconnect
False
http://localhost:5002


Exception in thread Thread-6:
Traceback (most recent call last):
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\threading.py", line 973, in _bootstrap_inner
    self.run()
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\site-packages\engineio\client.py", line 685, in _write_loop
    self.queue.task_done()
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\queue.py", line 75, in task_done
    raise ValueError('task_done() called too many times')
ValueError: task_done() called too many times


Reached emit
Reached disconnect
False


Exception in thread Thread-9:
Traceback (most recent call last):
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\threading.py", line 973, in _bootstrap_inner
    self.run()
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\site-packages\engineio\client.py", line 685, in _write_loop
    self.queue.task_done()
  File "C:\Users\arjun\AppData\Local\Programs\Python\Python39\lib\queue.py", line 75, in task_done
    raise ValueError('task_done() called too many times')
ValueError: task_done() called too many times
WebSocket transport not available. Install simple-websocket for improved performance.


 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
The WebSocket transport is not available, you must install a WebSocket server that is compatible with your async mode to enable it. See the documentation for details. (further occurrences of this error will be logged with level INFO)
127.0.0.1 - - [15/Sep/2021 17:07:20] "GET /socket.io/?transport=polling&EIO=4&t=1631750836.1289923 HTTP/1.1" 200 -
127.0.0.1 - - [15/Sep/2021 17:07:22] "POST /socket.io/?transport=polling&EIO=4&sid=RrJsvictq6GGL-6YAAAA HTTP/1.1" 200 -
127.0.0.1 - - [15/Sep/2021 17:07:22] "GET /socket.io/?transport=polling&EIO=4&sid=RrJsvictq6GGL-6YAAAA&t=1631750840.2613733 HTTP/1.1" 200 -


Previous node connected
