# Inference Notebook
Use the model trained in the modeling notebook to make predictions on an inference dataset.

#### NOTE: The user must have an inference dataset available as a table or view in Snowflake before running this notebook.
- If using a __direct multi-step forecasting__ pattern, the inference dataset does not need to contain records for the future datetime points.
- If using a __global modeling__ pattern, the inference dataset must contain records for each future datetime to be forecasted.
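
For the global modeling pattern, the future rows have to exist before this notebook runs. The following is a minimal sketch of how such rows could be generated, assuming a daily frequency. The helper `build_future_rows` is hypothetical and not part of `forecast_model_builder`, and the column names are taken from the sample outputs in this notebook purely for illustration.

```python
from datetime import datetime, timedelta
from itertools import product


def build_future_rows(partitions, last_observed, horizon, step=timedelta(days=1)):
    """Return one row per (partition, future datetime) for `horizon` steps."""
    # Future datetimes start one step after the last observed datetime.
    future_times = [last_observed + step * i for i in range(1, horizon + 1)]
    return [
        {"GROUP_IDENTIFIER_STRING": partition, "ORDER_TIMESTAMP": ts}
        for partition, ts in product(partitions, future_times)
    ]


future_rows = build_future_rows(
    partitions=["STORE_ID_1_PRODUCT_ID_1", "STORE_ID_1_PRODUCT_ID_2"],
    last_observed=datetime(2025, 1, 9),
    horizon=3,
)
for row in future_rows:
    print(row)
```

In practice these rows would also need values for any exogenous columns the model expects, which this sketch leaves out.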

❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 
## Instructions

1. Go to the ___set_global_variables___ cell in the __SETUP__ section below.
    - Adjust the values of the user constants.
2. Click ___Run all___ in the upper right corner of the notebook to run the entire notebook. 
    - The notebook will perform feature engineering steps and inference. Predictions will be stored in a Snowflake table.
    
❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

In [34]:
# Imports
import math
from datetime import datetime

from snowflake.ml.registry import registry
from snowflake.ml.dataset import Dataset
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.snowpark.window import Window
from snowflake.snowpark import DataFrame as SnowparkDataFrame
import streamlit as st

from forecast_model_builder.feature_engineering import (
    apply_functions_in_a_loop,
    expand_datetime,
    recent_rolling_avg,
    roll_up,
    verify_current_frequency,
    verify_valid_rollup_spec,
)
from forecast_model_builder.utils import connect

In [2]:
# Establish session
session = connect(connection_name="default")
session_db = session.connection.database
session_schema = session.connection.schema
session_wh = session.connection.warehouse
print(f"Session db.schema: {session_db}.{session_schema}")
print(f"Session warehouse: {session_wh}")

# Query tag
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}'
session.query_tag = query_tag

# Get the current datetime  (This will be saved in the model storage table)
run_dttm = datetime.now()
print(f"Current Datetime: {run_dttm}")

Session db.schema: FORECAST_MODEL_BUILDER.TEST
Session warehouse: FORECAST_MODEL_BUILDER_WH
Current Datetime: 2025-10-13 14:41:15.751743
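
The query tag in the cell above is a hand-typed JSON string. As a minimal sketch, the same tag could be built from a dictionary with `json.dumps`, which avoids quoting mistakes. The name `query_tag_fields` is illustrative and not part of the notebook.

```python
import json

# Same fields as the hand-written query tag, expressed as a dictionary.
query_tag_fields = {
    "origin": "sf_sit",
    "name": "sit_forecasting",
    "version": {"major": 1, "minor": 0},
    "attributes": {"component": "inference"},
}

# Serialize to a JSON string suitable for assigning to session.query_tag.
query_tag = json.dumps(query_tag_fields)
print(query_tag)
```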


-----
# SETUP
-----

In [3]:
# Establish cutoff datetime for the records that will be scored using the forecast model.
# NOTE: For Direct Multi-Step Forecasting, this value will likely be the most recent date, since this pattern does not take in records for future dates as input.
#       For Global Modeling, this value will be the first future datetime value in the records that will be scored.
# NOTE: This cutoff will be applied AFTER feature engineering, so that lag features can be calculated.

# Table name to store the PREDICTION results.
# NOTE: If the table name is not fully qualified with DB.SCHEMA, the session's default database and schema will be used.
# NOTE: Currently the code will overwrite the existing predictions table with the predictions from this run.
INFERENCE_RESULT_TBL_NM = "FORECAST_RESULTS"

# Input data for inference
INFERENCE_DB = "FORECAST_MODEL_BUILDER"
INFERENCE_SCHEMA = "BASE"
INFERENCE_FV = "FORECAST_FEATURES"

# Name of the model to use for inference, as well as the Database and Schema of the model registry.
# NOTE: The default model version from the registry will be used.
MODEL_DB = "FORECAST_MODEL_BUILDER"
MODEL_SCHEMA = "MODELING"
MODEL_NAME = "TEST_MODEL_1"

# Scaling up the warehouse may speed up execution time, especially if there are many partitions.
# NOTE: If set to None, then the session warehouse will be used.
INFERENCE_WH = "STANDARD_XL"
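
The note about applying the cutoff after feature engineering matters because lag features need rows from before the cutoff. The following is a small plain-Python sketch of that ordering, not the notebook's Snowpark code. The data, the `CUTOFF` value, and the `add_lag` helper are all hypothetical, and the rows are assumed to be sorted by date.

```python
from datetime import date

# Hypothetical daily history for one partition: (date, target).
# The last row is a future date with no target yet.
rows = [
    (date(2025, 1, 7), 10.0),
    (date(2025, 1, 8), 12.0),
    (date(2025, 1, 9), 11.0),
    (date(2025, 1, 10), None),
]
CUTOFF = date(2025, 1, 10)  # hypothetical cutoff: first date to score


def add_lag(rows, lag=1):
    """Attach the target from `lag` rows earlier (None when unavailable)."""
    out = []
    for i, (d, y) in enumerate(rows):
        lagged = rows[i - lag][1] if i >= lag else None
        out.append({"date": d, "target": y, f"lag_{lag}": lagged})
    return out


# Correct order: engineer features on the full history, then apply the cutoff.
featured = add_lag(rows)
to_score = [r for r in featured if r["date"] >= CUTOFF]

# Wrong order: filtering first leaves no history to compute the lag from.
filtered_first = add_lag([r for r in rows if r[0] >= CUTOFF])

print(to_score[0]["lag_1"], filtered_first[0]["lag_1"])
```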

-----
# Establish objects needed for this run
-----

In [24]:
# Derived Objects

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
SESSION_WH = session.connection.warehouse
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Inference Warehouse
# -----------------------------------------------------------------------
# Check that the user specified an available warehouse as INFERENCE_WH. If not, use the session warehouse.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

if INFERENCE_WH in available_warehouses:
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
else:
    print(
        f"WARNING: User does not have access to INFERENCE_WH = '{INFERENCE_WH}'. Inference will use '{SESSION_WH}' instead. \n"
    )
    INFERENCE_WH = SESSION_WH

# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Get the model and the version name of the default version
# -----------------------------------------------------------------------
# Establish registry object
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# Get the default version of the model from the registry
mv = reg.get_model(qualified_model_name).default

# Get the default version name
model_version_nm = mv.version_name

print(f"Model Version:              {model_version_nm}")

# --------------------------------
# User Constants from Model Setup
# --------------------------------
stored_constants = mv.show_metrics()["user_settings"]

TIME_PERIOD_COLUMN = stored_constants["TIME_PERIOD_COLUMN"]
TARGET_COLUMN = stored_constants["TARGET_COLUMN"]
PARTITION_COLUMNS = stored_constants["PARTITION_COLUMNS"]
EXOGENOUS_COLUMNS = stored_constants["EXOGENOUS_COLUMNS"]
ALL_EXOG_COLS_HAVE_FUTURE_VALS = stored_constants["ALL_EXOG_COLS_HAVE_FUTURE_VALS"]
CREATE_LAG_FEATURE = stored_constants["CREATE_LAG_FEATURE"]
CURRENT_FREQUENCY = stored_constants["CURRENT_FREQUENCY"]
ROLLUP_FREQUENCY = stored_constants["ROLLUP_FREQUENCY"]
ROLLUP_AGGREGATIONS = stored_constants["ROLLUP_AGGREGATIONS"]
FORECAST_HORIZON = stored_constants["FORECAST_HORIZON"]
INFERENCE_APPROX_BATCH_SIZE = stored_constants["INFERENCE_APPROX_BATCH_SIZE"]
USE_CONTEXT = stored_constants["USE_CONTEXT"]

# --------------------------------
# Get datasets
# --------------------------------

def load_df_from_ds(fully_qualified_name, version):
    ds_db, ds_schema, ds_name = fully_qualified_name.split('.')

    return Dataset(
        session=session,
        database=ds_db,
        schema=ds_schema,
        name=ds_name,
        selected_version=version
    ).read.to_snowpark_dataframe()

train_df = load_df_from_ds(
    fully_qualified_name=mv.show_metrics()['train_dataset']['name'],
    version=mv.show_metrics()['train_dataset']['version']
).drop("GROUP_IDENTIFIER")

test_df = load_df_from_ds(
    fully_qualified_name=mv.show_metrics()['test_dataset']['name'],
    version=mv.show_metrics()['test_dataset']['version']
).drop("GROUP_IDENTIFIER")


Session warehouse:          FORECAST_MODEL_BUILDER_WH
Inference warehouse:        FORECAST_MODEL_BUILDER_WH 

Model Version:              LIGHT_HOUND_1


-----
# Inference
-----

In [26]:
# ------------------------------------------------------------------------
# INFERENCE
# ------------------------------------------------------------------------

def perform_inference(inference_input_df):
    # If the inference dataset does not have the TARGET column already, add it and fill it with null values
    if TARGET_COLUMN not in inference_input_df.columns:
        inference_input_df = inference_input_df.with_column(TARGET_COLUMN, F.lit(None).cast(T.FloatType()))

    if not USE_CONTEXT:
        MODEL_BINARY_STORAGE_TBL_NM = stored_constants["MODEL_BINARY_STORAGE_TBL_NM"]
        model_bytes_table = (
            session.table(MODEL_BINARY_STORAGE_TBL_NM)
            .filter(F.col("MODEL_NAME") == MODEL_NAME)
            .filter(F.col("MODEL_VERSION") == model_version_nm)
            .select("GROUP_IDENTIFIER_STRING", "MODEL_BINARY")
        )

        # NOTE: We inner join to the model bytes table to ensure that we only try to run inference on partitions that have a model.
        inference_input_df = inference_input_df.join(
            model_bytes_table, on=["GROUP_IDENTIFIER_STRING"], how="inner"
        )

    # Add a column called BATCH_GROUP,
    #   which has the property that for each unique value there are roughly the number of records specified in batch_size.
    # Use that to create a PARTITION_ID column that will be used to run inference in batches.
    # We do this to avoid running out of memory when performing inference on a large number of records.
    largest_partition_record_count = (
        inference_input_df.group_by("GROUP_IDENTIFIER_STRING")
        .agg(F.count("*").alias("PARTITION_RECORD_COUNT"))
        .agg(F.max("PARTITION_RECORD_COUNT").alias("MAX_PARTITION_RECORD_COUNT"))
        .collect()[0]["MAX_PARTITION_RECORD_COUNT"]
    )
    batch_size = INFERENCE_APPROX_BATCH_SIZE
    number_of_batches = math.ceil(largest_partition_record_count / batch_size)
    inference_input_df = (
        inference_input_df.with_column(
            "BATCH_GROUP", F.abs(F.random(123)) % F.lit(number_of_batches)
        )
        .with_column(
            "PARTITION_ID",
            F.concat_ws(
                F.lit("__"), F.col("GROUP_IDENTIFIER_STRING"), F.col("BATCH_GROUP")
            ),
        )
        .drop("BATCH_GROUP")
    )

    # Report the row count of the inference input data and the number of batches that will be scored
    print(f"Inference input data row count: {inference_input_df.count()}")
    print(
        f"Number of end partition invocations to expect in the query profile: {inference_input_df.select('PARTITION_ID').distinct().count()}"
    )
    # Use the model to score the input data
    inference_result = mv.run(inference_input_df, partition_column="PARTITION_ID").select(
        "_PRED_",
        F.col("GROUP_IDENTIFIER_STRING_OUT_").alias("GROUP_IDENTIFIER_STRING"),
        F.col(f"{TIME_PERIOD_COLUMN}_OUT_").alias(TIME_PERIOD_COLUMN),
    )

    return inference_result


print("Predictions")
session.use_warehouse(INFERENCE_WH)

train_result = perform_inference(train_df).with_column("DATASET",F.lit("TRAIN"))
test_result = perform_inference(test_df).with_column("DATASET",F.lit("TEST"))
test_result.show(2)

session.use_warehouse(SESSION_WH)

Predictions
Inference input data row count: 344750
Number of end partition invocations to expect in the query profile: 1750
Inference input data row count: 22750
Number of end partition invocations to expect in the query profile: 250
------------------------------------------------------------------------------------
|"_PRED_"            |"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |"DATASET"  |
------------------------------------------------------------------------------------
|145.43472290039062  |STORE_ID_6_PRODUCT_ID_6    |2024-12-06 00:00:00  |TEST       |
|141.02732849121094  |STORE_ID_6_PRODUCT_ID_6    |2024-12-17 00:00:00  |TEST       |
------------------------------------------------------------------------------------



In [28]:
# Write predictions to a Snowflake table.
# Right now this is set up to overwrite the table if it already exists.

inference_result = train_result.union_all_by_name(test_result)
inference_result.write.save_as_table(
    INFERENCE_RESULT_TBL_NM,
    mode="overwrite",
    comment='{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}',
)

print(
    f"Predictions written to table: {session_db}.{session_schema}.{INFERENCE_RESULT_TBL_NM}"
)

# Look at a few rows of the snowflake table
print("Sample records:")
inference_result = session.table(INFERENCE_RESULT_TBL_NM)
inference_result.limit(3).show()

Predictions written to table: FORECAST_MODEL_BUILDER.TEST.FORECAST_RESULTS
Sample records:
-----------------------------------------------------------------------------------
|"_PRED_"           |"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |"DATASET"  |
-----------------------------------------------------------------------------------
|562.4960327148438  |STORE_ID_25_PRODUCT_ID_8   |2023-10-02 00:00:00  |TRAIN      |
|588.8887939453125  |STORE_ID_25_PRODUCT_ID_8   |2023-07-25 00:00:00  |TRAIN      |
|577.991455078125   |STORE_ID_25_PRODUCT_ID_8   |2023-07-26 00:00:00  |TRAIN      |
-----------------------------------------------------------------------------------



In [None]:
sdf = train_df.union_all_by_name(test_df).select("GROUP_IDENTIFIER_STRING",TIME_PERIOD_COLUMN,TARGET_COLUMN)

pred_v_actuals = (
    inference_result
    .join(sdf, on=["GROUP_IDENTIFIER_STRING", TIME_PERIOD_COLUMN])
    .select(
        "GROUP_IDENTIFIER_STRING",
        TIME_PERIOD_COLUMN,
        TARGET_COLUMN,
        F.col("_PRED_").alias("PREDICTED"),
        "DATASET",  # Keep DATASET so the TRAIN/TEST filters below can reference it
    )
)

inference_partition_count = pred_v_actuals.select("GROUP_IDENTIFIER_STRING").distinct().count()

training_pred_v_actuals = pred_v_actuals.filter(F.col("DATASET") == "TRAIN")

test_pred_v_actuals = pred_v_actuals.filter(F.col("DATASET") == "TEST")
print(f"Dataset has {inference_partition_count} partitions")
pred_v_actuals.drop("DATASET").show(3)

Test dataset has 250 partitions
-------------------------------------------------------------------------------------------
|"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |"TARGET"           |"PREDICTED"        |
-------------------------------------------------------------------------------------------
|STORE_ID_9_PRODUCT_ID_8    |2023-07-30 00:00:00  |947.8779296875     |941.7086181640625  |
|STORE_ID_9_PRODUCT_ID_8    |2023-05-26 00:00:00  |887.1766967773438  |888.5808715820312  |
|STORE_ID_9_PRODUCT_ID_8    |2023-06-01 00:00:00  |880.4500122070312  |885.3113403320312  |
-------------------------------------------------------------------------------------------



In [30]:
total_window = Window.partition_by()

partition_weights = (
    pred_v_actuals
    .group_by("GROUP_IDENTIFIER_STRING")
    .agg(F.sum(TARGET_COLUMN).alias(f'PARTITION_{TARGET_COLUMN}_SUM'), F.min(TIME_PERIOD_COLUMN), F.max(TIME_PERIOD_COLUMN))
    .with_column(f"TOTAL_{TARGET_COLUMN}", F.sum(f'PARTITION_{TARGET_COLUMN}_SUM').over(total_window))
    .with_column("PARTITION_WEIGHT", F.col(f'PARTITION_{TARGET_COLUMN}_SUM')/F.col(f"TOTAL_{TARGET_COLUMN}"))
)

partition_weights.sort(F.col("PARTITION_WEIGHT").desc()).show()

-----------------------------------------------------------------------------------------------------------------------------------------------------
|"GROUP_IDENTIFIER_STRING"  |"PARTITION_TARGET_SUM"  |"MIN(ORDER_TIMESTAMP)"  |"MAX(ORDER_TIMESTAMP)"  |"TOTAL_TARGET"      |"PARTITION_WEIGHT"     |
-----------------------------------------------------------------------------------------------------------------------------------------------------
|STORE_ID_17_PRODUCT_ID_1   |1456821.035583496       |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963  |0.00795719234225625    |
|STORE_ID_24_PRODUCT_ID_6   |1438105.7832641602      |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963  |0.007854969173588749   |
|STORE_ID_22_PRODUCT_ID_3   |1411616.508605957       |2021-01-01 00:00:00     |2025-01-09 00:00:00     |183082294.97571963  |0.007710284103622175   |
|STORE_ID_21_PRODUCT_ID_9   |1404011.4907226562      |2021-01-01 00:00:00     |2025-01-09 00:00:00  

In [None]:
def produce_metrics(sdf: SnowparkDataFrame) -> tuple[SnowparkDataFrame, SnowparkDataFrame]:
    # Row-level metrics
    row_actual_v_fcst = (
        sdf
        .with_column("PRED_ERROR", F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        .with_column(
            "ABS_ERROR", F.abs(F.col(TARGET_COLUMN) - F.col("PREDICTED"))
        )
        .with_column(
            "APE",
            F.when(F.col(TARGET_COLUMN) == 0, F.lit(None)).otherwise(
                F.abs(F.col("ABS_ERROR") / F.col(TARGET_COLUMN))
            ),
        )
        .with_column("SQ_ERROR", F.pow(F.col(TARGET_COLUMN) - F.col("PREDICTED"), 2))
    )
    
    # Metrics per partition
    partition_metrics = row_actual_v_fcst.group_by("GROUP_IDENTIFIER_STRING").agg(
        F.avg("APE").alias("MAPE"),
        F.avg("ABS_ERROR").alias("MAE"),
        F.sqrt(F.avg("SQ_ERROR")).alias("RMSE"),
        F.count("*").alias("TOTAL_PRED_COUNT"),
    )
    
    # Overall modeling process across all partitions
    overall_avg_metrics = partition_metrics.agg(
        F.avg("MAPE").alias("OVERALL_MAPE"),
        F.avg("MAE").alias("OVERALL_MAE"),
        F.avg("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("AVG"))
    
    overall_weighted_avg_metrics = (
        partition_metrics
            .join(partition_weights.select("GROUP_IDENTIFIER_STRING", "PARTITION_WEIGHT"), on=["GROUP_IDENTIFIER_STRING"])
            .agg(
                F.sum(F.col("PARTITION_WEIGHT")*F.col("MAPE")).alias("OVERALL_MAPE"),
                F.sum(F.col("PARTITION_WEIGHT")*F.col("MAE")).alias("OVERALL_MAE"),
                F.sum(F.col("PARTITION_WEIGHT")*F.col("RMSE")).alias("OVERALL_RMSE"),
                 )
            .with_column("AGGREGATION", F.lit("WEIGHTED_AVG"))
    )
    
    overall_median_metrics = partition_metrics.agg(
        F.median("MAPE").alias("OVERALL_MAPE"),
        F.median("MAE").alias("OVERALL_MAE"),
        F.median("RMSE").alias("OVERALL_RMSE"),
    ).with_column("AGGREGATION", F.lit("MEDIAN"))
    
    overall_metrics = (
        overall_avg_metrics
            .union(overall_median_metrics)
            .union(overall_weighted_avg_metrics)
            .select("AGGREGATION", "OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE")
            .sort("AGGREGATION")
    )
    
    # Show the metrics
    if inference_partition_count == 1:
        st.write(
            "There is only 1 partition, so these values are the metrics for that single model:"
        )
        st.dataframe(
            overall_median_metrics.select("OVERALL_MAPE", "OVERALL_MAE", "OVERALL_RMSE"),
            use_container_width=True,
        )
    else:
        st.write("Avg, Median, and Weighted Avg of each metric over all the partitions:")
        st.dataframe(overall_metrics, use_container_width=True)

    return row_actual_v_fcst, partition_metrics

st.write("TRAINING SET")
train_metric_sdf, train_partition_metrics = produce_metrics(training_pred_v_actuals)
st.write("TEST SET")
test_metric_sdf, test_partition_metrics = produce_metrics(test_pred_v_actuals)

2025-10-13 15:44:05.581 
  command:

    streamlit run /opt/anaconda3/envs/forecast/lib/python3.12/site-packages/ipykernel_launcher.py [ARGUMENTS]
2025-10-13 15:44:07.224 Please replace `use_container_width` with `width`.

`use_container_width` will be removed after 2025-12-31.

For `use_container_width=True`, use `width='stretch'`. For `use_container_width=False`, use `width='content'`.
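
To make the three aggregations in `produce_metrics` concrete, the following is a small plain-Python sketch with made-up numbers. It mirrors the logic as read from the cell above: per-partition MAPE, MAE, and RMSE, with rows whose actual value is zero excluded from MAPE, followed by the average, median, and weighted average of the partition MAPEs. The data and the helper `partition_metrics` are hypothetical, and the weights are each partition's share of the summed actuals.

```python
import math
from statistics import mean, median


def partition_metrics(rows):
    """rows: list of (actual, predicted) pairs for a single partition."""
    abs_errors = [abs(a - p) for a, p in rows]
    # APE is undefined when the actual is zero, so those rows are skipped.
    apes = [abs(a - p) / abs(a) for a, p in rows if a != 0]
    return {
        "MAPE": mean(apes) if apes else None,
        "MAE": mean(abs_errors),
        "RMSE": math.sqrt(mean(e ** 2 for e in abs_errors)),
    }


# Hypothetical (actual, predicted) pairs for two partitions.
data = {
    "A": [(100.0, 90.0), (200.0, 220.0)],
    "B": [(10.0, 12.0), (0.0, 1.0)],
}
metrics = {name: partition_metrics(rows) for name, rows in data.items()}

# Weight of each partition: its share of the total of the actual values.
target_sums = {name: sum(a for a, _ in rows) for name, rows in data.items()}
total = sum(target_sums.values())
weights = {name: s / total for name, s in target_sums.items()}

avg_mape = mean(m["MAPE"] for m in metrics.values())
median_mape = median(m["MAPE"] for m in metrics.values())
weighted_mape = sum(weights[name] * metrics[name]["MAPE"] for name in metrics)

print(avg_mape, median_mape, weighted_mape)
```

The weighted average leans toward partition A because it carries most of the total volume, which is why it can differ noticeably from the simple average when partition sizes are uneven.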
