#### Reference blog post: https://zhuanlan.zhihu.com/p/86867138

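#### 1. To get a more concrete view of what an ONNX file contains, we train a simple Logistic Regression model and export it to ONNX, again using the widely used iris classification dataset: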

In [1]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

clr = LogisticRegression()
clr.fit(X_train, y_train)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression()
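
The warning above says the solver hit its iteration limit and suggests either raising `max_iter` or scaling the data. Below is a minimal sketch of both options. The value `max_iter=1000`, the `random_state=0` seed, and the names `clr_more_iter` and `clr_scaled` are assumptions chosen for illustration, not part of the original notebook.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)
# random_state is fixed here only to make the sketch reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Option 1: give the solver more iterations (1000 is an arbitrary choice).
clr_more_iter = LogisticRegression(max_iter=1000)
clr_more_iter.fit(X_train, y_train)

# Option 2: standardize the features before fitting, as the warning suggests.
clr_scaled = make_pipeline(StandardScaler(), LogisticRegression())
clr_scaled.fit(X_train, y_train)

# n_iter_ reports how many iterations the solver actually used.
print(clr_more_iter.n_iter_, clr_more_iter.score(X_test, y_test))
print(clr_scaled[-1].n_iter_, clr_scaled.score(X_test, y_test))
```

Either variant could replace `clr` in the later steps; the rest of this walkthrough keeps the original `clr`.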

#### 2. Use skl2onnx to serialize the Scikit-learn model to the ONNX format:

In [2]:
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Declare one float input named 'float_input' with 4 features per row.
# None leaves the batch dimension open; FloatTensorType([1, 4]) would fix it at one row.
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)

# Serialize the ONNX protobuf to disk.
with open("logreg_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
    
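
The input is declared as `FloatTensorType([None, 4])`: a 2-D float tensor whose first (batch) dimension is left open. The sketch below uses NumPy only to show which array shapes are compatible with that declaration. `matches_declared_shape` is a hypothetical helper written for this illustration, not a skl2onnx or ONNX Runtime function, and the second and third rows of `X_demo` are made-up values.

```python
import numpy as np

# Stand-in for X_test: 3 samples with 4 features each (illustrative values).
X_demo = np.array([[5.0, 3.6, 1.4, 0.2],
                   [6.1, 2.8, 4.7, 1.2],
                   [6.3, 3.3, 6.0, 2.5]])

single_1d = X_demo[0]   # shape (4,): rank 1, so it does not fit [None, 4]
single_2d = X_demo[:1]  # shape (1, 4): a batch containing one row
batch = X_demo          # shape (3, 4): a batch containing three rows


def matches_declared_shape(arr, declared=(None, 4)):
    """Hypothetical check: same rank, and every fixed dimension agrees."""
    if arr.ndim != len(declared):
        return False
    return all(d is None or d == n for d, n in zip(declared, arr.shape))


for name, arr in [("single_1d", single_1d), ("single_2d", single_2d), ("batch", batch)]:
    print(name, arr.shape, matches_declared_shape(arr))

# The declared element type is float, so cast from float64 before feeding the model.
model_input = single_2d.astype(np.float32)
```

Slicing with `X_demo[:1]` instead of indexing with `X_demo[0]` keeps the batch dimension, which is why a single sample should be passed as a one-row 2-D array.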

#### 3. Inspect and validate the model with the ONNX Python API:

In [3]:
import onnx

model = onnx.load('logreg_iris.onnx')
print(model)

ir_version: 4
producer_name: "skl2onnx"
producer_version: "1.6.1"
domain: "ai.onnx"
model_version: 0
doc_string: ""
graph {
  node {
    input: "float_input"
    output: "label"
    output: "probability_tensor"
    name: "LinearClassifier"
    op_type: "LinearClassifier"
    attribute {
      name: "classlabels_ints"
      ints: 0
      ints: 1
      ints: 2
      type: INTS
    }
    attribute {
      name: "coefficients"
      floats: 0.34892532229423523
      floats: 1.400669813156128
      floats: -2.1071150302886963
      floats: -0.9723743200302124
      floats: 0.43998441100120544
      floats: -1.5219184160232544
      floats: 0.4314039945602417
      floats: -0.9461677074432373
      floats: -1.639839768409729
      floats: -1.4062713384628296
      floats: 2.3834683895111084
      floats: 2.0455572605133057
      type: FLOATS
    }
    attribute {
      name: "intercepts"
      floats: 0.2185143083333969
      floats: 0.9689608812332153
      floats: -0.9929876327514648
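
To make the printed attributes more concrete, here is a rough sketch of the linear part of a `LinearClassifier` node in plain Python. It assumes the 12 `coefficients` are laid out row by row, one row of 4 weights per class. Because the printout above is cut off, the node's post-transform is not visible, so the sketch stops at raw scores and takes the arg-max as the label. The sample `x` is the first test row printed later in this walkthrough and may come from a different run than these weights.

```python
# Values copied from the printed "coefficients" and "intercepts" attributes.
coefficients = [
    0.34892532229423523, 1.400669813156128, -2.1071150302886963, -0.9723743200302124,
    0.43998441100120544, -1.5219184160232544, 0.4314039945602417, -0.9461677074432373,
    -1.639839768409729, -1.4062713384628296, 2.3834683895111084, 2.0455572605133057,
]
intercepts = [0.2185143083333969, 0.9689608812332153, -0.9929876327514648]
class_labels = [0, 1, 2]  # from "classlabels_ints"

n_classes = len(class_labels)
n_features = len(coefficients) // n_classes  # 12 / 3 = 4

# Assumption: row-major layout, one row of n_features weights per class.
weights = [coefficients[i * n_features:(i + 1) * n_features] for i in range(n_classes)]


def raw_scores(x):
    """Return one linear score (dot product plus intercept) per class."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, intercepts)]


x = [5.0, 3.6, 1.4, 0.2]
scores = raw_scores(x)
label = class_labels[scores.index(max(scores))]
print(scores, label)
```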
     

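#### 4. Next we run predictions on this ONNX model with the ONNX Runtime Python API. The whole test set is passed as one batch and only the predicted labels are kept: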

In [12]:
import onnxruntime as rt
import numpy

sess = rt.InferenceSession("logreg_iris.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
probability_name = sess.get_outputs()[1].name

# Run the whole test set as one float32 batch; [0] keeps only the predicted labels.
# To score just the first sample, pass X_test[:1] so the input stays 2-D.
pred_onx = sess.run([label_name, probability_name], {input_name: X_test.astype(numpy.float32)})[0]

# Print the tensor names, the first test sample, and the predicted labels.
print('input_name: ' + input_name)
print('label_name: ' + label_name)
print('probability_name: ' + probability_name)
print(X_test[0])
print(pred_onx)

input_name: float_input
label_name: output_label
probability_name: output_probability
[5.  3.6 1.4 0.2]
[0 1 2 2 0 1 0 0 2 0 0 2 2 0 1 2 2 2 0 0 0 0 2 2 1 2 0 1 1 1 2 2 2 0 2 2 2
 0]
