# Partitioned Time Series Modeling

This notebook can be used to train a time series forecasting model. 

It is especially useful when many series must be modeled separately. For example, if a retailer needs a separate model for each store location, this code trains those models in parallel, which can greatly reduce run time when there are many partitions.

❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

__Prerequisites before running this notebook:__ 

- The time series data in Snowflake must have at least the following columns: __date or timestamp__ column, __target__ column, and __partition__ column(s) if multi-series. 
- The date column name MUST be in [unquoted identifier](https://docs.snowflake.com/en/sql-reference/identifiers-syntax#label-unquoted-identifier) format, i.e. contains only __upper case letters, underscores, and decimal digits__. It is __recommended__ that all other column names also be in that format so that [double-quoted identifiers](https://docs.snowflake.com/en/sql-reference/identifiers-syntax#label-delimited-identifier) are not needed.
- The target column (and any exogenous feature columns if they exist) should have values in a numeric format like FLOAT, DOUBLE, or INT. 
- Any null values in the data should already be imputed. 
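The identifier requirement above can be checked before running the notebook. The sketch below is a minimal, hypothetical helper (`needs_double_quotes` is not part of this project) that flags column names which do not fit the pattern described in the prerequisites. It also assumes the first character must not be a digit.

```python
import re

# Assumption: mirrors the rule stated in the prerequisites (upper case letters,
# underscores, decimal digits) and additionally rejects a leading digit.
_UNQUOTED_IDENTIFIER = re.compile(r"[A-Z_][A-Z0-9_]*")


def needs_double_quotes(column_name: str) -> bool:
    """Return True if the name does not fit the unquoted-identifier pattern."""
    return _UNQUOTED_IDENTIFIER.fullmatch(column_name) is None


# Hypothetical column names, for illustration only
columns = ["ORDER_TIMESTAMP", "TARGET", "store id", "Feature_1"]
flagged = [name for name in columns if needs_double_quotes(name)]
print(flagged)
```

Any flagged name would need to be passed with embedded double quotes (for example `'"store id"'`), and a flagged date column would need to be renamed.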

## Instructions


1. Go to the ___set_global_variables___ cell in the __SETUP__ section below. 
    - Change the values of the user constants to match the specifications of the use case.
    - Descriptions of each value are written in that cell.
2. Click ___Run all___ in the upper right corner of the notebook to run the entire notebook. 
    - The notebook will perform feature engineering and will train models. 
    - If ___SAVE_MODEL_VERSION_THIS_RUN=True___, then the models will be saved to the model registry for later inference. 
    
❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ ❄️ 

In [None]:
# Imports
import importlib.metadata
import json
import tempfile
import os
import pickle
import pkgutil
import random
from datetime import datetime
from typing import Optional

import pandas as pd
import xgboost as xgb
from snowflake.ml.model import custom_model
from snowflake.ml.dataset import Dataset
from snowflake.ml.registry import registry
from snowflake.snowpark import Window
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode,
)

from forecast_model_builder.feature_engineering import (
    apply_functions_in_a_loop,
    expand_datetime,
    recent_rolling_avg,
    roll_up,
    verify_current_frequency,
    verify_valid_rollup_spec,
)
from forecast_model_builder.utils import (
    connect,
    version_featureview,
    version_data,
)



In [2]:
# Establish session
session = connect(connection_name="default")
session_db = session.connection.database
session_schema = session.connection.schema
session_db_schema = f"{session_db}.{session_schema}"
print(f"Session db.schema: {session_db_schema}")

# Query tag
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"modeling"}}'
session.query_tag = query_tag

# Get the current datetime  (This will be saved in the model storage table)
run_dttm = datetime.now()
print(f"Current Datetime: {run_dttm}")

Session db.schema: FORECAST_MODEL_BUILDER.TEST
Current Datetime: 2025-10-15 14:40:50.100966


-----
# SETUP
-----

In [3]:
# SET GLOBAL VARIABLES FOR THIS RUN

# Name the model (if model already exists, a new version will be created)
MODEL_NAME = "TEST_MODEL_1"

# Boolean that is True if we want to save the model version in the current run.
# Users may want to set it to False while they experiment with different specifications, and then set it to True once they have developed the model they want to save.
SAVE_MODEL_VERSION_THIS_RUN = True

# --------------------------------
# Input Time Series Data
# --------------------------------
# Establish the Snowflake database, schema, and table containing the time series data
TS_DB = "FORECAST_MODEL_BUILDER"
TS_SCHEMA = "BASE"
TS_TABLE_NM = "DAILY_PARTITIONED_SAMPLE_DATA"

# --------------------------------
# Modeling setup
# --------------------------------
# Establish the Database and Schema that will be used to store the models
MODEL_DB = "FORECAST_MODEL_BUILDER"
MODEL_SCHEMA = "MODELING"

# --------------------------------
# Virtual Warehouse
# --------------------------------
# For modeling and inference, a larger warehouse may speed up execution time depending on the number of partitions.
# Scale up if there are a lot of partitions.
# NOTE: If set to None, then the session warehouse will be used.
MODELING_WH = "STANDARD_XL"

# --------------------------------
# Modeling
# --------------------------------
# From the time series data (TS_TABLE_NM), specify the name of column containing the datetime information
TIME_PERIOD_COLUMN = "ORDER_TIMESTAMP"

# NOTE: For the next 3 constants, if column names require a double-quoted identifier, include double quotes within the single quotes.
#       Examples: '"Target"', ['"store id"', '"product id"'], ['"Feature 1"'].

# Name of column containing the target variable (i.e. the value we are trying to predict)
TARGET_COLUMN = "TARGET"

# List of column names to use as partition columns. This is how you define each individual series to be modeled.
# If modeling a single series (i.e. no partitions) set this as an EMPTY LIST [].
PARTITION_COLUMNS = ["STORE_ID", "PRODUCT_ID"]

# List of column names in the time series table to use as EXOGENOUS FEATURES.
# Exogenous features are variables outside the main time series that can impact future values of the target variable.
#     Examples: weather features, promotions, holidays, economic indicators (like inflation), inventory on hand, etc.
# If there are no exogenous features in the data set, set this as an EMPTY LIST [].
# NOTE: This notebook will create several features (like YEAR, MONTH, DAY_OF_YEAR, etc). You do NOT need to list those.
#       Only list features that are already in the Snowflake table (TS_TABLE_NM).
EXOGENOUS_COLUMNS = ["FEATURE_1"]

# ALL_EXOG_COLS_HAVE_FUTURE_VALS is a boolean that is True if all exogenous features have future values present in the inference data.
#     For example, if you are predicting 56 days into the future (i.e. FORECAST_HORIZON=56),
#                  but you only know promotions for the next 4 weeks, you would set ALL_EXOG_COLS_HAVE_FUTURE_VALS = False.
# NOTE: There are two modeling patterns in this notebook:
#       1. Direct Multi-Step Forecasting - If ALL_EXOG_COLS_HAVE_FUTURE_VALS = False,
#                                           the code will create separate models for each lead/step (from step = 1 to step = FORECAST_HORIZON) within each partition.
#                                           In this pattern, inference is done using the most current date's information to predict each future step.
#       2. Global Modeling               - If ALL_EXOG_COLS_HAVE_FUTURE_VALS = True (or EXOGENOUS_COLUMNS is empty),
#                                           the code will train a single model within each partition.
#                                           In this pattern, inference is done using the information for each future step,
#                                               so the inference dataset will need a separate record for each future date to be predicted.
# This variable will determine which pattern is used.
ALL_EXOG_COLS_HAVE_FUTURE_VALS = True

# Specify if we should create lag features for the target variable (including avgs of previous periods). This will affect the lag_and_target_prep & recent_rolling_avg functions.
# NOTE: If we are using the Global Modeling pattern, we will not create recent rolling avg features.
CREATE_LAG_FEATURE = False

# Frequency of the data (choose from: "second", "minute", "hour", "day", "week", "month", "other")
# This is the frequency of the data as it currently exists in the Snowflake table (TS_TABLE_NM).
# If it is not a standard frequency, select "other"
CURRENT_FREQUENCY = "day"

# Frequency to roll up to (choose from: "second", "minute", "hour", "day", "week", "month", or None)
# If you do not wish to roll up to a higher level, set ROLLUP_FREQUENCY=None.
ROLLUP_FREQUENCY = None

# Specify how each column should agg on roll-up (choose from: "sum", "avg", "min", or "max")
# NOTE: If ROLLUP_FREQUENCY is None, then this can be an empty dictionary {}.
#       Otherwise, you must specify an aggregation for the TARGET column AND for each of the EXOGENOUS_COLUMNS.
ROLLUP_AGGREGATIONS = {
    TARGET_COLUMN: "sum",
    "FEATURE_1": "sum",
}

# Forecast Horizon. Number of time periods to forecast into the future (UNITS will be that of the ROLLUP_FREQUENCY if specified, otherwise CURRENT_FREQUENCY).
# NOTE: Keep this number as small as possible if doing Direct Multi-Step Forecasting (in which a separate model gets built for each future time period).
FORECAST_HORIZON = 7

# Specify how many days to set aside for validation.
# NOTE: If this is set to 0, then the model will be trained on all historic data.
# Setting aside testing data is required to run the subsequent evaluation notebook.
VALIDATION_DAYS = 90

# XGBRegressor hyperparameter selections.
# NOTE: This notebook does not perform hyperparameter tuning, so you can set these parameters here if you know which values you would like to use.
XGB_PARAMS = {
    "learning_rate": 0.05,
    "subsample": 0.80,
    "colsample_bytree": 0.80,
    "random_state": 42,
}

# --------------------------------
# Inference
# --------------------------------
# When distributing the inference records, we can set the batch size here.
# If the number is too high, inference on a large number of records might use up all available memory.
INFERENCE_APPROX_BATCH_SIZE = 200

# --------------------------------
# Calculated Constants
# --------------------------------
# Model context can currently only accept a maximum of 1000 models. Solutions with > 1000 models will use a model storage table instead.
MODELCONTEXT_MAX = 1000
# Establish the name of the table that will hold model binaries.
# This will be a Snowflake table in your project schema if the number of models is > MODELCONTEXT_MAX
MODEL_BINARY_STORAGE_TBL_NM = f"MODEL_STORAGE_{MODEL_NAME}"
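The two modeling patterns described in the ALL_EXOG_COLS_HAVE_FUTURE_VALS comments can be illustrated on a toy series. This is a sketch in plain Python, not the notebook's feature engineering code. The function names are hypothetical, and the single lagged value stands in for the full feature set.

```python
def uses_separate_lead_models(all_exog_have_future_vals, exogenous_columns):
    """Direct multi-step is needed only when some exogenous feature lacks future values."""
    return not (all_exog_have_future_vals or len(exogenous_columns) == 0)


def build_direct_training_pairs(series, horizon):
    """For each lead 1..horizon, pair the value at t with the target at t + lead."""
    return {
        lead: [(series[t], series[t + lead]) for t in range(len(series) - lead)]
        for lead in range(1, horizon + 1)
    }


history = [10, 12, 11, 15, 14]  # toy values, for illustration only
pairs = build_direct_training_pairs(history, horizon=2)
print(uses_separate_lead_models(False, ["FEATURE_1"]))  # True
print(pairs[2])  # [(10, 11), (12, 15), (11, 14)]
```

Each lead gets its own training set, and therefore its own model, which is why the number of models grows with FORECAST_HORIZON in the direct multi-step pattern.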



-----
# Establish objects needed for this run
-----

In [None]:
# DERIVED OBJECTS

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
SESSION_WH = session.connection.warehouse
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Modeling Warehouse
# -----------------------------------------------------------------------
# Check that the user specified an available warehouse as MODELING_WH. If not, use the session warehouse.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

if MODELING_WH in available_warehouses:
    print(f"Modeling warehouse:         {MODELING_WH} \n")
else:
    print(
        f"WARNING: User does not have access to MODELING_WH = '{MODELING_WH}'. Model training will use '{SESSION_WH}' instead. \n"
    )
    MODELING_WH = SESSION_WH


# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Create dictionary of user settings to log with the model
# -----------------------------------------------------------------------
user_settings_dict = {
    "MODEL_NAME": MODEL_NAME,
    "SAVE_MODEL_VERSION_THIS_RUN": SAVE_MODEL_VERSION_THIS_RUN,
    "TS_DB": TS_DB,
    "TS_SCHEMA": TS_SCHEMA,
    "TS_TABLE_NM": TS_TABLE_NM,
    "MODEL_DB": MODEL_DB,
    "MODEL_SCHEMA": MODEL_SCHEMA,
    "SESSION_WH": SESSION_WH,
    "MODELING_WH": MODELING_WH,
    "TIME_PERIOD_COLUMN": TIME_PERIOD_COLUMN,
    "PARTITION_COLUMNS": PARTITION_COLUMNS,
    "EXOGENOUS_COLUMNS": EXOGENOUS_COLUMNS,
    "ALL_EXOG_COLS_HAVE_FUTURE_VALS": ALL_EXOG_COLS_HAVE_FUTURE_VALS,
    "CREATE_LAG_FEATURE": CREATE_LAG_FEATURE,
    "CURRENT_FREQUENCY": CURRENT_FREQUENCY,
    "ROLLUP_FREQUENCY": ROLLUP_FREQUENCY,
    "ROLLUP_AGGREGATIONS": ROLLUP_AGGREGATIONS,
    "FORECAST_HORIZON": FORECAST_HORIZON,
    "VALIDATION_DAYS": VALIDATION_DAYS,
    "XGB_PARAMS": XGB_PARAMS,
    "INFERENCE_APPROX_BATCH_SIZE": INFERENCE_APPROX_BATCH_SIZE,
}

# -----------------------------------------------------------------------
# BACKEND SETUP: Create Model Schema
# -----------------------------------------------------------------------
# Create a schema to hold our models if it does not already exist
schema_exists = (
    session.table(f"{MODEL_DB}.INFORMATION_SCHEMA.SCHEMATA")
    .filter(F.upper(F.col("SCHEMA_NAME")) == F.upper(F.lit(MODEL_SCHEMA)))
    .count()
)

if schema_exists == 0:
    try:
        session.sql(f"create schema if not exists {MODEL_DB}.{MODEL_SCHEMA}").collect()
    except Exception as e:
        if "insufficient privileges" in str(e).lower():
            raise PermissionError(f"""Schema {MODEL_SCHEMA} does not already exist in {MODEL_DB}, and user does not have sufficient privileges to CREATE SCHEMA. 
            Please specify an existing schema for MODEL_SCHEMA constant.""") from e
        else:
            raise RuntimeError(
                f"An error occurred while attempting to create schema {MODEL_DB}.{MODEL_SCHEMA}: {e}"
            ) from e

# Reset the schema to the original session schema. (If we created a new schema, the session schema was set to the new schema)
session.use_schema(session_db_schema)

# -----------------------------------------------------------------------
# Create a window spec
# -----------------------------------------------------------------------
window_spec = Window.partitionBy(PARTITION_COLUMNS).orderBy(TIME_PERIOD_COLUMN)

# -----------------------------------------------------------------------
# Create a variable for the frequency at which we will be modeling
# -----------------------------------------------------------------------
CURRENT_FREQUENCY = CURRENT_FREQUENCY.lower()

if ROLLUP_FREQUENCY is not None:
    ROLLUP_FREQUENCY = ROLLUP_FREQUENCY.lower()
    if ROLLUP_FREQUENCY.lower() == "none":
        ROLLUP_FREQUENCY = None

modeling_frequency = CURRENT_FREQUENCY if ROLLUP_FREQUENCY is None else ROLLUP_FREQUENCY
print(f"Modeling Frequency:         {modeling_frequency}")

# -----------------------------------------------------------------------
# Variable for modeling pattern
# -----------------------------------------------------------------------
# Either (1) train_separate_lead_models = False : all features have future values in the inference data, so we don't need a separate model for each lead
# or (2) train_separate_lead_models = True : data contains exogenous variables that the inference data won't have future values for, requiring direct multi-step (lead) modeling
train_separate_lead_models = (
    False
    if ALL_EXOG_COLS_HAVE_FUTURE_VALS is True or len(EXOGENOUS_COLUMNS) == 0
    else True
)
print(f"Train Separate Lead Models: {train_separate_lead_models}")

# -----------------------------------------------------------------------
# Establish model registry object
# -----------------------------------------------------------------------
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# -----------------------------------------------------------------------
# Does model already exist in the registry?
# -----------------------------------------------------------------------
try:
    number_of_versions = len(reg.get_model(qualified_model_name).show_versions())
    if number_of_versions > 0:
        print(
            f"Model {qualified_model_name} already exists. This notebook will build a new version."
        )
except Exception:
    print(f"This will be the first version of model {qualified_model_name}.")

Session warehouse:          FORECAST_MODEL_BUILDER_WH

Modeling Frequency:         day
Train Separate Lead Models: False
Model FORECAST_MODEL_BUILDER.MODELING.TEST_MODEL_1 already exists. This notebook will build a new version.


In [5]:
# Create Snowpark DataFrame from table in Snowflake
sdf = session.table(f"{TS_DB}.{TS_SCHEMA}.{TS_TABLE_NM}")

# Only keep the columns specified in the config
sdf = sdf.select(
    TIME_PERIOD_COLUMN, TARGET_COLUMN, *PARTITION_COLUMNS, *EXOGENOUS_COLUMNS
)

In [6]:
# -----------------------------------------
# Preliminary checks
# -----------------------------------------
# Verify valid rollup specification.
# Raise an error if the user specifies a rollup frequency that is finer grain than the current frequency
# Raise an error if the user does not specify a rollup aggregation for the target and all exogenous columns
verify_valid_rollup_spec(
    CURRENT_FREQUENCY, ROLLUP_FREQUENCY, ROLLUP_AGGREGATIONS, EXOGENOUS_COLUMNS
)

# Roughly verify the current frequency (datetime difference between consecutive records) of the time series data
# Note the existence of gaps if range is anything other than 1 - 1
verify_current_frequency(sdf, TIME_PERIOD_COLUMN, window_spec, CURRENT_FREQUENCY)

Most common time between consecutive records (frequency): 1.0 day(s)
    The current frequency appears to be in DAY granularity.
    The range of values is 1.0 - 1.0 day (s)
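The frequency check above reports the most common gap between consecutive records and the range of gaps. The idea can be sketched for a single series in plain Python. This is not the implementation of `verify_current_frequency`, which works on a Snowpark DataFrame with a window per partition. `summarize_gaps` is a hypothetical name.

```python
from collections import Counter
from datetime import datetime, timedelta


def summarize_gaps(timestamps):
    """Return (most common gap, smallest gap, largest gap) between consecutive records."""
    ordered = sorted(timestamps)
    gaps = [later - earlier for earlier, later in zip(ordered, ordered[1:])]
    most_common_gap, _ = Counter(gaps).most_common(1)[0]
    return most_common_gap, min(gaps), max(gaps)


# Toy daily series with one missing day (January 4)
stamps = [datetime(2024, 1, day) for day in (1, 2, 3, 5, 6)]
mode_gap, smallest, largest = summarize_gaps(stamps)
print(mode_gap, smallest, largest)
```

A largest gap bigger than the most common gap, as in this toy series, signals missing periods that may need to be filled before modeling.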
    


In [None]:
# NOTE: while the sample data in this repo has no gaps, real data may require inserting missing dates
# and filling with zero (or other relevant imputed value, depending on the business problem).
# To maintain compatibility with incremental feature view refresh, consider using a calendar table cross join

# Sample code:

"""
START_DATE = (
    sdf.select(F.min(TIME_PERIOD_COLUMN).alias(TIME_PERIOD_COLUMN))
    .collect()[0][TIME_PERIOD_COLUMN]
)

cal = (
    session.table(f"{TS_DB}.{TS_SCHEMA}.{CAL_TABLE_NM}")
    .select(F.col(CAL_DATE_COLUMN).alias(TIME_PERIOD_COLUMN))
)

time_series = (
    cal.filter(F.col(TIME_PERIOD_COLUMN)>=START_DATE)
    .join(sdf.select(PARTITION_COLUMNS).distinct(), how='cross')
)

sdf = (
    time_series.join(sdf, on=PARTITION_COLUMNS+[TIME_PERIOD_COLUMN], how='left')
    .fillna({TARGET_COLUMN:0})
)
"""

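The calendar cross join in the sample code above can be illustrated without Snowflake. The sketch below uses plain Python and made-up values. It builds every (partition, date) combination and fills missing targets with zero, which is only one possible imputation choice. Unlike the sample code, it also bounds the calendar at the last observed date.

```python
from datetime import date, timedelta
from itertools import product

# Toy observations keyed by (partition, date); values are made up
observed = {
    ("STORE_1", date(2024, 1, 1)): 5.0,
    ("STORE_1", date(2024, 1, 3)): 7.0,
    ("STORE_2", date(2024, 1, 2)): 2.0,
}

# Build the calendar from the first to the last observed date
start = min(day for _, day in observed)
end = max(day for _, day in observed)
calendar = [start + timedelta(days=i) for i in range((end - start).days + 1)]
partitions = sorted({partition for partition, _ in observed})

# Cross join calendar x partitions, then "left join" the observations
filled = {
    (partition, day): observed.get((partition, day), 0.0)
    for partition, day in product(partitions, calendar)
}
print(len(filled), filled[("STORE_1", date(2024, 1, 2))])
```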
-----
# Feature Engineering
-----

In [7]:
# First Convert Decimal data types to Floats (because DecimalType doesn't work in modeling algorithms)
sdf_converted = sdf.select(
    [
        (
            F.col(field.name).cast(T.FloatType()).alias(field.name)
            if isinstance(field.datatype, T.DecimalType)
            else F.col(field.name)
        )
        for field in sdf.schema
    ]
)

# ------------------------------------------------------------------------
# ROLL UP to specified frequency
# ------------------------------------------------------------------------
sdf_rollup = roll_up(
    sdf_converted,
    TIME_PERIOD_COLUMN,
    PARTITION_COLUMNS,
    TARGET_COLUMN,
    EXOGENOUS_COLUMNS,
    ROLLUP_FREQUENCY,
    ROLLUP_AGGREGATIONS,
)

# ------------------------------------------------------------------------
# Create time-derived features
# ------------------------------------------------------------------------
sdf_engineered = expand_datetime(sdf_rollup, TIME_PERIOD_COLUMN, modeling_frequency)

# ------------------------------------------------------------------------
# Create rolling average of most recent time periods
# ------------------------------------------------------------------------
# NOTE: We can only generate recent rolling average features if we are training separate lead models (direct multi-step forecasting).
if CREATE_LAG_FEATURE and train_separate_lead_models:
    sdf_engineered = recent_rolling_avg(
        sdf_engineered, [TARGET_COLUMN], window_spec, modeling_frequency
    )

# ------------------------------------------------------------------------
# Create LAG features (and possibly LEAD feature) of the TARGET variable
# ------------------------------------------------------------------------
final_sdf = apply_functions_in_a_loop(
    train_separate_lead_models=train_separate_lead_models,
    partition_column_list=PARTITION_COLUMNS,
    input_sdf=sdf_engineered,
    target_column=TARGET_COLUMN,
    time_step_frequency=modeling_frequency,
    forecast_horizon=FORECAST_HORIZON,
    w_spec=window_spec,
    create_lag_feature=CREATE_LAG_FEATURE,
)

# Inspect data
print(f"Total record count after rolling up:   {sdf_rollup.count()}")
print(f"Total record count of final data:      {final_sdf.count()}")
final_sdf.show(2)

Total record count after rolling up:   367500
Total record count of final data:      367500
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"ORDER_TIMESTAMP"    |"TARGET"       |"FEATURE_1"    |"YEAR"  |"MONTH_SIN"         |"MONTH_COS"          |"WEEK_OF_YEAR_SIN"  |"WEEK_OF_YEAR_COS"   |"DAY_OF_WEEK_SUN"  |"DAY_OF_WEEK_MON"  |"DAY_OF_WEEK_TUE"  |"DAY_OF_WEEK_WED"  |"DAY_OF_WEEK_THU"  |"DAY_OF_WEEK_FRI"  |"DAY_OF_WEEK_SAT"  |"DAY_OF_YEAR_SIN"   |"DAY_OF_YEAR_COS"    |"DAYS_SINCE_JAN2020"  |"MODEL_TARGET"  |"GROUP_IDENTIFIER"  |"GROUP_IDENTIFIER_STRING"  |
--------------------------------------------------------------
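Column names such as MONTH_SIN and MONTH_COS in the output above suggest a cyclical encoding of calendar fields. The exact formula used by `expand_datetime` is not shown in this notebook, so the sketch below is an assumption. It shows the standard sine/cosine mapping and why it helps: values at the ends of a cycle end up close together.

```python
import math


def cyclical_encode(value, period):
    """Map a periodic value onto the unit circle as (sin, cos)."""
    angle = 2.0 * math.pi * value / period
    return math.sin(angle), math.cos(angle)


# Assumption: months are numbered 1..12 and the period is 12
january = cyclical_encode(1, 12)
february = cyclical_encode(2, 12)
december = cyclical_encode(12, 12)

# December is as close to January as January is to February
print(math.dist(december, january), math.dist(january, february))
```

With a raw month number, December (12) and January (1) would be eleven units apart even though they are adjacent in time.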

-----
# Feature Store
-----

In [8]:
# Create or retrieve feature store based on session database and schema
fs = FeatureStore(
    session,
    database=session_db,
    name=session_schema,
    default_warehouse=SESSION_WH,
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
)

# Create and register entity based on partition column.
# If entity already exists, registration will be skipped.
entity = Entity(
    name="TS_PARTITION_ENTITY",
    join_keys=["GROUP_IDENTIFIER_STRING"],
)

fs.register_entity(entity)



Entity(name=TS_PARTITION_ENTITY, join_keys=['GROUP_IDENTIFIER_STRING'], owner=None, desc=)

In [9]:
# Create feature view based on dataframe engineered above
# This allows the same logic to be applied to training and testing data,
# keeping features up to date on an incrementally refreshed schedule
fv = FeatureView(
    name="FORECAST_FEATURES",
    entities=[entity],
    feature_df=final_sdf,
    timestamp_col=TIME_PERIOD_COLUMN,
    refresh_freq="1 days",
    refresh_mode="INCREMENTAL",
)

# Automatically versions the feature view definition (not the data itself)
# If there are changes to the feature view (ex. a different transformation of the dataframe)
# the version will change. If there are no changes, the same version will be returned, and 
# feature view registration will be skipped, returning the existing feature view.
version = version_featureview(fv)

fv_reg = fs.register_feature_view(fv, version=version)
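The comments above describe a version that changes only when the feature view definition changes. One common way to get that behavior is to hash the definition text. The sketch below illustrates the idea. It is not the implementation of `version_featureview`, which is not shown in this notebook, and `version_from_definition` is a hypothetical name.

```python
import hashlib


def version_from_definition(definition: str, prefix: str = "V") -> str:
    """Derive a stable version label from the text of a feature definition."""
    normalized = " ".join(definition.split())  # ignore whitespace-only changes
    digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()
    return f"{prefix}_{digest[:12]}"


# Toy definitions, for illustration only
v1 = version_from_definition("SELECT A, B FROM T")
v2 = version_from_definition("SELECT  A,  B\nFROM T")
v3 = version_from_definition("SELECT A, B, C FROM T")
print(v1 == v2, v1 == v3)  # True False
```

Because the label is derived from the definition, rerunning the notebook without changes yields the same version, and registration can be skipped.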




-----
# TRAIN/TEST SPLIT
-----

In [10]:
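# Read the registered feature view and cache the result for reuse below
sdf_fv = fs.read_feature_view(fv_reg).cache_result()

# TRAIN/TEST SPLIT (the most recent records are held out when VALIDATION_DAYS > 0)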
if VALIDATION_DAYS == 0:
    sdf_train = sdf_fv
elif VALIDATION_DAYS > 0:
    # Get the last time period in the dataset
    last_time_period = sdf_fv.select(
        F.max(TIME_PERIOD_COLUMN).alias("MAX_DTTM")
    ).collect()[0]["MAX_DTTM"]
    # Remove the validation records from the training set
    sdf_train = sdf_fv.filter(
        F.date_trunc("day", TIME_PERIOD_COLUMN)
        < F.dateadd("day", F.lit(-VALIDATION_DAYS), F.lit(last_time_period))
    )
    sdf_test = sdf_fv.filter(
        F.date_trunc("day", TIME_PERIOD_COLUMN)
        >= F.dateadd("day", F.lit(-VALIDATION_DAYS), F.lit(last_time_period))
    )

# Inspect the data
training_dttm_boundaries = sdf_train.select(
    F.min(TIME_PERIOD_COLUMN).alias("MIN_DTTM"),
    F.max(TIME_PERIOD_COLUMN).alias("MAX_DTTM"),
).collect()[0]
print(f"Training set row count: {sdf_train.count()}")
print(f"First time period in training set: {training_dttm_boundaries['MIN_DTTM']}")
print(f"Last time period in training set:  {training_dttm_boundaries['MAX_DTTM']}")
if len(PARTITION_COLUMNS) > 0:
    print(
        f"Total Partition Count: {sdf_train.select(F.get(F.split('GROUP_IDENTIFIER_STRING',F.lit('_LEAD')),0)).distinct().count()}"
    )
    model_count = sdf_train.select('GROUP_IDENTIFIER_STRING').distinct().count()
    USE_CONTEXT = model_count <= MODELCONTEXT_MAX
else:
    print("No partitions specified.")
    USE_CONTEXT = False
sdf_train.show(2)

Training set row count: 344750
First time period in training set: 2021-01-01 00:00:00
Last time period in training set:  2024-10-10 00:00:00
Total Partition Count: 250
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"ORDER_TIMESTAMP"    |"TARGET"       |"FEATURE_1"    |"YEAR"  |"MONTH_SIN"  |"MONTH_COS"              |"WEEK_OF_YEAR_SIN"  |"WEEK_OF_YEAR_COS"   |"DAY_OF_WEEK_SUN"  |"DAY_OF_WEEK_MON"  |"DAY_OF_WEEK_TUE"  |"DAY_OF_WEEK_WED"  |"DAY_OF_WEEK_THU"  |"DAY_OF_WEEK_FRI"  |"DAY_OF_WEEK_SAT"  |"DAY_OF_YEAR_SIN"   |"DAY_OF_YEAR_COS"    |"DAYS_SINCE_JAN2020"  |"MODEL_TARGET"  |"GROUP_IDENTIFIER"  |"GROUP_IDENTIFIER_STR
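The cutoff logic in the split above can be traced on a toy series. The sketch below is plain Python with hypothetical names. It mirrors the two comparisons in the cell: records strictly before `last - VALIDATION_DAYS` go to training, and everything else goes to testing.

```python
from datetime import date, timedelta


def split_by_cutoff(rows, validation_days):
    """Split (day, value) rows into train and test using a date cutoff."""
    if validation_days <= 0:
        return list(rows), []
    last_day = max(day for day, _ in rows)
    cutoff = last_day - timedelta(days=validation_days)
    train = [row for row in rows if row[0] < cutoff]
    test = [row for row in rows if row[0] >= cutoff]
    return train, test


# Ten consecutive days starting 2024-01-01, values are made up
rows = [(date(2024, 1, 1) + timedelta(days=i), float(i)) for i in range(10)]
train, test = split_by_cutoff(rows, validation_days=3)
print(len(train), len(test))  # 6 4
```

Note that the test set spans the cutoff date through the last date inclusive, so with daily data it contains one more distinct day than the validation setting.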

In [None]:
# Save training and testing datasets for future use

dataset_name = "FORECAST_FEATURES"
dataset = Dataset.create(session, name=dataset_name+"_TRAIN", exist_ok=True)

# Version the data itself; If the data changes on a subsequent run, a new dataset will be created.
# If the data does not change, the previously saved dataset will be returned.
ds_train_version = version_data(sdf_train)

if ds_train_version in dataset.list_versions():
    ds_train = dataset.select_version(ds_train_version)
else:
    ds_train = dataset.create_version(
        version=ds_train_version,
        input_dataframe=sdf_train,
        label_cols=["MODEL_TARGET"],        
    )

if VALIDATION_DAYS > 0:
    dataset = Dataset.create(session, name=dataset_name+"_TEST", exist_ok=True)

    ds_test_version = version_data(sdf_test)
    
    if ds_test_version in dataset.list_versions():
        ds_test = dataset.select_version(ds_test_version)
    else:
        ds_test = dataset.create_version(
            version=ds_test_version,
            input_dataframe=sdf_test,
            label_cols=["MODEL_TARGET"],        
        )



-----
# Model Training
-----

In [None]:
# --------------------------------------------------------
# Define and register a UDTF to perform model training
# --------------------------------------------------------

# Get all of the column names except the group identifier columns (those are used to partition the UDTF call)
training_udtf_input_col_nms = [
    colnm
    for colnm in sdf_train.columns
    if colnm not in ["GROUP_IDENTIFIER", "GROUP_IDENTIFIER_STRING"]
]


def train_model(df: pd.DataFrame) -> pd.DataFrame:
    """Trains a forecasting model and returns the model binary and metadata.

    Parameters
    ----------
    df : pandas.DataFrame
        The input DataFrame.

    Returns
    -------
    pandas.DataFrame
        A DataFrame containing the model binary and metadata.

    """
    # NOTE: In a vectorized UDTF we need to RENAME the columns to match the input dataset
    df.columns = training_udtf_input_col_nms

    # Set the index
    df = df.set_index(pd.to_datetime(df.pop(TIME_PERIOD_COLUMN)))

    # Create X and y dataframes.
    X = df.drop(columns=[TARGET_COLUMN, "MODEL_TARGET"])
    y = df["MODEL_TARGET"]

    # train a model
    model = xgb.XGBRegressor(**XGB_PARAMS)
    model.fit(X, y)
    # Save the model
    raw_model = json.loads(model.get_booster().save_raw(raw_format='json'))
    model_binary = pickle.dumps(raw_model)

    # Obtain feature importances
    feature_importance_dict = dict(
        zip(X.columns, [float(val) for val in model.feature_importances_])
    )
    metadata = {
        "feature_importance": feature_importance_dict,
    }

    # Save the environment specs
    module_dict = {}
    for finder, module_name, is_pkg in pkgutil.iter_modules():
        try:
            distribution = importlib.metadata.distribution(module_name)
            version = distribution.version
            module_dict[module_name] = version
        except importlib.metadata.PackageNotFoundError:
            continue
    model_df = pd.DataFrame(
        [[model.__class__.__name__, model_binary, metadata, module_dict]],
        columns=["ALGORITHM", "MODEL_BINARY", "METADATA","ENVIRONMENT_SPECS"],
    )

    return model_df


# Define UDTF class
class ModelTrainingUDTF:
    """Class which is registered as a UDTF to train forecasting models."""

    def end_partition(self, df):
        """End partition method which utilizes the train model function."""
        forecast_df = train_model(df)
        yield forecast_df


# Get the data types for the input dataframe
vect_udtf_input_dtypes = [
    T.PandasDataFrameType(
        [
            field.datatype
            for field in sdf_train.schema.fields
            if field.name not in ["GROUP_IDENTIFIER", "GROUP_IDENTIFIER_STRING"]
        ]
    )
]

# Register the class as a temporary UDTF
# Give the UDTF a unique name so that it doesn't conflict with anyone else running the same notebook
udtf_name = f"MODEL_TRAINER_{MODEL_NAME}_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}__{random.randint(1, 999)}"
session.udtf.register(
    ModelTrainingUDTF,
    name=udtf_name,
    input_types=vect_udtf_input_dtypes,
    output_schema=T.PandasDataFrameType(
        [T.StringType(), T.BinaryType(), T.VariantType(), T.VariantType()],
        ["ALGORITHM", "MODEL_BINARY", "METADATA","ENVIRONMENT_SPECS"],
    ),
    packages=[
        "snowflake-snowpark-python",
        "pandas",
        "numpy",
        "xgboost",
        "scikit-learn",
    ],
    replace=True,
    is_permanent=False,
    comment=query_tag,
)

print("Registration complete")

Registration complete
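The UDTF above receives all rows of one partition at a time and returns one output row per partition. The same pattern can be emulated locally to reason about it. The sketch below is plain Python with hypothetical names. A per-partition mean stands in for the XGBoost model.

```python
from collections import defaultdict
from statistics import fmean


def train_per_partition(rows, fit):
    """Group (partition_id, target) rows and fit one 'model' per partition."""
    grouped = defaultdict(list)
    for partition_id, target in rows:
        grouped[partition_id].append(target)
    # One result per partition, analogous to one output row from end_partition
    return {partition_id: fit(targets) for partition_id, targets in grouped.items()}


# Toy rows, for illustration only
rows = [("S1_P1", 10.0), ("S1_P1", 14.0), ("S2_P1", 3.0)]
models = train_per_partition(rows, fit=fmean)
print(models)  # {'S1_P1': 12.0, 'S2_P1': 3.0}
```

In Snowflake the partitions are processed in parallel across the warehouse, whereas this local loop is sequential.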


In [None]:
session.use_warehouse(MODELING_WH)

# Before model training, remove records where MODEL_TARGET is null
sdf_train = sdf_train.filter(F.col("MODEL_TARGET").isNotNull())

# Run the UDTF
udtf_models = sdf_train.select(
    "GROUP_IDENTIFIER",
    "GROUP_IDENTIFIER_STRING",
    F.call_table_function(udtf_name, *training_udtf_input_col_nms).over(
        partition_by=["GROUP_IDENTIFIER", "GROUP_IDENTIFIER_STRING"],
        order_by=TIME_PERIOD_COLUMN,
    ),
)

# Add additional columns to the output
total_leads_modeled_this_run = FORECAST_HORIZON if train_separate_lead_models else None

udtf_models = udtf_models.select(
    "GROUP_IDENTIFIER",
    "GROUP_IDENTIFIER_STRING",
    F.lit(MODEL_NAME).alias("MODEL_NAME"),
    "ALGORITHM",
    F.lit(run_dttm).alias("MODEL_TRAINED_DTTM"),
    "MODEL_BINARY",
    "METADATA",
    "ENVIRONMENT_SPECS",
)

# Cache results for faster downstream usage of the udtf_models DataFrame
udtf_models = udtf_models.cache_result()

# Switch back to the original warehouse
session.use_warehouse(SESSION_WH)

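# Function to load an XGBoost model from its raw JSON representation
def get_xgb_model(raw_model):
    # Write the JSON to a temporary file so that load_model can read it from a path
    with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f:
        json.dump(raw_model, f)
        model_filename = f.name
    model = xgb.XGBRegressor(**XGB_PARAMS)
    try:
        model.load_model(model_filename)
    finally:
        # Remove the temporary file even if loading fails
        os.remove(model_filename)
    return model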

print("Model training complete.")

Model training complete.
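
`get_xgb_model` stages the raw JSON model in a temporary file, loads it from that path, and deletes the file. The same round trip can be sketched without XGBoost, assuming the stored binary is a pickled, JSON-serializable dict (as the `pickle.loads` calls later in this notebook suggest). `load_from_binary`, `read_json`, and the `loader` argument are hypothetical stand-ins; in the notebook, the loader role is played by building an `XGBRegressor` and calling `load_model`.

```python
import json
import os
import pickle
import tempfile
from typing import Any, Callable


def load_from_binary(model_binary: bytes, loader: Callable[[str], Any]) -> Any:
    """Unpickle a raw JSON model, stage it in a temp file, and pass the path to `loader`."""
    # Only unpickle binaries that come from a trusted source
    raw_model = pickle.loads(model_binary)
    with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f:
        json.dump(raw_model, f)
        path = f.name
    try:
        return loader(path)
    finally:
        # Clean up the staged file even if the loader raises
        os.remove(path)


def read_json(path: str) -> dict:
    """Stand-in loader that simply parses the staged file."""
    with open(path) as fh:
        return json.load(fh)


# Usage: round-trip a small dict through the same steps
binary = pickle.dumps({"version": [2, 0, 0], "learner": {"name": "demo"}})
restored = load_from_binary(binary, read_json)
```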


-----
# Model Registry
-----

In [None]:
# --------------------------------------------------------
# Define the Partitioned Custom Model
# --------------------------------------------------------

# Sample input (based on feature view table) to track lineage
sample_input = sdf_train.limit(100).drop("GROUP_IDENTIFIER")

# Input features
model_input_predictor_features = [
    colnm
    for colnm in sdf_train.columns
    if colnm
    not in [
        "GROUP_IDENTIFIER",
        "GROUP_IDENTIFIER_STRING",
        TIME_PERIOD_COLUMN,
        TARGET_COLUMN,
        "MODEL_TARGET",
    ]
]

# Custom model uses model context if provided, otherwise uses model storage table binaries
class ForecastingModel(custom_model.CustomModel):
    """Custom model class."""

    def __init__(self, context: Optional[custom_model.ModelContext] = None) -> None:
        """Initialize object."""
        super().__init__(context)
        self.partition_id = None
        self.model = None

    @custom_model.partitioned_api
    def predict(self, input_df: pd.DataFrame) -> pd.DataFrame:
        """Make predictions using unpickled model."""
        # Use positional access: the partition's index labels may not start at 0
        partition_id = input_df["GROUP_IDENTIFIER_STRING"].iloc[0]
        if self.partition_id != partition_id:
            self.partition_id = partition_id

            # Use model context if it exists
            if len(self.context.model_refs):
                self.model = self.context.model_ref(self.partition_id)
            else:
                # Get the model binary from the first row of the input DataFrame where the column is not null
                raw_model = pickle.loads(
                    input_df.loc[
                        input_df["MODEL_BINARY"].first_valid_index(), "MODEL_BINARY"
                    ]
                )
                self.model = get_xgb_model(raw_model)

        model_output = self.model.predict(input_df[model_input_predictor_features])
        res = pd.DataFrame(model_output, columns=["_PRED_"])
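        # Reset the index so these columns align with res row by row, regardless of input_df's index labels
        res["GROUP_IDENTIFIER_STRING_OUT_"] = input_df["GROUP_IDENTIFIER_STRING"].reset_index(drop=True)
        res[TIME_PERIOD_COLUMN + "_OUT_"] = input_df[TIME_PERIOD_COLUMN].reset_index(drop=True)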
        return res

if USE_CONTEXT:
    # Create model context based on json from dataframe
    model_dict = {row.GROUP_IDENTIFIER_STRING:get_xgb_model(pickle.loads(row.MODEL_BINARY)) for row in udtf_models.collect()}
    context = custom_model.ModelContext(models=model_dict)
else:
    # Create Model Storage table if it does not already exist
    # It will be created in the schema associated with the notebook (which is the schema that was created for this project).
    session.sql(
        f"""
            create table if not exists {MODEL_BINARY_STORAGE_TBL_NM} (
                GROUP_IDENTIFIER VARIANT,
                GROUP_IDENTIFIER_STRING VARCHAR,
                MODEL_NAME VARCHAR(100),
                MODEL_VERSION VARCHAR(100),
                ALGORITHM VARCHAR(100),
                MODEL_TRAINED_DTTM TIMESTAMP,
                MODEL_BINARY BINARY,
                METADATA VARIANT,
                ENVIRONMENT_SPECS VARIANT
                )
            comment = '{query_tag}'
    """
    ).collect()
    context = None
    user_settings_dict["MODEL_BINARY_STORAGE_TBL_NM"] = MODEL_BINARY_STORAGE_TBL_NM
    # Add model storage to sample input
    sample_input = sample_input.join(
        udtf_models.select("GROUP_IDENTIFIER_STRING","MODEL_BINARY"),
        on = "GROUP_IDENTIFIER_STRING",
    )
    print(f"Number of models ({model_count}) exceeds the model context maximum. Using the model storage table approach.")


m = ForecastingModel(context)

# --------------------------------------------------------
# Log Model to Model Registry
# --------------------------------------------------------


# Log the model to the model registry
options = {"function_type": "TABLE_FUNCTION", "relax_version": False}
user_settings_dict["USE_CONTEXT"] = USE_CONTEXT
user_settings_dict["TARGET_COLUMN"] = "MODEL_TARGET"
metrics_to_log = {
    "direct_multi_step_forecasting": train_separate_lead_models,
    "frequency": modeling_frequency,
    "training_data_start": training_dttm_boundaries["MIN_DTTM"].strftime(
        "%Y-%m-%d %H:%M:%S"
    ),
    "training_data_end": training_dttm_boundaries["MAX_DTTM"].strftime(
        "%Y-%m-%d %H:%M:%S"
    ),
    "user_settings": user_settings_dict,
    "train_dataset": {"name":ds_train.fully_qualified_name, "version":ds_train.selected_version.name},
    "test_dataset": {"name":ds_test.fully_qualified_name, "version":ds_test.selected_version.name},
}
mv = reg.log_model(
    m,
    model_name=qualified_model_name,
    options=options,
    metrics=metrics_to_log,
    conda_dependencies=["pandas", "xgboost"],
    sample_input_data=sample_input,
    #signatures={"predict": signature},
    comment=query_tag,
)

# In addition to setting the query tag for the model version, we also set it for the model itself
reg.get_model(qualified_model_name).comment = query_tag

print(f"Model version name: {mv.version_name}")

# Confirm that the new model/version is in the registry
reg.show_models()

Model logged successfully.: 100%|██████████| 6/6 [03:11<00:00, 31.96s/it]                          
Model version name: LIGHT_HOUND_1


| | created_on | name | model_type | database_name | schema_name | comment | owner | default_version_name | versions | aliases |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2025-10-09 13:46:18.962000-07:00 | TEST_MODEL_1 | USER_MODEL | FORECAST_MODEL_BUILDER | MODELING | {"origin":"sf_sit", "name":"sit_forecasting", ... | ML_DEV_ROLE | ROTTEN_GOOSE_2 | ["LIGHT_HOUND_1","ROTTEN_GOOSE_2","SLIPPERY_SN... | {"DEFAULT":"ROTTEN_GOOSE_2","FIRST":"SLIPPERY_... |
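
`ForecastingModel.predict` keeps the most recently loaded model on the instance and reloads only when the incoming partition ID differs from the cached one. The following is a minimal sketch of that caching rule in plain Python. `PartitionModelCache` and `load_fn` are illustrative names and are not part of the Snowflake ML API.

```python
from typing import Any, Callable, Optional


class PartitionModelCache:
    """Hold one model at a time, reloading only when the partition ID changes."""

    def __init__(self, load_fn: Callable[[str], Any]) -> None:
        self._load_fn = load_fn
        self.partition_id: Optional[str] = None
        self.model: Any = None
        self.loads = 0  # counts how often load_fn ran, for illustration only

    def get(self, partition_id: str) -> Any:
        if self.partition_id != partition_id:
            # Load first, then record the ID, so a failed load is retried next time
            self.model = self._load_fn(partition_id)
            self.partition_id = partition_id
            self.loads += 1
        return self.model


# Usage: two consecutive calls for the same partition trigger a single load
cache = PartitionModelCache(lambda pid: f"model-for-{pid}")
cache.get("STORE_1")
cache.get("STORE_1")
cache.get("STORE_2")
# cache.loads is now 2
```

One design difference from the cell above: the sketch records the partition ID only after a successful load, whereas `predict` updates `self.partition_id` before loading.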


In [17]:
# --------------------------------------------------------
# Set the model as the default version in the registry
# If using model storage table, save models there
# --------------------------------------------------------

if SAVE_MODEL_VERSION_THIS_RUN:
    # Set the default version of the model to this version name
    reg.get_model(qualified_model_name).default = mv.version_name

    if not USE_CONTEXT:
        # Append model binaries and metadata to the model binary storage table in Snowflake
        udtf_models_w_version = udtf_models.with_column(
            "MODEL_VERSION", F.lit(mv.version_name)
        ).select(session.table(f"{MODEL_BINARY_STORAGE_TBL_NM}").columns)

        udtf_models_w_version.write.save_as_table(
            f"{MODEL_BINARY_STORAGE_TBL_NM}", mode="append"
        )


    print(
        f"Model version '{mv.version_name}' set as the default version in the registry."
    )
else:
    print(
        f"""Model version '{mv.version_name}' will be deleted from the registry at the end of this notebook.
    If you wish to save this version, set SAVE_MODEL_VERSION_THIS_RUN = True."""
    )

# Look at the most recent 3 versions of the model
reg.get_model(qualified_model_name).show_versions().tail(3)

Model version 'LIGHT_HOUND_1' set as the default version in the registry.


| | created_on | name | aliases | comment | database_name | schema_name | model_name | is_default_version | functions | metadata | user_data | model_attributes | size | environment | runnable_in | inference_services |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2025-10-09 13:46:19.008000-07:00 | SLIPPERY_SNAKE_2 | ["FIRST"] | {"origin":"sf_sit", "name":"sit_forecasting", ... | FORECAST_MODEL_BUILDER | MODELING | TEST_MODEL_1 | False | ["PREDICT"] | {"metrics": {"direct_multi_step_forecasting": ... | {} | {"framework":"custom","client":"snowflake-ml-p... | 112164418 | {"default":{"python_version":"3.12","cuda_vers... | ["WAREHOUSE"] | [] |
| 1 | 2025-10-09 13:57:36.875000-07:00 | ROTTEN_GOOSE_2 | [] | {"origin":"sf_sit", "name":"sit_forecasting", ... | FORECAST_MODEL_BUILDER | MODELING | TEST_MODEL_1 | False | ["PREDICT"] | {"metrics": {"direct_multi_step_forecasting": ... | {} | {"framework":"custom","client":"snowflake-ml-p... | 4275486 | {"default":{"python_version":"3.12","snowflake... | ["WAREHOUSE"] | [] |
| 2 | 2025-10-13 11:40:42.696000-07:00 | LIGHT_HOUND_1 | ["DEFAULT","LAST"] | {"origin":"sf_sit", "name":"sit_forecasting", ... | FORECAST_MODEL_BUILDER | MODELING | TEST_MODEL_1 | True | ["PREDICT"] | {"metrics": {"direct_multi_step_forecasting": ... | {} | {"framework":"custom","client":"snowflake-ml-p... | 112164510 | {"default":{"python_version":"3.12","cuda_vers... | ["WAREHOUSE"] | [] |


-----
# Clean up
-----

In [26]:
# If SAVE_MODEL_VERSION_THIS_RUN is False, remove the version we just built from the registry.
# NOTE: Comment this code out if you do not want to delete the model version from the model registry.
current_model = reg.get_model(qualified_model_name)

if not SAVE_MODEL_VERSION_THIS_RUN:
    deletion_message = ""
    try:
        current_model.version(mv.version_name)
    except Exception:
        deletion_message = f"WARNING: Model version '{mv.version_name}' does not exist in the registry."
        print(deletion_message)

    if len(deletion_message) == 0:
        try:
            if len(current_model.versions()) == 0:
                print(
                    f"WARNING: There are no versions for model '{MODEL_NAME}' in the registry."
                )
            elif (len(current_model.versions()) == 1) and (
                current_model.default.version_name == mv.version_name
            ):
                reg.delete_model(MODEL_NAME)
                print(
                    f"Model '{MODEL_NAME}' (which only had one version: '{mv.version_name}') was deleted from the registry."
                )
            else:
                current_model.delete_version(mv.version_name)
                print(
                    f"Model version '{mv.version_name}' was deleted from the registry."
                )
        except Exception:
            print(
                f"WARNING: Model version '{mv.version_name}' was not able to be deleted from the registry."
            )

    reg.show_models()
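
The clean-up cell chooses among three outcomes: warn that the version does not exist, delete the whole model when the version being removed is both its only version and its default, or delete just that version. A sketch of that decision as a pure function follows, which makes the branches easy to check without a registry connection. `plan_cleanup` and its string return values are hypothetical and only mirror the branching above.

```python
from typing import List, Optional


def plan_cleanup(
    version_names: List[str], default_name: Optional[str], target: str
) -> str:
    """Return which clean-up action the notebook's branching would take."""
    # Mirrors the first try/except: the version must exist before anything is deleted.
    # An empty version list also lands here.
    if target not in version_names:
        return "warn_missing_version"
    # Only one version, and it is the default: remove the model itself
    if len(version_names) == 1 and default_name == target:
        return "delete_model"
    # Otherwise remove just this version
    return "delete_version"
```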