# Evaluation Notebook
Use the model trained in the modeling notebook to make and evaluate predictions on a test dataset.

#### NOTE: The user must have split data into train/test datasets in the modeling notebook before running this notebook.

❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 
## Instructions

1. Go to the ____set_global_variables___ cell in the __SETUP__ section below. 
    - Adjust the values of the user constants
2. Click ___Run all___ in the upper right corner of the notebook to run the entire notebook. 
    - The notebook will perform inference and evaluation. 
    - PROMOTE_MODEL is set to false. If evaluation is satisfactory, change the value at the end of the notebook and run the last 2 cells. If model is promoted, predictions on test data will be stored in a Snowflake table and a model monitor will be create to track future inference values.
    
❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

In [1]:
# Imports
import math
import json
from datetime import datetime
from snowflake.ml.registry import registry
from snowflake.ml.dataset import Dataset
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark.window import Window
from snowflake.snowpark import DataFrame as SnowparkDataFrame
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from forecast_model_builder.utils import connect, perform_inference

  from .autonotebook import tqdm as notebook_tqdm


-----
# SETUP
-----

In [2]:
# Name of project
PROJECT_SCHEMA = "TEST_PROJECT"

# Set warehouse
SESSION_WH = "FORECAST_MODEL_BUILDER_WH"

# Table name to store the PREDICTION results.
# NOTE: If the table name is not fully qualified with DB.SCHEMA, the session's default database and schema will be used.
# NOTE: Model version will be appended to the table name to save predictions from a particular run.
INFERENCE_RESULT_TBL_NM = "FORECAST_RESULTS"

# Input data for inference
INFERENCE_DB = "FORECAST_MODEL_BUILDER"
INFERENCE_SCHEMA = "BASE"
INFERENCE_FV = "FORECAST_FEATURES"

# Name of the model to use for inference, as well as the Database and Schema of the model registry.
# NOTE: The default model version from the registry will be used.
MODEL_DB = "FORECAST_MODEL_BUILDER"
MODEL_SCHEMA = "MODELING"
MODEL_NAME = "TEST_MODEL_1"

# If using direct multistep forecasting, set LEAD to the lead model you wish to evaluate.
# Otherwise, set to 0
LEAD = 0

# Scaling up the warehouse may speed up execution time, especially if there are many partitions.
# NOTE: If set to None, then the session warehouse will be used.
INFERENCE_WH = "STANDARD_XL"

In [3]:
# Establish session
session = connect(connection_name="default")
session.use_database(MODEL_DB)
session_db = MODEL_DB
session.use_schema(PROJECT_SCHEMA)
session_schema = PROJECT_SCHEMA
print(f"Session db.schema: {session_db}.{session_schema}")
print(f"Session warehouse: {SESSION_WH}")

# Query tag
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}'
session.query_tag = query_tag

# Get the current datetime  (This will be saved in the model storage table)
run_dttm = datetime.now()
print(f"Current Datetime: {run_dttm}")

Session db.schema: FORECAST_MODEL_BUILDER.TEST_PROJECT
Session warehouse: FORECAST_MODEL_BUILDER_WH
Current Datetime: 2026-02-17 11:55:57.621871


-----
# Establish objects needed for this run
-----

In [4]:
# Derived Objects

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
session.use_warehouse(SESSION_WH)
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Inference Warehouse
# -----------------------------------------------------------------------
# Check that the user specified an available warehouse as INFERENCE_WH. If not, use the session warehouse.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

if INFERENCE_WH in available_warehouses:
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
else:
    print(
        f"WARNING: User does not have access to INFERENCE_WH = '{INFERENCE_WH}'. Inference will use '{SESSION_WH}' instead. \n"
    )
    INFERENCE_WH = SESSION_WH

# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Get the model and the version name of the default version
# -----------------------------------------------------------------------
# Establish registry object
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# Get the model from the registry
mv = reg.get_model(qualified_model_name).last()

# Get the default version name
model_version_nm = mv.version_name

print(f"Model Version:              {model_version_nm}")

# --------------------------------
# User Constants from Model Setup
# --------------------------------
stored_constants = mv.show_metrics()["user_settings"]

TIME_PERIOD_COLUMN = stored_constants["TIME_PERIOD_COLUMN"]
TARGET_COLUMN = stored_constants["TARGET_COLUMN"]
PARTITION_COLUMNS = stored_constants["PARTITION_COLUMNS"]
ALL_EXOG_COLS_HAVE_FUTURE_VALS = stored_constants["ALL_EXOG_COLS_HAVE_FUTURE_VALS"]
USE_CONTEXT = stored_constants["USE_CONTEXT"]

if not USE_CONTEXT:
    MODEL_BINARY_STORAGE_TBL_NM = stored_constants["MODEL_BINARY_STORAGE_TBL_NM"]

if (not ALL_EXOG_COLS_HAVE_FUTURE_VALS) & (LEAD==0):
    raise ValueError(
        """If using direct multistep modeling approach, LEAD must be set to a number 
        greater than 0 to filter results to a particular lead model"""
    )
if (ALL_EXOG_COLS_HAVE_FUTURE_VALS) & (LEAD>0):
    raise ValueError(
        """If using global modeling approach, LEAD must be set to a 0"""
    )
# --------------------------------
# Get datasets
# --------------------------------

def load_df_from_ds(fully_qualified_name, version):
    ds_db, ds_schema, ds_name = fully_qualified_name.split('.')

    return Dataset(
        session=session,
        database=ds_db,
        schema=ds_schema,
        name=ds_name,
        selected_version=version
    ).read.to_snowpark_dataframe()

train_df = load_df_from_ds(
    fully_qualified_name=mv.show_metrics()['train_dataset']['name'],
    version=mv.show_metrics()['train_dataset']['version']
).drop("GROUP_IDENTIFIER")

test_df = load_df_from_ds(
    fully_qualified_name=mv.show_metrics()['test_dataset']['name'],
    version=mv.show_metrics()['test_dataset']['version']
).drop("GROUP_IDENTIFIER")

# Filter to a particular lead model if performing direct multi step forecasting
if LEAD > 0:
    train_df = train_df.filter(
            F.col("GROUP_IDENTIFIER_STRING").endswith(f"LEAD_{LEAD}")
        )
    test_df = test_df.filter(
            F.col("GROUP_IDENTIFIER_STRING").endswith(f"LEAD_{LEAD}")
        )

Session warehouse:          FORECAST_MODEL_BUILDER_WH

Model Version:              DULL_BIRD_3


-----
# Inference
-----

In [5]:
# ------------------------------------------------------------------------
# INFERENCE
# ------------------------------------------------------------------------

print("Predictions")
session.use_warehouse(INFERENCE_WH)

train_result = perform_inference(session, train_df, mv).with_column("DATASET",F.lit("TRAIN")).cache_result()
test_result = perform_inference(session, test_df, mv).with_column("DATASET",F.lit("TEST")).cache_result()
test_result.show(2)


inference_result = train_result.union_all_by_name(test_result)
session.use_warehouse(SESSION_WH)

Predictions
Inference input data row count: 344750
Number of end partition invocations to expect in the query profile: 1750
Inference input data row count: 22750
Number of end partition invocations to expect in the query profile: 250
-----------------------------------------------------------------------------------
|"_PRED_"           |"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |"DATASET"  |
-----------------------------------------------------------------------------------
|644.90185546875    |STORE_ID_8_PRODUCT_ID_3    |2024-12-06 00:00:00  |TEST       |
|634.2355346679688  |STORE_ID_8_PRODUCT_ID_3    |2024-12-17 00:00:00  |TEST       |
-----------------------------------------------------------------------------------



In [6]:
# Get predicted vs. actual dataframes for train and test

sdf = train_df.union_all_by_name(test_df).select("GROUP_IDENTIFIER_STRING",TIME_PERIOD_COLUMN,TARGET_COLUMN)

pred_v_actuals = (
    inference_result
    .join(sdf, on=["GROUP_IDENTIFIER_STRING", TIME_PERIOD_COLUMN])
    .select(
        "GROUP_IDENTIFIER_STRING", 
        TIME_PERIOD_COLUMN, 
        TARGET_COLUMN,
        F.col("_PRED_").alias("PREDICTED"),
        "DATASET"
    )
)

inference_partition_count = pred_v_actuals.select("GROUP_IDENTIFIER_STRING").distinct().count()

training_pred_v_actuals = pred_v_actuals.filter(F.col("DATASET")=="TRAIN")

test_pred_v_actuals = pred_v_actuals.filter(F.col("DATASET")=="TEST")
print(f"Dataset has {inference_partition_count} partitions")
pred_v_actuals.show(3)

Dataset has 250 partitions
-------------------------------------------------------------------------------------------------------
|"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |"MODEL_TARGET"     |"PREDICTED"        |"DATASET"  |
-------------------------------------------------------------------------------------------------------
|STORE_ID_9_PRODUCT_ID_8    |2023-10-18 00:00:00  |876.3895263671875  |877.7786254882812  |TRAIN      |
|STORE_ID_9_PRODUCT_ID_8    |2023-07-01 00:00:00  |925.4633178710938  |926.1979370117188  |TRAIN      |
|STORE_ID_9_PRODUCT_ID_8    |2023-07-05 00:00:00  |913.06884765625    |912.70751953125    |TRAIN      |
-------------------------------------------------------------------------------------------------------



In [7]:
# Calculate weights of each partition for weighted metrics

total_window = Window.partition_by()

partition_weights = (
    pred_v_actuals
    .group_by("GROUP_IDENTIFIER_STRING")
    .agg(F.sum(TARGET_COLUMN).alias(f'PARTITION_{TARGET_COLUMN}_SUM'), F.min(TIME_PERIOD_COLUMN), F.max(TIME_PERIOD_COLUMN))
    .with_column(f"TOTAL_{TARGET_COLUMN}", F.sum(f'PARTITION_{TARGET_COLUMN}_SUM').over(total_window))
    .with_column("PARTITION_WEIGHT", F.col(f'PARTITION_{TARGET_COLUMN}_SUM')/F.col(f"TOTAL_{TARGET_COLUMN}"))
)

partition_weights.sort(F.col("PARTITION_WEIGHT").desc()).show()

-------------------------------------------------------------------------------------------------------------------------------------------------------------
|"GROUP_IDENTIFIER_STRING"  |"PARTITION_MODEL_TARGET_SUM"  |"MIN(ORDER_TIMESTAMP)"  |"MAX(ORDER_TIMESTAMP)"  |"TOTAL_MODEL_TARGET"  |"PARTITION_WEIGHT"     |
-------------------------------------------------------------------------------------------------------------------------------------------------------------
|STORE_ID_17_PRODUCT_ID_1   |1456821.035583496             |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963    |0.00795719234225625    |
|STORE_ID_24_PRODUCT_ID_6   |1438105.7832641602            |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963    |0.007854969173588749   |
|STORE_ID_22_PRODUCT_ID_3   |1411616.508605957             |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963    |0.007710284103622175   |
|STORE_ID_21_PRODUCT_ID_9   |1404011.4907226562     

# Overall Performance

In [8]:
def produce_metrics(sdf: SnowparkDataFrame) -> SnowparkDataFrame:
    # Row-level metrics
    row_actual_v_fcst = (
        sdf
        .with_column("PRED_ERROR", F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        .with_column(
            "ABS_ERROR", F.abs(F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        )
        .with_column(
            "APE",
            F.when(F.col(TARGET_COLUMN) == 0, F.lit(None)).otherwise(
                F.abs(F.col("ABS_ERROR") / F.col(TARGET_COLUMN))
            ),
        )
        .with_column("SQ_ERROR", F.pow(F.col(TARGET_COLUMN) - F.col("PREDICTED"), 2))
    )
    
    # Metrics per partition
    partition_metrics = row_actual_v_fcst.group_by("GROUP_IDENTIFIER_STRING").agg(
        F.avg("APE").alias("MAPE"),
        F.avg("ABS_ERROR").alias("MAE"),
        F.sqrt(F.avg("SQ_ERROR")).alias("RMSE"),
        F.count("*").alias("TOTAL_PRED_COUNT"),
    )
    
    # Overall modeling process across all partitions
    overall_avg_metrics = partition_metrics.agg(
        F.avg("MAPE").alias("OVERALL_MAPE"),
        F.avg("MAE").alias("OVERALL_MAE"),
        F.avg("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("AVG"))
    
    overall_weighted_avg_metrics = (
        partition_metrics
            .join(partition_weights.select("GROUP_IDENTIFIER_STRING", "PARTITION_WEIGHT"), on=["GROUP_IDENTIFIER_STRING"])
            .agg(
                F.sum(F.col("PARTITION_WEIGHT")*F.col("MAPE")).alias("OVERALL_MAPE"),
                F.sum(F.col("PARTITION_WEIGHT")*F.col("MAE")).alias("OVERALL_MAE"),
                F.sum(F.col("PARTITION_WEIGHT")*F.col("RMSE")).alias("OVERALL_RMSE"),
                 )
            .with_column("AGGREGATION", F.lit("WEIGHTED_AVG"))
    )
    
    overall_median_metrics = partition_metrics.agg(
        F.median("MAPE").alias("OVERALL_MAPE"),
        F.median("MAE").alias("OVERALL_MAE"),
        F.median("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("MEDIAN"))
    
    overall_metrics = (
        overall_avg_metrics
            .union(overall_median_metrics)
            .union(overall_weighted_avg_metrics)
            .select("AGGREGATION", "OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE")
            .sort("AGGREGATION")
    )
    
    # Show the metrics
    if inference_partition_count == 1:
        print(
            "There is only 1 partition, so these values are the metrics for that single model:"
        )
        display(
            overall_median_metrics.select("OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE")
        )
    else:
        print("Avg and Median of each metric over all the partitions:")
        display(overall_metrics)

    return row_actual_v_fcst, partition_metrics

print("TRAINING SET")
train_metric_sdf, train_partition_metrics = produce_metrics(training_pred_v_actuals)
print("VALIDATION SET")
test_metric_sdf, test_partition_metrics = produce_metrics(test_pred_v_actuals)

TRAINING SET
Avg and Median of each metric over all the partitions:


<snowflake.snowpark.dataframe.DataFrame at 0x174f27f20>

VALIDATION SET
Avg and Median of each metric over all the partitions:


<snowflake.snowpark.dataframe.DataFrame at 0x174fa5400>

# Partition Performance

In [9]:
if (len(PARTITION_COLUMNS) > 0) & (inference_partition_count > 1):
    distribution_df = test_partition_metrics.to_pandas()
    
    table_to_show_sdf = (
        test_partition_metrics
        .join(partition_weights.select("GROUP_IDENTIFIER_STRING", "PARTITION_WEIGHT"), on=["GROUP_IDENTIFIER_STRING"])
        .with_column("WGHT_PCT", F.col("PARTITION_WEIGHT")*100)
        .select("GROUP_IDENTIFIER_STRING", "MAPE", "MAE", "RMSE", "WGHT_PCT")
    )
    table_df = table_to_show_sdf.to_pandas()
    
    metrics = ["MAPE", "MAE", "RMSE"]
    
    fig = go.Figure()
    
    for i, metric in enumerate(metrics):
        visible = (i == 0)
        fig.add_trace(go.Box(
            x=distribution_df[metric],
            name=metric,
            boxpoints="all",
            jitter=0.3,
            pointpos=-1.8,
            hovertext=distribution_df["GROUP_IDENTIFIER_STRING"],
            hoverinfo="text+x",
            visible=visible
        ))
    
    fig.update_layout(
        title="MAPE Distribution",
        template="plotly_white",
        updatemenus=[
            dict(
                active=0,
                buttons=[
                    dict(
                        label=metric,
                        method="update",
                        args=[
                            {"visible": [m == metric for m in metrics]},
                            {"title": f"{metric} Distribution",
                             "xaxis": {"title": metric, "range": [distribution_df[metric].min(), distribution_df[metric].max()]}}
                        ]
                    ) for metric in metrics
                ],
                direction="down",
                showactive=True,
                x=0.0,
                xanchor="left",
                y=1.15,
                yanchor="top"
            )
        ],
        xaxis=dict(
            rangeslider=dict(visible=True),
            title="MAPE"
        ),
        annotations=[
            dict(text="Metric:", x=0, xref="paper", y=1.12, yref="paper", showarrow=False, xanchor="right")
        ]
    )
    
    fig.show()

    print("## BEST Performing Partitions (sorted by MAPE)")
    display(table_df.sort_values("MAPE", key=abs).head(20))
    
    print("## WORST Performing Partitions (sorted by MAPE desc)")
    display(table_df.sort_values("MAPE", key=abs, ascending=False).head(20))

AttributeError: module 'modin' has no attribute 'pandas'

In [None]:
# ------------------------------------------------------------------------------
# Visualize individual partition actual vs pred on a time series line chart
# ------------------------------------------------------------------------------

pred_v_actuals_pdf = pred_v_actuals.to_pandas()
pred_v_actuals_pdf[TIME_PERIOD_COLUMN] = pd.to_datetime(pred_v_actuals_pdf[TIME_PERIOD_COLUMN])
partition_list = sorted(pred_v_actuals_pdf["GROUP_IDENTIFIER_STRING"].unique().tolist())

fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=("Line Plot: Actual & Predicted", "Scatter Plot: Actual vs. Predicted"),
    vertical_spacing=0.15
)

for i, partition in enumerate(partition_list):
    pdf = pred_v_actuals_pdf[pred_v_actuals_pdf["GROUP_IDENTIFIER_STRING"] == partition].sort_values(TIME_PERIOD_COLUMN)
    visible = (i == 0)
    
    split_date = pdf[pdf["DATASET"] == "TRAIN"][TIME_PERIOD_COLUMN].max()
    
    fig.add_trace(go.Scatter(
        x=pdf[TIME_PERIOD_COLUMN], y=pdf[TARGET_COLUMN],
        mode="lines", name="Actual", line=dict(color="blue"),
        visible=visible, legendgroup=partition, showlegend=True
    ), row=1, col=1)
    
    fig.add_trace(go.Scatter(
        x=pdf[TIME_PERIOD_COLUMN], y=pdf["PREDICTED"],
        mode="lines", name="Predicted", line=dict(color="red"),
        visible=visible, legendgroup=partition, showlegend=True
    ), row=1, col=1)
    
    fig.add_trace(go.Scatter(
        x=[split_date, split_date], y=[pdf[TARGET_COLUMN].min(), pdf[TARGET_COLUMN].max()],
        mode="lines", name="Forecast Start", line=dict(color="green", dash="dash"),
        visible=visible, legendgroup=partition, showlegend=True
    ), row=1, col=1)
    
    fig.add_trace(go.Scatter(
        x=pdf[TARGET_COLUMN], y=pdf["PREDICTED"],
        mode="markers", name="Actual vs Pred", opacity=0.6,
        visible=visible, legendgroup=partition, showlegend=False
    ), row=2, col=1)
    
    min_val, max_val = pdf[TARGET_COLUMN].min(), pdf[TARGET_COLUMN].max()
    fig.add_trace(go.Scatter(
        x=[min_val, max_val], y=[min_val, max_val],
        mode="lines", name="y=x", line=dict(color="black", dash="dash"),
        visible=visible, legendgroup=partition, showlegend=True
    ), row=2, col=1)

traces_per_partition = 5

buttons = []
for i, partition in enumerate(partition_list):
    visibility = [False] * (len(partition_list) * traces_per_partition)
    for j in range(traces_per_partition):
        visibility[i * traces_per_partition + j] = True
    buttons.append(dict(
        label=partition,
        method="update",
        args=[{"visible": visibility}, {"title": f"Partition: {partition}"}]
    ))

fig.update_layout(
    height=800,
    title=f"Partition: {partition_list[0]}",
    updatemenus=[dict(
        active=0,
        buttons=buttons,
        direction="down",
        showactive=True,
        x=0.0,
        xanchor="left",
        y=1.08,
        yanchor="top"
    )],
    annotations=[
        dict(text="Partition:", x=0, xref="paper", y=1.06, yref="paper", showarrow=False, xanchor="right")
    ]
)

fig.update_xaxes(title_text=TIME_PERIOD_COLUMN, row=1, col=1)
fig.update_xaxes(title_text=TARGET_COLUMN, row=2, col=1)
fig.update_yaxes(title_text="Value", row=1, col=1)
fig.update_yaxes(title_text="PREDICTED", row=2, col=1)

fig.show()

# Feature Importance

In [None]:
# Load model feature important data depending on use of model context or storage table

if USE_CONTEXT:
    model_obj = mv.load(force=True).context.model_refs
    model_data = [(
        part, 
        dict(feature_importance=dict(
            zip(ref.model.feature_names_in_, [float(val) for val in ref.model.feature_importances_])
        ))
    ) for part,ref in model_obj.items()]

    model_df = pd.DataFrame(model_data,columns=["GROUP_IDENTIFIER_STRING","METADATA"])
    model_df["MODEL_NAME"] = MODEL_NAME
    print(
        f"Feature Importances are from model version {model_version_nm} model context."
    )
else:
    models_sdf = (
        session.table(f"{MODEL_BINARY_STORAGE_TBL_NM}")
        .filter(F.col("MODEL_NAME") == MODEL_NAME)
        .filter(
            F.col("MODEL_VERSION")
            == reg.get_model(qualified_model_name).default.version_name
        )
    )
    model_df = models_sdf.select(
        "MODEL_NAME", "GROUP_IDENTIFIER_STRING", "METADATA"
    ).to_pandas()
    print(
        f"Feature Importances are for model version {reg.get_model(qualified_model_name).default.version_name} in table {MODEL_BINARY_STORAGE_TBL_NM}."
    )

# Filter models to given lead if using direct multistep modeling
if LEAD > 0:
    model_df = model_df[
        model_df["GROUP_IDENTIFIER_STRING"].str.endswith(f"LEAD_{LEAD}")
    ]

In [None]:
def preprocess_model_data(df):
    """Preprocess model data by extracting feature importance from the METADATA column."""
    df["FEATURE_IMPORTANCE"] = df["METADATA"].apply(
        lambda x: (
            json.loads(x).get("feature_importance", {})
            if isinstance(x, str)
            else x.get("feature_importance", {})
        )
    )

    feature_rows = []
    for _, row in df.iterrows():
        for feature, importance in row["FEATURE_IMPORTANCE"].items():
            feature_rows.append(
                {
                    "MODEL_NAME": row["MODEL_NAME"],
                    "GROUP_IDENTIFIER_STRING": row["GROUP_IDENTIFIER_STRING"],
                    "FEATURE": feature,
                    "IMPORTANCE": importance,
                }
            )

    feature_df = pd.DataFrame(feature_rows)
    return df, feature_df


def calculate_average_rank(feature_df):
    """Calculate the average rank and importance of features across different group partitions."""
    feature_df = feature_df.copy()
    feature_df.loc[:, "RANK"] = feature_df.groupby("GROUP_IDENTIFIER_STRING")[
        "IMPORTANCE"
    ].rank(ascending=False)

    avg_rank_df = (
        feature_df.groupby("FEATURE")
        .agg({"RANK": "mean", "IMPORTANCE": "mean"})
        .reset_index()
    )

    avg_rank_df.rename(
        columns={"RANK": "AVERAGE_RANK", "IMPORTANCE": "AVERAGE_IMPORTANCE"},
        inplace=True,
    )
    avg_rank_df = avg_rank_df.sort_values("AVERAGE_RANK", ascending=True)
    return feature_df, avg_rank_df


model_df, feature_df = preprocess_model_data(model_df)
partition_models = sorted(model_df["GROUP_IDENTIFIER_STRING"].unique().tolist())

_, avg_rank_df_all = calculate_average_rank(feature_df)

top_n = 20

fig = go.Figure()

agg_df = avg_rank_df_all.head(top_n).sort_values("AVERAGE_RANK", ascending=True)
fig.add_trace(go.Bar(
    y=agg_df["FEATURE"].tolist(),
    x=agg_df["AVERAGE_RANK"].tolist(),
    orientation="h",
    name="Aggregated",
    marker_color="steelblue",
    visible=True
))

buttons = [
    dict(
        label="Aggregated (All)",
        method="update",
        args=[
            {"y": [agg_df["FEATURE"].tolist()], "x": [agg_df["AVERAGE_RANK"].tolist()]},
            {"title.text": f"Top {top_n} Feature Importance (Aggregated by Average Rank)", "xaxis.title.text": "Average Rank"}
        ]
    )
]

for partition in partition_models:
    part_df = feature_df[feature_df["GROUP_IDENTIFIER_STRING"] == partition]
    part_df = part_df.sort_values("IMPORTANCE", ascending=False).head(top_n)
    part_df = part_df.sort_values("IMPORTANCE", ascending=True)
    
    label = partition if len(partition) <= 35 else partition[:32] + "..."
    buttons.append(dict(
        label=label,
        method="update",
        args=[
            {"y": [part_df["FEATURE"].tolist()], "x": [part_df["IMPORTANCE"].tolist()]},
            {"title.text": f"Top {top_n} Feature Importance: {partition}", "xaxis.title.text": "Importance"}
        ]
    ))

fig.update_layout(
    title=dict(text=f"Top {top_n} Feature Importance (Aggregated by Average Rank)", y=0.98, x=0.5, xanchor="center"),
    xaxis_title="Average Rank",
    yaxis_title="Feature",
    yaxis=dict(categoryorder="array", categoryarray=agg_df["FEATURE"].tolist()),
    height=650,
    margin=dict(l=200, r=50, t=120, b=50),
    template="plotly_white",
    updatemenus=[
        dict(
            active=0,
            buttons=buttons,
            direction="down",
            showactive=True,
            x=1.0,
            xanchor="right",
            y=1.18,
            yanchor="top"
        )
    ],
    annotations=[
        dict(text="Select Partition:", x=0.99, xref="paper", y=1.15, yref="paper", showarrow=False, xanchor="right")
    ]
)

fig.show()

print(f"\n## Aggregated Feature Importance (Top {top_n})")
display(avg_rank_df_all.head(top_n))

# Promote Model Version?

If set to True, model version evaluated in this notebook will be promoted to default and a new inference table will be created from the test data. A model monitor to track future inference results will also be created.

In [None]:
PROMOTE_MODEL_VERSION = False

In [None]:
if PROMOTE_MODEL_VERSION:
    m = reg.get_model(qualified_model_name)
    m.default = model_version_nm
    print(f"Model version {model_version_nm} promoted.")
    session.use_schema(MODEL_SCHEMA)
    source_name = f"{INFERENCE_RESULT_TBL_NM}_{MODEL_NAME}_{model_version_nm}"
    base_name = source_name + "_BASELINE"
    table_exist = session.sql(f"SHOW TABLES LIKE '{source_name}';").count() > 0
    if table_exist:
        table_data = session.table(source_name).select(TIME_PERIOD_COLUMN,"GROUP_IDENTIFIER_STRING")
        data_to_save = test_pred_v_actuals.join(table_data, on = [TIME_PERIOD_COLUMN, "GROUP_IDENTIFIER_STRING"], how="leftanti")
        data_to_save.drop("DATASET").write.save_as_table(source_name, mode="append")
    else:
        test_pred_v_actuals.drop("DATASET").write.save_as_table(source_name, mode="overwrite")
    session.sql(f"""
        CREATE OR REPLACE MODEL MONITOR {MODEL_NAME}_{model_version_nm}_MONITOR
        WITH
            MODEL={MODEL_NAME}
            VERSION={model_version_nm}
            FUNCTION=predict
            SOURCE={source_name}
            TIMESTAMP_COLUMN={TIME_PERIOD_COLUMN}
            PREDICTION_SCORE_COLUMNS=(PREDICTED)  
            ACTUAL_SCORE_COLUMNS=(MODEL_TARGET)
            SEGMENT_COLUMNS = (GROUP_IDENTIFIER_STRING)
            WAREHOUSE={SESSION_WH}
            REFRESH_INTERVAL='1 day'
            AGGREGATION_WINDOW='1 day';
    """).collect()
    print("Model monitor created")