# Model Comparison



In [1]:
# Model Comparison: Setup and Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from tslearn.clustering import TimeSeriesKMeans
from tslearn.utils import to_time_series_dataset
import seaborn as sns

# Set plotting style
sns.set(style="whitegrid")

# Make the project-local `src` package importable from this notebook.
import sys
import os
# The path four directory levels above the kernel's current working directory
# is appended to sys.path so the `src.*` imports below can resolve.
# NOTE(review): this depends on where the kernel was started — confirm the
# notebook is always run from its own folder.
project_root = os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..", "..",))
sys.path.append(project_root)
from src.core.config_loader import ConfigLoader
from src.core.clients.bigquery import BigQueryClient
%matplotlib inline

# Shared handles reused by later cells: `bq_client` runs the extraction query,
# `config_loader` supplies the target/benchmark column names and the
# val_start_date / val_end_date bounds interpolated into that query.
bq_client = BigQueryClient()
config_loader = ConfigLoader()

from src.core.models.xgboost import XgboostModel

def load_model(model_name: str, models_dir=None):
    """Load the pickled XGBoost artifact stored under a given run name.

    Args:
        model_name: Name of the run folder that contains ``model_0.pickle``
            (e.g. ``"20260131_133708_all"``).
        models_dir: Optional directory holding the run folders. When omitted,
            it defaults to ``<project_root>/src/artifacts/models/xgboost`` so
            the notebook does not depend on one user's absolute home path.

    Returns:
        Whatever ``XgboostModel().load_model`` returns for the resolved path.
    """
    if models_dir is None:
        # Resolve relative to the repository root computed in the setup cell
        # instead of a hard-coded, machine-specific absolute path.
        models_dir = os.path.join(project_root, "src", "artifacts", "models", "xgboost")
    model_path = os.path.join(models_dir, model_name, "model_0.pickle")
    return XgboostModel().load_model(model_path)

Install h5py to use hdf5 features: http://docs.h5py.org/
  warn(h5py_msg)
INFO:src.core.clients.bigquery:BigQuery client initialized for project: porygon-pipelines


In [2]:
# Load models to compare.
# Keys are the short labels used in result column names and printed headers.
models_to_compare = {
    "all": load_model("20260131_133708_all"),
    "new": load_model("20260131_144504_top100"),
}

# Every feature name listed by any model, duplicates included.
all_required_features = [
    feature
    for model in models_to_compare.values()
    for feature in model.features
]

# De-duplicated and sorted: a plain set has a run-dependent iteration order for
# strings, which would make the generated SELECT list differ between kernel
# restarts. Sorting keeps the query text reproducible.
unique_required_features = sorted(set(all_required_features))

INFO:src.core.models.xgboost:Model loaded from /Users/anapreciado/Desktop/porygon-demand-forecasting/src/artifacts/models/xgboost/20260131_133708_all/model_0.pickle
INFO:src.core.models.xgboost:Model loaded from /Users/anapreciado/Desktop/porygon-demand-forecasting/src/artifacts/models/xgboost/20260131_144504_top100/model_0.pickle


In [3]:
# Build the extraction query for store CA_1. It selects the context columns,
# the configured target and benchmark columns, and every feature required by
# any of the compared models. Rows are kept only when both stockout flags are
# 0, item longevity is at least 3 months, the month lies between the config's
# val_start_date and val_end_date (inclusive), and the target is <= 420.
# NOTE(review): the 420 cap presumably trims outliers — confirm where the
# threshold comes from and consider moving it to the config.
sales_query = f"""
    SELECT 
    ctx_item_id,
    ctx_date_month,
    ctx_cat_id,
    ctx_dept_id,
    ctx_store_id,
    {config_loader.target_col},
    {config_loader.benchmark_col},
    {','.join(unique_required_features)}
    FROM `porygon-pipelines.walmart_training_tables.walmart_master_table` tgt
    WHERE 
        tgt.ctx_store_id = 'CA_1'
        AND is_stockout_tgt = 0
        AND fea_item_longevity_months >=3
        AND is_stockout_prev_3_m = 0
        AND ctx_date_month >= '{config_loader.val_start_date}' 
        AND ctx_date_month <= '{config_loader.val_end_date}'
        AND  {config_loader.target_col} <=420
"""


# Run the query through the BigQuery client; the result is kept in `df` and
# reused by every evaluation cell below.
df = bq_client.load_from_query(sales_query)

INFO:src.core.clients.bigquery:Loaded 19340 rows from custom query.


In [4]:
from src.core.models.evaluation import error_for_group


# One single-column frame of error metrics per model, computed on the same
# extract so the columns are directly comparable.
error_frames = [
    pd.DataFrame(error_for_group(df, model), columns=[f"predictions_{name}"])
    for name, model in models_to_compare.items()
]

# Concatenate once instead of growing a frame inside the loop. With no models,
# fall back to an empty frame, since pd.concat raises on an empty list.
df_results = pd.concat(error_frames, axis=1) if error_frames else pd.DataFrame()




INFO:src.core.clients.bigquery:BigQuery client initialized for project: porygon-pipelines


In [5]:
# Show the overall error metrics side by side, one `predictions_<name>` column per model.
df_results

Unnamed: 0,predictions_all,predictions_new
pred_mae,21.584373,21.653933
bench_mae,27.648811,27.648811
diff_mae,-6.064437,-5.994878
pred_rsme,35.141457,35.275345
bench_rsme,49.299468,49.299468
diff_rsme,-14.158011,-14.024123
pred_mdae,12.375878,12.411365
bench_mdae,16.0,16.0
diff_mdae,-3.624122,-3.588635
pred_mape,0.282197,0.283665


In [6]:
from src.core.models.evaluation import retrieve_error_per_group


# Error metrics broken down by product category (`ctx_cat_id`), one table per
# model. `df_results` is deliberately not reset here: this cell never used it,
# and clearing it wiped the overall comparison table built earlier.
for name, model in models_to_compare.items():
    print(name)
    display(retrieve_error_per_group(df, "ctx_cat_id", model))


all


  return df.groupby(groupby_col).apply(lambda group: error_for_group(group, model))


Unnamed: 0_level_0,pred_mae,bench_mae,diff_mae,pred_rsme,bench_rsme,diff_rsme,pred_mdae,bench_mdae,diff_mdae,pred_mape,bench_mape,diff_mape,pred_mdape,bench_mdape,diff_mdape
ctx_cat_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
FOODS,27.888716,36.223551,-8.334835,42.456821,61.807653,-19.350832,17.727428,22.0,-4.272572,0.251454,0.316506,-0.065052,0.179861,0.228571,-0.048711
HOBBIES,16.489933,20.96651,-4.476577,28.641068,38.561002,-9.919935,9.056427,11.0,-1.943573,0.380547,0.455995,-0.075448,0.221379,0.287129,-0.06575
HOUSEHOLD,17.257027,21.59721,-4.340183,28.559435,36.82205,-8.262615,9.947779,13.0,-3.052221,0.253908,0.317129,-0.063221,0.173709,0.224138,-0.050429


new


  return df.groupby(groupby_col).apply(lambda group: error_for_group(group, model))


Unnamed: 0_level_0,pred_mae,bench_mae,diff_mae,pred_rsme,bench_rsme,diff_rsme,pred_mdae,bench_mdae,diff_mdae,pred_mape,bench_mape,diff_mape,pred_mdape,bench_mdape,diff_mdape
ctx_cat_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
FOODS,27.941595,36.223551,-8.281956,42.548527,61.807653,-19.259126,17.847183,22.0,-4.152817,0.252094,0.316506,-0.064413,0.179455,0.228571,-0.049116
HOBBIES,16.603615,20.96651,-4.362896,29.143948,38.561002,-9.417054,9.026049,11.0,-1.973951,0.382387,0.455995,-0.073608,0.225362,0.287129,-0.061767
HOUSEHOLD,17.317448,21.59721,-4.279762,28.528671,36.82205,-8.293378,10.089266,13.0,-2.910734,0.256146,0.317129,-0.060983,0.175259,0.224138,-0.048879


In [7]:
# Error metrics broken down by department (`ctx_dept_id`), one table per
# model. The unused `df_results` reset was dropped so the overall comparison
# table stays available after this cell runs.
for name, model in models_to_compare.items():
    print(name)
    display(retrieve_error_per_group(df, "ctx_dept_id", model))

all


  return df.groupby(groupby_col).apply(lambda group: error_for_group(group, model))


Unnamed: 0_level_0,pred_mae,bench_mae,diff_mae,pred_rsme,bench_rsme,diff_rsme,pred_mdae,bench_mdae,diff_mdae,pred_mape,bench_mape,diff_mape,pred_mdape,bench_mdape,diff_mdape
ctx_dept_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
FOODS_1,30.335514,37.939346,-7.603832,48.936489,65.999867,-17.063378,17.687912,22.0,-4.312088,0.269001,0.328559,-0.059557,0.185105,0.230769,-0.045664
FOODS_2,24.206619,31.363636,-7.157017,36.405254,52.786111,-16.380857,15.861504,20.0,-4.138496,0.260938,0.332836,-0.071898,0.194483,0.25,-0.055517
FOODS_3,29.074669,38.199308,-9.124639,43.37867,64.762321,-21.383651,18.733582,23.0,-4.266418,0.241933,0.305036,-0.063103,0.17249,0.214286,-0.041795
HOBBIES_1,18.038809,22.611358,-4.572549,30.893066,41.345519,-10.452453,9.99795,12.0,-2.00205,0.351485,0.414293,-0.062808,0.206271,0.25,-0.043729
HOBBIES_2,11.720586,15.901639,-4.181054,20.186014,28.316388,-8.130374,6.704884,9.0,-2.295116,0.470034,0.584404,-0.11437,0.279512,0.41565,-0.136139
HOUSEHOLD_1,21.45236,26.62083,-5.16847,34.14476,44.35036,-10.2056,12.993225,16.0,-3.006775,0.231793,0.285921,-0.054129,0.161645,0.196429,-0.034784
HOUSEHOLD_2,11.242032,14.39467,-3.152638,17.706316,21.902527,-4.196211,7.153622,10.0,-2.846378,0.285616,0.361873,-0.076257,0.197932,0.269231,-0.071299


  return df.groupby(groupby_col).apply(lambda group: error_for_group(group, model))


new


Unnamed: 0_level_0,pred_mae,bench_mae,diff_mae,pred_rsme,bench_rsme,diff_rsme,pred_mdae,bench_mdae,diff_mdae,pred_mape,bench_mape,diff_mape,pred_mdape,bench_mdape,diff_mdape
ctx_dept_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
FOODS_1,30.341461,37.939346,-7.597884,48.716721,65.999867,-17.283146,17.765503,22.0,-4.234497,0.268625,0.328559,-0.059934,0.18605,0.230769,-0.044719
FOODS_2,24.382126,31.363636,-6.981511,36.665794,52.786111,-16.120317,16.143929,20.0,-3.856071,0.26261,0.332836,-0.070226,0.197706,0.25,-0.052294
FOODS_3,29.078686,38.199308,-9.120622,43.494724,64.762321,-21.267596,18.833969,23.0,-4.166031,0.242329,0.305036,-0.062707,0.169565,0.214286,-0.04472
HOBBIES_1,18.230421,22.611358,-4.380937,31.617559,41.345519,-9.72796,9.998001,12.0,-2.001999,0.353913,0.414293,-0.060381,0.205434,0.25,-0.044566
HOBBIES_2,11.594305,15.901639,-4.307334,19.661116,28.316388,-8.655273,6.874136,9.0,-2.125864,0.470067,0.584404,-0.114337,0.29747,0.41565,-0.11818
HOUSEHOLD_1,21.532297,26.62083,-5.088533,34.128735,44.35036,-10.221625,13.465294,16.0,-2.534706,0.232499,0.285921,-0.053423,0.159879,0.196429,-0.03655
HOUSEHOLD_2,11.274471,14.39467,-3.120198,17.629751,21.902527,-4.272776,7.25153,10.0,-2.74847,0.29005,0.361873,-0.071823,0.196323,0.269231,-0.072908
