# Evaluation Notebook
Use the model trained in the modeling notebook to make and evaluate predictions on a test dataset.

#### NOTE: The user must have split data into train/test datasets in the modeling notebook before running this notebook.

❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 
## Instructions

1. Go to the ___set_global_variables___ cell in the __SETUP__ section below. 
    - Adjust the values of the user constants
2. Click ___Run all___ in the upper right corner of the notebook to run the entire notebook. 
    - The notebook will perform inference and evaluation. 
    - PROMOTE_MODEL_VERSION is set to False. If the evaluation is satisfactory, change the value at the end of the notebook and run the last 2 cells. If the model is promoted, predictions on the test data will be stored in a Snowflake table and a model monitor will be created to track future inference values.
    
❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

In [None]:
# Imports
import math
import json
from datetime import datetime
from snowflake.ml.registry import registry
from snowflake.ml.dataset import Dataset
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark.window import Window
from snowflake.snowpark import DataFrame as SnowparkDataFrame
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from forecast_model_builder.utils import connect, perform_inference

-----
# SETUP
-----

In [3]:
# Name of project. Used as the session schema when the session is established.
PROJECT_SCHEMA = "TEST_PROJECT"

# Warehouse used for the notebook session (everything except inference).
SESSION_WH = "FORECAST_MODEL_BUILDER_WH"

# Base table name to store the PREDICTION results when the model version is promoted.
# NOTE: The model name and model version are appended to this name
#       (<INFERENCE_RESULT_TBL_NM>_<MODEL_NAME>_<version>) to keep predictions from a particular run.
# NOTE: The promotion cell switches the session schema to MODEL_SCHEMA before writing,
#       so an unqualified name is created in MODEL_DB.MODEL_SCHEMA.
INFERENCE_RESULT_TBL_NM = "FORECAST_RESULTS"

# Input data for inference
# NOTE(review): these three values are not referenced in the cells visible in this notebook;
#               train/test data is loaded from the datasets recorded on the model version. Confirm before removing.
INFERENCE_DB = "FORECAST_MODEL_BUILDER"
INFERENCE_SCHEMA = "BASE"
INFERENCE_FV = "FORECAST_FEATURES"

# Name of the model to use for inference, as well as the Database and Schema of the model registry.
# NOTE: The version evaluated is the one returned by `.last()` on the registry model,
#       which is not necessarily the current default version.
MODEL_DB = "FORECAST_MODEL_BUILDER"
MODEL_SCHEMA = "MODELING"
MODEL_NAME = "TEST_MODEL_1"

# If using direct multistep forecasting, set LEAD to the lead model you wish to evaluate.
# Otherwise, set to 0
LEAD = 0

# Scaling up the warehouse may speed up execution time, especially if there are many partitions.
# NOTE: If set to None, then the session warehouse will be used.
INFERENCE_WH = "STANDARD_XL"

In [2]:
# Establish session and point it at the model database / project schema.
session = connect(connection_name="default")
session.use_database(MODEL_DB)
session_db = MODEL_DB
session.use_schema(PROJECT_SCHEMA)
session_schema = PROJECT_SCHEMA
print(f"Session db.schema: {session_db}.{session_schema}")
# SESSION_WH has not been activated yet at this point, so report whatever
# warehouse the connection currently has (the original referenced an
# undefined `session_wh` name here and raised NameError).
print(f"Session warehouse: {session.get_current_warehouse()}")

# Query tag attached to every query issued by this session.
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}'
session.query_tag = query_tag

# Timestamp of this run.
# NOTE(review): the original comment says this is saved in the model storage
# table; nothing in the visible cells writes it - confirm.
run_dttm = datetime.now()
print(f"Current Datetime: {run_dttm}")

-----
# Establish objects needed for this run
-----

In [None]:
# Derived Objects

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
session.use_warehouse(SESSION_WH)
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Inference Warehouse
# -----------------------------------------------------------------------
# INFERENCE_WH is either None (meaning "use the session warehouse") or a
# warehouse name. Only warn when a name was given but SHOW WAREHOUSES does not
# list it for this session; None is a documented setting, not an access problem.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

if INFERENCE_WH is None:
    INFERENCE_WH = SESSION_WH
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
elif INFERENCE_WH in available_warehouses:
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
else:
    print(
        f"WARNING: User does not have access to INFERENCE_WH = '{INFERENCE_WH}'. Inference will use '{SESSION_WH}' instead. \n"
    )
    INFERENCE_WH = SESSION_WH

# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Get the model and the name of the version to evaluate
# -----------------------------------------------------------------------
# Establish registry object
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# Evaluate the version returned by `.last()` (not the registry default). The
# promotion cell at the end of the notebook can make this version the default.
mv = reg.get_model(qualified_model_name).last()

# Name of the version under evaluation
model_version_nm = mv.version_name

print(f"Model Version:              {model_version_nm}")

# --------------------------------
# User Constants from Model Setup
# --------------------------------
# The settings are read from the "user_settings" entry of the model version's metrics.
stored_constants = mv.show_metrics()["user_settings"]

TIME_PERIOD_COLUMN = stored_constants["TIME_PERIOD_COLUMN"]
TARGET_COLUMN = stored_constants["TARGET_COLUMN"]
PARTITION_COLUMNS = stored_constants["PARTITION_COLUMNS"]
ALL_EXOG_COLS_HAVE_FUTURE_VALS = stored_constants["ALL_EXOG_COLS_HAVE_FUTURE_VALS"]
USE_CONTEXT = stored_constants["USE_CONTEXT"]

# The binary storage table is only needed when models are not kept in the model context.
if not USE_CONTEXT:
    MODEL_BINARY_STORAGE_TBL_NM = stored_constants["MODEL_BINARY_STORAGE_TBL_NM"]

# LEAD must agree with the modeling approach recorded at training time.
if not ALL_EXOG_COLS_HAVE_FUTURE_VALS and LEAD == 0:
    raise ValueError(
        """If using direct multistep modeling approach, LEAD must be set to a number 
        greater than 0 to filter results to a particular lead model"""
    )
if ALL_EXOG_COLS_HAVE_FUTURE_VALS and LEAD > 0:
    raise ValueError(
        """If using global modeling approach, LEAD must be set to 0"""
    )
# --------------------------------
# Get datasets
# --------------------------------

def load_df_from_ds(fully_qualified_name, version):
    """Read one version of a Snowflake ML Dataset as a Snowpark DataFrame.

    Args:
        fully_qualified_name: Dataset name in ``DB.SCHEMA.NAME`` form; it is
            split on ``.`` and must contain exactly three parts.
        version: Dataset version to select.

    Returns:
        The Snowpark DataFrame produced by the dataset reader.
    """
    name_parts = fully_qualified_name.split('.')
    ds_db, ds_schema, ds_name = name_parts

    dataset = Dataset(
        session=session,
        database=ds_db,
        schema=ds_schema,
        name=ds_name,
        selected_version=version
    )
    return dataset.read.to_snowpark_dataframe()

# Fetch the model version's metrics once; the train/test dataset names and
# versions are recorded there (the original called show_metrics() four times).
model_metrics = mv.show_metrics()
train_dataset_info = model_metrics['train_dataset']
test_dataset_info = model_metrics['test_dataset']

train_df = load_df_from_ds(
    fully_qualified_name=train_dataset_info['name'],
    version=train_dataset_info['version']
).drop("GROUP_IDENTIFIER")

test_df = load_df_from_ds(
    fully_qualified_name=test_dataset_info['name'],
    version=test_dataset_info['version']
).drop("GROUP_IDENTIFIER")

# Filter to a particular lead model if performing direct multi step forecasting.
# Lead models are identified by a GROUP_IDENTIFIER_STRING ending in "LEAD_<n>".
if LEAD > 0:
    lead_suffix = f"LEAD_{LEAD}"
    train_df = train_df.filter(
        F.col("GROUP_IDENTIFIER_STRING").endswith(lead_suffix)
    )
    test_df = test_df.filter(
        F.col("GROUP_IDENTIFIER_STRING").endswith(lead_suffix)
    )

-----
# Inference
-----

In [5]:
# ------------------------------------------------------------------------
# INFERENCE
# ------------------------------------------------------------------------

print("Predictions")
session.use_warehouse(INFERENCE_WH)

# Run inference on the inference warehouse, and always switch back to the
# session warehouse afterwards - even if inference fails - so later cells do
# not keep running on the (possibly much larger) inference warehouse.
try:
    # cache_result() materializes each result so the predictions are computed
    # here rather than lazily on the session warehouse.
    train_result = perform_inference(session, train_df, mv).with_column("DATASET", F.lit("TRAIN")).cache_result()
    test_result = perform_inference(session, test_df, mv).with_column("DATASET", F.lit("TEST")).cache_result()
    test_result.show(2)

    inference_result = train_result.union_all_by_name(test_result)
finally:
    session.use_warehouse(SESSION_WH)

In [None]:
# Build predicted-vs-actual dataframes for the train and test splits.

# Keys that identify a single observation: partition + time period.
observation_keys = ["GROUP_IDENTIFIER_STRING", TIME_PERIOD_COLUMN]

# Actual target values for every train and test row.
all_actuals = train_df.union_all_by_name(test_df)
sdf = all_actuals.select("GROUP_IDENTIFIER_STRING", TIME_PERIOD_COLUMN, TARGET_COLUMN)

# Attach the actuals to the predictions and expose the prediction as PREDICTED.
predictions_with_actuals = inference_result.join(sdf, on=observation_keys)
pred_v_actuals = predictions_with_actuals.select(
    "GROUP_IDENTIFIER_STRING",
    TIME_PERIOD_COLUMN,
    TARGET_COLUMN,
    F.col("_PRED_").alias("PREDICTED"),
    "DATASET",
)

# Number of distinct partitions that received predictions.
distinct_partitions = pred_v_actuals.select("GROUP_IDENTIFIER_STRING").distinct()
inference_partition_count = distinct_partitions.count()

# Split back out by the DATASET label added during inference.
dataset_label = F.col("DATASET")
training_pred_v_actuals = pred_v_actuals.filter(dataset_label == "TRAIN")
test_pred_v_actuals = pred_v_actuals.filter(dataset_label == "TEST")

print(f"Dataset has {inference_partition_count} partitions")
pred_v_actuals.show(3)

In [30]:
# Weight of each partition = its share of the total target volume.
# These weights feed the weighted-average metrics.

partition_sum_col = f'PARTITION_{TARGET_COLUMN}_SUM'
grand_total_col = f"TOTAL_{TARGET_COLUMN}"

# An empty partition_by() makes the window span every row, giving a grand total.
total_window = Window.partition_by()

# Per-partition target sum plus the first/last time period seen.
per_partition_totals = pred_v_actuals.group_by("GROUP_IDENTIFIER_STRING").agg(
    F.sum(TARGET_COLUMN).alias(partition_sum_col),
    F.min(TIME_PERIOD_COLUMN),
    F.max(TIME_PERIOD_COLUMN),
)

with_grand_total = per_partition_totals.with_column(
    grand_total_col, F.sum(partition_sum_col).over(total_window)
)

partition_weights = with_grand_total.with_column(
    "PARTITION_WEIGHT", F.col(partition_sum_col) / F.col(grand_total_col)
)

# Largest partitions first.
partition_weights.sort(F.col("PARTITION_WEIGHT").desc()).show()

# Overall Performance

In [None]:
def produce_metrics(sdf: SnowparkDataFrame) -> "tuple[SnowparkDataFrame, SnowparkDataFrame]":
    """Compute and display error metrics for one predicted-vs-actual dataframe.

    Row-level errors are computed first, then aggregated per partition
    (MAPE, MAE, RMSE), and finally summarised across partitions as a plain
    average, a median, and an average weighted by ``PARTITION_WEIGHT``.
    The summary is displayed as a side effect.

    Relies on the notebook globals ``TARGET_COLUMN``, ``partition_weights``
    and ``inference_partition_count``.

    Args:
        sdf: Dataframe with at least ``GROUP_IDENTIFIER_STRING``, the target
            column and ``PREDICTED``.

    Returns:
        A ``(row_level_metrics, partition_metrics)`` tuple of Snowpark
        DataFrames.
    """
    # Row-level metrics
    row_actual_v_fcst = (
        sdf
        .with_column("PRED_ERROR", F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        .with_column(
            "ABS_ERROR", F.abs(F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        )
        .with_column(
            "APE",
            # Percentage error is undefined for a zero actual; NULL keeps the
            # row out of the MAPE average instead of dividing by zero.
            F.when(F.col(TARGET_COLUMN) == 0, F.lit(None)).otherwise(
                F.abs(F.col("ABS_ERROR") / F.col(TARGET_COLUMN))
            ),
        )
        .with_column("SQ_ERROR", F.pow(F.col(TARGET_COLUMN) - F.col("PREDICTED"), 2))
    )

    # Metrics per partition
    partition_metrics = row_actual_v_fcst.group_by("GROUP_IDENTIFIER_STRING").agg(
        F.avg("APE").alias("MAPE"),
        F.avg("ABS_ERROR").alias("MAE"),
        F.sqrt(F.avg("SQ_ERROR")).alias("RMSE"),
        F.count("*").alias("TOTAL_PRED_COUNT"),
    )

    # Overall modeling process across all partitions
    overall_avg_metrics = partition_metrics.agg(
        F.avg("MAPE").alias("OVERALL_MAPE"),
        F.avg("MAE").alias("OVERALL_MAE"),
        F.avg("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("AVG"))

    # Same metrics, each partition weighted by its PARTITION_WEIGHT.
    overall_weighted_avg_metrics = (
        partition_metrics
        .join(
            partition_weights.select("GROUP_IDENTIFIER_STRING", "PARTITION_WEIGHT"),
            on=["GROUP_IDENTIFIER_STRING"],
        )
        .agg(
            F.sum(F.col("PARTITION_WEIGHT") * F.col("MAPE")).alias("OVERALL_MAPE"),
            F.sum(F.col("PARTITION_WEIGHT") * F.col("MAE")).alias("OVERALL_MAE"),
            F.sum(F.col("PARTITION_WEIGHT") * F.col("RMSE")).alias("OVERALL_RMSE"),
        )
        .with_column("AGGREGATION", F.lit("WEIGHTED_AVG"))
    )

    overall_median_metrics = partition_metrics.agg(
        F.median("MAPE").alias("OVERALL_MAPE"),
        F.median("MAE").alias("OVERALL_MAE"),
        F.median("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("MEDIAN"))

    # union() matches columns by position; all three frames share the same
    # column order (metrics first, AGGREGATION last).
    overall_metrics = (
        overall_avg_metrics
        .union(overall_median_metrics)
        .union(overall_weighted_avg_metrics)
        .select("AGGREGATION", "OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE")
        .sort("AGGREGATION")
    )

    # Show the metrics
    if inference_partition_count == 1:
        # With a single partition every aggregation yields the same numbers,
        # so only one row is shown.
        print(
            "There is only 1 partition, so these values are the metrics for that single model:"
        )
        display(
            overall_median_metrics.select("OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE")
        )
    else:
        print("Avg and Median of each metric over all the partitions:")
        display(overall_metrics)

    return row_actual_v_fcst, partition_metrics

# Metrics on the data the model was fit on.
print("TRAINING SET")
train_outputs = produce_metrics(training_pred_v_actuals)
train_metric_sdf, train_partition_metrics = train_outputs

# Metrics on the held-out test split (reported under the "VALIDATION SET" heading).
print("VALIDATION SET")
test_outputs = produce_metrics(test_pred_v_actuals)
test_metric_sdf, test_partition_metrics = test_outputs

# Partition Performance

In [None]:
# Per-partition view of the test metrics. Only meaningful when the data is
# partitioned and more than one partition received predictions.
if len(PARTITION_COLUMNS) > 0 and inference_partition_count > 1:
    # Metric Distribution plot - set metric parameter here
    metric = "MAPE"  # Options: "MAPE", "MAE", "RMSE", "WGHT_PCT", "GROUP_IDENTIFIER_STRING"

    # Per-partition metrics with each partition's weight expressed as a percentage.
    # Built once and shared by both display modes below.
    table_to_show = (
        test_partition_metrics
        .join(partition_weights.select("GROUP_IDENTIFIER_STRING", "PARTITION_WEIGHT"), on=["GROUP_IDENTIFIER_STRING"])
        .with_column("WGHT_PCT", F.col("PARTITION_WEIGHT")*100)
        .select("GROUP_IDENTIFIER_STRING", "MAPE", "MAE", "RMSE", "WGHT_PCT")
    )

    if metric not in ["WGHT_PCT", "GROUP_IDENTIFIER_STRING"]:
        print(f"## {metric} Distribution")

        distribution_df = test_partition_metrics.to_pandas()

        value_min = float(distribution_df[metric].min())
        value_max = float(distribution_df[metric].max())

        # The original filtered to [min, max] of the same column, which only
        # ever removed rows whose metric is missing; say that directly.
        filtered_df = distribution_df.dropna(subset=[metric])

        fig = px.box(
            filtered_df,
            x=metric,  # Horizontal orientation
            points="all",  # Show individual data points as dots
            title=f"{metric} Distribution ({value_min:.2f} - {value_max:.2f})",
            labels={metric: metric, "GROUP_IDENTIFIER_STRING": "Partition"},
            hover_data=["GROUP_IDENTIFIER_STRING"],  # Add this for hover info
        )

        fig.update_layout(template="plotly_white", showlegend=True)
        fig.show()

        # Tables
        metric_magnitude = F.abs(F.col(metric))

        # Look at the best performing partitions
        print("## BEST Performing Partitions")
        display(table_to_show.sort(metric_magnitude.asc_nulls_last()))

        # Look at the worst performing partitions.
        # Snowflake puts NULLs first in a descending sort by default, which
        # would list partitions with an undefined metric (e.g. MAPE when every
        # actual is 0) as the "worst"; keep them at the bottom instead.
        print("## WORST Performing Partitions")
        display(table_to_show.sort(metric_magnitude.desc_nulls_last()))

    else:
        print(f"## Sorted by {metric}")
        display(table_to_show.sort(metric))

In [None]:
# ------------------------------------------------------------------------------
# Visualize individual partition actual vs pred on a time series line chart
# ------------------------------------------------------------------------------

# Enter partition manually - set partition name here
partition_input = ""  # Set partition name to analyze, e.g., "PARTITION_1"

partition_choice = partition_input.strip()

if partition_choice:
    # Create a pandas dataframe with the chosen partition's train + test rows
    partition_choice_df = (
        pred_v_actuals.filter(F.col("GROUP_IDENTIFIER_STRING") == partition_choice)
        .sort(TIME_PERIOD_COLUMN)
        .to_pandas()
    )

    if partition_choice_df.empty:
        # A mistyped partition name would otherwise fail further down with an
        # unrelated-looking error from NaT.timestamp().
        print(
            f"No rows found for partition '{partition_choice}'. "
            "Check that it matches a GROUP_IDENTIFIER_STRING value."
        )
    else:
        partition_choice_df[TIME_PERIOD_COLUMN] = pd.to_datetime(partition_choice_df[TIME_PERIOD_COLUMN])

        # --- LINE PLOT ---
        # Create a Plotly line chart
        fig_line = px.line(
            partition_choice_df,
            x=TIME_PERIOD_COLUMN,
            y=[TARGET_COLUMN, "PREDICTED"],
            title="Validation Actual vs. Predicted"
        )

        # Last time period in the training data marks where the test period starts.
        split_date = partition_choice_df[
            partition_choice_df["DATASET"]=="TRAIN"
        ][TIME_PERIOD_COLUMN].max()

        # Add a dashed vertical line at the split date. Skipped when the
        # partition has no TRAIN rows (split_date is NaT and has no timestamp).
        if pd.notna(split_date):
            fig_line.add_vline(
                # Plotly expects epoch milliseconds for a vline on a date axis.
                x=split_date.timestamp() * 1000,
                line_dash="dash",
                line_color="red",
                annotation_text="Forecast Start",
                annotation_position="top left"
            )

        print("## Line Plot: Validation Actual & Predicted")
        fig_line.show()

        # ----------------------
        # Validation Actuals vs. Predictions Scatter Plot
        fig_scatter = px.scatter(
            partition_choice_df,
            x=TARGET_COLUMN,
            y="PREDICTED",
            title="Predicted vs. Actual",
            opacity=0.6,
            trendline="ols",
            hover_data=["PREDICTED", TARGET_COLUMN, TIME_PERIOD_COLUMN],
        )

        # Add expected trendline (y = x) spanning the range of the actuals.
        # Series.min()/max() skip missing values, unlike the builtins.
        actual_min = partition_choice_df[TARGET_COLUMN].min()
        actual_max = partition_choice_df[TARGET_COLUMN].max()

        fig_scatter.add_trace(
            go.Scatter(
                x=[actual_min, actual_max],
                y=[actual_min, actual_max],
                mode="lines",
                line=dict(color="black", dash="dash"),
                name="Expected Trend (y = x)",
                showlegend=True,
            )
        )

        print("## Scatter Plot: Validation Actual vs. Predicted")
        fig_scatter.show()

# Feature Importance

In [None]:
# Load model feature importance data, either from the model context or from
# the binary storage table, for the model version evaluated in this notebook.

if USE_CONTEXT:
    # One fitted model per partition is held in the loaded model's context.
    model_obj = mv.load(force=True).context.model_refs
    model_data = [
        (
            part,
            dict(feature_importance=dict(
                zip(ref.model.feature_names_in_, [float(val) for val in ref.model.feature_importances_])
            )),
        )
        for part, ref in model_obj.items()
    ]

    model_df = pd.DataFrame(model_data, columns=["GROUP_IDENTIFIER_STRING", "METADATA"])
    model_df["MODEL_NAME"] = MODEL_NAME
    print(
        f"Feature Importances are from model version {model_version_nm} model context."
    )
else:
    # Filter on the version being evaluated (model_version_nm). The original
    # filtered on the registry's *default* version, which is a different model
    # until the evaluated version has been promoted.
    models_sdf = (
        session.table(MODEL_BINARY_STORAGE_TBL_NM)
        .filter(F.col("MODEL_NAME") == MODEL_NAME)
        .filter(F.col("MODEL_VERSION") == model_version_nm)
    )
    model_df = models_sdf.select(
        "MODEL_NAME", "GROUP_IDENTIFIER_STRING", "METADATA"
    ).to_pandas()
    print(
        f"Feature Importances are for model version {model_version_nm} in table {MODEL_BINARY_STORAGE_TBL_NM}."
    )

# Filter models to given lead if using direct multistep modeling
if LEAD > 0:
    # .copy() so later column assignments act on an independent frame rather
    # than on a slice of the unfiltered one.
    model_df = model_df[
        model_df["GROUP_IDENTIFIER_STRING"].str.endswith(f"LEAD_{LEAD}")
    ].copy()

In [None]:
def preprocess_model_data(df):
    """Preprocess model data by extracting feature importance from the METADATA column.

    This function performs the following steps:
    1. Extracts the "feature_importance" dictionary from the "METADATA" column.
       METADATA may be a JSON string or an already-parsed mapping; missing or
       unusable metadata is treated as having no feature importances.
    2. Converts the extracted feature importance data into a new DataFrame where each row
       represents a feature and its corresponding importance for a specific model.

    The input DataFrame is not modified; a copy carrying the extra column is returned.

    Args:
        df (pd.DataFrame): Input DataFrame containing model data with at least
                           the columns "MODEL_NAME", "GROUP_IDENTIFIER_STRING",
                           and "METADATA".

    Returns:
        tuple:
            - pd.DataFrame: A copy of the input with an additional "FEATURE_IMPORTANCE" column.
            - pd.DataFrame: A new DataFrame containing the extracted features and their importance,
              with columns ["MODEL_NAME", "GROUP_IDENTIFIER_STRING", "FEATURE", "IMPORTANCE"].
              The columns are present even when no feature importances were found.

    """
    feature_columns = ["MODEL_NAME", "GROUP_IDENTIFIER_STRING", "FEATURE", "IMPORTANCE"]

    def _extract_importance(metadata):
        # Stored metadata arrives as a JSON string; in-memory metadata is a dict.
        if isinstance(metadata, str):
            metadata = json.loads(metadata)
        # None / NaN / non-mapping JSON values carry no feature importances.
        if not hasattr(metadata, "get"):
            return {}
        return metadata.get("feature_importance") or {}

    # Work on a copy so the caller's frame (possibly a slice) is left untouched.
    df = df.copy()
    df["FEATURE_IMPORTANCE"] = df["METADATA"].apply(_extract_importance)

    # Explode feature importance into one row per (model, partition, feature).
    feature_rows = []
    for model_name, group_id, importances in zip(
        df["MODEL_NAME"], df["GROUP_IDENTIFIER_STRING"], df["FEATURE_IMPORTANCE"]
    ):
        for feature, importance in importances.items():
            feature_rows.append(
                {
                    "MODEL_NAME": model_name,
                    "GROUP_IDENTIFIER_STRING": group_id,
                    "FEATURE": feature,
                    "IMPORTANCE": importance,
                }
            )

    # Passing columns explicitly keeps the schema stable when there are no rows,
    # so downstream groupby/sort calls do not fail on missing columns.
    feature_df = pd.DataFrame(feature_rows, columns=feature_columns)
    return df, feature_df


def calculate_average_rank(feature_df):
    """Rank features within each partition and summarise the ranks per feature.

    Steps:
    1. Within every "GROUP_IDENTIFIER_STRING", rank features by "IMPORTANCE",
       highest importance receiving rank 1.
    2. For every feature, average its rank and its importance over all groups.
    3. Order the summary so the best (lowest) average rank comes first.

    The input frame is left untouched; ranking is done on a copy.

    Args:
        feature_df (pd.DataFrame): Feature importance rows with at least the
                                   columns ["GROUP_IDENTIFIER_STRING",
                                   "FEATURE", "IMPORTANCE"].

    Returns:
        tuple:
            - pd.DataFrame: Copy of the input with an additional "RANK" column.
            - pd.DataFrame: Per-feature summary with columns
              ["FEATURE", "AVERAGE_RANK", "AVERAGE_IMPORTANCE"], sorted by
              "AVERAGE_RANK" ascending.

    """
    ranked = feature_df.copy()
    importance_by_group = ranked.groupby("GROUP_IDENTIFIER_STRING")["IMPORTANCE"]
    ranked["RANK"] = importance_by_group.rank(ascending=False)

    # Named aggregation yields the final column names directly.
    summary = (
        ranked.groupby("FEATURE")
        .agg(
            AVERAGE_RANK=("RANK", "mean"),
            AVERAGE_IMPORTANCE=("IMPORTANCE", "mean"),
        )
        .reset_index()
        .sort_values("AVERAGE_RANK", ascending=True)
    )
    return ranked, summary


def plot_feature_importance(df, is_aggregated=True, top_n=20):
    """Build a horizontal bar chart of the top features.

    Two modes are supported:
    - aggregated: bars show each feature's average rank across groups
      (lowest rank = most important);
    - unaggregated: bars show raw importance for a single selected partition.

    Args:
        df (pd.DataFrame): Feature importance data.
                           Expected columns:
                           - If `is_aggregated=True`: ["FEATURE", "AVERAGE_RANK"]
                           - If `is_aggregated=False`: ["FEATURE", "IMPORTANCE"]
        is_aggregated (bool, optional): Selects the mode described above.
                                        Default is True.
        top_n (int, optional): Number of top features to display in the plot.
                               Default is 20.

    Returns:
        plotly.graph_objects.Figure: The bar chart for the top features.

    """
    # Everything that differs between the two modes is decided up front;
    # the figure itself is then built once.
    if is_aggregated:
        value_col = "AVERAGE_RANK"
        sort_ascending = True  # best (lowest) average rank first
        axis_label = "Average Rank"
        title = "Top Feature Importance (Aggregated by Average Rank)"
        category_order = "total descending"
    else:
        value_col = "IMPORTANCE"
        sort_ascending = False  # highest importance first
        axis_label = "Importance"
        title = "Top Feature Importance for Selected Partition"
        category_order = "total ascending"

    top_features = df.sort_values(value_col, ascending=sort_ascending).head(top_n)

    fig = px.bar(
        top_features,
        x=value_col,
        y="FEATURE",
        orientation="h",
        title=title,
        labels={"FEATURE": "Feature", value_col: axis_label},
    )

    fig.update_layout(
        yaxis=dict(categoryorder=category_order),
        xaxis_title=axis_label,
        yaxis_title="Feature",
        margin=dict(l=50, r=50, t=50, b=50),
    )

    return fig


# Extract per-feature importances from the model metadata
model_df, feature_df = preprocess_model_data(model_df)

# Partition names available for selection
partition_models = model_df["GROUP_IDENTIFIER_STRING"].unique()

# User settings for this cell
selected_partition_model = None  # Set to partition name to view individual, or None for aggregated
top_n = 20  # Number of top features to show (5-50)

# Restrict to the chosen partition, if any
if selected_partition_model:
    partition_mask = feature_df["GROUP_IDENTIFIER_STRING"] == selected_partition_model
    filtered_feature_df = feature_df[partition_mask]
else:
    filtered_feature_df = feature_df

# Display Feature Importance
print("## Feature Importance")

if not selected_partition_model:
    # Aggregated view: rank features within each partition, then average the ranks
    filtered_feature_df, avg_rank_df = calculate_average_rank(filtered_feature_df)
    fig = plot_feature_importance(avg_rank_df, is_aggregated=True, top_n=top_n)
else:
    fig = plot_feature_importance(filtered_feature_df, is_aggregated=False, top_n=top_n)

fig.show()

# Show Underlying Data
print("## Underlying Data")
if not selected_partition_model:
    print("### Average Importance")
    display(avg_rank_df.sort_values("AVERAGE_RANK", ascending=True))
    print("### Individual Importance")
display(filtered_feature_df.sort_values("IMPORTANCE", ascending=False))

# Promote Model Version?

If set to True, the model version evaluated in this notebook will be promoted to default and the predictions on the test data will be saved to an inference table (created if it does not exist, otherwise only new rows are appended). A model monitor to track future inference results will also be created.

In [None]:
# Set to True (after reviewing the evaluation above) to run the promotion cell below.
PROMOTE_MODEL_VERSION = False

In [None]:
# Promote the evaluated version, persist the test predictions, and create a
# model monitor over the persisted predictions.
if PROMOTE_MODEL_VERSION:
    # Make the evaluated version the registry default.
    m = reg.get_model(qualified_model_name)
    m.default = model_version_nm
    print(f"Model version {model_version_nm} promoted.")

    # Unqualified object names below resolve against the model schema.
    session.use_schema(MODEL_SCHEMA)
    source_name = f"{INFERENCE_RESULT_TBL_NM}_{MODEL_NAME}_{model_version_nm}"
    base_name = source_name + "_BASELINE"  # NOTE(review): currently unused - no baseline table is created

    table_exist = session.sql(f"SHOW TABLES LIKE '{source_name}';").count() > 0
    if table_exist:
        # Append only rows whose (time period, partition) key is not already
        # stored, so re-running this cell does not duplicate predictions.
        table_data = session.table(source_name).select(TIME_PERIOD_COLUMN, "GROUP_IDENTIFIER_STRING")
        data_to_save = test_pred_v_actuals.join(
            table_data, on=[TIME_PERIOD_COLUMN, "GROUP_IDENTIFIER_STRING"], how="leftanti"
        )
        data_to_save.drop("DATASET").write.save_as_table(source_name, mode="append")
    else:
        test_pred_v_actuals.drop("DATASET").write.save_as_table(source_name, mode="overwrite")

    # The saved table stores the actuals under TARGET_COLUMN, so the monitor
    # must reference that column (the original hard-coded MODEL_TARGET, which
    # only works when the target column happens to have that name).
    session.sql(f"""
        CREATE OR REPLACE MODEL MONITOR {MODEL_NAME}_{model_version_nm}_MONITOR
        WITH
            MODEL={MODEL_NAME}
            VERSION={model_version_nm}
            FUNCTION=predict
            SOURCE={source_name}
            TIMESTAMP_COLUMN={TIME_PERIOD_COLUMN}
            PREDICTION_SCORE_COLUMNS=(PREDICTED)  
            ACTUAL_SCORE_COLUMNS=({TARGET_COLUMN})
            SEGMENT_COLUMNS = (GROUP_IDENTIFIER_STRING)
            WAREHOUSE={SESSION_WH}
            REFRESH_INTERVAL='1 day'
            AGGREGATION_WINDOW='1 day';
    """).collect()
    print("Model monitor created")