In [2]:
import numpy as np
import pandas as pd

# Render every column of wide frames, but cap the number of rendered rows at 10.
pd.options.display.max_columns = None

pd.options.display.max_rows = 10

In [37]:
import warnings
import logging
import os
from azureml.automl.core.onnx_convert import OnnxInferenceHelper
import json
import time

In [21]:
from azureml.automl.core.onnx_convert import OnnxInferenceHelper
from typing import Any, Tuple
from numpy import ndarray


class OnnxModelWrapper:
    """
        Thin adapter exposing predict / predict_proba on top of OnnxInferenceHelper.
    """
    def __init__(self, onnx_model_bytes: bytes, onnx_input_map: dict):
        """
        Keep both inputs on the instance and build the inference helper from them.

        :param onnx_model_bytes: the serialized onnx model, as bytes
        :param onnx_input_map: the onnx resource dictionary handed to the helper
        """
        self.onnx_model_bytes, self.onnx_input_map = onnx_model_bytes, onnx_input_map
        self.wrapper_model = OnnxInferenceHelper(onnx_model_bytes, onnx_input_map)

    def predict(self, X) -> Tuple[Any, Any]:
        """
        Forward X to the helper's predict, using the helper's default arguments.

        :param X: features to predict
        :returns the helper's result unchanged (annotated as a <label, prob> tuple)
        """
        helper_output = self.wrapper_model.predict(X)
        return helper_output

    def predict_proba(self, X) -> ndarray:
        """
        Call the helper's predict with with_prob=True and keep only the second item.

        :param X: features to predict
        :returns second element of the helper's two-item result (annotated as ndarray of prob)
        """
        _labels, probabilities = self.wrapper_model.predict(X, with_prob=True)
        return probabilities


In [22]:
# Read both ONNX artifacts inside context managers so each file handle is
# closed as soon as its contents are in memory (bare open() calls leaked them).
with open('onnx.model', 'rb') as onnx_model_file:
    onnx_model_data = onnx_model_file.read()  # raw model bytes
with open('onnx.res', 'r') as onnx_res_file:
    onnx_res_data = onnx_res_file.read()  # text; parsed with json.loads when the wrapper is built

In [23]:
# Parse the resource text as JSON and pass it, with the model bytes, to the wrapper.
onnxrt_wrapper = OnnxModelWrapper(
    onnx_model_bytes=onnx_model_data,
    onnx_input_map=json.loads(onnx_res_data),
)

In [24]:
# Load the CSV at the relative path below into a DataFrame.
filepath = 'invoice.csv'
data_df = pd.read_csv(filepath_or_buffer=filepath)

In [43]:
# Preview the first rows of the loaded frame.
data_df.head()

Unnamed: 0,INVOICECUSTTRANSRECID,WASDISPUTED_VALUE,ROW_UNIQUEKEY,ISCLOSED_VALUE,TRANSTYPE,PAYMENTSCHEDULEID,CASHDISCOUNTCODE,RECID,ACCOUNTINGCURRENCY,WASCOLLECTIONLETTERSENT,...,InvoiceAccount.MEAN(Invoice.ROW_UNIQUEKEY),InvoiceAccount.MEAN(Invoice.ISCLOSED_VALUE),InvoiceAccount.MEAN(Invoice.RECID),InvoiceAccount.MEAN(Invoice.CASHDISCOUTPERCENT),InvoiceAccount.MEAN(Invoice.WASCOLLECTIONLETTERSENT_VALUE),InvoiceAccount.MEAN(Invoice.CUSTOMERRECID),InvoiceAccount.MEAN(Invoice.INVOICEAMOUNTACCOUNTING),InvoiceAccount.MEAN(Invoice.TRANSTYPE_VALUE),InvoiceAccount.COUNT(Invoice),InvoiceAccount.NUM_CHARACTERS(CUSTOMERID)
0,5637144576,0,720,1,5,1,5,5637144576,11,0,...,2986.29,1.0,5637145550,0.0,0.0,22565421558,319015.35,2.0,24,6
1,5637144577,0,740,1,5,1,5,5637144577,11,0,...,3423.28,0.9,15372649960,0.0,0.14,22565421559,395029.4,3.21,29,6
2,5637144578,0,763,1,5,1,5,5637144578,11,0,...,3307.62,1.0,5637145552,0.0,0.0,22565421560,301095.21,2.0,24,6
3,5637144579,0,782,1,5,1,5,5637144579,11,0,...,4385.5,0.93,12395974712,0.0,0.07,22565421562,182984.92,2.82,28,6
4,5637144580,0,802,1,5,1,5,5637144580,11,0,...,4533.59,0.96,10309914194,0.0,0.0,22565421563,80743.67,2.44,27,6


In [44]:
# (rows, columns) of the loaded frame.
data_df.shape

(14917, 67)

In [45]:
# Time one predict_proba pass over the whole frame. perf_counter() is a
# monotonic, high-resolution clock, so the elapsed value cannot be skewed by
# system clock adjustments the way a time.time() difference can.
start = time.perf_counter()
predictions = onnxrt_wrapper.predict_proba(data_df)
end = time.perf_counter()
print(end - start)  # elapsed seconds

33.04594302177429


In [46]:
from azureml.explain.model.mimic.mimic_explainer import MimicExplainer
from azureml.explain.model.mimic.models.lightgbm_model import LGBMExplainableModel
# Time construction of the mimic explainer around the ONNX wrapper, using the
# full frame and the LightGBM explainable model. perf_counter() is monotonic,
# so the interval is not affected by system clock adjustments.
start = time.perf_counter()
explainer = MimicExplainer(onnxrt_wrapper, data_df, LGBMExplainableModel, augment_data=False)
end = time.perf_counter()
print(end - start)  # elapsed seconds

35.38013672828674


In [49]:
def explain(train_data):
    """
    Run a global explanation over train_data with the module-level explainer.

    :param train_data: data passed straight to explainer.explain_global
    :returns the object returned by explain_global (it used to be computed and
        then discarded, so callers could never inspect the explanation)
    """
    # include_local=False is forwarded exactly as in the original call.
    return explainer.explain_global(train_data, include_local=False)

In [50]:
# Time a global explanation on 1000 sampled rows. The fixed random_state makes
# the sample reproducible across kernel restarts; perf_counter() is a monotonic
# clock suited to measuring intervals.
sample = data_df.sample(1000, random_state=42)
start = time.perf_counter()
explain(sample)
end = time.perf_counter()
print(end - start)  # elapsed seconds

52.213995695114136


In [51]:
# Time a global explanation on 2000 sampled rows. The fixed random_state makes
# the sample reproducible across kernel restarts; perf_counter() is a monotonic
# clock suited to measuring intervals.
sample = data_df.sample(2000, random_state=42)
start = time.perf_counter()
explain(sample)
end = time.perf_counter()
print(end - start)  # elapsed seconds

97.92370271682739


In [52]:
# Time a global explanation on 5000 sampled rows. The fixed random_state makes
# the sample reproducible across kernel restarts; perf_counter() is a monotonic
# clock suited to measuring intervals.
sample = data_df.sample(5000, random_state=42)
start = time.perf_counter()
explain(sample)
end = time.perf_counter()
print(end - start)  # elapsed seconds

239.50455570220947


In [54]:
# Time a global explanation over the entire frame. perf_counter() is a
# monotonic clock, so the interval is not affected by system clock adjustments.
start = time.perf_counter()
explain(data_df)
end = time.perf_counter()
print(end - start)  # elapsed seconds

689.9257900714874


In [55]:
# Display (rows, columns) of the frame again after the explanation runs.
data_df.shape

(14917, 67)