In [None]:
import numpy as np
import onnx
import onnxruntime as rt
import pandas as pd
import plotly.express as px
import sklearn.model_selection as skm
from joblib import parallel_backend
from lightgbm import LGBMClassifier
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm,
)
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.algebra.onnx_ops import OnnxConcat, OnnxCos, OnnxMul, OnnxSin
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,
)
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    f1_score,
    fbeta_score,
    make_scorer,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_curve,
)
from sklearn.model_selection import (
    RandomizedSearchCV,
    StratifiedKFold,
    cross_val_predict,
    cross_val_score,
)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

In [2]:
# NUM_POSTE   : numéro Météo-France du poste sur 8 chiffres
# NOM_USUEL   : nom usuel du poste
# LAT         : latitude, négative au sud (en degrés et millionièmes de degré)
# LON         : longitude, négative à l’ouest de GREENWICH (en degrés et millionièmes de degré)
# ALTI        : altitude du pied de l'abri ou du pluviomètre si pas d'abri (en m)
# AAAAMMJJ    : date de la mesure (année mois jour)
# RR          : quantité de précipitation tombée en 24 heures (de 06h FU le jour J à 06h FU le jour J+1). La valeur relevée à J+1 est affectée au jour J (en mm et 1/10)
# TN          : température minimale sous abri (en °C et 1/10)
# HTN         : heure de TN (hhmm)
# TX          : température maximale sous abri (en °C et 1/10)
# HTX         : heure de TX (hhmm)
# TM          : moyenne quotidienne des températures horaires sous abri (en °C et 1/10)
# TNTXM       : moyenne quotidienne (TN+TX)/2 (en °C et 1/10)
# TAMPLI      : amplitude thermique quotidienne : écart entre TX et TN quotidiens (TX-TN) (en °C et 1/10)
# TNSOL       : température quotidienne minimale à 10 cm au-dessus du sol (en °C et 1/10)
# TN50        : température quotidienne minimale à 50 cm au-dessus du sol (en °C et 1/10)
# DG          : durée de gel sous abri (T ≤ 0°C) (en mn)
# FFM         : moyenne quotidienne de la force du vent moyenné sur 10 mn, à 10 m (en m/s et 1/10)
# FF2M        : moyenne quotidienne de la force du vent moyenné sur 10 mn, à 2 m (en m/s et 1/10)
# FXY         : maximum quotidien de la force maximale horaire du vent moyenné sur 10 mn, à 10 m (en m/s et 1/10)
# DXY         : direction de FXY (en rose de 360)
# HXY         : heure de FXY (hhmm)
# FXI         : maximum quotidien de la force maximale horaire du vent instantané, à 10 m (en m/s et 1/10)
# DXI         : direction de FXI (en rose de 360)
# HXI         : heure de FXI (hhmm)
# FXI2        : maximum quotidien de la force maximale horaire du vent instantané, à 2 m (en m/s et 1/10)
# DXI2        : direction de FXI2 (en rose de 360)
# HXI2        : heure de FXI2 (hhmm)
# FXI3S       : maximum quotidien de la force maximale horaire du vent moyenné sur 3 s, à 10 m (en m/s et 1/10)
# DXI3S       : direction de FXI3S (en rose de 360)
# HXI3S       : heure de FXI3S (hhmm)
# DRR         : durée des précipitations (en mn)

# A chaque donnée est associé un code qualité (ex: T;QT) :
#  9 : donnée filtrée (la donnée a passé les filtres/contrôles de premiers niveaux)
#  0 : donnée protégée (la donnée a été validée définitivement par le climatologue)
#  1 : donnée validée (la donnée a été validée par contrôle automatique ou par le climatologue)
#  2 : donnée douteuse en cours de vérification (la donnée a été mise en doute par contrôle automatique)

# D'une façon générale, les valeurs fournies sont données avec une précision qui correspond globalement à la résolution de l'appareil de mesure de la valeur.
# Toutefois, il peut arriver, pour des raisons techniques de stockage ou d'extraction des valeurs, que cette règle ne soit pas respectée.
# Du fait d'arrondis, il peut ponctuellement arriver que des valeurs de base à un pas de temps inférieur (par exemple données minutes) ne soient pas exactement cohérentes avec leurs correspondants sur un pas de temps supérieur (par exemple données horaires).

In [3]:
# Load the merged dataset (one row per day: per-city temperature columns plus
# the `is_red_day` target), then show its column names and first rows.
df = pd.read_feather("../data/merged_meteo_red_days_from_20170101.feather")
print(df.columns)
df.head()

Index(['AAAAMMJJ', 'TN_BORDEAUX', 'TN_LILLE', 'TN_LYON', 'TN_MARSEILLE',
       'TN_MONTPELLIER', 'TN_NANTES', 'TN_NICE', 'TN_PARIS', 'TN_REIMS',
       'TN_RENNES', 'TN_STRASBOURG', 'TN_TOULON', 'TN_TOULOUSE', 'TX_BORDEAUX',
       'TX_LILLE', 'TX_LYON', 'TX_MARSEILLE', 'TX_MONTPELLIER', 'TX_NANTES',
       'TX_NICE', 'TX_PARIS', 'TX_REIMS', 'TX_RENNES', 'TX_STRASBOURG',
       'TX_TOULON', 'TX_TOULOUSE', 'TNTXM_BORDEAUX', 'TNTXM_LILLE',
       'TNTXM_LYON', 'TNTXM_MARSEILLE', 'TNTXM_MONTPELLIER', 'TNTXM_NANTES',
       'TNTXM_NICE', 'TNTXM_PARIS', 'TNTXM_REIMS', 'TNTXM_RENNES',
       'TNTXM_STRASBOURG', 'TNTXM_TOULON', 'TNTXM_TOULOUSE', 'TAMPLI_BORDEAUX',
       'TAMPLI_LILLE', 'TAMPLI_LYON', 'TAMPLI_MARSEILLE', 'TAMPLI_MONTPELLIER',
       'TAMPLI_NANTES', 'TAMPLI_NICE', 'TAMPLI_PARIS', 'TAMPLI_REIMS',
       'TAMPLI_RENNES', 'TAMPLI_STRASBOURG', 'TAMPLI_TOULON',
       'TAMPLI_TOULOUSE', 'is_red_day'],
      dtype='object')


Unnamed: 0,AAAAMMJJ,TN_BORDEAUX,TN_LILLE,TN_LYON,TN_MARSEILLE,TN_MONTPELLIER,TN_NANTES,TN_NICE,TN_PARIS,TN_REIMS,...,TAMPLI_MONTPELLIER,TAMPLI_NANTES,TAMPLI_NICE,TAMPLI_PARIS,TAMPLI_REIMS,TAMPLI_RENNES,TAMPLI_STRASBOURG,TAMPLI_TOULON,TAMPLI_TOULOUSE,is_red_day
0,20170101,2.6,-4.3,-1.2,4.5,2.4,-4.0,4.6,-4.3,-4.4,...,9.9,9.2,10.3,5.7,4.2,7.3,2.2,10.5,4.8,False
1,20170102,2.2,-1.0,-0.5,6.3,3.3,3.8,5.6,-0.6,-2.1,...,9.0,2.0,10.5,3.4,2.9,3.6,1.8,6.6,1.9,False
2,20170103,1.3,-2.6,-0.8,-0.1,0.3,0.0,5.1,0.9,-1.0,...,11.3,3.8,9.7,2.4,3.2,4.5,3.7,10.2,2.1,True
3,20170104,-4.9,2.1,-2.0,1.5,-2.8,-3.2,5.4,0.1,0.5,...,14.4,9.3,7.8,6.5,4.8,9.3,5.0,11.1,3.8,True
4,20170105,-2.8,0.3,0.3,2.4,1.6,0.1,4.4,2.0,0.4,...,8.5,7.6,8.6,5.1,4.4,7.2,4.7,7.1,7.9,True


In [4]:
# Keep only rows dated strictly before 2024-04-01 (AAAAMMJJ is compared as a
# YYYYMMDD number); the tail displayed below shows the last day actually kept.
df = df[df["AAAAMMJJ"] < 20240401]
df.tail()

Unnamed: 0,AAAAMMJJ,TN_BORDEAUX,TN_LILLE,TN_LYON,TN_MARSEILLE,TN_MONTPELLIER,TN_NANTES,TN_NICE,TN_PARIS,TN_REIMS,...,TAMPLI_MONTPELLIER,TAMPLI_NANTES,TAMPLI_NICE,TAMPLI_PARIS,TAMPLI_REIMS,TAMPLI_RENNES,TAMPLI_STRASBOURG,TAMPLI_TOULON,TAMPLI_TOULOUSE,is_red_day
2622,20240307,5.7,3.4,0.3,0.8,0.5,4.0,7.4,3.8,-0.6,...,12.5,10.2,7.2,10.4,13.7,11.0,5.5,13.4,10.7,True
2623,20240308,8.9,2.4,4.8,7.0,8.7,6.9,6.9,3.6,-1.0,...,4.5,5.2,6.4,10.3,15.2,6.5,12.3,5.5,5.6,False
2624,20240309,7.5,5.2,5.9,7.7,9.9,6.2,8.2,6.4,2.0,...,3.5,3.7,4.7,8.1,12.7,4.5,12.3,4.4,6.3,False
2625,20240310,3.3,7.6,5.7,9.2,6.3,0.9,8.4,7.1,5.5,...,8.8,11.9,4.9,4.3,5.5,13.1,5.3,5.0,12.9,False
2626,20240311,4.5,7.6,5.1,6.7,7.6,3.9,8.1,6.4,0.2,...,7.6,10.1,5.4,6.6,13.7,13.9,10.5,8.4,4.3,True


In [5]:
# Per-column NaN counts, followed by the names of the columns holding any NaN.
nan_counts = df.isna().sum()
print(nan_counts)

# Lift the row-display limit for the rest of the session.
pd.set_option("display.max_rows", None)
df.columns[nan_counts > 0]

AAAAMMJJ                0
TN_BORDEAUX             0
TN_LILLE                0
TN_LYON                 0
TN_MARSEILLE            0
TN_MONTPELLIER        365
TN_NANTES               0
TN_NICE                 0
TN_PARIS                0
TN_REIMS                0
TN_RENNES               0
TN_STRASBOURG           0
TN_TOULON               0
TN_TOULOUSE             0
TX_BORDEAUX             0
TX_LILLE                0
TX_LYON                 0
TX_MARSEILLE            0
TX_MONTPELLIER        365
TX_NANTES               0
TX_NICE                 0
TX_PARIS                0
TX_REIMS                0
TX_RENNES               0
TX_STRASBOURG           0
TX_TOULON               0
TX_TOULOUSE             0
TNTXM_BORDEAUX          0
TNTXM_LILLE             0
TNTXM_LYON              0
TNTXM_MARSEILLE         0
TNTXM_MONTPELLIER     365
TNTXM_NANTES            0
TNTXM_NICE              0
TNTXM_PARIS             0
TNTXM_REIMS             0
TNTXM_RENNES            0
TNTXM_STRASBOURG        0
TNTXM_TOULON

Index(['TN_MONTPELLIER', 'TX_MONTPELLIER', 'TNTXM_MONTPELLIER',
       'TAMPLI_MONTPELLIER'],
      dtype='object')

In [6]:
# Cities whose columns are removed. Montpellier is the one with missing values
# (see the NaN counts printed earlier); the reason for excluding the other
# three is not recorded in this notebook.
cities_to_drop = ["MONTPELLIER", "REIMS", "RENNES", "NICE"]

# Column names follow `<VARIABLE>_<CITY>`, so the city is the last `_` token.
# `df` is rebound instead of dropped with `inplace=True`: the frame comes from a
# boolean filter, and in-place mutation of such a frame is what pandas warns
# about (SettingWithCopyWarning).
df = df.drop(
    columns=[c for c in df.columns if c.split("_")[-1] in cities_to_drop]
)
assert df.isna().sum().sum() == 0, "There are still NaN values in the dataframe"

Feature engineering

In [558]:
# Work on an explicit copy: `df` was produced by boolean filtering, and adding
# columns to such a frame triggers pandas' SettingWithCopyWarning.
df = df.copy()

# `.dt` accessor over the parsed measurement dates (AAAAMMJJ = YYYYMMDD).
# The name `daydt` is kept because a later cell reads it.
daydt = pd.to_datetime(df["AAAAMMJJ"], format="%Y%m%d").dt

# Weekday flag: dayofweek is 0 for Monday ... 4 for Friday.
# The comparison already yields a boolean Series, no cast needed.
df["is_week_day"] = daydt.dayofweek < 5

# True when at least one of the 7 previous rows was a red day: rolling sum over
# 7 rows, shifted by one so the current row never sees its own target.
# NOTE(review): this is row-based, so it assumes one row per consecutive day in
# chronological order - confirm against the source data.
df["red_days_last_week"] = (
    df["is_red_day"]
    .rolling(window=7, min_periods=1)
    .sum()
    .shift(1)
    .fillna(0)
    .astype(bool)
)

# Month as a plain float. A pandas `category` month
# (`daydt.month.astype("category")`) was the simplest and best performing
# encoding but fails with the ONNX export, so the month is embedded with
# sin/cos inside the pipeline instead, which performs similarly.
df["month"] = daydt.month.astype(float)

# Target as int (0/1)
df["is_red_day"] = df["is_red_day"].astype(int)


class SinCosTransformer(FunctionTransformer):
    """
    Replace the "month" column of a DataFrame by its cyclical sin/cos encoding.

    The month is mapped to sin(2π * month / 12) and cos(2π * month / 12); every
    other column is passed through unchanged. It is implemented as a dedicated
    transformer class so that a matching ONNX converter can be registered for it.

    Input:  a pandas DataFrame containing a "month" column (KeyError otherwise).
    Output: a float32 numpy array whose columns are
            [sin, cos, <remaining input columns in their original order>].
    """

    def __init__(self):
        # validate=False: the function needs the DataFrame (column access by
        # name), not an array converted by scikit-learn.
        super().__init__(func=self._sin_cos_transform, validate=False)

    def _sin_cos_transform(self, X):
        """Return the float32 feature matrix [sin, cos, other columns] for X."""
        month = X["month"]
        # Ask pandas for float32 directly: with mixed bool/float columns,
        # `.values` would build an object-dtype array first.
        other_cols = X.drop(columns=["month"]).to_numpy(dtype=np.float32)

        two_pi_month_norm = 2 * np.pi * month / 12
        sin_feat = np.sin(two_pi_month_norm).astype(np.float32)
        cos_feat = np.cos(two_pi_month_norm).astype(np.float32)

        return np.column_stack(
            [
                sin_feat,
                cos_feat,
                other_cols,
            ]
        ).astype(np.float32)


In [507]:
# Show the current column set of `df`.
df.columns

Index(['AAAAMMJJ', 'TN_BORDEAUX', 'TN_LILLE', 'TN_LYON', 'TN_MARSEILLE',
       'TN_NANTES', 'TN_PARIS', 'TN_STRASBOURG', 'TN_TOULON', 'TN_TOULOUSE',
       'TX_BORDEAUX', 'TX_LILLE', 'TX_LYON', 'TX_MARSEILLE', 'TX_NANTES',
       'TX_PARIS', 'TX_STRASBOURG', 'TX_TOULON', 'TX_TOULOUSE',
       'TNTXM_BORDEAUX', 'TNTXM_LILLE', 'TNTXM_LYON', 'TNTXM_MARSEILLE',
       'TNTXM_NANTES', 'TNTXM_PARIS', 'TNTXM_STRASBOURG', 'TNTXM_TOULON',
       'TNTXM_TOULOUSE', 'TAMPLI_BORDEAUX', 'TAMPLI_LILLE', 'TAMPLI_LYON',
       'TAMPLI_MARSEILLE', 'TAMPLI_NANTES', 'TAMPLI_PARIS',
       'TAMPLI_STRASBOURG', 'TAMPLI_TOULON', 'TAMPLI_TOULOUSE', 'is_red_day',
       'is_week_day', 'red_days_last_week', 'month'],
      dtype='object')

In [508]:
# Drop April-October rows (months 4 to 10 inclusive) and everything before
# April 2017.
# Both conditions are combined into a single mask computed from `df` itself.
# Filtering in two steps with a mask indexed on the unfiltered frame makes
# pandas reindex the boolean key (UserWarning); a self-derived mask also keeps
# the cell safe to re-run.
dates = pd.to_datetime(df["AAAAMMJJ"], format="%Y%m%d").dt
in_excluded_months = (dates.month >= 4) & (dates.month <= 10)
before_april_2017 = (dates.year == 2017) & (dates.month < 4)
df = df[~(in_excluded_months | before_april_2017)]

Learning

In [509]:
# Class balance of the target (1 = red day, 0 = any other day).
df["is_red_day"].value_counts()

is_red_day
0    892
1    147
Name: count, dtype: int64

In [510]:
def confusion_table(
    y_true: np.ndarray | list, y_pred: np.ndarray | list
) -> pd.DataFrame:
    """Rows: Actual, Columns: Predicted. Aligns inputs by position (not index)."""
    y_true_series = pd.Series(np.asarray(y_true), name="Actual")
    y_pred_series = pd.Series(np.asarray(y_pred), name="Predicted")

    if len(y_true_series) != len(y_pred_series):
        raise ValueError(
            f"y_true and y_pred must have same length, got {len(y_true_series)} and {len(y_pred_series)}"
        )

    ct = pd.crosstab(y_true_series, y_pred_series)
    return ct


def evaluate(model, X, y, plot_roc_rurve=False, threshold=0.5):
    """Print classification metrics and the confusion table of `model` on (X, y).

    Args:
        model: fitted classifier exposing `predict_proba`; column 1 of its
            output is used as the positive-class probability.
        X: features passed to `model.predict_proba`.
        y: true binary labels.
        plot_roc_rurve: when True, also draw the ROC curve (parameter name kept
            as is for existing callers).
        threshold: probability at or above which a sample is predicted positive.

    Returns:
        None. Everything is printed or displayed.
    """
    # Positive-class probabilities, computed once and reused for the ROC curve.
    y_probs = model.predict_proba(X)[:, 1]
    y_pred = (y_probs >= threshold).astype(int)

    print(f"Accuracy: {accuracy_score(y, y_pred)}")
    print("Precision", precision_score(y, y_pred))
    print("Recall", recall_score(y, y_pred))
    print("F1 score", f1_score(y, y_pred))
    print("F2 score", fbeta_score(y, y_pred, beta=2))
    # Threshold-independent: computed on the probabilities, not on y_pred.
    print("Average precision (AUPRC)", average_precision_score(y, y_probs))

    display(confusion_table(y, y_pred))

    if plot_roc_rurve:
        fpr, tpr, _ = roc_curve(y, y_probs)
        roc_auc = auc(fpr, tpr)
        # Arrays are passed directly, so plotly names them "x" and "y" and the
        # `labels` keys below refer to those names.
        fig = px.area(
            x=fpr,
            y=tpr,
            title=f"ROC Curve (AUC={roc_auc:.4f})",
            labels=dict(x="False Positive Rate", y="True Positive Rate"),
            width=700,
            height=500,
        )
        fig.show()

In [511]:
# Features: every column except the target and the raw date.
X = df.drop(["is_red_day", "AAAAMMJJ"], axis=1)

# Measurement families excluded from the features. Columns are named
# `<VARIABLE>_<CITY>`, so the variable is the first `_` token; matching on that
# token (rather than on a substring of the whole name) cannot accidentally hit
# a city name or an engineered feature.
excluded_variables = {"TAMPLI", "RR", "FFM", "TM", "TNTXM"}
todrop = [c for c in X.columns if c.split("_")[0] in excluded_variables]

X = X.drop(todrop, axis=1)
y = df["is_red_day"]

# NOTE(review): the split is random and not stratified, although the target is
# imbalanced and `red_days_last_week` is derived from neighbouring days.
# Consider `stratify=y` or a chronological split; left unchanged here so the
# metrics recorded in this notebook stay comparable.
X_train, X_test, y_train, y_test = skm.train_test_split(
    X, y, test_size=0.2, random_state=42
)
# F-beta scorer with beta=2 (recall weighted higher than precision).
f2_scorer = make_scorer(fbeta_score, beta=2)

In [512]:
# Show the feature columns that remain after the drops above.
X.columns

Index(['TN_BORDEAUX', 'TN_LILLE', 'TN_LYON', 'TN_MARSEILLE', 'TN_NANTES',
       'TN_PARIS', 'TN_STRASBOURG', 'TN_TOULON', 'TN_TOULOUSE', 'TX_BORDEAUX',
       'TX_LILLE', 'TX_LYON', 'TX_MARSEILLE', 'TX_NANTES', 'TX_PARIS',
       'TX_STRASBOURG', 'TX_TOULON', 'TX_TOULOUSE', 'is_week_day',
       'red_days_last_week', 'month'],
      dtype='object')

In [565]:
# Randomized hyper-parameter search for the LightGBM pipeline.
# The whole search is skipped when RUN_HP_TUNING is False.


SEED = 42
RUN_HP_TUNING = True

if RUN_HP_TUNING:
    # Parameters are addressed through the pipeline step name ("classifier").
    param_prefix = "classifier__"
    param_grid = {
        f"{param_prefix}n_estimators": [100, 200, 500],
        f"{param_prefix}learning_rate": [0.01, 0.05],
        f"{param_prefix}max_depth": [3, 5, 10],
        f"{param_prefix}num_leaves": [15, 31],
        f"{param_prefix}min_child_samples": [20, 50, 100],
        f"{param_prefix}reg_alpha": [0.1, 1, 10],
        f"{param_prefix}reg_lambda": [0.1, 1, 10],
        f"{param_prefix}scale_pos_weight": [5, 10, 20],
    }

    # Create base LGBM classifier
    # NOTE(review): both this estimator and the search below request n_jobs=-1;
    # nested parallelism may oversubscribe the CPUs - confirm this is intended.
    lgbm = LGBMClassifier(random_state=SEED, force_row_wise=True, verbose=-1, n_jobs=-1)
    # with preprocessor
    pipeline = Pipeline(
        [
            ("preprocessor", SinCosTransformer()),
            ("classifier", lgbm),
        ]
    )

    # 3 stratified, shuffled folds; 50 sampled candidates -> 150 fits.
    kf = StratifiedKFold(n_splits=3, shuffle=True, random_state=SEED)
    random_search = RandomizedSearchCV(
        pipeline,
        param_grid,
        scoring="average_precision",  # This is the PR AUC alias,
        n_iter=50,
        cv=kf,
        verbose=1,
        n_jobs=-1,
        random_state=SEED,
    )

    # Run the search with joblib's "loky" backend.
    with parallel_backend("loky"):
        random_search.fit(X_train, y_train)

    # Print the best parameters (step prefix stripped) and the best CV score
    print(
        "Best parameters found: ",
        {k.replace(param_prefix, ""): v for k, v in random_search.best_params_.items()},
    )
    print("Best score:", random_search.best_score_)

Fitting 3 folds for each of 50 candidates, totalling 150 fits
Best parameters found:  {'scale_pos_weight': 5, 'reg_lambda': 1, 'reg_alpha': 10, 'num_leaves': 15, 'n_estimators': 100, 'min_child_samples': 50, 'max_depth': 3, 'learning_rate': 0.05}
Best score: 0.6299460239223746


In [566]:
# Hyper-parameters copied from the output of the random search above.
best_params = dict(
    scale_pos_weight=5,
    reg_lambda=1,
    reg_alpha=10,
    num_leaves=15,
    n_estimators=100,
    min_child_samples=50,
    max_depth=3,
    learning_rate=0.05,
)

# Base learner and the folds shared by the calibration and by the CV scoring.
lgbm = LGBMClassifier(random_state=0, force_row_wise=True, verbose=-1, **best_params)
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Final model: sin/cos month encoding followed by a Platt-scaled (sigmoid)
# LightGBM. Calibrating through cross-validation with `ensemble=True` lets
# every training sample be used for both fitting and calibration.
calibrated_lgbm = Pipeline(
    steps=[
        ("preprocessor", SinCosTransformer()),
        (
            "classifier",
            CalibratedClassifierCV(
                estimator=lgbm,
                method="sigmoid",
                cv=kf,
                ensemble=True,
            ),
        ),
    ]
)
calibrated_lgbm.fit(X_train, y_train)

# Cross-validated average precision of the whole pipeline on the training set.
cv_scores = cross_val_score(
    calibrated_lgbm, X_train, y_train, scoring="average_precision", cv=kf
)
print("CV Average Precision scores:", cv_scores)
print("Mean CV Average Precision:", cv_scores.mean())

CV Average Precision scores: [0.62988328 0.69088745 0.6494513  0.64744277 0.69165021]
Mean CV Average Precision: 0.6618630032932529


In [567]:
def plot_cv_pr_curve(calibrated_lgbm, X_train, y_train, kf):
    """Plot the cross-validated precision-recall curve from out-of-fold probabilities.

    Args:
        calibrated_lgbm: estimator exposing `predict_proba`; it is re-fitted on
            each fold by `cross_val_predict`.
        X_train: training features.
        y_train: training labels (binary).
        kf: cross-validation splitter used to produce the out-of-fold predictions.

    Returns:
        None. The figure is shown.
    """
    # Column 1 = positive-class probability, predicted on held-out folds only.
    y_oof_probs = cross_val_predict(
        calibrated_lgbm, X_train, y_train, cv=kf, method="predict_proba"
    )[:, 1]
    precision, recall, thresholds = precision_recall_curve(y_train, y_oof_probs)
    pr_auc = average_precision_score(y_train, y_oof_probs)

    # precision/recall hold one more point than thresholds; pad thresholds with
    # a trailing 1 so the three arrays have the same length.
    thresholds_padded = np.append(thresholds, 1)
    pr_curve_df = pd.DataFrame(
        {"Recall": recall, "Precision": precision, "Threshold": thresholds_padded}
    )

    fig = px.area(
        pr_curve_df,
        x="Recall",
        y="Precision",
        title=f"Cross-Validated Precision-Recall Curve (PR-AUC = {pr_auc:.4f})",
        hover_data=["Threshold"],
        # `labels` is keyed by column name when a DataFrame is plotted; the
        # previous "x"/"y" keys matched no column and were silently ignored.
        labels={
            "Recall": "Recall (Ability to catch positives)",
            "Precision": "Precision (Accuracy of positive calls)",
        },
        width=700,
        height=500,
    )

    fig.show()


# Draw the out-of-fold precision-recall curve of the calibrated pipeline on
# the training set, reusing the folds `kf`.
print("Generating Out-of-Fold probabilities...")
plot_cv_pr_curve(calibrated_lgbm, X_train, y_train, kf)

Generating Out-of-Fold probabilities...


In [568]:
# Probability at or above which a day is predicted as a red day
# (overrides the 0.5 default of `evaluate`).
THRESHOLD = 0.2

# Same metrics on the data the model was fitted on...
print("Train set")
evaluate(calibrated_lgbm, X_train, y_train, threshold=THRESHOLD)

# ...and on the held-out split.
print("Test set")
evaluate(calibrated_lgbm, X_test, y_test, threshold=THRESHOLD)

Train set
Accuracy: 0.8724428399518652
Precision 0.5329949238578681
Recall 0.8823529411764706
F1 score 0.6645569620253164
F2 score 0.7800891530460624
Average precision (AUPRC) 0.78695696396704


Predicted,0,1
Actual,Unnamed: 1_level_1,Unnamed: 2_level_1
0,620,92
1,14,105


Test set
Accuracy: 0.8798076923076923
Precision 0.5319148936170213
Recall 0.8928571428571429
F1 score 0.6666666666666666
F2 score 0.7861635220125787
Average precision (AUPRC) 0.7312187177034881


Predicted,0,1
Actual,Unnamed: 1_level_1,Unnamed: 2_level_1
0,158,22
1,3,25


Explore incorrect predictions

In [569]:
# Collect the test rows whose thresholded prediction disagrees with the target.
X_test_indices = X_test.index
y_pred_proba = calibrated_lgbm.predict_proba(X_test)[:, 1]

# Full rows (all original columns) of the test samples, plus the model output.
df_test = df.loc[X_test_indices].copy()
df_test["y_pred_proba"] = y_pred_proba
df_test["y_pred"] = (y_pred_proba >= THRESHOLD).astype(int)

is_misclassified = df_test["is_red_day"] != df_test["y_pred"]
incorrect_preds = df_test[is_misclassified]
print(len(incorrect_preds), "incorrect predictions")

# Sanity check on the selection.
labels_differ = incorrect_preds["y_pred"] != incorrect_preds["is_red_day"]
assert all(labels_differ), "y_pred should be different from is_red_day"
incorrect_preds.head(10)

25 incorrect predictions


Unnamed: 0,AAAAMMJJ,TN_BORDEAUX,TN_LILLE,TN_LYON,TN_MARSEILLE,TN_NANTES,TN_PARIS,TN_STRASBOURG,TN_TOULON,TN_TOULOUSE,...,TAMPLI_PARIS,TAMPLI_STRASBOURG,TAMPLI_TOULON,TAMPLI_TOULOUSE,is_red_day,is_week_day,red_days_last_week,month,y_pred_proba,y_pred
2257,20230308,7.9,0.8,1.9,4.7,7.6,5.3,2.9,6.4,3.6,...,9.8,12.7,10.8,17.5,1,True,True,3.0,0.151464,0
807,20190319,5.8,0.0,1.6,3.4,3.5,4.1,-0.3,3.4,5.7,...,9.5,12.1,11.7,5.7,1,True,False,3.0,0.129515,0
1853,20220128,-0.9,0.2,-2.7,-1.4,2.1,6.3,0.2,1.0,-2.6,...,1.2,5.8,13.6,5.5,0,True,True,1.0,0.463626,1
1845,20220120,5.2,1.4,0.7,1.7,2.9,3.8,1.3,4.0,3.9,...,3.3,4.2,10.3,2.9,0,True,True,1.0,0.553951,1
765,20190205,3.6,2.1,-2.1,1.9,3.7,3.1,-3.3,2.8,3.8,...,5.7,11.5,11.7,7.9,0,True,True,2.0,0.568644,1
417,20180222,-0.5,-2.4,-0.3,-0.5,-2.1,-1.2,-2.1,3.6,0.6,...,6.2,5.8,7.8,4.6,0,True,False,2.0,0.539751,1
443,20180320,1.3,-2.1,0.6,2.0,-0.2,-1.3,-2.8,6.7,2.4,...,10.1,7.4,8.8,2.4,0,True,False,3.0,0.370913,1
334,20171201,0.3,-0.7,0.0,-0.4,-0.5,0.8,0.1,3.7,1.9,...,3.7,2.5,5.2,2.8,0,True,True,12.0,0.65695,1
358,20171225,2.5,5.8,1.0,-0.4,5.9,4.3,0.3,4.9,2.6,...,3.2,5.0,7.9,9.2,0,True,True,12.0,0.255013,1
1849,20220124,-1.3,0.2,-2.8,0.7,0.5,0.7,-2.0,1.1,2.1,...,6.9,9.6,12.5,10.9,0,True,True,1.0,0.678219,1


Export to ONNX

In [570]:


def sin_cos_shape_calculator(operator):
    """Declare the output shape of the SinCosTransformer ONNX operator.

    The converter concatenates sin(month), cos(month) and every other input
    along axis 1, so the output width is the sum of all input widths plus one
    (the month column is replaced by two columns). The previous version only
    looked at the first input and therefore declared too few columns as soon as
    the operator had more than one input.

    Args:
        operator: conversion operator; each entry of `operator.inputs` and
            `operator.outputs` exposes `.type.shape` as `[batch, width]`.

    Side effect:
        Sets `operator.outputs[0].type.shape` to `[batch, total_width]`, where
        `total_width` is None if any input width is unknown (None).
    """
    batch_dim = operator.inputs[0].type.shape[0]
    widths = [variable.type.shape[1] for variable in operator.inputs]

    if any(width is None for width in widths):
        # An unknown input width makes the output width unknown as well.
        n_features = None
    else:
        n_features = sum(widths) + 1

    operator.outputs[0].type.shape = [batch_dim, n_features]


def sin_cos_converter(scope, operator, container):
    """Emit the ONNX graph for SinCosTransformer.

    The input whose raw name is "month" is scaled by 2π/12 and passed through
    Sin and Cos; the two results are concatenated along axis 1 with all the
    other inputs, kept in their original order: [sin, cos, col1, col2, ...].

    Raises:
        ValueError: if no operator input is named "month".
    """
    month_candidates = [
        variable for variable in operator.inputs if variable.raw_name == "month"
    ]
    if not month_candidates:
        raise ValueError("Input 'month' not found in operator inputs.")
    month_input = month_candidates[-1]
    others = [variable for variable in operator.inputs if variable.raw_name != "month"]

    opset = container.target_opset

    # angle = month * 2π / 12, as a float32 constant multiplication
    angle = OnnxMul(
        month_input,
        np.array([2 * np.pi / 12], dtype=np.float32),
        op_version=opset,
    )

    # Single feature matrix for the classifier, written to the operator output.
    feature_matrix = OnnxConcat(
        OnnxSin(angle, op_version=opset),
        OnnxCos(angle, op_version=opset),
        *others,
        axis=1,
        op_version=opset,
        output_names=operator.outputs[0].full_name,
    )
    feature_matrix.add_to(scope, container)


# Register the custom ONNX conversion of SinCosTransformer: its shape
# calculator and its converter, both defined above in this cell.
update_registered_converter(
    SinCosTransformer,
    "SinCosTransformer",
    sin_cos_shape_calculator,
    sin_cos_converter,
)

# Register the onnxmltools LightGBM converter for LGBMClassifier, together with
# the linear-classifier output-shape calculator and the options it accepts.
update_registered_converter(
    LGBMClassifier,
    "LightGbmLGBMClassifier",
    calculate_linear_classifier_output_shapes,
    convert_lightgbm,
    options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
)

# One float input of shape [batch, 1] per training column, named after the
# column; the converter of SinCosTransformer looks up the "month" input by name.
initial_types = [(col_name, FloatTensorType([None, 1])) for col_name in X_train.columns]
# Comma-separated input names, embedded in the model doc string below.
input_names_str = ",".join(X_train.columns)
# Convert the fitted pipeline.
# NOTE(review): zipmap=False presumably makes the probabilities a plain tensor
# instead of a list of dicts - confirm against the skl2onnx documentation.
model = convert_sklearn(
    calibrated_lgbm,
    "lightgbm_model",
    initial_types=initial_types,
    target_opset={"": 12, "ai.onnx.ml": 2},
    options=dict(zipmap=False),
    doc_string=f"Predict if the next day will be a red day.\n\nInput names: {input_names_str}",
)

# Save the converted model
onnx_model_path = "lgbm_model_red_days_2026_01_09.onnx"
onnx.save(model, onnx_model_path)

print(f"Modified ONNX model saved as {onnx_model_path}")

Modified ONNX model saved as lgbm_model_red_days_2026_01_09.onnx


In [571]:
# Round-trip check: run the exported model with onnxruntime and compare its
# output with the scikit-learn pipeline on the same data.

sess = rt.InferenceSession(onnx_model_path)
input_names = [model_input.name for model_input in sess.get_inputs()]
# The second output is the one compared with `predict_proba` below.
output_name = sess.get_outputs()[1].name

# One float32 column vector of shape (n_rows, 1) per model input.
input_data = {
    input_name: X[input_name].to_numpy(dtype=np.float32).reshape(-1, 1)
    for input_name in input_names
}
pred_onnx = sess.run([output_name], input_data)[0]

# Reference probabilities from the Python pipeline, on float32 inputs as well.
y_pred = calibrated_lgbm.predict_proba(X.astype(np.float32))

# Both models must agree within the tolerances.
predictions_match = np.allclose(y_pred, pred_onnx, atol=1e-3, rtol=1e-2)
assert predictions_match, (
    f"Predictions should be the same:\n{pred_onnx[0:10]}\n{y_pred[0:10]}"
)