# Inference Notebook
Use the model trained in the modeling notebook to make predictions on an inference dataset.

#### NOTE: The user must have an inference dataset available as a table or view in Snowflake before running this notebook.
- If using a __direct multi-step forecasting__ pattern, the inference dataset does not need to contain records for the future datetime points.
- If using a __global modeling__ pattern, the inference dataset must contain records for each future datetime to be forecasted.
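
The difference between the two patterns can be sketched in plain Python. This is a minimal illustration, not part of the notebook's code. It assumes a daily modeling frequency, and the `rows_needed` helper and its `"global"`/`"direct"` labels are hypothetical. It lists which future datetimes must already exist as records in the inference dataset.

```python
from datetime import date, timedelta


def future_dates(last_observed: date, horizon: int) -> list:
    """Return the `horizon` daily datetimes after the last observed date.

    Assumes a daily modeling frequency (hypothetical simplification).
    """
    return [last_observed + timedelta(days=step) for step in range(1, horizon + 1)]


def rows_needed(pattern: str, last_observed: date, horizon: int) -> list:
    """Future datetimes that must already exist as records in the inference dataset."""
    if pattern == "global":
        # Global modeling scores one record per future datetime, so each
        # forecasted datetime needs its own row (with its exogenous values).
        return future_dates(last_observed, horizon)
    if pattern == "direct":
        # Direct multi-step forecasting predicts every lead from the most
        # recent observed record, so no future records are required.
        return []
    raise ValueError(f"unknown pattern: {pattern!r}")


print(rows_needed("global", date(2024, 12, 31), 3))
print(rows_needed("direct", date(2024, 12, 31), 3))
```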

❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 
## Instructions

1. Go to the ___set_global_variables___ cell in the __SETUP__ section below.
    - Adjust the values of the user constants
2. Click ___Run all___ in the upper right corner of the notebook to run the entire notebook. 
    - The notebook reads the model's feature view, runs inference, and stores the predictions in a Snowflake table.
    
❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

In [1]:
# Imports
import math
from datetime import datetime

from snowflake.ml.registry import registry
from snowflake.ml.feature_store import FeatureStore, CreationMode
from snowflake.snowpark import Window
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T

from forecast_model_builder.feature_engineering import (
    apply_functions_in_a_loop,
    expand_datetime,
    recent_rolling_avg,
    roll_up,
    verify_current_frequency,
    verify_valid_rollup_spec,
)
from forecast_model_builder.utils import connect



In [2]:
# Establish session
session = connect(connection_name="default")
session_db = session.connection.database
session_schema = session.connection.schema
session_wh = session.connection.warehouse
print(f"Session db.schema: {session_db}.{session_schema}")
print(f"Session warehouse: {session_wh}")

# Query tag
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}'
session.query_tag = query_tag

# Get the current datetime of this inference run
run_dttm = datetime.now()
print(f"Current Datetime: {run_dttm}")

Session db.schema: FORECAST_MODEL_BUILDER.TEST
Session warehouse: FORECAST_MODEL_BUILDER_WH
Current Datetime: 2025-09-30 12:36:20.473785
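
The query tag above is a hand-written JSON string, and the same string is repeated as the table comment when the predictions are written. A minimal sketch, assuming only the standard library, builds it from a dict with `json.dumps`; the `separators` argument reproduces the exact spacing of the original string, so the tag stays valid JSON if fields are edited later.

```python
import json

# Build the query tag from a dict so the JSON is always well formed.
query_tag_dict = {
    "origin": "sf_sit",
    "name": "sit_forecasting",
    "version": {"major": 1, "minor": 0},
    "attributes": {"component": "inference"},
}

# (", ", ":") matches the spacing of the hand-written tag used in this notebook.
query_tag = json.dumps(query_tag_dict, separators=(", ", ":"))
print(query_tag)
```

The resulting string could then be assigned to `session.query_tag` and reused as `comment=query_tag` when saving the results table.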


-----
# SETUP
-----

In [3]:
# Establish cutoff datetime for the records that will be scored using the forecast model.
# NOTE: For Direct Multi-Step Forecasting, this value will likely be the most recent date, since this pattern does not take in records for future dates as input.
#       For Global Modeling, this value will be the first future datetime value in the records that will be scored.
# NOTE: This cutoff will be applied AFTER feature engineering, so that lag features can be calculated.
INFERENCE_START_TIMESTAMP = "2025-01-01 00:00:00.000"

# Table name to store the PREDICTION results.
# NOTE: If the table name is not fully qualified with DB.SCHEMA, the session's default database and schema will be used.
# NOTE: Currently the code will overwrite the existing predictions table with the predictions from this run.
INFERENCE_RESULT_TBL_NM = "FORECAST_RESULTS"

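# Input data for inference
# NOTE: The next cell reads the feature view from the model's upstream lineage,
#       so these three values are not currently referenced by the code below.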
INFERENCE_DB = "FORECAST_MODEL_BUILDER"
INFERENCE_SCHEMA = "BASE"
INFERENCE_FV = "FORECAST_FEATURES"

# Name of the model to use for inference, as well as the Database and Schema of the model registry.
# NOTE: The default model version from the registry will be used.
MODEL_DB = "FORECAST_MODEL_BUILDER"
MODEL_SCHEMA = "MODELING"
MODEL_NAME = "TEST_MODEL_1"

# Scaling up the warehouse may speed up execution time, especially if there are many partitions.
# NOTE: If set to None, then the session warehouse will be used.
INFERENCE_WH = "STANDARD_XL"
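
The NOTE that the cutoff is applied after feature engineering matters for lag features. The toy sketch below uses plain Python, hypothetical data, and a hypothetical `add_lag_1` helper rather than the notebook's Snowpark code. It shows that computing a lag on the full history and then filtering keeps the lag value on the first inference record, while filtering first loses it.

```python
from datetime import date

# Toy daily series for one partition (hypothetical values).
history = [
    (date(2024, 12, 30), 10.0),
    (date(2024, 12, 31), 12.0),
    (date(2025, 1, 1), 11.0),
    (date(2025, 1, 2), 13.0),
]
INFERENCE_START = date(2025, 1, 1)


def add_lag_1(rows):
    """Attach the previous record's target as a lag-1 feature (None for the first row)."""
    out = []
    prev = None
    for ts, target in rows:
        out.append({"ts": ts, "target": target, "lag_1": prev})
        prev = target
    return out


# Correct order: engineer features on the full history, then apply the cutoff.
features_then_filter = [r for r in add_lag_1(history) if r["ts"] >= INFERENCE_START]

# Wrong order: filtering first drops the history the lag depends on.
filter_then_features = add_lag_1([r for r in history if r[0] >= INFERENCE_START])

print(features_then_filter[0]["lag_1"])  # 12.0
print(filter_then_features[0]["lag_1"])  # None
```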

-----
# Establish objects needed for this run
-----

In [4]:
# Derived Objects

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
SESSION_WH = session.connection.warehouse
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Inference Warehouse
# -----------------------------------------------------------------------
# Check that the user specified an available warehouse as INFERENCE_WH. If not, use the session warehouse.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

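if INFERENCE_WH is None:
    # No inference warehouse requested: run inference on the session warehouse.
    INFERENCE_WH = SESSION_WH
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
elif INFERENCE_WH in available_warehouses:
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
else:
    print(
        f"WARNING: User does not have access to INFERENCE_WH = '{INFERENCE_WH}'. Inference will use '{SESSION_WH}' instead. \n"
    )
    INFERENCE_WH = SESSION_WH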

# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Get the model and the version name of the default version
# -----------------------------------------------------------------------
# Establish registry object
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# Get the model from the registry
mv = reg.get_model(qualified_model_name).default

# Get the default version name
model_version_nm = mv.version_name

print(f"Model Version:              {model_version_nm}")

# --------------------------------
# User Constants from Model Setup
# --------------------------------
stored_constants = mv.show_metrics()["user_settings"]

TIME_PERIOD_COLUMN = stored_constants["TIME_PERIOD_COLUMN"]
TARGET_COLUMN = stored_constants["TARGET_COLUMN"]
PARTITION_COLUMNS = stored_constants["PARTITION_COLUMNS"]
EXOGENOUS_COLUMNS = stored_constants["EXOGENOUS_COLUMNS"]
ALL_EXOG_COLS_HAVE_FUTURE_VALS = stored_constants["ALL_EXOG_COLS_HAVE_FUTURE_VALS"]
CREATE_LAG_FEATURE = stored_constants["CREATE_LAG_FEATURE"]
CURRENT_FREQUENCY = stored_constants["CURRENT_FREQUENCY"]
ROLLUP_FREQUENCY = stored_constants["ROLLUP_FREQUENCY"]
ROLLUP_AGGREGATIONS = stored_constants["ROLLUP_AGGREGATIONS"]
FORECAST_HORIZON = stored_constants["FORECAST_HORIZON"]
INFERENCE_APPROX_BATCH_SIZE = stored_constants["INFERENCE_APPROX_BATCH_SIZE"]
MODEL_BINARY_STORAGE_TBL_NM = stored_constants["MODEL_BINARY_STORAGE_TBL_NM"]

# -----------------------------------------------------------------------
# Create a window spec
# -----------------------------------------------------------------------
window_spec = Window.partitionBy(PARTITION_COLUMNS).orderBy(TIME_PERIOD_COLUMN)

# -----------------------------------------------------------------------
# Create a variable to hold the granularity at which we will be modeling
# -----------------------------------------------------------------------
CURRENT_FREQUENCY = CURRENT_FREQUENCY.lower()

# A stored value of None or the string "none" both mean that no rollup was configured.
if ROLLUP_FREQUENCY is not None:
    ROLLUP_FREQUENCY = ROLLUP_FREQUENCY.lower()
    if ROLLUP_FREQUENCY == "none":
        ROLLUP_FREQUENCY = None

modeling_frequency = CURRENT_FREQUENCY if ROLLUP_FREQUENCY is None else ROLLUP_FREQUENCY
print(f"Modeling Frequency:         {modeling_frequency}")

# -----------------------------------------------------------------------
# Establish modeling pattern
# -----------------------------------------------------------------------
# Either (1) train_separate_lead_models = False : all features have future values in the inference data, so we don't need a separate model for each lead
# or (2) train_separate_lead_models = True : data contains exogenous variables that the inference data won't have future values for, requiring lead modeling
train_separate_lead_models = (
    False
    if ALL_EXOG_COLS_HAVE_FUTURE_VALS is True or len(EXOGENOUS_COLUMNS) == 0
    else True
)
print(f"Train Separate Lead Models: {train_separate_lead_models}")

# --------------------------------
# Get feature view
# --------------------------------

fs = FeatureStore(
    session,
    database=session_db,
    name=session_schema,
    default_warehouse=SESSION_WH,
    creation_mode=CreationMode.FAIL_IF_NOT_EXIST,
)

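# NOTE: This assumes the first upstream lineage node of the model version is the
#       feature view the model was trained on; its current contents are the inference input.
fv = mv.lineage(direction='upstream')[0]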

sdf = fs.read_feature_view(fv).cache_result()

Session warehouse:          FORECAST_MODEL_BUILDER_WH

Model Version:              HOT_STARFISH_1
Modeling Frequency:         day
Train Separate Lead Models: False
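
The derived settings printed above come from three small decisions in the cell. The sketch below restates them as plain-Python functions so that each rule can be checked in isolation. The function names are hypothetical; the logic mirrors the cell: rollup frequency over current frequency, separate lead models only when some exogenous column lacks future values, and falling back to the session warehouse.

```python
from typing import Optional


def resolve_modeling_frequency(current: str, rollup: Optional[str]) -> str:
    """Model at the rollup frequency when one is configured, else at the current frequency."""
    if rollup is not None and rollup.lower() != "none":
        return rollup.lower()
    return current.lower()


def needs_separate_lead_models(all_exog_have_future_vals: bool, exogenous_columns: list) -> bool:
    """True when at least one exogenous column will lack future values at inference time."""
    return not (all_exog_have_future_vals is True or len(exogenous_columns) == 0)


def resolve_warehouse(requested: Optional[str], session_wh: str, available: list) -> str:
    """Use the requested warehouse only if it is set and available; otherwise the session warehouse."""
    if requested is not None and requested in available:
        return requested
    return session_wh


print(resolve_modeling_frequency("DAY", None))        # day
print(needs_separate_lead_models(True, ["PROMO"]))   # False
print(resolve_warehouse(None, "SESSION_WH", []))     # SESSION_WH
```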




-----
# Inference
-----

In [5]:
# ------------------------------------------------------------------------
# INFERENCE
# ------------------------------------------------------------------------
# Prep data set for inference

# If the inference dataset does not have the TARGET column already, add it and fill it with null values
if TARGET_COLUMN not in sdf.columns:
    sdf = sdf.with_column(TARGET_COLUMN, F.lit(None).cast(T.FloatType()))

inference_input_df = sdf.filter(
    F.col(TIME_PERIOD_COLUMN) >= INFERENCE_START_TIMESTAMP
).drop("GROUP_IDENTIFIER")

model_bytes_table = (
    session.table(MODEL_BINARY_STORAGE_TBL_NM)
    .filter(F.col("MODEL_NAME") == MODEL_NAME)
    .filter(F.col("MODEL_VERSION") == model_version_nm)
    .select("GROUP_IDENTIFIER_STRING", "MODEL_BINARY")
)

# NOTE: We inner join to the model bytes table to ensure that we only try to run inference on partitions that have a model.
inference_input_df = inference_input_df.join(
    model_bytes_table, on=["GROUP_IDENTIFIER_STRING"], how="inner"
)

# Within each GROUP_IDENTIFIER_STRING, assign every record a random BATCH_GROUP in [0, number_of_batches).
#   number_of_batches is sized from the largest partition, so each (group, batch) pair holds
#   roughly batch_size records or fewer.
# Combine the group and batch into a PARTITION_ID column that is used to run inference in batches.
# We do this to avoid running out of memory when performing inference on a large number of records.
largest_partition_record_count = (
    inference_input_df.group_by("GROUP_IDENTIFIER_STRING")
    .agg(F.count("*").alias("PARTITION_RECORD_COUNT"))
    .agg(F.max("PARTITION_RECORD_COUNT").alias("MAX_PARTITION_RECORD_COUNT"))
    .collect()[0]["MAX_PARTITION_RECORD_COUNT"]
)
batch_size = INFERENCE_APPROX_BATCH_SIZE
number_of_batches = math.ceil(largest_partition_record_count / batch_size)
inference_input_df = (
    inference_input_df.with_column(
        "BATCH_GROUP", F.abs(F.random(123)) % F.lit(number_of_batches)
    )
    .with_column(
        "PARTITION_ID",
        F.concat_ws(
            F.lit("__"), F.col("GROUP_IDENTIFIER_STRING"), F.col("BATCH_GROUP")
        ),
    )
    .drop("BATCH_GROUP")  # BATCH_GROUP is already encoded in PARTITION_ID
)

# Report the size of the inference input and the number of batches it will be split into
print(f"Inference input data row count: {inference_input_df.count()}")
print(
    f"Number of end partition invocations to expect in the query profile: {inference_input_df.select('PARTITION_ID').distinct().count()}"
)

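# Switch to the inference warehouse for scoring
session.use_warehouse(INFERENCE_WH)

# Use the model to score the input data.
# NOTE: Snowpark DataFrames are evaluated lazily, so cache_result() materializes the
#       predictions while INFERENCE_WH is active. Without it, the write in the next
#       cell would rerun the model on the session warehouse.
inference_result = (
    mv.run(inference_input_df, partition_column="PARTITION_ID")
    .select(
        "_PRED_",
        F.col("GROUP_IDENTIFIER_STRING_OUT_").alias("GROUP_IDENTIFIER_STRING"),
        F.col(f"{TIME_PERIOD_COLUMN}_OUT_").alias(TIME_PERIOD_COLUMN),
    )
    .cache_result()
)

print("Predictions")
inference_result.show(2)

# Switch back to the session warehouse
session.use_warehouse(SESSION_WH)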

Inference input data row count: 2250
Number of end partition invocations to expect in the query profile: 250
Predictions
-----------------------------------------------------------------------
|"_PRED_"           |"GROUP_IDENTIFIER_STRING"  |"ORDER_TIMESTAMP"    |
-----------------------------------------------------------------------
|399.9723205566406  |STORE_ID_12_PRODUCT_ID_5   |2025-01-01 00:00:00  |
|407.5075378417969  |STORE_ID_12_PRODUCT_ID_5   |2025-01-02 00:00:00  |
-----------------------------------------------------------------------
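
The input preparation in this cell (inner join to the model table, then random batching) can be sketched in plain Python to make the batch arithmetic concrete. The records, group names, and `groups_with_model` set below are hypothetical, and Python's `random.Random(123)` stands in for Snowflake's `F.random(123)`. The two will not produce the same assignments; only the sizing rule is the same.

```python
import math
import random
from collections import Counter

# Hypothetical input: (group identifier, time index) records.
records = (
    [("STORE_1", t) for t in range(9)]
    + [("STORE_2", t) for t in range(4)]
    + [("STORE_3", t) for t in range(3)]
)
# Hypothetical set of groups that have a trained model (stand-in for the inner join).
groups_with_model = {"STORE_1", "STORE_2"}
batch_size = 4

# 1. Keep only partitions that have a model.
records = [r for r in records if r[0] in groups_with_model]

# 2. Size the number of batches from the largest partition.
largest = max(Counter(group for group, _ in records).values())
number_of_batches = math.ceil(largest / batch_size)

# 3. Give each record a random batch and build PARTITION_ID = "<group>__<batch>".
rng = random.Random(123)
partition_ids = [f"{group}__{rng.randrange(number_of_batches)}" for group, _ in records]

print(number_of_batches)            # 3
print(len(set(partition_ids)))      # at most 2 groups x 3 batches = 6
```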



In [6]:
# Write predictions to a Snowflake table.
# Right now this is set up to overwrite the table if it already exists.
inference_result.write.save_as_table(
    INFERENCE_RESULT_TBL_NM,
    mode="overwrite",
    comment='{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}',
)

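# Report the fully qualified table name; missing DB or SCHEMA parts resolve to the session's database and schema.
name_parts = INFERENCE_RESULT_TBL_NM.split(".")
result_tbl_fqn = ".".join([session_db, session_schema][: 3 - len(name_parts)] + name_parts)
print(f"Predictions written to table: {result_tbl_fqn}")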

# Look at a few rows of the Snowflake table
print("Sample records:")
session.table(INFERENCE_RESULT_TBL_NM).show(3)

Predictions written to table: FORECAST_MODEL_BUILDER.TEST.FORECAST_RESULTS
Sample records: