In [91]:
from datetime import datetime

import pandas as pd
from dateutil.relativedelta import relativedelta
import snowflake.snowpark.functions as F
from snowflake.ml.feature_store import FeatureStore, CreationMode
from snowflake.ml.registry import registry
from snowflake.snowpark.exceptions import SnowparkSQLException

from forecast_model_builder.utils import connect

In [None]:
# Establish the Snowpark session and report which database, schema and
# warehouse it is pointed at.
session = connect(connection_name="default")
session_db = session.connection.database
session_schema = session.connection.schema
session_wh = session.connection.warehouse
print(f"Session db.schema: {session_db}.{session_schema}")
print(f"Session warehouse: {session_wh}")

# Query tag (a JSON payload) assigned to the session via `session.query_tag`.
# The literal is single-quoted so the embedded double quotes do not terminate
# it. The previous version opened with `"` and closed with `'`, which is a
# SyntaxError.
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"inference"}}'
session.query_tag = query_tag

Session db.schema: FORECAST_MODEL_BUILDER.TEST
Session warehouse: FORECAST_MODEL_BUILDER_WH


In [95]:
# ---------------------------------------------------------------------------
# User configuration (VARS)
# ---------------------------------------------------------------------------

# -- Output -----------------------------------------------------------------
# Destination table for the prediction results. A name that is not fully
# qualified as DB.SCHEMA.TABLE is resolved against the session's default
# database and schema.
INFERENCE_RESULT_TBL_NM = "FORECAST_PREDICTIONS"

# -- Model ------------------------------------------------------------------
# Location of the model registry (database + schema) and the registered model
# to score with. Inference always uses the model's *default* version.
MODEL_DB, MODEL_SCHEMA = "FORECAST_MODEL_BUILDER", "MODELING"
MODEL_NAME = "TEST_MODEL_1"

# -- Compute ----------------------------------------------------------------
# Warehouse used for inference. A larger warehouse can shorten runtime,
# particularly when there are many partitions. Set to None to run on the
# session warehouse instead.
INFERENCE_WH = "STANDARD_XL"

# -- Scheduling ---------------------------------------------------------------
# True: (planned) re-run inference whenever the feature view refreshes.
# False: run on a fixed schedule, in which case TASK_SCHEDULE is required.
REFRESH_WITH_FEATURE_VIEW = False
TASK_SCHEDULE = "1 day"

# First date ("%Y-%m-%d") to produce predictions for, so the very first run
# does not score the entire history.
INFERENCE_START_DATE = "2025-01-01"

In [93]:
# get objects needed

# -----------------------------------------------------------------------
# Notebook Warehouse
# -----------------------------------------------------------------------
SESSION_WH = session.connection.warehouse
print(f"Session warehouse:          {SESSION_WH}")

# -----------------------------------------------------------------------
# Check Inference Warehouse
# -----------------------------------------------------------------------
# Names of the warehouses reported by SHOW WAREHOUSES for the current role.
available_warehouses = [
    row["NAME"]
    for row in session.sql("SHOW WAREHOUSES")
    .select(F.col('"name"').alias("NAME"))
    .collect()
]

if INFERENCE_WH is None:
    # The config cell documents None as "use the session warehouse". Honor it
    # without emitting the misleading "no access" warning below.
    INFERENCE_WH = SESSION_WH
    print(f"Inference warehouse:        {INFERENCE_WH} (INFERENCE_WH is None; using the session warehouse) \n")
elif INFERENCE_WH in available_warehouses:
    print(f"Inference warehouse:        {INFERENCE_WH} \n")
else:
    # The configured warehouse is not listed for this role: fall back to the
    # session warehouse.
    print(
        f"WARNING: User does not have access to INFERENCE_WH = '{INFERENCE_WH}'. Inference will use '{SESSION_WH}' instead. \n"
    )
    INFERENCE_WH = SESSION_WH

# -----------------------------------------------------------------------
# Fully qualified MODEL NAME
# -----------------------------------------------------------------------
qualified_model_name = f"{MODEL_DB}.{MODEL_SCHEMA}.{MODEL_NAME}"

# -----------------------------------------------------------------------
# Get the model and the version name of the default version
# -----------------------------------------------------------------------
# Registry object scoped to the database/schema that hold the model.
reg = registry.Registry(
    session=session, database_name=MODEL_DB, schema_name=MODEL_SCHEMA
)

# The default version of the model is the one used for inference.
mv = reg.get_model(qualified_model_name).default
model_version_nm = mv.version_name

print(f"Model Version:              {model_version_nm}")

# --------------------------------
# User Constants from Model Setup
# --------------------------------
# Settings saved with the model version under the "user_settings" metric.
stored_constants = mv.show_metrics()["user_settings"]
USE_CONTEXT = stored_constants["USE_CONTEXT"]
TIME_PERIOD_COLUMN = stored_constants["TIME_PERIOD_COLUMN"]
ROLLUP_FREQUENCY = stored_constants["ROLLUP_FREQUENCY"]
CURRENT_FREQUENCY = stored_constants["CURRENT_FREQUENCY"]
FORECAST_HORIZON = stored_constants["FORECAST_HORIZON"]
# Use the rollup frequency when one was configured; otherwise CURRENT_FREQUENCY.
frequency = ROLLUP_FREQUENCY if ROLLUP_FREQUENCY else CURRENT_FREQUENCY

# The binary-storage table setting is only read when USE_CONTEXT is falsy.
if not USE_CONTEXT:
    MODEL_BINARY_STORAGE_TBL_NM = stored_constants["MODEL_BINARY_STORAGE_TBL_NM"]

# --------------------------------
# Get features for inference
# --------------------------------
# NOTE(review): assumes the feature store lives in the session's DB.SCHEMA — confirm.
fs = FeatureStore(
    session,
    database=session_db,
    name=session_schema,
    default_warehouse=SESSION_WH,
    creation_mode=CreationMode.FAIL_IF_NOT_EXIST,
)

# The feature view is discovered through the model version's upstream lineage.
upstream_lineage = mv.lineage(direction='upstream')
if not upstream_lineage:
    # Fail with an actionable message instead of a bare IndexError on [0].
    raise ValueError(
        f"Model {qualified_model_name} (version {model_version_nm}) has no upstream lineage; "
        "cannot determine which feature view to read for inference."
    )
# NOTE(review): takes the first upstream node as the feature view — confirm
# this holds when a model version has several upstream objects.
fv = upstream_lineage[0]

Session warehouse:          FORECAST_MODEL_BUILDER_WH

Model Version:              COWARDLY_BADGER_3




In [96]:
# -----------------------------------------------------------------------
# Resolve the result table name
# -----------------------------------------------------------------------
# Per the note on INFERENCE_RESULT_TBL_NM, only an unqualified name is
# resolved against the session's database and schema. A qualified name is
# used as-is rather than being double-prefixed.
if "." in INFERENCE_RESULT_TBL_NM:
    inference_table_name = INFERENCE_RESULT_TBL_NM
else:
    inference_table_name = f"{session_db}.{session_schema}.{INFERENCE_RESULT_TBL_NM}"

# -----------------------------------------------------------------------
# Inference window start
# -----------------------------------------------------------------------
# The configured start is a "%Y-%m-%d" string. After this cell has run once,
# INFERENCE_START_DATE already holds a date, so only strings are parsed. This
# keeps the cell re-runnable instead of failing inside strptime.
if isinstance(INFERENCE_START_DATE, str):
    configured_start_date = datetime.strptime(INFERENCE_START_DATE, "%Y-%m-%d")
else:
    configured_start_date = INFERENCE_START_DATE

# Probe for the result table. Only SQL errors (e.g. object does not exist)
# mean "no table yet"; anything else, such as an interrupt, propagates.
try:
    result_table = session.table(inference_table_name)
    # LIMIT 1 proves the table is queryable without pulling every row.
    result_table.limit(1).collect()
except SnowparkSQLException:
    result_table_exists = False
else:
    result_table_exists = True

INFERENCE_START_DATE = configured_start_date
if not result_table_exists:
    print(f"Table {INFERENCE_RESULT_TBL_NM} does not exist and will be created once inference is complete")
else:
    # Resume one `frequency` step after the latest period already predicted.
    next_period = result_table.select(
        F.dateadd(
            frequency,
            F.lit(1),
            F.max(TIME_PERIOD_COLUMN)).alias(TIME_PERIOD_COLUMN)
    ).collect()[0][TIME_PERIOD_COLUMN]
    # MAX over an empty table is NULL. Keep the configured start rather than
    # filtering the feature view against None.
    if next_period is not None:
        INFERENCE_START_DATE = next_period

# -----------------------------------------------------------------------
# Inference window end: today + FORECAST_HORIZON periods of `frequency`
# -----------------------------------------------------------------------
# relativedelta takes lower-case plural keywords (days, weeks, months, ...),
# so normalise the case of the stored date part.
# NOTE(review): assumes `frequency` is a unit relativedelta supports (e.g.
# day/week/month/year); a part such as "quarter" would need mapping — confirm.
INFERENCE_END_DATE = datetime.today() + relativedelta(**{frequency.lower() + "s": FORECAST_HORIZON})

Table FORECAST_PREDICTIONS does not exist and will be created once inference is complete


In [None]:
# Read the model's upstream feature view and keep only rows inside the
# half-open inference window [INFERENCE_START_DATE, INFERENCE_END_DATE).
period = F.col(TIME_PERIOD_COLUMN)
inference_window = (period >= INFERENCE_START_DATE) & (period < INFERENCE_END_DATE)

sdf = fs.read_feature_view(fv).filter(inference_window)
sdf.show()

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"ORDER_TIMESTAMP"    |"TARGET"       |"FEATURE_1"    |"YEAR"  |"MONTH_SIN"         |"MONTH_COS"         |"WEEK_OF_YEAR_SIN"   |"WEEK_OF_YEAR_COS"  |"DAY_OF_WEEK_SUN"  |"DAY_OF_WEEK_MON"  |"DAY_OF_WEEK_TUE"  |"DAY_OF_WEEK_WED"  |"DAY_OF_WEEK_THU"  |"DAY_OF_WEEK_FRI"  |"DAY_OF_WEEK_SAT"  |"DAY_OF_YEAR_SIN"    |"DAY_OF_YEAR_COS"   |"DAYS_SINCE_JAN2020"  |"MODEL_TARGET"  |"GROUP_IDENTIFIER"   |"GROUP_IDENTIFIER_STRING"  |
----------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
# if inference table doesn't exist, create it

# get model and version
# IF REFRESH_WITH_FEATURE_VIEW: run inference every time the feature view refreshes
    # get fv
    # make stream
# ELSE
    # schedule the task with TASK_SCHEDULE
    # if direct multistep, the schedule interval must be <= the forecast horizon

In [None]:
# TASK

# get model version
# get fv and read to df
# filter df to the inference window
    # start: > max date already in the inference table (direct multistep may need different handling)
    # if there is none (no inference table yet), start at INFERENCE_START_DATE
    # end: < today + forecast horizon
# run inference
# add columns like inference date, model version, etc.
# append data

In [None]:
# demonstrate force rerun

In [None]:
# handling previous existing tasks