In [2]:
import grpc
import sys
from concurrent import futures
import inference_pb2_grpc 
import onnx
import onnxruntime
import numpy as np
import inference_pb2 
import ezkl
import ast

In [3]:
class InferenceServer(inference_pb2_grpc.InferenceServicer):
    def RunInference(self, inferenceParams, context):
        results = self.Infer(inferenceParams.modelHash, inferenceParams.modelInput)
        return inference_pb2.InferenceResult(tx=inferenceParams.tx, node=inferenceParams.modelHash, value=str(results))
    
    def Infer(self, modelHash, modelInput):
        session = onnxruntime.InferenceSession(modelHash)
        results = session.run(curateOutputs(session), curateInputs(session, modelInput))[-1]
        return results[0][0]

In [4]:
def serve(port, maxWorkers):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=maxWorkers))
    inference_pb2_grpc.add_InferenceServicer_to_server(InferenceServer(), server)
    server.add_insecure_port("[::]:" + str(port))
    server.start()
    server.wait_for_termination()

In [5]:
def parseInput(modelInput, typeString):
    if "tensor" in typeString:
        if "float" in typeString:
            return onnxruntime.OrtValue.ortvalue_from_numpy(np.array([modelInput]).astype("float32"))
        if "string" in typeString:
            return onnxruntime.OrtValue.ortvalue_from_numpy(np.array([modelInput]).astype("string"))

def curateInputs(session, modelInput):
    inputs = {}
    sessionInputs = session.get_inputs()
    for i in range(0, len(sessionInputs)):
        param = ast.literal_eval(modelInput)[i]
        inputs[sessionInputs[i].name] = parseInput(param, sessionInputs[i].type)
    return inputs

def curateOutputs(session):
    outputs = []
    for o in session.get_outputs():
        outputs.append(o.name)
    return outputs

In [None]:
serve(port=5125, maxWorkers=100)

In [None]:
session = onnxruntime.InferenceSession("QmXQpupTphRTeXJMEz3BCt9YUF6kikcqExxPdcVoL1BBhy")
o = session.run(curateOutputs(session), curateInputs(session, "[[0.1,0.2,0.3]]"))

In [183]:
o[0][0][0]

0.051708132

In [41]:
session.get_modelmeta()

<onnxruntime.capi.onnxruntime_pybind11_state.ModelMetadata at 0x7ff3e3081b70>

In [58]:
session = onnxruntime.InferenceSession("Volatility.onnx")

In [121]:
session.get_inputs()[0].shape

[None, 3]

In [91]:
np.fromstring('[1,2]', dtype="float32", sep=',')

  np.fromstring('[1,2]', dtype="float32", sep=',')


array([], dtype=float32)

In [179]:
session.get_outputs()[0].type

'tensor(float)'