In [None]:
!pip install skl2onnx

## Import Libraries

In [1]:
import time
import skl2onnx
import onnxruntime as ort
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
import numpy as np

## Load Dataset & Preprocessing

In [2]:
# Load the digits dataset and flatten every image into a single feature row.
dataset = load_digits()
X = dataset.images.reshape(len(dataset.images), -1)
Y = dataset.target

# Hold out 20% of the samples for testing. The seed is fixed so that
# "Restart & Run All" reproduces the same partition (and therefore the same
# predictions further down); without it every run shuffles differently.
x_train, x_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.2, random_state=123
)

# Cast the features to float32.
# NOTE(review): presumably so the dtype matches the float tensor input of the
# converted ONNX model - confirm.
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)

## Scikit-Learn MLPClassifier

In [3]:
# Build the multi-layer perceptron with a capped iteration count and a fixed
# seed, then train it on the training split. The fit call stays the last
# expression so the notebook still displays the fitted estimator.
mlp_settings = {"max_iter": 100, "random_state": 123}
sklearn_model = MLPClassifier(**mlp_settings)
sklearn_model.fit(X=x_train, y=y_train)



## Convert Scikit-Learn Model To ONNX

In [4]:
# Convert the fitted estimator to ONNX; the sample row is what the converter
# uses to work out the model's input type.
# NOTE(review): the sample is a single 1-D row, and the inference cells below
# also feed 1-D rows - keep both in sync if this is ever changed to a 2-D slice.
sample_row = x_train[0]
onnx_model = skl2onnx.to_onnx(sklearn_model, sample_row)

# Persist the serialized model so it can be loaded by ONNX Runtime.
onnx_bytes = onnx_model.SerializeToString()
with open("sklearn_to_onnx_model.onnx", mode="wb") as onnx_file:
    onnx_file.write(onnx_bytes)

## Use ONNX Model

In [5]:
# Open the saved model with ONNX Runtime, restricted to the CPU provider.
onnx_model_session = ort.InferenceSession(
    "sklearn_to_onnx_model.onnx", providers=["CPUExecutionProvider"]
)

# Only the first declared input and the first declared output are used by the
# inference cells, so look their names up once and show them.
model_inputs = onnx_model_session.get_inputs()
model_outputs = onnx_model_session.get_outputs()
input_name, output_name = model_inputs[0].name, model_outputs[0].name
print(input_name, output_name)

X output_label


In [6]:
# Time only the inference call. The clock is stopped before anything is
# printed so stdout I/O is not counted, and perf_counter() is used because it
# is the high-resolution clock intended for measuring short intervals.
# NOTE(review): this is the first run() on this session, so the figure may
# include one-time warm-up cost - confirm before comparing against later runs.
start_time = time.perf_counter()
predict = onnx_model_session.run(
    [output_name],
    {input_name: x_test[0]}
)
end_time = time.perf_counter()

print(predict)
print(f"Inference Time: {end_time - start_time}")

[array([4], dtype=int64)]
Inference Time: 0.004395723342895508


## Use ONNX Model With Thread Management

In [7]:
# Configure the thread usage of the ONNX Runtime CPU session.
session_options = ort.SessionOptions()
# In ORT_SEQUENTIAL mode the graph's operators run one after another, so the
# inter-op thread count is ignored by the runtime; the intra-op thread count
# (parallelism inside a single operator) is the setting that takes effect.
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
session_options.intra_op_num_threads = 2

# Re-create the session with the custom options and look up the tensor names.
onnx_model_session = ort.InferenceSession(
    "sklearn_to_onnx_model.onnx",
    sess_options=session_options,
    providers=["CPUExecutionProvider"],
)
input_name = onnx_model_session.get_inputs()[0].name
output_name = onnx_model_session.get_outputs()[0].name

# Time only the inference call: the clock is stopped before printing so that
# stdout I/O is not part of the measurement, and perf_counter() is the
# high-resolution clock intended for short intervals.
start_time = time.perf_counter()
predict = onnx_model_session.run(
    [output_name],
    {input_name: x_test[0]},
)
end_time = time.perf_counter()

print(predict)
print(f"Inference Time: {end_time - start_time}")

[array([4], dtype=int64)]
Inference Time: 0.0009095668792724609
