## Training and saving a model in ONNX format

In [1]:
# Load the data and split it (the model itself is trained in the next cell)
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Load the iris dataset and separate features from class labels
iris = load_iris()
X = iris.data
y = iris.target

# Hold out a test split (train_test_split's default test fraction);
# a fixed random_state makes the split reproducible across runs
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

In [2]:
# Fit a random forest on the training split; the fixed seed keeps the
# fitted trees (and therefore the predictions below) reproducible.
# `fit` returns the estimator itself, so it can be chained off the constructor.
clr = RandomForestClassifier(random_state=42).fit(X_train, y_train)
clr

In [3]:
# Convert the fitted model into ONNX format and save it to disk
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Declare one float32 input of shape (batch, n_features). The batch dimension
# is dynamic (None); the feature count comes from the training data rather
# than a hard-coded 4, so the ONNX signature cannot drift from what the model
# was actually trained on.
n_features = X_train.shape[1]
initial_type = [("float_input", FloatTensorType([None, n_features]))]
onx = convert_sklearn(clr, initial_types=initial_type)

# Serialize the ONNX protobuf; the `with` block guarantees the file is closed.
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

In [4]:
# Compute the prediction with ONNX Runtime
import onnxruntime as rt
import numpy

# Request the CPU provider explicitly: since onnxruntime 1.9, builds that ship
# more than one execution provider (e.g. onnxruntime-gpu) raise an error when
# `providers` is omitted. On CPU-only builds this is the same as the default.
sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
# The first output holds the predicted class labels (see the displayed
# pred_onnx below); any further outputs are not requested here.
label_name = sess.get_outputs()[0].name
# The model input was declared as FloatTensorType, so feed float32.
pred_onnx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]

In [5]:
# Class labels predicted by ONNX Runtime, one per row of X_test
pred_onnx

array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1, 2, 0, 0, 0, 0, 1, 2, 1, 1, 2, 0, 2,
       0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0], dtype=int64)

In [6]:
# Score new, unlabeled rows: start by loading them from a CSV file
import pandas as pd

SCORE_INPUT_CSV = "score_input_data.csv"
df = pd.read_csv(SCORE_INPUT_CSV)

In [7]:
# Peek at the first rows. NOTE(review): the leading "Unnamed: 0" column looks
# like a row index saved into the CSV rather than a model feature.
df.head(n=5)

Unnamed: 0,x1,x2,x3,x4
0,4.9,2.4,3.3,1.0
1,4.6,3.6,1.0,0.2
2,5.4,3.4,1.5,0.4
3,5.8,2.7,5.1,1.9
4,6.3,3.3,4.7,1.6


In [8]:
# Turn the CSV rows into a numpy feature matrix.
# The CSV has an extra leading "Unnamed: 0" column (a stored row index, see
# the df.head() output above). Converting the whole frame would produce 5
# columns, but the ONNX model takes exactly the training features, so select
# the feature columns explicitly.
# NOTE(review): assumes x1..x4 are the iris features in training order — confirm.
FEATURE_COLUMNS = ["x1", "x2", "x3", "x4"]
df_array = df[FEATURE_COLUMNS].to_numpy()

# Fail fast with a clear message instead of an opaque ONNX Runtime shape error.
assert df_array.shape[1] == X_train.shape[1], "feature count must match the training data"

In [9]:
# Score the CSV rows with the ONNX model (float32, matching the declared input)
features_f32 = df_array.astype(numpy.float32)
pred_onnx = sess.run([label_name], {input_name: features_f32})[0]

# Pair each prediction with a sequential id (0..n-1, in row order)
out_df = pd.DataFrame({"id": range(len(df_array)), "prediction": pred_onnx})

In [10]:
# One {"id": ..., "prediction": ...} dict per scored row ("records" orientation)
out_df.to_dict("records")

[{'id': 0, 'prediction': 1},
 {'id': 1, 'prediction': 0},
 {'id': 2, 'prediction': 0},
 {'id': 3, 'prediction': 2},
 {'id': 4, 'prediction': 1},
 {'id': 5, 'prediction': 0},
 {'id': 6, 'prediction': 0},
 {'id': 7, 'prediction': 1},
 {'id': 8, 'prediction': 2},
 {'id': 9, 'prediction': 2},
 {'id': 10, 'prediction': 2},
 {'id': 11, 'prediction': 1},
 {'id': 12, 'prediction': 0},
 {'id': 13, 'prediction': 1},
 {'id': 14, 'prediction': 1},
 {'id': 15, 'prediction': 2},
 {'id': 16, 'prediction': 1},
 {'id': 17, 'prediction': 0},
 {'id': 18, 'prediction': 2},
 {'id': 19, 'prediction': 0},
 {'id': 20, 'prediction': 2},
 {'id': 21, 'prediction': 0},
 {'id': 22, 'prediction': 0},
 {'id': 23, 'prediction': 0},
 {'id': 24, 'prediction': 1},
 {'id': 25, 'prediction': 1},
 {'id': 26, 'prediction': 0},
 {'id': 27, 'prediction': 1},
 {'id': 28, 'prediction': 1},
 {'id': 29, 'prediction': 2},
 {'id': 30, 'prediction': 2},
 {'id': 31, 'prediction': 2},
 {'id': 32, 'prediction': 1},
 {'id': 33, 'predict