# Set up the environment

In [1]:
# Load the Kedro IPython extension (reloading it if it is already loaded).
# NOTE(review): this is presumably what injects the `context` and `catalog`
# objects used in later cells — neither is assigned in the visible cells.
%reload_ext kedro.ipython

In [2]:
import pandas as pd
import numpy as np
# import shap

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from feature_engine.imputation import CategoricalImputer,ArbitraryNumberImputer
from feature_engine.encoding import RareLabelEncoder, CountFrequencyEncoder
from feature_engine.creation import MathFeatures, RelativeFeatures
from feature_engine.wrappers import SklearnTransformerWrapper

from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_predict
from scipy.stats import ks_2samp

import matplotlib.pyplot as plt
import seaborn as sns

import mlflow

import warnings
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import (
    RandomForestClassifier, ExtraTreesClassifier,
    HistGradientBoostingClassifier, GradientBoostingClassifier
)
from sklearn.linear_model import LogisticRegression
from olist_project.utils.model import mlflow_experiment_run_cv
from category_encoders.cat_boost import CatBoostEncoder


In [3]:
# Raise pandas' display limits to 110 columns / 110 rows for inspecting wide feature frames.
pd.set_option('display.max_columns', 110, 'display.max_rows', 110)

In [4]:
# Track MLflow runs in an `mlflow/` directory under the Kedro project root
# (a file-based store, i.e. a `file://` URI).
# Join with pathlib instead of appending '/mlflow' to the URI string, which
# would yield a malformed 'file:////mlflow' when the project path is '/'.
mlflow_tracking_uri = (context.project_path / 'mlflow').as_uri()
mlflow.set_tracking_uri(mlflow_tracking_uri)
experiment_name = 'model_selection'
# Trailing ';' suppresses the Experiment repr, which embeds the absolute local
# artifact path in the saved notebook output.
mlflow.set_experiment(experiment_name);

[1m<[0m[1;95mExperiment:[0m[39m [0m[33martifact_location[0m[39m=[0m[32m'file:///home/bruno/Documents/Programming/Programming_projects/olist_project/mlflow/603732696037769194'[0m[39m, [0m[33mcreation_time[0m[39m=[0m[1;36m1735500393001[0m[39m, [0m[33mexperiment_id[0m[39m=[0m[32m'603732696037769194'[0m[39m, [0m[33mlast_update_time[0m[39m=[0m[1;36m1735500393001[0m[39m, [0m[33mlifecycle_stage[0m[39m=[0m[32m'active'[0m[39m, [0m[33mname[0m[39m=[0m[32m'model_selection'[0m[39m, [0m[33mtags[0m[39m=[0m[1;39m{[0m[1;39m}[0m[1m>[0m

# Load data

In [10]:
# Run parameters, read from the Kedro catalog's `params:` entries.
random_state = catalog.load('params:random_state')
id_col = catalog.load('params:audience_building.id_col')
cohort_col = catalog.load('params:audience_building.cohort_col')
target_name = catalog.load('params:modeling.target')


def _load_split(features_ds, target_ds, ids_ds):
    """Load one data split from the Kedro catalog.

    Returns ``(features, target, ids, cohort)``. ``target`` is the
    ``target_name`` column of the target dataset, and ``cohort`` is
    ``ids[cohort_col]`` parsed as ``%Y%m`` dates.
    """
    features = catalog.load(features_ds)
    target = catalog.load(target_ds)[target_name]
    ids = catalog.load(ids_ds)
    cohort = pd.to_datetime(ids[cohort_col], format='%Y%m')
    return features, target, ids, cohort


# Development (train) split, then the OOT test split.
X_dev, y_dev, id_model_dev, cohort_dev = _load_split("X_train", "y_train", "id_model_train")
X_oot, y_oot, id_model_oot, cohort_oot = _load_split("X_test_oot", "y_test_oot", "id_model_test_oot")

# Model definitions

In [6]:
# Names of the category-dtype columns in the development features
# (later passed to CatBoost as `cat_features`).
cat_vars = list(X_dev.select_dtypes('category').columns)

def _inf_to_nan(X):
    return X.replace([np.inf, -np.inf], np.nan)
# Stateless scikit-learn wrapper around `_inf_to_nan` so the conversion can be a Pipeline step.
inf_to_nan_transformer = FunctionTransformer(func=_inf_to_nan)

def get_lgbm_model(X_dev, params=None):
    """Build an unfitted LightGBM classification pipeline for ``X_dev``.

    When ``X_dev`` has ``category``-dtype columns, a ``CategoricalImputer``
    and a ``RareLabelEncoder`` step are placed before the estimator;
    otherwise the pipeline contains only the estimator.

    Parameters
    ----------
    X_dev : pandas.DataFrame
        Development features; only their dtypes are inspected here.
    params : dict, optional
        Extra ``LGBMClassifier`` keyword arguments. They take precedence over
        the defaults ``verbosity=-1`` and ``random_state=random_state`` (the
        notebook-level seed), so e.g. ``{'random_state': 0}`` overrides the
        seed instead of raising ``TypeError`` for a duplicated keyword.

    Returns
    -------
    sklearn.pipeline.Pipeline
        Unfitted pipeline whose final step is named ``'estimator'``.
    """
    has_cat_vars = not X_dev.select_dtypes('category').columns.empty
    pipe_steps = []
    if has_cat_vars:
        pipe_steps.append(('imputer', CategoricalImputer()))
        pipe_steps.append(('encoder', RareLabelEncoder()))

    # Defaults first so caller-supplied params win on key collisions.
    estimator_params = {'verbosity': -1, 'random_state': random_state, **(params or {})}
    pipe_steps.append(('estimator', LGBMClassifier(**estimator_params)))

    return Pipeline(pipe_steps)

In [9]:
def _categorical_prep_steps():
    """Return new, unfitted categorical imputation + rare-label grouping steps.

    Each call builds a fresh list holding fresh transformer instances, so no
    two pipelines share a transformer object.
    """
    return [
        ('imputer', CategoricalImputer()),
        ('encoder', RareLabelEncoder()),
    ]


# Candidate pipelines for model selection, keyed by a short model name.
models = {
    # LightGBM on the imputed, rare-label-grouped categorical columns.
    'lgbm': Pipeline(
        _categorical_prep_steps()
        + [('estimator', LGBMClassifier(verbosity=-1, random_state=random_state))]
    ),
    # As 'lgbm', with a CatBoostEncoder step before the estimator.
    'lgbm_cat_enc': Pipeline(
        _categorical_prep_steps()
        + [
            ('cat_encoder', CatBoostEncoder()),
            ('estimator', LGBMClassifier(verbosity=-1, random_state=random_state)),
        ]
    ),
    # XGBoost: map +/-inf to NaN first, then count/frequency-encode categories.
    'xgb': Pipeline(
        [('inf_to_nan_transformer', SklearnTransformerWrapper(transformer=inf_to_nan_transformer))]
        + _categorical_prep_steps()
        + [
            ('encoder_freq', CountFrequencyEncoder()),
            ('estimator', XGBClassifier(random_state=random_state)),
        ]
    ),
    # XGBoost with a CatBoostEncoder step instead of count/frequency encoding.
    'xgb_cat_enc': Pipeline(
        [('inf_to_nan_transformer', SklearnTransformerWrapper(transformer=inf_to_nan_transformer))]
        + _categorical_prep_steps()
        + [
            ('cat_encoder', CatBoostEncoder()),
            ('estimator', XGBClassifier(random_state=random_state)),
        ]
    ),
    # CatBoost, told which columns are categorical via `cat_features`.
    'cat': Pipeline(
        _categorical_prep_steps()
        + [('estimator', CatBoostClassifier(cat_features=cat_vars, verbose=0, random_state=random_state))]
    ),
}

In [16]:
# Fit on the development split and score the OOT split (Pipeline.fit returns the pipeline).
# NOTE(review): scoring the OOT set during model selection risks selecting on the test set.
y_proba = models['lgbm_cat_enc'].fit(X_dev, y_dev).predict_proba(X_oot)[:, 1]
roc_auc_score(y_true=y_oot, y_score=y_proba)

[1;36m0.9335920766032932[0m

In [18]:
# Features as the LightGBM estimator receives them: apply every pipeline step except the last.
lgbm_cat_enc_preprocessor = models['lgbm_cat_enc'][:-1]
lgbm_cat_enc_preprocessor.transform(X_oot)

Unnamed: 0,ord_total_orders_m3,ord_nunique_cohorts_m3,ord_mean_estimated_days_to_order_delivery_m3,ord_std_estimated_days_to_order_delivery_m3,ord_max_estimated_days_to_order_delivery_m3,ord_min_estimated_days_to_order_delivery_m3,ord_mean_days_to_order_approval_m3,ord_std_days_to_order_approval_m3,ord_max_days_to_order_approval_m3,ord_min_days_to_order_approval_m3,ord_mean_days_to_order_posting_m3,ord_std_days_to_order_posting_m3,ord_max_days_to_order_posting_m3,ord_min_days_to_order_posting_m3,ord_mean_days_to_order_delivery_m3,ord_std_days_to_order_delivery_m3,ord_max_days_to_order_delivery_m3,ord_min_days_to_order_delivery_m3,ord_mean_diff_days_actual_estimated_delivery_m3,ord_std_diff_days_actual_estimated_delivery_m3,ord_max_diff_days_actual_estimated_delivery_m3,ord_min_diff_days_actual_estimated_delivery_m3,ord_recency_m3,ord_mean_frequency_m3,ord_rate_nunique_cohorts_m3,ord_slope_nunique_order_id_m3,ord_total_orders_m6,ord_nunique_cohorts_m6,ord_mean_estimated_days_to_order_delivery_m6,ord_std_estimated_days_to_order_delivery_m6,ord_max_estimated_days_to_order_delivery_m6,ord_min_estimated_days_to_order_delivery_m6,ord_mean_days_to_order_approval_m6,ord_std_days_to_order_approval_m6,ord_max_days_to_order_approval_m6,ord_min_days_to_order_approval_m6,ord_mean_days_to_order_posting_m6,ord_std_days_to_order_posting_m6,ord_max_days_to_order_posting_m6,ord_min_days_to_order_posting_m6,ord_mean_days_to_order_delivery_m6,ord_std_days_to_order_delivery_m6,ord_max_days_to_order_delivery_m6,ord_min_days_to_order_delivery_m6,ord_mean_diff_days_actual_estimated_delivery_m6,ord_std_diff_days_actual_estimated_delivery_m6,ord_max_diff_days_actual_estimated_delivery_m6,ord_min_diff_days_actual_estimated_delivery_m6,ord_recency_m6,ord_mean_frequency_m6,ord_rate_nunique_cohorts_m6,ord_slope_nunique_order_id_m6,ord_total_orders_m9,ord_nunique_cohorts_m9,ord_mean_estimated_days_to_order_delivery_m9,...,pay_max_value_m9,pay_min_value_m9,pay_max_payment_sequential_m9,pay_median_
payment_sequential_m9,pay_min_payment_sequential_m9,pay_count_payment_credit_card_m9,pay_mean_installments_credit_card_m9,pay_std_installments_credit_card_m9,pay_max_installments_credit_card_m9,pay_min_installments_credit_card_m9,pay_sum_value_credit_card_m9,pay_mean_value_credit_card_m9,pay_std_value_credit_card_m9,pay_max_value_credit_card_m9,pay_min_value_credit_card_m9,pay_count_payment_not_credit_card_m9,pay_sum_value_not_credit_card_m9,pay_mean_value_not_credit_card_m9,pay_std_value_not_credit_card_m9,pay_max_value_not_credit_card_m9,pay_min_value_not_credit_card_m9,ctm_nunique_customer_unique_id_m3,ctm_nunique_customer_zip_code_prefix_dig_1_m3,ctm_nunique_customer_zip_code_prefix_dig_2_m3,ctm_nunique_customer_zip_code_prefix_dig_3_m3,ctm_nunique_customer_zip_code_prefix_dig_4_m3,ctm_nunique_customer_zip_code_prefix_dig_5_m3,ctm_nunique_customer_state_m3,ctm_nunique_customer_unique_id_m6,ctm_nunique_customer_zip_code_prefix_dig_1_m6,ctm_nunique_customer_zip_code_prefix_dig_2_m6,ctm_nunique_customer_zip_code_prefix_dig_3_m6,ctm_nunique_customer_zip_code_prefix_dig_4_m6,ctm_nunique_customer_zip_code_prefix_dig_5_m6,ctm_nunique_customer_state_m6,ctm_nunique_customer_unique_id_m9,ctm_nunique_customer_zip_code_prefix_dig_1_m9,ctm_nunique_customer_zip_code_prefix_dig_2_m9,ctm_nunique_customer_zip_code_prefix_dig_3_m9,ctm_nunique_customer_zip_code_prefix_dig_4_m9,ctm_nunique_customer_zip_code_prefix_dig_5_m9,ctm_nunique_customer_state_m9,geo_mean_distance_customer_seller_m3,geo_std_distance_customer_seller_m3,geo_max_distance_customer_seller_m3,geo_min_distance_customer_seller_m3,geo_mean_distance_customer_seller_m6,geo_std_distance_customer_seller_m6,geo_max_distance_customer_seller_m6,geo_min_distance_customer_seller_m6,geo_mean_distance_customer_seller_m9,geo_std_distance_customer_seller_m9,geo_max_distance_customer_seller_m9,geo_min_distance_customer_seller_m9,sel_seller_state
0,,,,,,,,,,,,,,,,,,,,,,,,,,,19.0,3.0,17.842105,9.239187,50.0,9.0,0.105263,0.315302,1.0,0.0,1.894737,1.100239,4.0,0.0,8.894737,5.743545,25.0,1.0,8.526316,5.048305,24.0,2.0,114.0,3.166667,0.500000,-1.857143,53.0,5.0,17.000000,...,127.45,2.00,3.0,1.0,1.0,44.0,2.090909,1.552465,6.0,1.0,2381.59,54.127045,22.314501,127.45,2.77,14.0,668.67,47.762143,21.130961,90.18,2.00,,,,,,,,19.0,7.0,16.0,19.0,19.0,19.0,7.0,53.0,8.0,28.0,49.0,53.0,53.0,9.0,,,,,396.415519,614.061641,2708.402561,18.566632,322.062512,526.838706,2708.402561,15.803204,0.318233
1,43.0,3.0,24.488372,6.616617,50.0,12.0,0.279070,0.734380,4.0,0.0,3.186047,1.930406,7.0,0.0,12.116279,8.056968,34.0,0.0,12.093023,6.900014,22.0,-12.0,0.0,14.333333,1.000000,3.5,74.0,6.0,22.527027,6.198150,50.0,12.0,0.243243,0.615125,4.0,0.0,3.121622,1.766916,8.0,0.0,11.108108,7.278355,34.0,0.0,11.148649,6.349808,22.0,-12.0,0.0,12.333333,1.000000,1.714286,89.0,9.0,22.528090,...,361.47,3.35,2.0,1.0,1.0,70.0,2.728571,2.232315,10.0,1.0,8146.30,116.375714,36.568900,194.92,20.14,23.0,2704.13,117.570870,69.021203,361.47,3.35,43.0,8.0,31.0,41.0,43.0,43.0,11.0,74.0,10.0,41.0,65.0,72.0,73.0,15.0,89.0,10.0,44.0,77.0,86.0,88.0,16.0,717.153340,532.207595,2551.745767,5.899857,643.158858,503.701970,2551.745767,4.101376,631.882484,477.524693,2551.745767,4.101376,0.311438
2,38.0,3.0,20.657895,5.901334,32.0,12.0,0.210526,0.576939,3.0,0.0,1.842105,1.461697,5.0,0.0,10.842105,10.468812,47.0,1.0,9.421053,8.410520,23.0,-21.0,1.0,12.666667,1.000000,2.5,53.0,5.0,19.830189,6.031124,32.0,9.0,0.301887,0.774222,3.0,0.0,1.905660,1.457925,5.0,0.0,10.830189,9.879536,47.0,1.0,8.603774,7.983973,23.0,-21.0,1.0,8.833333,0.833333,2.657143,53.0,5.0,19.830189,...,200.04,16.22,1.0,1.0,1.0,39.0,2.051282,1.834590,8.0,1.0,2528.83,64.841795,38.925570,200.04,16.22,13.0,734.16,56.473846,32.808791,112.44,21.71,37.0,8.0,23.0,34.0,37.0,37.0,7.0,52.0,8.0,28.0,45.0,52.0,52.0,8.0,52.0,8.0,28.0,45.0,52.0,52.0,8.0,335.695914,435.560842,2189.157132,7.165947,311.141991,384.371196,2189.157132,7.165947,311.141991,384.371196,2189.157132,7.165947,0.318233
3,102.0,3.0,21.147059,6.461458,44.0,13.0,0.156863,0.482262,3.0,0.0,3.225490,2.009418,9.0,0.0,11.647059,8.328226,52.0,1.0,9.137255,6.700139,25.0,-22.0,1.0,34.000000,1.000000,1.0,165.0,6.0,20.830303,6.457765,44.0,10.0,0.200000,0.586432,3.0,0.0,3.193939,1.834284,9.0,0.0,13.660606,13.838131,97.0,1.0,6.830303,13.476096,25.0,-80.0,1.0,27.500000,1.000000,5.057143,233.0,9.0,20.429185,...,459.01,2.08,3.0,1.0,1.0,181.0,4.104972,2.848982,10.0,1.0,33641.37,185.863923,84.740278,459.01,2.08,57.0,10070.65,176.678070,88.115675,439.76,20.04,99.0,10.0,36.0,81.0,95.0,99.0,14.0,161.0,10.0,58.0,121.0,149.0,160.0,18.0,227.0,10.0,66.0,161.0,205.0,224.0,20.0,370.041752,506.856880,2121.122218,4.074288,445.118097,561.589970,2337.880105,4.074288,446.706692,544.802385,2460.285152,4.074288,0.318233
4,167.0,3.0,44.892216,12.169006,81.0,27.0,0.371257,0.787638,5.0,0.0,14.407186,9.681155,82.0,0.0,24.041916,13.540814,82.0,0.0,20.688623,12.535686,54.0,-44.0,0.0,55.666667,1.000000,11.5,322.0,6.0,35.947205,13.687216,81.0,12.0,0.338509,0.714998,5.0,0.0,12.527950,7.608867,82.0,0.0,23.413043,14.796234,172.0,0.0,12.288820,16.910995,54.0,-148.0,0.0,53.666667,1.000000,4.342857,497.0,9.0,32.989940,...,2234.66,1.67,19.0,1.0,1.0,359.0,4.891365,3.427120,18.0,1.0,90799.07,252.922201,229.564539,2234.66,1.67,176.0,33343.55,189.451989,161.069071,1029.84,5.30,162.0,10.0,66.0,132.0,155.0,155.0,19.0,315.0,10.0,81.0,219.0,289.0,300.0,23.0,489.0,10.0,91.0,295.0,431.0,461.0,23.0,543.248078,613.409475,2473.456150,4.342648,575.464239,650.301646,2744.793085,4.342648,582.413872,631.790801,2744.793085,4.342648,0.318233
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1726,1.0,1.0,29.000000,,29.0,29.0,0.000000,,0.0,0.0,2.000000,,2.0,2.0,17.000000,,17.0,17.0,12.000000,,12.0,12.0,57.0,0.333333,0.333333,0.0,1.0,1.0,29.000000,,29.0,29.0,0.000000,,0.0,0.0,2.000000,,2.0,2.0,17.000000,,17.0,17.0,12.000000,,12.0,12.0,57.0,0.166667,0.166667,0.085714,1.0,1.0,29.000000,...,165.80,165.80,1.0,1.0,1.0,1.0,1.000000,,1.0,1.0,165.80,165.800000,,165.80,165.80,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,637.218643,,637.218643,637.218643,637.218643,,637.218643,637.218643,637.218643,,637.218643,637.218643,0.318233
1727,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1.0,29.000000,...,47.32,47.32,1.0,1.0,1.0,1.0,4.000000,,4.0,4.0,47.32,47.320000,,47.32,47.32,,,,,,,,,,,,,,,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,,,,,,,,,1132.358825,,1132.358825,1132.358825,0.318233
1728,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1.0,27.000000,,27.0,27.0,0.000000,,0.0,0.0,1.000000,,1.0,1.0,36.000000,,36.0,36.0,-10.000000,,-10.0,-10.0,100.0,0.166667,0.166667,-0.028571,1.0,1.0,27.000000,...,79.56,79.56,1.0,1.0,1.0,1.0,2.000000,,2.0,2.0,79.56,79.560000,,79.56,79.56,,,,,,,,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,,,,,958.997349,,958.997349,958.997349,958.997349,,958.997349,958.997349,0.318233
1729,1.0,1.0,28.000000,,28.0,28.0,0.000000,,0.0,0.0,10.000000,,10.0,10.0,18.000000,,18.0,18.0,9.000000,,9.0,9.0,48.0,0.333333,0.333333,0.0,1.0,1.0,28.000000,,28.0,28.0,0.000000,,0.0,0.0,10.000000,,10.0,10.0,18.000000,,18.0,18.0,9.000000,,9.0,9.0,48.0,0.166667,0.166667,0.085714,1.0,1.0,28.000000,...,86.15,86.15,1.0,1.0,1.0,1.0,3.000000,,3.0,3.0,86.15,86.150000,,86.15,86.15,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,399.262430,,399.262430,399.262430,399.262430,,399.262430,399.262430,399.262430,,399.262430,399.262430,0.318233
