# Converting a MetaLearner to ONNX

<!-- mkdocs warning -->
<div class="admonition warning">
    <p class="admonition-title">Warning</p>
    <p style="margin-top: 0.6rem">This is an experimental feature which is not subject to deprecation cycles. Use
it at your own risk!
</p>
</div>

ONNX is an open standard for representing trained machine learning models.
Converting a MetaLearner into an ONNX model makes it easier to use the model in different
environments without worrying about compatibility or performance issues.

In particular, the conversion allows models to run on a variety of hardware setups. ONNX
runtimes are also optimized for efficient computation, which can enable faster inference than
the Python interface.

For more information about ONNX, you can check the ONNX [website](https://onnx.ai/).

In this example we will show how most MetaLearners can be converted to ONNX.

## Installation

In order to convert a MetaLearner to ONNX, we first need to install the following packages:

* [onnx](https://github.com/onnx/onnx)
* [onnxmltools](https://github.com/onnx/onnxmltools)
* [onnxruntime](https://github.com/microsoft/onnxruntime)
* [spox](https://github.com/Quantco/spox)

We can do so either via conda and conda-forge:

```console
$ conda install onnx onnxmltools onnxruntime spox -c conda-forge
```

or via pip and PyPI:

```console
$ pip install onnx onnxmltools onnxruntime spox
```
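
Before continuing, it can help to confirm that the packages are importable. The following is a minimal sketch using only the standard library; the helper name `missing_packages` is ours and not part of any of the packages above.

```python
from importlib.util import find_spec

# Import names of the packages listed above.
REQUIRED_PACKAGES = ["onnx", "onnxmltools", "onnxruntime", "spox"]


def missing_packages(names: list[str]) -> list[str]:
    """Return the subset of ``names`` that cannot be found for import."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages(REQUIRED_PACKAGES)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All ONNX-related packages are importable.")
```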

## Usage

<!-- mkdocs warning -->
<div class="admonition warning">
    <p class="admonition-title">Warning</p>
    <p style="margin-top: 0.6rem">Note that this method only works for <code>TLearner</code>,
<code>XLearner</code>, <code>RLearner</code> and <code>DRLearner</code>.

Converting an <code>SLearner</code> depends heavily on whether the base
model supports categorical variables, and is not implemented yet.
</p>
</div>
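
A conversion pipeline could guard against unsupported MetaLearner types up front. The helper below is a hypothetical sketch, not part of the `metalearners` API; it only compares the class name of an object against the list in the warning above.

```python
# Class names taken from the warning above; SLearner is deliberately absent.
ONNX_SUPPORTED_LEARNERS = frozenset({"TLearner", "XLearner", "RLearner", "DRLearner"})


def check_onnx_support(learner: object) -> None:
    """Raise if ``learner`` is not one of the MetaLearner types listed above."""
    name = type(learner).__name__
    if name not in ONNX_SUPPORTED_LEARNERS:
        raise NotImplementedError(
            f"ONNX conversion is not available for {name}; "
            f"supported types: {sorted(ONNX_SUPPORTED_LEARNERS)}"
        )
```

Calling `check_onnx_support(learner)` before converting any base model fails early with a readable message instead of partway through the conversion.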

### Loading the data

Just like in our [example on estimating CATEs with a MetaLearner](../example_basic/),
we will first load some experiment data:

```python
import pandas as pd
from git_root import git_root

df = pd.read_csv(git_root("data/learning_mindset.zip"))
outcome_column = "achievement_score"
treatment_column = "intervention"
feature_columns = [
    column for column in df.columns if column not in [outcome_column, treatment_column]
]
categorical_feature_columns = [
    "ethnicity",
    "gender",
    "frst_in_family",
    "school_urbanicity",
    "schoolid",
]
# Note that explicitly setting the dtype of these features to category
# allows both lightgbm and shap plots to
# 1. operate on features which are not of type int, bool or float, and
# 2. interpret categoricals with int values as categoricals
#    rather than as ordinals/numericals.
for categorical_feature_column in categorical_feature_columns:
    df[categorical_feature_column] = df[categorical_feature_column].astype("category")
```

Now that we've loaded the experiment data, we can train a MetaLearner.


### Training a MetaLearner

Again, mirroring our [example on estimating CATEs with a MetaLearner](../example_basic/), we can train an
`XLearner` as follows:

```python
from metalearners import XLearner
from lightgbm import LGBMRegressor, LGBMClassifier

xlearner = XLearner(
    nuisance_model_factory=LGBMRegressor,
    propensity_model_factory=LGBMClassifier,
    treatment_model_factory=LGBMRegressor,
    is_classification=False,
    n_variants=2,
    nuisance_model_params={"n_estimators": 5, "verbose": -1},
    propensity_model_params={"n_estimators": 5, "verbose": -1},
    treatment_model_params={"n_estimators": 5, "verbose": -1},
    n_folds=2,
)

xlearner.fit(
    X=df[feature_columns],
    y=df[outcome_column],
    w=df[treatment_column],
)
```

```text
<metalearners.xlearner.XLearner at 0x16a753050>
```

<!-- mkdocs note -->
<div class="admonition note">
    <p class="admonition-title">Note</p>
    <p style="margin-top: 0.6rem">In this example, we used only <code>lightgbm</code> models because these are the only type of models
that we managed to get to work with categorical encodings from <code>pandas</code>
while also being convertible to ONNX. Other models which support categoricals, such as <code>sklearn</code>'s
<code>HistGradientBoostingRegressor</code> or <code>xgboost</code> models, do not support them
in their conversion to ONNX. See <a href="https://github.com/onnx/sklearn-onnx/issues/1051" target="_blank">this issue</a>
and <a href="https://github.com/onnx/onnxmltools/issues/469#issuecomment-1993880910" target="_blank">this comment</a>.
</p>
</div>

### Converting the base models to ONNX

Before converting the MetaLearner to ONNX, we need to manually convert the base models
required for prediction. To find out which base models need to be
converted, we can use <a href="../../api_documentation/#metalearners.metalearner.MetaLearner._necessary_onnx_models"><code>_necessary_onnx_models</code></a>.

```python
necessary_models = xlearner._necessary_onnx_models()
necessary_models
```

```text
{'propensity_model': [LGBMClassifier(n_estimators=5, verbose=-1)],
 'control_effect_model': [LGBMRegressor(n_estimators=5, verbose=-1)],
 'treatment_effect_model': [LGBMRegressor(n_estimators=5, verbose=-1)]}
```

We see that we need to convert the `"propensity_model"`, the `"control_effect_model"`
and the `"treatment_effect_model"`. We can do this with the following code where we
use the `convert_lightgbm` function from the `onnxmltools` package.


<!-- mkdocs note -->
<div class="admonition note">
    <p class="admonition-title">Note</p>
    <p style="margin-top: 0.6rem">Note that for classifiers we need to pass the <code>zipmap=False</code> option. This
is required so that the output probabilities are a matrix and not a list of dictionaries.
When using a <code>sklearn</code> model with the <code>convert_sklearn</code> function, this
option needs to be specified via the <code>options={"zipmap": False}</code> parameter.
</p>
</div>
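
To make the difference concrete, the sketch below uses made-up probabilities to contrast the two output layouts: a list of dictionaries keyed by class label (the layout produced with `zipmap` enabled) and a plain matrix with one column per class. The helper `zipmap_to_matrix` is hypothetical and only illustrates the shapes involved.

```python
# Illustrative values only: with ZipMap enabled, a classifier returns
# one dictionary per row, mapping class label -> probability.
zipmap_output = [
    {0: 0.8, 1: 0.2},
    {0: 0.3, 1: 0.7},
]


def zipmap_to_matrix(rows: list[dict[int, float]]) -> list[list[float]]:
    """Convert a list of label->probability dicts into a row-major matrix."""
    labels = sorted(rows[0])
    return [[row[label] for label in labels] for row in rows]


matrix_output = zipmap_to_matrix(zipmap_output)
print(matrix_output)  # [[0.8, 0.2], [0.3, 0.7]]
```

Passing `zipmap=False` avoids this conversion altogether, since the converted model then emits the matrix layout directly.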



```python
import onnx
from onnxmltools import convert_lightgbm
from onnxconverter_common.data_types import FloatTensorType

onnx_models: dict[str, list[onnx.ModelProto]] = {}

for model_kind, models in necessary_models.items():
    onnx_models[model_kind] = []
    for model in models:
        onnx_models[model_kind].append(
            convert_lightgbm(
                model,
                initial_types=[("X", FloatTensorType([None, len(feature_columns)]))],
                zipmap=False,
            )
        )
```

```text
The maximum opset needed by this model is only 9.
The maximum opset needed by this model is only 8.
The maximum opset needed by this model is only 8.
```


Now we can call <a href="../../api_documentation/#metalearners.metalearner.MetaLearner._build_onnx"><code>_build_onnx</code></a>, which combines
the converted ONNX base models into a single ONNX model.
This combined model has a single 2D input ``"X"`` and a single output named ``"tau"``.
The output name can be changed using the ``output_name`` parameter.

```python
onnx_model = xlearner._build_onnx(onnx_models)
```

```text
_build_onnx is an experimental feature. Use it at your own risk!
```


We can inspect the input and output of the model: it expects a matrix with 11
columns and returns a three-dimensional tensor with shape ``(..., 1, 1)``. This is expected,
as there are only two treatment variants and, since this is a regression problem, a single outcome.

```python
print("ONNX model input: ", onnx_model.graph.input)
print("ONNX model output: ", onnx_model.graph.output)
```

```text
ONNX model input:  [name: "X"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
      }
      dim {
        dim_value: 11
      }
    }
  }
}
]
ONNX model output:  [name: "tau"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
      }
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 1
      }
    }
  }
}
]
```
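
Downstream code often wants a single CATE value per row. As a sketch with a stand-in array (not actual model output), the two trailing singleton dimensions can be dropped with NumPy indexing:

```python
import numpy as np

# Stand-in for the ONNX output "tau" with shape (n_obs, 1, 1):
# one non-control treatment variant and one outcome dimension.
tau = np.array([[[0.12]], [[-0.05]], [[0.31]]], dtype=np.float32)
assert tau.shape == (3, 1, 1)

# Select the single variant and the single outcome to get one value per row.
cate_per_row = tau[:, 0, 0]
print(cate_per_row.shape)  # (3,)
```

Explicit indexing is used here rather than `np.squeeze` so that the code fails loudly if the model ever has more than one variant or outcome, instead of silently returning a differently shaped array.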


We can also visualize the ONNX model with, e.g. [netron](https://netron.app/):
<img src="../../imgs/onnx_netron.png" alt="Visualized ONNX model">


<!-- mkdocs note -->
<div class="admonition note">
    <p class="admonition-title">Note</p>
    <p style="margin-top: 0.6rem">We noticed that <code>convert_lightgbm</code> does not support native pandas categorical variables,
because a numpy array needs to be passed when predicting. For this reason, we need to
use the category codes in the input matrix. For more context on this issue see
<a href="https://github.com/onnx/onnxmltools/issues/309" target="_blank">here</a> and
<a href="https://github.com/microsoft/LightGBM/issues/5162" target="_blank">here</a>.
</p>
</div>


```python
import numpy as np

X_onnx = df[feature_columns].copy(deep=True)
for c in categorical_feature_columns:
    X_onnx[c] = df[c].cat.codes
X_onnx = X_onnx.to_numpy(np.float32)
```
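
One caveat with category codes is that they depend on the set and order of categories. When scoring new data, the codes should be derived from the same categories that were used during training. The sketch below uses toy values (not the columns of this dataset) to show one way of doing this with `pd.Categorical`; values not seen during training receive the code `-1`.

```python
import numpy as np
import pandas as pd

# Toy training column; the inferred categories are sorted:
# ["rural", "suburban", "urban"].
train = pd.Series(["urban", "rural", "suburban", "rural"], dtype="category")
train_categories = train.cat.categories

# New data reuses the training categories so that codes stay aligned.
new_values = ["suburban", "urban", "unknown"]
new = pd.Categorical(new_values, categories=train_categories)

print(train.cat.codes.tolist())  # [2, 0, 1, 0]
print(new.codes.tolist())  # [1, 2, -1]

# The codes can then be cast to float32, matching the FloatTensorType input.
new_codes_float = np.asarray(new.codes, dtype=np.float32)
```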

We can finally use ``onnxruntime`` to perform predictions using our model:

```python
import onnxruntime as rt

sess = rt.InferenceSession(
    onnx_model.SerializeToString(), providers=rt.get_available_providers()
)

(pred_onnx,) = sess.run(
    ["tau"],
    {"X": X_onnx},
)
```

```python
onnx.save_model(onnx_model, "model.onnx")
```

We recommend always doing a final check on some data that the CATEs predicted by the Python
implementation and by the ONNX model are the same (up to some tolerance). This can be done with
the following code:

<div class="admonition note">
    <p class="admonition-title">Note</p>
    <p style="margin-top: 0.6rem">We have to treat the data as if it were out-of-sample by passing <code>is_oos=True</code>, since
we used the <code>_overall_estimator</code> when converting the base models.
</p>
</div>


```python
np.testing.assert_allclose(
    xlearner.predict(df[feature_columns], is_oos=True, oos_method="overall"),
    pred_onnx,
    atol=1e-6,
)
```
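
A tolerance is needed because the ONNX graph in this example computes in single precision (`FloatTensorType`), while the Python-side predictions are typically double precision. The sketch below uses arbitrary values, not model output, to show the size of the rounding error a cast to `float32` introduces.

```python
import numpy as np

# Arbitrary double-precision values standing in for Python-side predictions.
reference = np.array([0.1, 0.2, 0.3], dtype=np.float64)

# Casting to float32 and back mimics the precision of a float32 ONNX graph.
single_precision = reference.astype(np.float32).astype(np.float64)

max_abs_diff = np.max(np.abs(reference - single_precision))
print(f"max abs difference: {max_abs_diff:.2e}")

# The values differ, but well within the tolerance used above.
assert max_abs_diff > 0
np.testing.assert_allclose(reference, single_precision, atol=1e-6)
```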

## Further comments

* It would be desirable to work with ``DoubleTensorType`` instead of ``FloatTensorType``,
  but we have noticed that some converters have issues with it. We recommend trying
  ``DoubleTensorType`` first and switching to ``FloatTensorType`` if the converter fails.
* If the final assertion fails, we recommend first testing that the individual
  base models produce the same outputs, as we discovered issues with some converters;
  see [this issue](https://github.com/onnx/sklearn-onnx/issues/1117) and
  [this issue](https://github.com/onnx/sklearn-onnx/issues/1116).
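
To narrow down which base model diverges, predictions can be compared per model. The helper below is a hypothetical sketch: it assumes the Python and ONNX predictions of each base model have already been collected into two dictionaries keyed by model kind, and the arrays shown are made up.

```python
import numpy as np


def find_mismatches(
    python_preds: dict[str, np.ndarray],
    onnx_preds: dict[str, np.ndarray],
    atol: float = 1e-6,
) -> dict[str, float]:
    """Return the max absolute difference for every model exceeding ``atol``."""
    mismatches = {}
    for kind, expected in python_preds.items():
        # Align dtype and shape before comparing element-wise.
        actual = np.asarray(onnx_preds[kind], dtype=np.float64).reshape(expected.shape)
        diff = float(np.max(np.abs(expected - actual)))
        if diff > atol:
            mismatches[kind] = diff
    return mismatches


# Made-up predictions: only the propensity model disagrees.
python_preds = {
    "propensity_model": np.array([0.40, 0.60]),
    "control_effect_model": np.array([1.0, 2.0]),
}
onnx_preds = {
    "propensity_model": np.array([0.40, 0.65]),
    "control_effect_model": np.array([1.0, 2.0]),
}
print(find_mismatches(python_preds, onnx_preds))
```

An empty dictionary means every base model agrees within the tolerance, in which case the discrepancy is more likely to come from how the models are combined.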