# Evaluation Notebook
Use the model trained in the modeling notebook to make and evaluate predictions on a test dataset.

#### NOTE: The user must have split data into train/test datasets in the modeling notebook before running this notebook.

❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 
## Instructions

1. Go to the ___set_global_variables___ cell in the __SETUP__ section below. 
    - Adjust the values of the user constants
2. Click ___Run all___ in the upper right corner of the notebook to run the entire notebook. 
    - The notebook will perform inference and evaluation. Predictions will be stored in a Snowflake table.
    
❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

In [1]:
# Imports
import math
import json
from datetime import datetime
from snowflake.ml.registry import registry
from snowflake.ml.dataset import Dataset
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark.window import Window
from snowflake.snowpark import DataFrame as SnowparkDataFrame
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from forecast_model_builder.utils import connect

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Establish session
# `connect` is the project helper imported from forecast_model_builder.utils;
# here it is asked for the connection named "default".
session = connect(connection_name="default")
# Capture the session's current database / schema / warehouse. Later cells use
# these when printing the location of the predictions table.
session_db = session.connection.database
session_schema = session.connection.schema
session_wh = session.connection.warehouse
print(f"Session db.schema: {session_db}.{session_schema}")
print(f"Session warehouse: {session_wh}")

# Query tag
# JSON-formatted tag assigned to the session so queries issued by this notebook
# carry the "inference" component label.
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}'
session.query_tag = query_tag

# Get the current datetime  (This will be saved in the model storage table)
# NOTE(review): datetime.now() is timezone-naive (local clock of the kernel), and
# run_dttm is not referenced again in the visible part of this notebook.
run_dttm = datetime.now()
print(f"Current Datetime: {run_dttm}")

DatabaseError: 250001 (08001): Failed to connect to DB: SFSENORTHAMERICA-AFERAS_AWS1.snowflakecomputing.com:443. Incoming request with IP/Token 65.56.243.156 is not allowed to access Snowflake. Contact your account administrator. For more information about this error, go to https://community.snowflake.com/s/ip-xxxxxxxxxxxx-is-not-allowed-to-access.

-----
# SETUP
-----

In [None]:
# User constants: edit the values in this cell before running the notebook.

# Table name to store the PREDICTION results.
# NOTE: If the table name is not fully qualified with DB.SCHEMA, the session's default database and schema will be used.
# NOTE: Model version will be appended to the table name to save predictions from a particular run.
#       (The write cell builds the final name as <table>_<MODEL_NAME>_<model version>.)
INFERENCE_RESULT_TBL_NM = "FORECAST_RESULTS"

# Input data for inference
# NOTE(review): these three constants are not referenced anywhere else in the
# visible notebook; the train/test data is loaded from the datasets recorded in
# the model version's metrics instead. Confirm whether they are still needed.
INFERENCE_DB = "FORECAST_MODEL_BUILDER"
INFERENCE_SCHEMA = "BASE"
INFERENCE_FV = "FORECAST_FEATURES"

# Name of the model to use for inference, as well as the Database and Schema of the model registry.
# NOTE: The default model version from the registry will be used.
MODEL_DB = "FORECAST_MODEL_BUILDER"
MODEL_SCHEMA = "MODELING"
MODEL_NAME = "TEST_MODEL_1"

# If using direct multistep forecasting, set LEAD to the lead model you wish to evaluate.
# Otherwise, set to 0
# The next cell validates LEAD against the stored ALL_EXOG_COLS_HAVE_FUTURE_VALS
# setting and raises ValueError when the two are inconsistent.
LEAD = 0

# Scaling up the warehouse may speed up execution time, especially if there are many partitions.
# NOTE: If set to None, then the session warehouse will be used.
# If the named warehouse is not listed by SHOW WAREHOUSES, the next cell falls
# back to the session warehouse and prints a warning.
INFERENCE_WH = "STANDARD_XL"

-----
# Establish objects needed for this run
-----

In [None]:
# Derived Objects
# Resolves the warehouses, the registry model version and the user settings that
# were stored with the model, then validates LEAD against those settings.

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
SESSION_WH = session.connection.warehouse
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Inference Warehouse
# -----------------------------------------------------------------------
# Check that the user specified an available warehouse as INFERENCE_WH. If not, use the session warehouse.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

if INFERENCE_WH is None:
    # None is the documented way to opt into the session warehouse, so this is
    # not an access problem and must not be reported as a warning.
    INFERENCE_WH = SESSION_WH
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
elif INFERENCE_WH in available_warehouses:
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
else:
    print(
        f"WARNING: User does not have access to INFERENCE_WH = '{INFERENCE_WH}'. Inference will use '{SESSION_WH}' instead. \n"
    )
    INFERENCE_WH = SESSION_WH

# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Get the model and the version name of the default version
# -----------------------------------------------------------------------
# Establish registry object
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# Get the model from the registry
mv = reg.get_model(qualified_model_name).default

# Get the default version name
model_version_nm = mv.version_name

print(f"Model Version:              {model_version_nm}")

# --------------------------------
# User Constants from Model Setup
# --------------------------------
# The modeling run stored its user settings under the "user_settings" metric of
# the model version; a missing key raises KeyError here.
stored_constants = mv.show_metrics()["user_settings"]

TIME_PERIOD_COLUMN = stored_constants["TIME_PERIOD_COLUMN"]
TARGET_COLUMN = stored_constants["TARGET_COLUMN"]
PARTITION_COLUMNS = stored_constants["PARTITION_COLUMNS"]
EXOGENOUS_COLUMNS = stored_constants["EXOGENOUS_COLUMNS"]
ALL_EXOG_COLS_HAVE_FUTURE_VALS = stored_constants["ALL_EXOG_COLS_HAVE_FUTURE_VALS"]
CREATE_LAG_FEATURE = stored_constants["CREATE_LAG_FEATURE"]
CURRENT_FREQUENCY = stored_constants["CURRENT_FREQUENCY"]
ROLLUP_FREQUENCY = stored_constants["ROLLUP_FREQUENCY"]
ROLLUP_AGGREGATIONS = stored_constants["ROLLUP_AGGREGATIONS"]
FORECAST_HORIZON = stored_constants["FORECAST_HORIZON"]
XGB_PARAMS = stored_constants["XGB_PARAMS"]
INFERENCE_APPROX_BATCH_SIZE = stored_constants["INFERENCE_APPROX_BATCH_SIZE"]
USE_CONTEXT = stored_constants["USE_CONTEXT"]

# The model-binary storage table is only needed when the models are not held in
# the model context.
if not USE_CONTEXT:
    MODEL_BINARY_STORAGE_TBL_NM = stored_constants["MODEL_BINARY_STORAGE_TBL_NM"]

# LEAD must agree with the modeling approach recorded in the stored settings.
# Logical `and` is used (not bitwise `&`) because these are truth tests.
if (not ALL_EXOG_COLS_HAVE_FUTURE_VALS) and (LEAD == 0):
    raise ValueError(
        """If using direct multistep modeling approach, LEAD must be set to a number 
        greater than 0 to filter results to a particular lead model"""
    )
if ALL_EXOG_COLS_HAVE_FUTURE_VALS and (LEAD > 0):
    raise ValueError(
        """If using global modeling approach, LEAD must be set to a 0"""
    )
# --------------------------------
# Get datasets
# --------------------------------

def load_df_from_ds(fully_qualified_name, version):
    """Read one version of a Snowflake ML Dataset as a Snowpark DataFrame.

    Args:
        fully_qualified_name: Dataset name in the form ``DB.SCHEMA.NAME``. It is
            split on "." and must yield exactly three parts, otherwise the
            unpacking raises ValueError.
        version: Dataset version to select.

    Returns:
        The Snowpark DataFrame produced by the dataset reader.

    Uses the notebook-level ``session``.
    """
    database_nm, schema_nm, dataset_nm = fully_qualified_name.split('.')

    dataset = Dataset(
        session=session,
        database=database_nm,
        schema=schema_nm,
        name=dataset_nm,
        selected_version=version,
    )
    return dataset.read.to_snowpark_dataframe()

# Fetch the model version's metrics once; they hold the name and version of the
# train and test datasets recorded by the modeling notebook. (Each
# show_metrics() call is presumably a registry round trip, so it is not repeated.)
model_metrics = mv.show_metrics()
train_dataset_info = model_metrics['train_dataset']
test_dataset_info = model_metrics['test_dataset']

# GROUP_IDENTIFIER is dropped; the rest of the notebook keys on
# GROUP_IDENTIFIER_STRING.
train_df = load_df_from_ds(
    fully_qualified_name=train_dataset_info['name'],
    version=train_dataset_info['version']
).drop("GROUP_IDENTIFIER")

test_df = load_df_from_ds(
    fully_qualified_name=test_dataset_info['name'],
    version=test_dataset_info['version']
).drop("GROUP_IDENTIFIER")

# Filter to a particular lead model if performing direct multi step forecasting
if LEAD > 0:
    # Partitions are matched by the "LEAD_<n>" suffix of GROUP_IDENTIFIER_STRING.
    lead_suffix = f"LEAD_{LEAD}"
    train_df = train_df.filter(
        F.col("GROUP_IDENTIFIER_STRING").endswith(lead_suffix)
    )
    test_df = test_df.filter(
        F.col("GROUP_IDENTIFIER_STRING").endswith(lead_suffix)
    )


Session warehouse:          FORECAST_MODEL_BUILDER_WH

Model Version:              YOUNG_LEECH_2


-----
# Inference
-----

In [5]:
# ------------------------------------------------------------------------
# INFERENCE
# ------------------------------------------------------------------------

def perform_inference(inference_input_df):
    """Score a Snowpark DataFrame with the registry model version ``mv``.

    Steps:
      1. Add a null TARGET column if the input does not have one.
      2. When the models are not held in the model context, inner join the
         model-binary table so only partitions that have a model are scored.
      3. Split each partition into randomly assigned batches of roughly
         INFERENCE_APPROX_BATCH_SIZE rows (PARTITION_ID) to limit memory use.
      4. Run the model partitioned by PARTITION_ID.

    Args:
        inference_input_df: Snowpark DataFrame that contains
            GROUP_IDENTIFIER_STRING and the time period column.

    Returns:
        Snowpark DataFrame with columns ``_PRED_``, ``GROUP_IDENTIFIER_STRING``
        and the time period column.

    Raises:
        ValueError: If there are no rows to score (empty input, or no partition
            matched a stored model).

    Relies on notebook-level names: session, mv, TARGET_COLUMN,
    TIME_PERIOD_COLUMN, USE_CONTEXT, MODEL_BINARY_STORAGE_TBL_NM, MODEL_NAME,
    model_version_nm and INFERENCE_APPROX_BATCH_SIZE.
    """
    # If the inference dataset does not have the TARGET column already, add it and fill it with null values
    if TARGET_COLUMN not in inference_input_df.columns:
        inference_input_df = inference_input_df.with_column(TARGET_COLUMN, F.lit(None).cast(T.FloatType()))

    if not USE_CONTEXT:
        model_bytes_table = (
            session.table(MODEL_BINARY_STORAGE_TBL_NM)
            .filter(F.col("MODEL_NAME") == MODEL_NAME)
            .filter(F.col("MODEL_VERSION") == model_version_nm)
            .select("GROUP_IDENTIFIER_STRING", "MODEL_BINARY")
        )

        # NOTE: We inner join to the model bytes table to ensure that we only try to run inference on partitions that have a model.
        inference_input_df = inference_input_df.join(
            model_bytes_table, on=["GROUP_IDENTIFIER_STRING"], how="inner"
        )

    # Add a column called BATCH_GROUP,
    #   which has the property that for each unique value there are roughly the number of records specified in batch_size.
    # Use that to create a PARTITION_ID column that will be used to run inference in batches.
    # We do this to avoid running out of memory when performing inference on a large number of records.
    largest_partition_record_count = (
        inference_input_df.group_by("GROUP_IDENTIFIER_STRING")
        .agg(F.count("*").alias("PARTITION_RECORD_COUNT"))
        .agg(F.max("PARTITION_RECORD_COUNT").alias("MAX_PARTITION_RECORD_COUNT"))
        .collect()[0]["MAX_PARTITION_RECORD_COUNT"]
    )
    # MAX over zero rows comes back as NULL/None; fail with a clear message
    # instead of a TypeError inside math.ceil below.
    if largest_partition_record_count is None:
        raise ValueError(
            "No rows available for inference: the input is empty or none of its "
            "partitions matched a stored model."
        )

    batch_size = INFERENCE_APPROX_BATCH_SIZE
    number_of_batches = math.ceil(largest_partition_record_count / batch_size)
    inference_input_df = (
        inference_input_df.with_column(
            "BATCH_GROUP", F.abs(F.random(123)) % F.lit(number_of_batches)
        )
        .with_column(
            "PARTITION_ID",
            F.concat_ws(
                F.lit("__"), F.col("GROUP_IDENTIFIER_STRING"), F.col("BATCH_GROUP")
            ),
        )
        # NOTE(review): RANDOM_NUMBER is never created in this function; it looks
        # like a leftover name. Kept as-is in case callers pass such a column.
        .drop("RANDOM_NUMBER", "BATCH_GROUP")
    )

    # Report the size of the inference input and the number of batches
    print(f"Inference input data row count: {inference_input_df.count()}")
    print(
        f"Number of end partition invocations to expect in the query profile: {inference_input_df.select('PARTITION_ID').distinct().count()}"
    )
    # Use the model to score the input data
    inference_result = mv.run(inference_input_df, partition_column="PARTITION_ID").select(
        "_PRED_",
        F.col("GROUP_IDENTIFIER_STRING_OUT_").alias("GROUP_IDENTIFIER_STRING"),
        F.col(f"{TIME_PERIOD_COLUMN}_OUT_").alias(TIME_PERIOD_COLUMN),
    )

    return inference_result


print("Predictions")
# Score on the (possibly larger) inference warehouse, and always switch back to
# the session warehouse afterwards, even if inference fails or the cell is
# interrupted, so later cells do not keep running on the bigger warehouse.
session.use_warehouse(INFERENCE_WH)
try:
    # cache_result() materialises each prediction set so it is not recomputed
    # by the cells that follow.
    train_result = perform_inference(train_df).with_column("DATASET",F.lit("TRAIN")).cache_result()
    test_result = perform_inference(test_df).with_column("DATASET",F.lit("TEST")).cache_result()
    test_result.show(2)
finally:
    session.use_warehouse(SESSION_WH)

Predictions
Inference input data row count: 344750


KeyboardInterrupt: 

In [None]:
# Write predictions for train and test to a Snowflake table.
# Right now this is set up to overwrite the table if it already exists.

inference_result = train_result.union_all_by_name(test_result)
# The model name and version are appended so each model version gets its own table.
TABLE_NAME_VERSION = f"{INFERENCE_RESULT_TBL_NM}_{MODEL_NAME}_{model_version_nm}"
inference_result.write.save_as_table(
    TABLE_NAME_VERSION,
    mode="overwrite",
    comment='{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}',
)

# INFERENCE_RESULT_TBL_NM may already be qualified (SCHEMA.TABLE or
# DB.SCHEMA.TABLE); only prepend the session qualifiers that are missing so the
# printed location is not double-qualified.
name_part_count = len(TABLE_NAME_VERSION.split("."))
missing_qualifiers = [session_db, session_schema][: max(0, 3 - name_part_count)]
qualified_table_name = ".".join([*missing_qualifiers, TABLE_NAME_VERSION])

print(
    f"Predictions written to table: {qualified_table_name}"
)

# Look at a few rows of the snowflake table
print("Sample records:")
# From here on inference_result refers to the persisted table, not the in-memory union.
inference_result = session.table(TABLE_NAME_VERSION)
inference_result.show(3)

Predictions written to table: FORECAST_MODEL_BUILDER.TEST.FORECAST_RESULTS
Sample records:
-----------------------------------------------------------------------------------
|"_PRED_"           |"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |"DATASET"  |
-----------------------------------------------------------------------------------
|562.4960327148438  |STORE_ID_25_PRODUCT_ID_8   |2023-10-02 00:00:00  |TRAIN      |
|588.8887939453125  |STORE_ID_25_PRODUCT_ID_8   |2023-07-25 00:00:00  |TRAIN      |
|577.991455078125   |STORE_ID_25_PRODUCT_ID_8   |2023-07-26 00:00:00  |TRAIN      |
-----------------------------------------------------------------------------------



In [None]:
# Build the predicted-vs-actual dataframes for the train and test sets.

# Columns that identify one observation: the partition and the time period.
join_keys = ["GROUP_IDENTIFIER_STRING", TIME_PERIOD_COLUMN]

# Actual target values for every train and test row.
sdf = train_df.union_all_by_name(test_df).select(*join_keys, TARGET_COLUMN)

# Attach the actuals to the stored predictions.
joined_predictions = inference_result.join(sdf, on=join_keys)
pred_v_actuals = joined_predictions.select(
    *join_keys,
    TARGET_COLUMN,
    F.col("_PRED_").alias("PREDICTED"),
    "DATASET",
)

# Number of distinct partitions that have predictions.
inference_partition_count = (
    pred_v_actuals.select("GROUP_IDENTIFIER_STRING").distinct().count()
)

# Split by the DATASET label that was added at inference time.
dataset_label = F.col("DATASET")
training_pred_v_actuals = pred_v_actuals.filter(dataset_label == "TRAIN")
test_pred_v_actuals = pred_v_actuals.filter(dataset_label == "TEST")

print(f"Dataset has {inference_partition_count} partitions")
pred_v_actuals.show(3)

Test dataset has 250 partitions
-------------------------------------------------------------------------------------------
|"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |"TARGET"           |"PREDICTED"        |
-------------------------------------------------------------------------------------------
|STORE_ID_9_PRODUCT_ID_8    |2023-07-30 00:00:00  |947.8779296875     |941.7086181640625  |
|STORE_ID_9_PRODUCT_ID_8    |2023-05-26 00:00:00  |887.1766967773438  |888.5808715820312  |
|STORE_ID_9_PRODUCT_ID_8    |2023-06-01 00:00:00  |880.4500122070312  |885.3113403320312  |
-------------------------------------------------------------------------------------------



In [30]:
# Weight of each partition = its share of the summed target over all partitions.
# These weights feed the weighted-average metrics.

partition_sum_col = f'PARTITION_{TARGET_COLUMN}_SUM'
grand_total_col = f"TOTAL_{TARGET_COLUMN}"

# A window with no partitioning keys spans every row, so the windowed sum is the grand total.
total_window = Window.partition_by()

# Per-partition target sum plus the first and last time period seen.
per_partition_totals = pred_v_actuals.group_by("GROUP_IDENTIFIER_STRING").agg(
    F.sum(TARGET_COLUMN).alias(partition_sum_col),
    F.min(TIME_PERIOD_COLUMN),
    F.max(TIME_PERIOD_COLUMN),
)

with_grand_total = per_partition_totals.with_column(
    grand_total_col, F.sum(partition_sum_col).over(total_window)
)

partition_weights = with_grand_total.with_column(
    "PARTITION_WEIGHT", F.col(partition_sum_col) / F.col(grand_total_col)
)

# Show the heaviest partitions first.
heaviest_first = F.col("PARTITION_WEIGHT").desc()
partition_weights.sort(heaviest_first).show()

-----------------------------------------------------------------------------------------------------------------------------------------------------
|"GROUP_IDENTIFIER_STRING"  |"PARTITION_TARGET_SUM"  |"MIN(ORDER_TIMESTAMP)"  |"MAX(ORDER_TIMESTAMP)"  |"TOTAL_TARGET"      |"PARTITION_WEIGHT"     |
-----------------------------------------------------------------------------------------------------------------------------------------------------
|STORE_ID_17_PRODUCT_ID_1   |1456821.035583496       |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963  |0.00795719234225625    |
|STORE_ID_24_PRODUCT_ID_6   |1438105.7832641602      |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963  |0.007854969173588749   |
|STORE_ID_22_PRODUCT_ID_3   |1411616.508605957       |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963  |0.007710284103622175   |
|STORE_ID_21_PRODUCT_ID_9   |1404011.4907226562      |2021-01-01 00:00:00     |2025-01-09 00:00:00  

# Overall Performance

In [None]:
def produce_metrics(
    sdf: SnowparkDataFrame,
) -> tuple[SnowparkDataFrame, SnowparkDataFrame]:
    """Compute error metrics for a predicted-vs-actual dataframe and display a summary.

    Args:
        sdf: Snowpark DataFrame with GROUP_IDENTIFIER_STRING, the target column
            (TARGET_COLUMN) and a PREDICTED column.

    Returns:
        A 2-tuple ``(row_metrics, partition_metrics)``:
          - row_metrics: the input rows plus PRED_ERROR, ABS_ERROR, APE and SQ_ERROR.
          - partition_metrics: one row per GROUP_IDENTIFIER_STRING with MAPE,
            MAE, RMSE and TOTAL_PRED_COUNT.

    Side effects:
        Writes the overall metrics to the Streamlit output.

    Relies on notebook-level names: TARGET_COLUMN, partition_weights,
    inference_partition_count and ``st``.
    """
    # Row-level metrics
    row_actual_v_fcst = (
        sdf
        .with_column("PRED_ERROR", F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        .with_column(
            "ABS_ERROR", F.abs(F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        )
        .with_column(
            "APE",
            # APE is undefined when the actual is 0; leave it NULL so the row is
            # ignored by the MAPE average.
            F.when(F.col(TARGET_COLUMN) == 0, F.lit(None)).otherwise(
                F.abs(F.col("ABS_ERROR") / F.col(TARGET_COLUMN))
            ),
        )
        .with_column("SQ_ERROR", F.pow(F.col(TARGET_COLUMN) - F.col("PREDICTED"), 2))
    )

    # Metrics per partition
    partition_metrics = row_actual_v_fcst.group_by("GROUP_IDENTIFIER_STRING").agg(
        F.avg("APE").alias("MAPE"),
        F.avg("ABS_ERROR").alias("MAE"),
        F.sqrt(F.avg("SQ_ERROR")).alias("RMSE"),
        F.count("*").alias("TOTAL_PRED_COUNT"),
    )

    # Overall modeling process across all partitions
    overall_avg_metrics = partition_metrics.agg(
        F.avg("MAPE").alias("OVERALL_MAPE"),
        F.avg("MAE").alias("OVERALL_MAE"),
        F.avg("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("AVG"))

    # Weighted by each partition's share of the total target (PARTITION_WEIGHT).
    overall_weighted_avg_metrics = (
        partition_metrics
            .join(partition_weights.select("GROUP_IDENTIFIER_STRING", "PARTITION_WEIGHT"), on=["GROUP_IDENTIFIER_STRING"])
            .agg(
                F.sum(F.col("PARTITION_WEIGHT")*F.col("MAPE")).alias("OVERALL_MAPE"),
                F.sum(F.col("PARTITION_WEIGHT")*F.col("MAE")).alias("OVERALL_MAE"),
                F.sum(F.col("PARTITION_WEIGHT")*F.col("RMSE")).alias("OVERALL_RMSE"),
                 )
            .with_column("AGGREGATION", F.lit("WEIGHTED_AVG"))
    )

    overall_median_metrics = partition_metrics.agg(
        F.median("MAPE").alias("OVERALL_MAPE"),
        F.median("MAE").alias("OVERALL_MAE"),
        F.median("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("MEDIAN"))

    overall_metrics = (
        overall_avg_metrics
            .union(overall_median_metrics)
            .union(overall_weighted_avg_metrics)
            .select("AGGREGATION", "OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE")
            .sort("AGGREGATION")
    )

    # Show the metrics
    if inference_partition_count == 1:
        # With a single partition the median row is simply that partition's metrics.
        st.write(
            "There is only 1 partition, so these values are the metrics for that single model:"
        )
        st.dataframe(
            overall_median_metrics.select("OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE"),
            use_container_width=True,
        )
    else:
        st.write("Avg and Median of each metric over all the partitions:")
        st.dataframe(overall_metrics, use_container_width=True)

    return row_actual_v_fcst, partition_metrics

# Compute and display metrics for each split. produce_metrics returns the
# row-level metrics and the per-partition metrics; the per-partition test
# metrics are reused by the "Partition Performance" cell.
st.write("TRAINING SET")
train_metric_sdf, train_partition_metrics = produce_metrics(training_pred_v_actuals)
# The heading says "VALIDATION SET" but the data passed in is the TEST split.
st.write("VALIDATION SET")
test_metric_sdf, test_partition_metrics = produce_metrics(test_pred_v_actuals)

2025-10-13 15:44:05.581 
  command:

    streamlit run /opt/anaconda3/envs/forecast/lib/python3.12/site-packages/ipykernel_launcher.py [ARGUMENTS]
2025-10-13 15:44:07.224 Please replace `use_container_width` with `width`.

`use_container_width` will be removed after 2025-12-31.

For `use_container_width=True`, use `width='stretch'`. For `use_container_width=False`, use `width='content'`.
2025-10-13 15:44:18.480 Please replace `use_container_width` with `width`.

`use_container_width` will be removed after 2025-12-31.

For `use_container_width=True`, use `width='stretch'`. For `use_container_width=False`, use `width='content'`.


# Partition Performance

In [None]:
# Partition-level view of the TEST metrics; only shown when the data is
# partitioned and more than one partition was scored.
if len(PARTITION_COLUMNS) > 0 and inference_partition_count > 1:
    # Choices that are shown as a sorted table rather than a distribution plot.
    table_only_choices = ["WGHT_PCT", "GROUP_IDENTIFIER_STRING"]

    metric = st.selectbox(
        "Metric",
        ["MAPE", "MAE", "RMSE", *table_only_choices],
        key="metric_select_box_mnth",
    )

    distribution_df = test_partition_metrics.to_pandas()

    # Per-partition metrics with the partition weight expressed as a percentage.
    # This is a lazy Snowpark dataframe shared by both branches below.
    weights_sdf = partition_weights.select("GROUP_IDENTIFIER_STRING", "PARTITION_WEIGHT")
    table_to_show = (
        test_partition_metrics
        .join(weights_sdf, on=["GROUP_IDENTIFIER_STRING"])
        .with_column("WGHT_PCT", F.col("PARTITION_WEIGHT")*100)
        .select("GROUP_IDENTIFIER_STRING", "MAPE", "MAE", "RMSE", "WGHT_PCT")
    )

    if metric in table_only_choices:
        st.subheader(f"Sorted by {metric}")
        st.dataframe(table_to_show.sort(metric))
    else:
        st.subheader(f"{metric} Distribution")

        # Slider to trim outliers from the plot; defaults to the full range.
        lowest = float(distribution_df[metric].min())
        highest = float(distribution_df[metric].max())
        value_min, value_max = st.slider(
            f"Filter {metric} range in plot:",
            lowest,
            highest,
            (lowest, highest),
        )

        # Keep only the partitions whose metric falls inside the chosen range.
        metric_values = distribution_df[metric]
        within_range = (metric_values >= value_min) & (metric_values <= value_max)
        filtered_df = distribution_df[within_range]

        fig = px.box(
            filtered_df,
            x=metric,  # Horizontal orientation
            points="all",  # Show individual data points as dots
            title=f"{metric} Distribution ({value_min:.2f} - {value_max:.2f})",
            labels={metric: metric, "GROUP_IDENTIFIER_STRING": "Partition"},
            hover_data=["GROUP_IDENTIFIER_STRING"],  # Partition name on hover
        )

        fig.update_layout(template="plotly_white", showlegend=True)
        st.plotly_chart(fig, use_container_width=True)

        # Side-by-side tables: best partitions on the left, worst on the right.
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("BEST Performing Partitions")
            st.dataframe(table_to_show.sort(F.abs(metric)))
        with col2:
            st.subheader("WORST Performing Partitions")
            st.dataframe(table_to_show.sort(F.abs(metric).desc()))

In [None]:
# ------------------------------------------------------------------------------
# Visualize individual partition actual vs pred on a time series line chart
# ------------------------------------------------------------------------------

# Enter partition manually instead of using selectbox
partition_input = st.text_input("Enter Partition Name", key="partition_selector_1")
load_partition = st.button("Load Partition")

if load_partition and partition_input.strip():
    partition_choice = partition_input.strip()

    # Pull the chosen partition's rows (train and test) into pandas, ordered by time
    partition_choice_df = (
        pred_v_actuals.filter(F.col("GROUP_IDENTIFIER_STRING") == partition_choice)
        .sort(TIME_PERIOD_COLUMN)
        .to_pandas()
    )

    if partition_choice_df.empty:
        # A mistyped partition name returns no rows; say so instead of failing
        # further down while building the charts.
        st.warning(f"No predictions found for partition '{partition_choice}'.")
    else:
        partition_choice_df[TIME_PERIOD_COLUMN] = pd.to_datetime(
            partition_choice_df[TIME_PERIOD_COLUMN]
        )

        tabs = st.tabs(
            [
                "Line Plot: Validation Actual & Predicted",
                "Scatter Plot: Validation Actual vs. Predicted",
            ]
        )

        # ----------------------
        # Actual and predicted values over time
        fig_line = px.line(
            partition_choice_df,
            x=TIME_PERIOD_COLUMN,
            y=[TARGET_COLUMN, "PREDICTED"],
            title="Validation Actual vs. Predicted"
        )

        # The last TRAIN timestamp marks where the held-out period starts.
        train_dates = partition_choice_df.loc[
            partition_choice_df["DATASET"] == "TRAIN", TIME_PERIOD_COLUMN
        ]
        split_date = train_dates.max()

        # Only draw the marker when the partition actually has TRAIN rows;
        # max() of an empty/all-null selection is NaT, which has no timestamp().
        if pd.notna(split_date):
            fig_line.add_vline(
                # The x position is passed as epoch milliseconds.
                x=split_date.timestamp() * 1000,
                line_dash="dash",
                line_color="red",
                annotation_text="Forecast Start",
                annotation_position="top left"
            )

        # Render the Plotly figure in Streamlit
        tabs[0].plotly_chart(fig_line, use_container_width=True)

        # ----------------------
        # Validation Actuals vs. Predictions Scatter Plot
        fig_scatter = px.scatter(
            partition_choice_df,
            x=TARGET_COLUMN,
            y="PREDICTED",
            title="Predicted vs. Actual",
            opacity=0.6,
            trendline="ols",
            hover_data=["PREDICTED", TARGET_COLUMN, TIME_PERIOD_COLUMN],
        )

        # Add expected trendline (y = x) spanning the range of the actuals.
        # Series.min()/max() skip missing values, unlike the builtin min()/max().
        actual_min = partition_choice_df[TARGET_COLUMN].min()
        actual_max = partition_choice_df[TARGET_COLUMN].max()

        fig_scatter.add_trace(
            go.Scatter(
                x=[actual_min, actual_max],
                y=[actual_min, actual_max],
                mode="lines",
                line=dict(color="black", dash="dash"),
                name="Expected Trend (y = x)",
                showlegend=True,
            )
        )

        tabs[1].plotly_chart(fig_scatter, use_container_width=True)

# Feature Importance

In [8]:
# Load model feature importance data, either from the model context or from the
# model-binary storage table. Both branches produce a pandas DataFrame `model_df`
# with MODEL_NAME, GROUP_IDENTIFIER_STRING and METADATA columns.

if USE_CONTEXT:
    # One fitted model per partition is held in the loaded model's context.
    model_obj = mv.load().context.model_refs
    model_data = [(
        part, 
        dict(feature_importance=dict(
            zip(ref.model.feature_names_in_, [float(val) for val in ref.model.feature_importances_])
        ))
    ) for part,ref in model_obj.items()]

    model_df = pd.DataFrame(model_data,columns=["GROUP_IDENTIFIER_STRING","METADATA"])
    model_df["MODEL_NAME"] = MODEL_NAME
    print(
        f"Feature Importances are from model version {model_version_nm} model context."
    )
else:
    # Use the model version already resolved for inference (model_version_nm)
    # so the importances always describe the same version that produced the
    # predictions, without re-querying the registry.
    models_sdf = (
        session.table(f"{MODEL_BINARY_STORAGE_TBL_NM}")
        .filter(F.col("MODEL_NAME") == MODEL_NAME)
        .filter(F.col("MODEL_VERSION") == model_version_nm)
    )
    model_df = models_sdf.select(
        "MODEL_NAME", "GROUP_IDENTIFIER_STRING", "METADATA"
    ).to_pandas()
    print(
        f"Feature Importances are for model version {model_version_nm} in table {MODEL_BINARY_STORAGE_TBL_NM}."
    )

# Filter models to given lead if using direct multistep modeling
if LEAD > 0:
    model_df = model_df[
        model_df["GROUP_IDENTIFIER_STRING"].str.endswith(f"LEAD_{LEAD}")
    ]

Feature Importances are from model version YOUNG_LEECH_2 in table MODEL_STORAGE_TEST_MODEL_1.


In [None]:
def _extract_feature_importance(metadata):
    """Return the "feature_importance" mapping from one METADATA value, or {} if absent."""
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    # Null/NaN metadata (or JSON that is not an object) carries no importances.
    if not hasattr(metadata, "get"):
        return {}
    return metadata.get("feature_importance", {}) or {}


def preprocess_model_data(df):
    """Preprocess model data by extracting feature importance from the METADATA column.

    This function performs the following steps:
    1. Extracts the "feature_importance" dictionary from the "METADATA" column.
       METADATA values may be JSON strings or already-parsed mappings; missing
       or null metadata is treated as having no features.
    2. Converts the extracted feature importance data into a new DataFrame where each row
       represents a feature and its corresponding importance for a specific model.

    Args:
        df (pd.DataFrame): Input DataFrame containing model data with at least
                           the columns "MODEL_NAME", "GROUP_IDENTIFIER_STRING",
                           and "METADATA". It is modified in place (a
                           "FEATURE_IMPORTANCE" column is added).

    Returns:
        tuple:
            - pd.DataFrame: The original DataFrame with an additional "FEATURE_IMPORTANCE" column.
            - pd.DataFrame: A new DataFrame containing the extracted features and their importance,
              with columns ["MODEL_NAME", "GROUP_IDENTIFIER_STRING", "FEATURE", "IMPORTANCE"].
              The columns are present even when no features were found.

    """
    # Extract feature importance from METADATA
    df["FEATURE_IMPORTANCE"] = df["METADATA"].apply(_extract_feature_importance)

    # Explode feature importance into one row per (model, partition, feature)
    feature_rows = [
        {
            "MODEL_NAME": model_name,
            "GROUP_IDENTIFIER_STRING": group_id,
            "FEATURE": feature,
            "IMPORTANCE": importance,
        }
        for model_name, group_id, importances in zip(
            df["MODEL_NAME"], df["GROUP_IDENTIFIER_STRING"], df["FEATURE_IMPORTANCE"]
        )
        for feature, importance in importances.items()
    ]

    # Passing the columns explicitly keeps the schema stable for empty results,
    # so downstream filtering/grouping does not hit a KeyError.
    feature_df = pd.DataFrame(
        feature_rows,
        columns=["MODEL_NAME", "GROUP_IDENTIFIER_STRING", "FEATURE", "IMPORTANCE"],
    )
    return df, feature_df


def calculate_average_rank(feature_df):
    """Rank features within each partition and average those ranks per feature.

    Within every "GROUP_IDENTIFIER_STRING" the features are ranked by
    "IMPORTANCE" in descending order (rank 1 = most important). The ranks and
    the raw importances are then averaged per "FEATURE" across all partitions.

    Args:
        feature_df (pd.DataFrame): Feature importances with at least the columns
                                   ["GROUP_IDENTIFIER_STRING", "FEATURE", "IMPORTANCE"].
                                   The input is not modified.

    Returns:
        tuple:
            - pd.DataFrame: A copy of the input with an additional "RANK" column.
            - pd.DataFrame: One row per feature with columns
              ["FEATURE", "AVERAGE_RANK", "AVERAGE_IMPORTANCE"], sorted by
              "AVERAGE_RANK" ascending.

    """
    # assign() works on a copy, so the caller's frame is left untouched.
    ranked = feature_df.assign(
        RANK=lambda frame: frame.groupby("GROUP_IDENTIFIER_STRING")["IMPORTANCE"].rank(
            ascending=False
        )
    )

    # Mean rank and mean importance of each feature across partitions.
    per_feature = ranked.groupby("FEATURE", as_index=False).agg(
        AVERAGE_RANK=("RANK", "mean"),
        AVERAGE_IMPORTANCE=("IMPORTANCE", "mean"),
    )

    return ranked, per_feature.sort_values("AVERAGE_RANK", ascending=True)


def plot_feature_importance(df, is_aggregated=True, top_n=20):
    """Build a horizontal bar chart of the top features.

    Two modes are supported:
    - Aggregated: bars show each feature's average rank across partitions; the
      ``top_n`` features with the lowest average rank are kept.
    - Single partition: bars show raw importance; the ``top_n`` features with
      the highest importance are kept.

    Args:
        df (pd.DataFrame): Feature importance data. Expected columns:
                           - If `is_aggregated=True`: ["FEATURE", "AVERAGE_RANK"]
                           - If `is_aggregated=False`: ["FEATURE", "IMPORTANCE"]
        is_aggregated (bool, optional): Selects the mode described above.
                                        Default is True.
        top_n (int, optional): Number of features to display. Default is 20.

    Returns:
        plotly.graph_objects.Figure: The bar chart.

    """
    # Everything that differs between the two modes is decided up front.
    if is_aggregated:
        x_col, axis_label = "AVERAGE_RANK", "Average Rank"
        sort_ascending = True
        title = "Top Feature Importance (Aggregated by Average Rank)"
        category_order = "total descending"
    else:
        x_col, axis_label = "IMPORTANCE", "Importance"
        sort_ascending = False
        title = "Top Feature Importance for Selected Partition"
        category_order = "total ascending"

    top_features = df.sort_values(x_col, ascending=sort_ascending).head(top_n)

    fig = px.bar(
        top_features,
        x=x_col,
        y="FEATURE",
        orientation="h",
        title=title,
        labels={"FEATURE": "Feature", x_col: axis_label},
    )

    fig.update_layout(
        yaxis=dict(categoryorder=category_order),
        xaxis_title=axis_label,
        yaxis_title="Feature",
        margin=dict(l=50, r=50, t=50, b=50),
    )

    return fig


# Turn the per-model METADATA into one row per (partition, feature).
model_df, feature_df = preprocess_model_data(model_df)

# Partition picker; the leading None entry means "aggregate over all partitions".
partition_models = model_df["GROUP_IDENTIFIER_STRING"].unique()
selected_partition_model = st.selectbox(
    "Select Partition", [None, *sorted(partition_models)]
)

# How many features to show in the chart.
top_n = st.slider("Number of Top Features to Show", min_value=5, max_value=50, value=20)

st.subheader("Feature Importance")

if selected_partition_model:
    # Single partition: plot its raw importances.
    is_selected = feature_df["GROUP_IDENTIFIER_STRING"] == selected_partition_model
    filtered_feature_df = feature_df[is_selected]
    fig = plot_feature_importance(filtered_feature_df, is_aggregated=False, top_n=top_n)
else:
    # All partitions: rank features within each partition and plot the average rank.
    filtered_feature_df, avg_rank_df = calculate_average_rank(feature_df)
    fig = plot_feature_importance(avg_rank_df, is_aggregated=True, top_n=top_n)

st.plotly_chart(fig, use_container_width=True)

# Tables behind the chart, collapsed by default.
with st.expander("Show Underlying Data"):
    by_importance = filtered_feature_df.sort_values("IMPORTANCE", ascending=False)
    if selected_partition_model:
        st.dataframe(by_importance)
    else:
        tabs = st.tabs(["Average Importance", "Individual Importance"])
        tabs[0].dataframe(avg_rank_df.sort_values("AVERAGE_RANK", ascending=True))
        tabs[1].dataframe(by_importance)