In [4]:

# https://onnx.ai/sklearn-onnx/index.html

# Train a model.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

print(X_train)

# clr = Ridge()
clr = RandomForestClassifier()
clr.fit(X_train, y_train)

# Convert into ONNX format
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)

print(onx)

with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

# Compute the prediction with ONNX Runtime
import onnxruntime as rt
import numpy
sess = rt.InferenceSession("rf_iris.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]

pred_onx

# another example
# http://onnx.ai/sklearn-onnx/auto_examples/plot_complex_pipeline.html
# Note: predictions don't match!

[[5.7 2.6 3.5 1. ]
 [6.9 3.1 4.9 1.5]
 [5.  3.3 1.4 0.2]
 [6.3 2.9 5.6 1.8]
 [5.8 2.7 3.9 1.2]
 [5.6 2.8 4.9 2. ]
 [6.3 2.3 4.4 1.3]
 [7.7 3.  6.1 2.3]
 [6.6 3.  4.4 1.4]
 [6.6 2.9 4.6 1.3]
 [5.1 2.5 3.  1.1]
 [6.  3.4 4.5 1.6]
 [5.6 2.9 3.6 1.3]
 [4.8 3.  1.4 0.3]
 [4.8 3.1 1.6 0.2]
 [6.3 3.4 5.6 2.4]
 [6.4 3.1 5.5 1.8]
 [6.2 3.4 5.4 2.3]
 [6.1 3.  4.6 1.4]
 [5.5 3.5 1.3 0.2]
 [5.  2.  3.5 1. ]
 [5.1 3.5 1.4 0.2]
 [5.4 3.7 1.5 0.2]
 [4.9 3.1 1.5 0.1]
 [4.5 2.3 1.3 0.3]
 [5.5 2.3 4.  1.3]
 [4.9 2.5 4.5 1.7]
 [6.7 3.  5.2 2.3]
 [5.  3.2 1.2 0.2]
 [6.3 2.8 5.1 1.5]
 [5.6 2.7 4.2 1.3]
 [5.6 2.5 3.9 1.1]
 [5.7 4.4 1.5 0.4]
 [4.6 3.4 1.4 0.3]
 [7.9 3.8 6.4 2. ]
 [5.7 2.8 4.1 1.3]
 [5.5 4.2 1.4 0.2]
 [5.  2.3 3.3 1. ]
 [4.9 3.1 1.5 0.2]
 [6.4 3.2 4.5 1.5]
 [6.7 3.  5.  1.7]
 [5.8 2.7 5.1 1.9]
 [7.7 2.8 6.7 2. ]
 [5.3 3.7 1.5 0.2]
 [4.6 3.6 1.  0.2]
 [6.3 2.7 4.9 1.8]
 [6.8 2.8 4.8 1.4]
 [5.  3.6 1.4 0.2]
 [6.1 2.8 4.  1.3]
 [6.5 3.  5.5 1.8]
 [5.5 2.4 3.8 1.1]
 [4.9 3.6 1.4 0.1]
 [5.2 3.4 1.

array([2, 2, 0, 0, 1, 2, 0, 0, 2, 2, 1, 0, 1, 0, 2, 0, 2, 1, 1, 1, 0, 2,
       1, 1, 1, 0, 0, 2, 0, 1, 0, 2, 2, 2, 1, 2, 0, 2], dtype=int64)