# Batch Inference Job

This notebook is designed to be scheduled as a recurring job within Snowflake to perform batch inference using a trained forecasting model. It reads from the Feature Store, applies the model registered in the Model Registry, and appends predictions to a results table.

The notebook tracks which predictions have already been made and only generates new predictions for time periods that haven't been processed yet.


❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

__Prerequisites before running this notebook:__ 

- A model must have been trained and saved to the Model Registry using the **modeling.ipynb** notebook.
- The model should have been evaluated using the **evaluation.ipynb** notebook to ensure prediction quality.
- The Feature Store and Feature View must be set up and accessible.

## Instructions

1. Go to the ___VARS___ cell in the __SETUP__ section below. 
    - Adjust the values of the user constants to match your deployment requirements.
    - Key settings include the model name, inference warehouse, and task schedule.
2. Click ___Run all___ to execute the notebook manually, or schedule it as a Snowflake Notebook Task for recurring batch inference.
    - The notebook will automatically detect new time periods requiring predictions.
    - Predictions are appended to the results table, avoiding duplicate processing.
    
❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️
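
Step 2 mentions scheduling the notebook as a Snowflake Notebook Task. As a rough sketch only, the statement for such a task could be assembled as shown here. The helper `build_notebook_task_sql`, the task name, and the notebook name are hypothetical and are not part of `forecast_model_builder`. The SQL follows Snowflake's `CREATE TASK ... AS EXECUTE NOTEBOOK` form, and a CRON expression is used for the schedule.

```python
def build_notebook_task_sql(
    task_name: str, warehouse: str, schedule: str, notebook_fqn: str
) -> str:
    """Return a CREATE TASK statement that runs a notebook on a schedule."""
    return (
        f"CREATE OR REPLACE TASK {task_name}\n"
        f"  WAREHOUSE = {warehouse}\n"
        f"  SCHEDULE = '{schedule}'\n"
        f"AS\n"
        f"  EXECUTE NOTEBOOK {notebook_fqn}()"
    )


# Hypothetical object names, for illustration only.
task_sql = build_notebook_task_sql(
    task_name="BATCH_INFERENCE_TASK",
    warehouse="FORECAST_MODEL_BUILDER_WH",
    schedule="USING CRON 0 9 * * * America/Los_Angeles",
    notebook_fqn="FORECAST_MODEL_BUILDER.TEST.BATCH_INFERENCE_NOTEBOOK",
)
print(task_sql)

# With an existing Snowpark session, the task could then be created and started:
# session.sql(task_sql).collect()
# session.sql("ALTER TASK BATCH_INFERENCE_TASK RESUME").collect()  # new tasks start suspended
```

Keeping the statement in a small function makes it easy to review the generated SQL before it is executed.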


In [None]:
# Imports
from datetime import datetime
from dateutil.relativedelta import relativedelta
from forecast_model_builder.utils import connect, perform_inference
from snowflake.ml.registry import registry
from snowflake.ml.feature_store import FeatureStore, CreationMode
import snowflake.snowpark.functions as F



In [2]:
# Establish session
session = connect()
session_db = session.connection.database
session_schema = session.connection.schema
session_wh = session.connection.warehouse
print(f"Session db.schema: {session_db}.{session_schema}")
print(f"Session warehouse: {session_wh}")

# Query tag
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}'
session.query_tag = query_tag

Session db.schema: FORECAST_MODEL_BUILDER.TEST
Session warehouse: FORECAST_MODEL_BUILDER_WH
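
The query tag in the cell above is a hand-written JSON string. A minimal sketch of building the same tag with `json.dumps`, which avoids quoting mistakes when fields are added later, might look like this. The variable name `query_tag_fields` is illustrative and does not appear in the notebook.

```python
import json

# Same fields as the hand-written tag, expressed as a dictionary.
query_tag_fields = {
    "origin": "sf_sit",
    "name": "sit_forecasting",
    "version": {"major": 1, "minor": 0},
    "attributes": {"component": "inference"},
}

# Serialize to a JSON string suitable for session.query_tag.
query_tag = json.dumps(query_tag_fields)
print(query_tag)

# session.query_tag = query_tag  # requires an existing Snowpark session
```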


-----
# SETUP
-----


In [None]:
# SET GLOBAL VARIABLES FOR THIS JOB

# --------------------------------
# Prediction Results Storage
# --------------------------------
# Table name to store the PREDICTION results.
# NOTE: If the table name is not fully qualified with DB.SCHEMA, the session's default database and schema will be used.
INFERENCE_RESULT_TBL_NM = "FORECAST_RESULTS"

# --------------------------------
# Model Configuration
# --------------------------------
# Name of the model to use for inference, as well as the Database and Schema of the model registry.
# NOTE: The default model version from the registry will be used.
MODEL_DB = "FORECAST_MODEL_BUILDER"
MODEL_SCHEMA = "MODELING"
MODEL_NAME = "TEST_MODEL_1"

# --------------------------------
# Virtual Warehouse
# --------------------------------
# Scaling up the warehouse may speed up execution time, especially if there are many partitions.
# NOTE: If set to None, then the session warehouse will be used.
INFERENCE_WH = "STANDARD_XL"

# --------------------------------
# Scheduling Options
# --------------------------------
# If True, the notebook task will refresh on the same schedule as the feature view.
# If False, you must provide a custom schedule via TASK_SCHEDULE.
REFRESH_WITH_FEATURE_VIEW = False

# Schedule for the notebook task (only used if REFRESH_WITH_FEATURE_VIEW is False).
# Examples: "1 day", "1 hour", "USING CRON 0 9 * * * America/Los_Angeles"
TASK_SCHEDULE = "1 day"

# --------------------------------
# Inference Date Range
# --------------------------------
# Inference start date - prevents processing entire history on first run.
# Subsequent runs will automatically start from the last processed date + 1 period.
INFERENCE_START_DATE = "2025-01-01"

-----
# Establish Objects Needed for Inference
-----


In [None]:
# DERIVED OBJECTS

current_dttm = datetime.now()
INFERENCE_START_DATE = datetime.strptime(INFERENCE_START_DATE, "%Y-%m-%d")

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
SESSION_WH = session.connection.warehouse
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Inference Warehouse
# -----------------------------------------------------------------------
# Check that the user specified an available warehouse as INFERENCE_WH. If not, use the session warehouse.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

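if INFERENCE_WH is None:
    # No inference warehouse was requested, so fall back to the session warehouse.
    INFERENCE_WH = SESSION_WH
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
elif INFERENCE_WH in available_warehouses:
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
else:
    print(
        f"WARNING: User does not have access to INFERENCE_WH = '{INFERENCE_WH}'. Inference will use '{SESSION_WH}' instead. \n"
    )
    INFERENCE_WH = SESSION_WH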

# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Get the model and the version name of the default version
# -----------------------------------------------------------------------
# Establish registry object
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# Get the model from the registry (uses the default version)
mv = reg.get_model(qualified_model_name).default

# Get the default version name
model_version_nm = mv.version_name

print(f"Model Version:              {model_version_nm}")

inf_table_nm = f"{MODEL_SCHEMA}.{INFERENCE_RESULT_TBL_NM}_{MODEL_NAME}_{model_version_nm}"

# -----------------------------------------------------------------------
# Retrieve User Constants from Model Metadata
# -----------------------------------------------------------------------
# These settings were stored during model training and ensure consistency between training and inference.
stored_constants = mv.show_metrics()["user_settings"]
USE_CONTEXT = stored_constants["USE_CONTEXT"]
TIME_PERIOD_COLUMN = stored_constants["TIME_PERIOD_COLUMN"]
TARGET_COLUMN = stored_constants["TARGET_COLUMN"]
ROLLUP_FREQUENCY = stored_constants["ROLLUP_FREQUENCY"]
CURRENT_FREQUENCY = stored_constants["CURRENT_FREQUENCY"]
FORECAST_HORIZON = stored_constants["FORECAST_HORIZON"]
FREQUENCY = ROLLUP_FREQUENCY if ROLLUP_FREQUENCY else CURRENT_FREQUENCY

if not USE_CONTEXT:
    MODEL_BINARY_STORAGE_TBL_NM = stored_constants["MODEL_BINARY_STORAGE_TBL_NM"]

# -----------------------------------------------------------------------
# Create Prediction Results Table (if not exists)
# -----------------------------------------------------------------------
# This table stores all inference results with metadata for tracking and auditing.
# It is created under inf_table_nm, the same name the predictions are appended to below.
session.sql(
    f"""
        create table if not exists {inf_table_nm} (
            {TIME_PERIOD_COLUMN} TIMESTAMP,
            GROUP_IDENTIFIER VARIANT,
            GROUP_IDENTIFIER_STRING VARCHAR,
            MODEL_NAME VARCHAR(100),
            MODEL_VERSION VARCHAR(100),
            INFERENCE_DTTM TIMESTAMP,
            PREDICTION DOUBLE
        )
        comment = '{query_tag}'
    """
).collect()

# -----------------------------------------------------------------------
# Connect to Feature Store and Get Feature View
# -----------------------------------------------------------------------
# The feature view is retrieved from the model's lineage to ensure the same features used
# during training are used during inference.
fs = FeatureStore(
    session,
    database=session_db,
    name=session_schema,
    default_warehouse=SESSION_WH,
    creation_mode=CreationMode.FAIL_IF_NOT_EXIST,
)

fv = mv.lineage(direction='upstream')[0].lineage(direction='upstream')[0]

Session warehouse:          FORECAST_MODEL_BUILDER_WH

Model Version:              HOT_FISH_1




-----
# Determine Inference Date Range
-----


In [None]:
# -----------------------------------------------------------------------
# Calculate Inference Date Range
# -----------------------------------------------------------------------
# The start date is determined by checking for existing predictions.
# If predictions already exist for this model version, start from the last predicted date + 1 period.
# This ensures idempotency and prevents duplicate predictions.

# Get existing predictions for this model version.
# inf_table_nm is the table that predictions are appended to at the end of the job.
result_table = session.table(inf_table_nm)

# If predictions already exist, update start date to continue from where we left off
if result_table.count() > 0:
    INFERENCE_START_DATE = result_table.select(
        F.dateadd(
            FREQUENCY,
            F.lit(1),
            F.max(TIME_PERIOD_COLUMN)).alias(TIME_PERIOD_COLUMN)
    ).collect()[0][TIME_PERIOD_COLUMN]

# Calculate end date based on current date plus forecast horizon.
# relativedelta expects lowercase plural keywords such as "days" or "months".
INFERENCE_END_DATE = datetime.today() + relativedelta(**{FREQUENCY.lower() + "s": FORECAST_HORIZON})

print(f"Inference Start Date:       {INFERENCE_START_DATE}")
print(f"Inference End Date:         {INFERENCE_END_DATE}")
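
The date-range logic in this cell can be pictured as a watermark: the first run starts at the configured date, and later runs resume one period after the latest prediction. The following is a simplified sketch in plain Python, not the notebook's implementation. The function `inference_window` and the `STEP_BY_FREQUENCY` table are hypothetical, and the sketch assumes fixed-length frequencies only (calendar units such as months would need `relativedelta`).

```python
from datetime import datetime, timedelta
from typing import Optional, Tuple

# Assumption: only fixed-length frequencies are handled in this sketch.
STEP_BY_FREQUENCY = {
    "hour": timedelta(hours=1),
    "day": timedelta(days=1),
    "week": timedelta(weeks=1),
}


def inference_window(
    configured_start: datetime,
    last_predicted: Optional[datetime],
    now: datetime,
    frequency: str,
    horizon: int,
) -> Tuple[datetime, datetime]:
    """Return the half-open window [start, end) that still needs predictions."""
    step = STEP_BY_FREQUENCY[frequency]
    # First run: use the configured start. Later runs: resume one period after the watermark.
    start = configured_start if last_predicted is None else last_predicted + step
    # Predict up to the forecast horizon beyond the current time.
    end = now + horizon * step
    return start, end


# First run: no predictions exist yet.
first_run = inference_window(datetime(2025, 1, 1), None, datetime(2025, 1, 10), "day", 7)
# Next day: predictions already exist through 2025-01-16.
later_run = inference_window(
    datetime(2025, 1, 1), datetime(2025, 1, 16), datetime(2025, 1, 11), "day", 7
)
print(first_run)
print(later_run)
```

Because the window is half-open, the end of one run's window is the start of the next, so consecutive runs neither overlap nor leave gaps.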

-----
# Retrieve and Filter Inference Data
-----


In [None]:
# -----------------------------------------------------------------------
# Read Features from Feature View
# -----------------------------------------------------------------------
# Read features from the feature view and filter to the inference date range.
sdf = fs.read_feature_view(fv).filter(
    (F.col(TIME_PERIOD_COLUMN) >= INFERENCE_START_DATE)
    & (F.col(TIME_PERIOD_COLUMN) < INFERENCE_END_DATE)
)

# -----------------------------------------------------------------------
# Exclude Already Predicted Records
# -----------------------------------------------------------------------
# A left anti-join on GROUP_IDENTIFIER_STRING and the time period column would filter out
# records that have already been predicted, guarding against duplicates if the job reruns.
# NOTE: The join is currently disabled (commented out); only the date-range filter above applies.
# sdf = sdf.join(
#     result_table.select("GROUP_IDENTIFIER_STRING", TIME_PERIOD_COLUMN),
#     on=["GROUP_IDENTIFIER_STRING", TIME_PERIOD_COLUMN],
#     how="leftanti",
# ).drop("GROUP_IDENTIFIER")

print(f"Records to predict:         {sdf.count()}")
sdf.show(3)

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"ORDER_TIMESTAMP"    |"TARGET"       |"FEATURE_1"    |"YEAR"  |"MONTH_SIN"         |"MONTH_COS"         |"WEEK_OF_YEAR_SIN"   |"WEEK_OF_YEAR_COS"  |"DAY_OF_WEEK_SUN"  |"DAY_OF_WEEK_MON"  |"DAY_OF_WEEK_TUE"  |"DAY_OF_WEEK_WED"  |"DAY_OF_WEEK_THU"  |"DAY_OF_WEEK_FRI"  |"DAY_OF_WEEK_SAT"  |"DAY_OF_YEAR_SIN"     |"DAY_OF_YEAR_COS"   |"DAYS_SINCE_JAN2020"  |"MODEL_TARGET"  |"GROUP_IDENTIFIER"  |"GROUP_IDENTIFIER_STRING"        |
----------------------------------------------------------------------------------------------------------------------------------------------
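
The exclusion step described in this cell is a left anti-join on the group identifier and the time period. As an illustration of the idea only, the plain-Python sketch here keeps the feature rows whose key is absent from the existing predictions. The helper `left_anti_join` is hypothetical, the column names mirror the ones shown above, and the row values are invented.

```python
from typing import Dict, List, Sequence


def left_anti_join(
    left_rows: List[Dict], right_rows: List[Dict], keys: Sequence[str]
) -> List[Dict]:
    """Keep rows from left_rows whose key tuple does not appear in right_rows."""
    seen = {tuple(row[k] for k in keys) for row in right_rows}
    return [row for row in left_rows if tuple(row[k] for k in keys) not in seen]


# Invented example data.
features = [
    {"GROUP_IDENTIFIER_STRING": "STORE_1", "ORDER_TIMESTAMP": "2025-01-01", "FEATURE_1": 0.4},
    {"GROUP_IDENTIFIER_STRING": "STORE_1", "ORDER_TIMESTAMP": "2025-01-02", "FEATURE_1": 0.7},
    {"GROUP_IDENTIFIER_STRING": "STORE_2", "ORDER_TIMESTAMP": "2025-01-01", "FEATURE_1": 0.1},
]
already_predicted = [
    {"GROUP_IDENTIFIER_STRING": "STORE_1", "ORDER_TIMESTAMP": "2025-01-01"},
]

to_predict = left_anti_join(
    features, already_predicted, keys=["GROUP_IDENTIFIER_STRING", "ORDER_TIMESTAMP"]
)
print(len(to_predict))
```

Matching on both the group and the time period means a rerun skips only the exact combinations that already have a prediction.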

-----
# Run Inference
-----


In [None]:
if sdf.count() > 0:
    # -----------------------------------------------------------------------
    # Perform Batch Inference
    # -----------------------------------------------------------------------
    # Use the perform_inference utility to generate predictions using the registered model.
    # The function handles partitioned model execution efficiently.
    result = perform_inference(session, sdf, mv)
    
    # -----------------------------------------------------------------------
    # Add Metadata Columns
    # -----------------------------------------------------------------------
    # Append model name, version, and timestamp for tracking and auditing purposes.
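    # NOTE: "_PRED_" is assumed to be the prediction column returned by perform_inference.
    # It is renamed to PREDICTION to match the results table schema selected below.
    result = (
        result
        .with_column("MODEL_NAME", F.lit(MODEL_NAME))
        .with_column("MODEL_VERSION", F.lit(model_version_nm))
        .with_column("INFERENCE_DTTM", F.lit(current_dttm))
        .rename("_PRED_", "PREDICTION")
        .join(
            sdf.select(TIME_PERIOD_COLUMN, "GROUP_IDENTIFIER_STRING", "MODEL_TARGET"),
            on=[TIME_PERIOD_COLUMN, "GROUP_IDENTIFIER_STRING"],
        )
    ).select(result_table.columns)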
            
    
    print("Inference complete. Sample results:")
    result.show()

    # -----------------------------------------------------------------------
    # Append Predictions to Results Table
    # -----------------------------------------------------------------------
    # Append the new predictions to the results table.
    # Using 'append' mode ensures we don't overwrite existing predictions.
    result.write.save_as_table(inf_table_nm, mode='append')
    
    print(f"Predictions saved to: {inf_table_nm}")
    print(f"Inference job completed at: {datetime.now()}")

else:
    print("No new records found, inference not performed")

Inference input data row count: 15750
Number of end partition invocations to expect in the query profile: 1750
----------------------------------------------------------------------------------------------------------------------------------------------------------------
|"ORDER_TIMESTAMP"    |"GROUP_IDENTIFIER"  |"GROUP_IDENTIFIER_STRING"       |"MODEL_NAME"  |"MODEL_VERSION"  |"INFERENCE_DTTM"            |"PREDICTION"        |
----------------------------------------------------------------------------------------------------------------------------------------------------------------
|2025-01-01 00:00:00  |{                   |STORE_ID_2_PRODUCT_ID_5_LEAD_5  |TEST_MODEL_1  |HOT_FISH_1       |2025-10-28 09:58:44.305409  |501.5511779785156   |
|                     |  "LEAD": 5,        |                                |              |                 |                            |                    |
|                     |  "PRODUCT_ID": 5,  |                                |       