In [11]:
import joblib
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType, Int64TensorType, StringTensorType
import pandas as pd

Load the trained scikit-learn model.

In [12]:
# joblib.load unpickles the file, which can execute arbitrary code:
# only load model files that come from a trusted source.
clf = joblib.load('../model.pkl')

Some auxiliary functions to help convert the inputs to the ONNX format.

In [13]:
def convert_sklearn_onnx(clf, x_sample, output_path="model.onnx", target_opset=11):
    """Convert a scikit-learn model to ONNX and write it to disk.

    Parameters
    ----------
    clf
        The scikit-learn estimator or pipeline handed to ``convert_sklearn``.
    x_sample : pandas.DataFrame
        Example input. Only its column names and dtypes are used, through
        ``convert_dataframe_schema``, to declare the ONNX graph inputs.
    output_path : str or path-like, default "model.onnx"
        File the serialized ONNX model is written to (overwritten if present).
    target_opset : int, default 11
        ONNX opset version forwarded to ``convert_sklearn``.

    Returns
    -------
    The ONNX model object returned by ``convert_sklearn``.
    """
    initial_types = convert_dataframe_schema(x_sample)
    onnx_model = convert_sklearn(
        clf, 'model_pipeline', initial_types, target_opset=target_opset
    )

    # Persist the protobuf so that an inference session can load it from disk.
    with open(output_path, "wb") as f:
        f.write(onnx_model.SerializeToString())

    return onnx_model


def convert_dataframe_schema(df, drop=None):
    """Describe the columns of ``df`` as ONNX input declarations.

    Returns a list of ``(column_name, tensor_type)`` pairs in column order,
    one per kept column, each declared with shape ``[None, 1]``:
    ``int64`` columns become ``Int64TensorType``, ``float64`` columns become
    ``FloatTensorType`` and every other dtype becomes ``StringTensorType``.
    Columns contained in ``drop`` (when it is not ``None``) are left out.
    """
    def tensor_for(dtype):
        # Anything that is not exactly int64 / float64 is declared as a string input.
        if dtype == 'int64':
            return Int64TensorType([None, 1])
        if dtype == 'float64':
            return FloatTensorType([None, 1])
        return StringTensorType([None, 1])

    skipped = () if drop is None else drop
    return [
        (column, tensor_for(dtype))
        for column, dtype in zip(df.columns, df.dtypes)
        if column not in skipped
    ]


def onnx_input(x):
    """Turn a DataFrame into the feed dict passed to ``InferenceSession.run``.

    Each column becomes one entry, keyed by the column name, holding that
    column's values reshaped to ``(n_rows, 1)``.

    float64 columns are cast to float32: ``convert_dataframe_schema`` declares
    them as ``FloatTensorType`` (a 32-bit float tensor), and onnxruntime
    requires the fed array dtype to match the declared input type.
    All other dtypes are passed through unchanged.
    """
    feeds = {}
    for column in x.columns:
        values = x[column].values
        values = values.reshape((values.shape[0], 1))
        if values.dtype == 'float64':
            # astype returns a new array, so the caller's DataFrame is untouched.
            values = values.astype('float32')
        feeds[column] = values
    return feeds

In [14]:
# Load the test CSV, remove the 'income' column and cast the four listed
# columns to int in a single chained expression.
x = (
    pd.read_csv('../data/adult_test.csv')
    .drop(['income'], axis=1)
    .astype({
        'age': int,
        'hours_per_week': int,
        'capital_gain': int,
        'capital_loss': int,
    })
)

Model conversion (needs an example input so the ONNX input names and types can be derived from its columns)

In [16]:
# Only the column names and dtypes of x are used to declare the ONNX inputs;
# besides returning the model, this call also writes it to model.onnx.
onnx_model = convert_sklearn_onnx(clf, x)

A handy model visualization tool that renders the model's computational graph

In [17]:
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
# Export the converted model's graph as a Graphviz DOT file.
# rankdir takes a Graphviz layout direction (TB, BT, LR or RL);
# "TB" lays the graph out from top to bottom.
pydot_graph = GetPydotGraph(onnx_model.graph, name=onnx_model.graph.name, rankdir="TB",
                            node_producer=GetOpNodeProducer("docstring"))
pydot_graph.write_dot("graph.dot")

In [18]:
import os
# Requires the Graphviz `dot` executable on PATH. -Tpng selects PNG output and
# -O derives the output file name from the input (graph.dot.png).
# The value shown below the cell is the status returned by os.system; 0 means success.
os.system('dot -O -Tpng graph.dot')

0

Inference time (onnxruntime is quite lightweight)

In [19]:
import onnxruntime as rt

In [20]:
# Open an inference session on the ONNX file written by convert_sklearn_onnx.
sess = rt.InferenceSession('./model.onnx')


# Passing None as the output list asks for every model output; the feed dict
# maps each column name to that column's (n_rows, 1) array.
onnx_outputs = sess.run(None, onnx_input(x))


In [21]:
# Two outputs: the array of predicted labels, then a list with one
# {class label: probability} dict per input row.
onnx_outputs

[array(['>50K', '<=50K', '<=50K', ..., '<=50K', '<=50K', '>50K'],
       dtype=object),
 [{'<=50K': 0.45585644245147705, '>50K': 0.544143557548523},
  {'<=50K': 0.62424635887146, '>50K': 0.37575361132621765},
  {'<=50K': 0.7707868218421936, '>50K': 0.2292131632566452},
  {'<=50K': 0.7363759279251099, '>50K': 0.26362407207489014},
  {'<=50K': 0.6868534088134766, '>50K': 0.31314659118652344},
  {'<=50K': 0.5918997526168823, '>50K': 0.40810027718544006},
  {'<=50K': 0.7908186912536621, '>50K': 0.20918133854866028},
  {'<=50K': 0.7654048204421997, '>50K': 0.2345951795578003},
  {'<=50K': 0.9903500080108643, '>50K': 0.009650002233684063},
  {'<=50K': 0.9543892741203308, '>50K': 0.04561075195670128},
  {'<=50K': 0.896947979927063, '>50K': 0.10305202007293701},
  {'<=50K': 0.9837356209754944, '>50K': 0.016264351084828377},
  {'<=50K': 0.8453903794288635, '>50K': 0.15460963547229767},
  {'<=50K': 0.46343404054641724, '>50K': 0.5365659594535828},
  {'<=50K': 0.9279829263687134, '>50K': 0.072017