# SP500 Stock Demo — Notebook 04: Inference + Drift Monitoring

- Load latest registered model
- Generate batch predictions on recent data
- Persist predictions table
- Compute simple drift (PSI) on key features using Snowpark
- Optional: create a (suspended) Snowflake alert that emails when PSI exceeds a threshold


In [None]:
# 0) Imports and session/context
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, avg
from snowflake.ml.registry import Registry

session = get_active_session()

# Point the session at the demo warehouse, database and schema before any
# unqualified table references below.
for _ctx_stmt in (
    "USE WAREHOUSE DEMO_WH_M",
    "USE DATABASE SP500_STOCK_DEMO",
    "USE SCHEMA DATA",
):
    session.sql(_ctx_stmt).collect()

# Sanity check: the enriched feature table must exist and be readable.
feature_preview = session.table('PRICE_FEATURES').limit(5)
feature_preview.show()


In [None]:
# 1) Load latest model from Model Registry
reg = Registry(session=session, database_name='SP500_STOCK_DEMO', schema_name='DATA')
# `.last()` returns a ModelVersion handle for the most recently created version.
model = reg.get_model('XGB_SP500_RET3M').last()
# A ModelVersion exposes its name via `version_name`; there is no `.version`
# attribute, so the original `model.version` raised AttributeError.
print(model.version_name)


In [None]:
# 2) Prepare recent feature slice (last 5 days)
# The window is anchored on the newest TS present in PRICE_FEATURES (not the
# wall clock), so it is relative to whatever data has been loaded.
cutoff_row = session.sql(
    "select dateadd('day', -5, max(TS)) as c from PRICE_FEATURES"
).collect()[0]
cutoff = cutoff_row['C']

price_features = session.table('PRICE_FEATURES')
recent = price_features.filter(col('TS') >= cutoff)
recent.count()


In [None]:
# 3) Run batch predictions and persist
pred_table = 'PREDICTIONS_SP500_RET3M'

# Score the recent slice with the model's PREDICT method.
preds = model.run(recent, function_name='PREDICT')

# NOTE(review): mode='overwrite' replaces the table on every run, so it only
# ever holds the latest recent slice; the baseline/monitor cells read this table.
preds.write.save_as_table(pred_table, mode='overwrite')

# Quick peek at what was written
session.table(pred_table).limit(5).show()


In [None]:
# 4) Simple drift metric (PSI) on select features
import math

from snowflake.snowpark.functions import lit

# Define reference (older window) vs current (recent).
# Reference = [max(TS) - 35 days, cutoff); current = the recent slice from step 2.
ref_cutoff = session.sql("select dateadd('day', -35, max(TS)) as c from PRICE_FEATURES").collect()[0]['C']
reference = session.table('PRICE_FEATURES').filter((col('TS') >= ref_cutoff) & (col('TS') < cutoff))
recent_slice = recent  # avoid reserved word "current"

# Register temp views so the SQL string can reference them
reference.create_or_replace_temp_view('REF_WINDOW')
recent_slice.create_or_replace_temp_view('CURR_WINDOW')

# Feature list for drift check. These are fixed, trusted identifiers: they are
# interpolated directly into the SQL below, so never feed user input here.
feat_cols = ['RET_1', 'SMA_5', 'SMA_20', 'VOL_20']

# Number of equal-width bins per feature
num_bins = 10


def _psi_from_counts(ref_counts, cur_counts, eps=1e-6):
    """Return the Population Stability Index for aligned per-bin counts.

    ref_counts / cur_counts: same-length sequences of counts, one entry per bin.
    Empty bins are floored at ``eps`` so the log ratio stays finite.
    Returns None when either side has no observations: PSI is undefined there,
    and substituting a denominator of 1 would report a huge, spurious drift.
    """
    ref_total = sum(ref_counts)
    cur_total = sum(cur_counts)
    if ref_total == 0 or cur_total == 0:
        return None
    psi = 0.0
    for ref_n, cur_n in zip(ref_counts, cur_counts):
        ref_pct = (ref_n / ref_total) or eps
        cur_pct = (cur_n / cur_total) or eps
        psi += (cur_pct - ref_pct) * math.log(cur_pct / ref_pct)
    return float(psi)


# Compute basic PSI using equal-width bins in Snowflake SQL.
# Bin range spans the feature's full-table MIN/MAX so both windows share the
# same edges (values equal to the max land in WIDTH_BUCKET's overflow bucket
# num_bins + 1). NULL feature values are mapped to bucket -1 so they line up
# across windows; NULL buckets would otherwise never match in the FULL OUTER JOIN.
psi_results = []
for f in feat_cols:
    stats = session.sql(f"""
        WITH bounds AS (
          SELECT MIN({f}) AS mn, MAX({f}) AS mx FROM PRICE_FEATURES
        ),
        ref AS (
          SELECT COALESCE(WIDTH_BUCKET(w.{f}, bounds.mn, bounds.mx, {num_bins}), -1) AS b
          FROM REF_WINDOW AS w CROSS JOIN bounds
        ), cur AS (
          SELECT COALESCE(WIDTH_BUCKET(w.{f}, bounds.mn, bounds.mx, {num_bins}), -1) AS b
          FROM CURR_WINDOW AS w CROSS JOIN bounds
        ),
        refc AS (
          SELECT b, COUNT(*) AS cnt FROM ref GROUP BY b
        ), curc AS (
          SELECT b, COUNT(*) AS cnt FROM cur GROUP BY b
        )
        SELECT COALESCE(curc.b, refc.b) AS bin,
               COALESCE(refc.cnt, 0) AS ref_cnt,
               COALESCE(curc.cnt, 0) AS cur_cnt
        FROM refc FULL OUTER JOIN curc ON refc.b = curc.b
    """)
    df = stats.to_pandas()
    if df.empty:
        # Both windows empty: nothing to compare.
        psi_results.append((f, None))
        continue
    psi_results.append((f, _psi_from_counts(df['REF_CNT'].tolist(), df['CUR_CNT'].tolist())))

print(dict(psi_results))


### Snowflake Model Monitor (native) setup

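This section creates a baseline table (the first 14 days of predictions) and a Model Monitor on the predictions table written above. It uses Snowflake's native model monitoring, so you can view the metrics in Snowsight under **AI & ML > Models**. The monitor pins `VERSION = 'V_1'`; update it if the latest registered version loaded above differs.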


In [None]:
-- Baseline + Monitor creation (safe re-run: both objects use CREATE OR REPLACE)
USE WAREHOUSE DEMO_WH_M;
USE DATABASE SP500_STOCK_DEMO;
USE SCHEMA DATA;

-- Baseline: predictions whose TS falls in the 14 days starting at the earliest TS
-- (half-open window [start_ts, start_ts + 14 days)).
-- NOTE(review): the Python cell above overwrites PREDICTIONS_SP500_RET3M with only
-- the recent slice, so this window likely covers the whole table and the baseline
-- matches the monitored source; confirm this is intended.
CREATE OR REPLACE TABLE BASELINE_PREDICTIONS AS
WITH first_ts AS (
    SELECT MIN(TS) AS start_ts
    FROM PREDICTIONS_SP500_RET3M
)
SELECT p.*
FROM PREDICTIONS_SP500_RET3M AS p
CROSS JOIN first_ts AS f
WHERE p.TS >= f.start_ts
  AND p.TS < DATEADD('day', 14, f.start_ts);

-- Model Monitor over the prediction scores (regression-style output).
-- NOTE(review): VERSION is pinned to 'V_1', whereas the Python cell loads the
-- latest version via .last(); keep the two in sync.
CREATE OR REPLACE MODEL MONITOR SP500_RET3M_MONITOR
WITH
    MODEL = SP500_STOCK_DEMO.DATA.XGB_SP500_RET3M,
    VERSION = 'V_1',
    FUNCTION = 'PREDICT',
    SOURCE = SP500_STOCK_DEMO.DATA.PREDICTIONS_SP500_RET3M,
    BASELINE = SP500_STOCK_DEMO.DATA.BASELINE_PREDICTIONS,
    WAREHOUSE = DEMO_WH_M,
    REFRESH_INTERVAL = '1 DAY',
    AGGREGATION_WINDOW = '7 DAYS',
    TIMESTAMP_COLUMN = TS,
    ID_COLUMNS = ('TICKER'),
    PREDICTION_SCORE_COLUMNS = ('PREDICTED_RETURN');

-- Confirm the monitor exists
SHOW MODEL MONITORS LIKE 'SP500_RET3M_MONITOR';


### Drift alert (skeleton)

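Creates a Snowflake alert that computes PSI drift on `PREDICTED_RETURN` between the baseline and the last 7 days of predictions, and sends an email when PSI exceeds 0.2. The alert is created in a suspended (paused) state; run `ALTER ALERT PSI_DRIFT_ALERT RESUME` to enable it. Later, you can switch it to the model monitor's built-in drift metrics instead of this hand-written PSI view.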


In [None]:
-- Create a paused alert for PSI drift on PREDICTED_RETURN
USE WAREHOUSE DEMO_WH_M;
USE DATABASE SP500_STOCK_DEMO;
USE SCHEMA DATA;

-- Helper view: PSI of PREDICTED_RETURN with BASELINE_PREDICTIONS as the reference
-- and the last 7 days of PREDICTIONS_SP500_RET3M as the current population,
-- over 10 equal-width bins. Returns one row; PSI is NULL when either side is empty.
CREATE OR REPLACE VIEW PSI_PREDICTED_RETURN_7D AS
WITH bounds AS (
    -- Shared bin range over both populations so each side uses identical edges
    SELECT
        MIN(PREDICTED_RETURN) AS mn,
        MAX(PREDICTED_RETURN) AS mx
    FROM (
        SELECT PREDICTED_RETURN FROM BASELINE_PREDICTIONS
        UNION ALL
        SELECT PREDICTED_RETURN FROM PREDICTIONS_SP500_RET3M
    )
),
ref AS (
    -- Reference counts per bucket; values equal to mx land in overflow bucket 11,
    -- and NULL predictions go to bucket -1 so they still join across sides.
    SELECT
        COALESCE(WIDTH_BUCKET(p.PREDICTED_RETURN, bounds.mn, bounds.mx, 10), -1) AS b,
        COUNT(*) AS cnt
    FROM BASELINE_PREDICTIONS AS p
    CROSS JOIN bounds
    GROUP BY b
),
cur AS (
    -- Current counts per bucket over the trailing 7 days
    SELECT
        COALESCE(WIDTH_BUCKET(p.PREDICTED_RETURN, bounds.mn, bounds.mx, 10), -1) AS b,
        COUNT(*) AS cnt
    FROM PREDICTIONS_SP500_RET3M AS p
    CROSS JOIN bounds
    WHERE p.TS >= DATEADD('day', -7, CURRENT_DATE())
    GROUP BY b
),
joined AS (
    SELECT
        COALESCE(cur.b, ref.b) AS bin,
        CAST(COALESCE(ref.cnt, 0) AS FLOAT) AS ref_cnt,
        CAST(COALESCE(cur.cnt, 0) AS FLOAT) AS cur_cnt
    FROM ref
    FULL OUTER JOIN cur
        ON ref.b = cur.b
),
pct AS (
    -- Per-bin proportions, computed in their own step because a window function
    -- cannot be nested inside an aggregate. Empty bins are floored at 1e-6 so LN
    -- stays finite; an empty side (total 0) becomes NULL via NULLIFZERO, and
    -- GREATEST propagates that NULL.
    SELECT
        bin,
        GREATEST(ref_cnt / NULLIFZERO(SUM(ref_cnt) OVER ()), 1e-6) AS ref_pct,
        GREATEST(cur_cnt / NULLIFZERO(SUM(cur_cnt) OVER ()), 1e-6) AS cur_pct
    FROM joined
)
-- PSI = sum over bins of (cur_pct - ref_pct) * ln(cur_pct / ref_pct)
SELECT SUM((cur_pct - ref_pct) * LN(cur_pct / ref_pct)) AS PSI
FROM pct;

-- Alert: runs daily and emails when PSI exceeds the threshold.
-- OR REPLACE (not combinable with IF NOT EXISTS) keeps this cell re-runnable and
-- picks up edits to the condition; the condition must be wrapped in EXISTS.
CREATE OR REPLACE ALERT PSI_DRIFT_ALERT
  WAREHOUSE = DEMO_WH_M
  SCHEDULE = 'USING CRON 0 13 * * * UTC'  -- daily at 13:00 UTC
  IF (EXISTS (
    SELECT 1
    FROM PSI_PREDICTED_RETURN_7D
    WHERE PSI > 0.2  -- threshold; tweak as desired (keep the email body in sync)
  ))
  THEN CALL SYSTEM$SEND_EMAIL(
    'ML_ALERTS',
    'harley.chen@snowflake.com',
    'PSI drift detected for PREDICTED_RETURN',
    'PSI over the last 7 days exceeded 0.2. Review SP500_RET3M_MONITOR in Snowsight.'
  );

-- Pause the alert by default; run ALTER ALERT PSI_DRIFT_ALERT RESUME when ready
ALTER ALERT PSI_DRIFT_ALERT SUSPEND;
