# SP500 Stock Demo — Notebook 06: Streamlit App (Snowflake)

- Minimal app: select ticker and model version
- Plot actual vs predicted returns
- Show the latest population stability index (PSI) per feature


In [None]:
# 0) Streamlit app (Snowflake)
import streamlit as st
import pandas as pd
import numpy as np
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col

# Attach to Snowflake session and set context
session = get_active_session()
session.sql("USE DATABASE SP500_STOCK_DEMO").collect()
session.sql("USE SCHEMA DATA").collect()

st.set_page_config(page_title="SP500 ML Forecasts", layout="wide")
st.title("SP500 Forecasts and Monitoring")
st.caption("Model: XGB_SP500_RET3M — training, inference, drift and explainability")

# Sidebar controls
with st.sidebar:
    st.header("Controls")

    # Model versions (from catalog). Any failure here (missing catalog view,
    # NULL or unparseable VERSIONS value) falls back to the default version
    # instead of crashing the whole app.
    versions = []
    try:
        models_df = session.sql("""
            SELECT name, versions
            FROM DATA.SNOWFLAKE_ML_MODELS
            WHERE name = 'XGB_SP500_RET3M'
        """).to_pandas()
    except Exception as e:
        models_df = pd.DataFrame()
        st.caption(f"Model catalog unavailable ({e}); using default version.")
    if not models_df.empty:
        import ast as _ast
        raw_versions = models_df.iloc[0]['VERSIONS']
        try:
            # VERSIONS is expected as a list literal in text form
            # (e.g. '["V_1", "V_2"]'); non-strings are used as-is.
            parsed = _ast.literal_eval(raw_versions) if isinstance(raw_versions, str) else raw_versions
        except (ValueError, SyntaxError):
            parsed = None
        if isinstance(parsed, str):
            # A single bare version name: do not let selectbox iterate its characters.
            versions = [parsed]
        elif isinstance(parsed, (list, tuple)):
            versions = [str(v) for v in parsed]
    selected_version = st.selectbox("Model version", options=versions if versions else ["V_1"])

    # Data source for chart
    source_mode = st.radio(
        "Prediction source",
        options=["Existing predictions", "On-demand scoring"],
        index=0,
        help="Use persisted predictions or run the selected model on-the-fly for the range"
    )

    # Ticker choices: prefer the ticker dimension table, fall back to the
    # distinct tickers present in the feature table.
    try:
        tickers = session.table('SP500_TICKERS').select('TICKER').to_pandas()['TICKER'].tolist()
    except Exception:
        tickers = session.table('PRICE_FEATURES').select('TICKER').distinct().to_pandas()['TICKER'].tolist()
    ticker = st.selectbox("Ticker", options=sorted(tickers)[:500])

    # Date range defaults from PRICE_FEATURES
    bounds = session.sql("SELECT MIN(TS) AS MN, MAX(TS) AS MX FROM PRICE_FEATURES").collect()[0]
    min_ts = pd.to_datetime(bounds["MN"]) if bounds["MN"] is not None else pd.Timestamp.today() - pd.Timedelta(days=90)
    max_ts = pd.to_datetime(bounds["MX"]) if bounds["MX"] is not None else pd.Timestamp.today()
    # Default to the last 30 days, but never earlier than the first available
    # date: st.date_input rejects a default value below min_value.
    default_start = max(min_ts, max_ts - pd.Timedelta(days=30))
    start_date = st.date_input("Start date", value=default_start.date(), min_value=min_ts.date(), max_value=max_ts.date())
    end_date = st.date_input("End date", value=max_ts.date(), min_value=min_ts.date(), max_value=max_ts.date())

    run_button = st.button("Update view")

# Helper to fetch predictions
def load_existing_predictions(sym: str, start_d: pd.Timestamp, end_d: pd.Timestamp) -> pd.DataFrame:
    """Return persisted predictions for one ticker within a time window.

    Reads ``PREDICTIONS_SP500_RET3M`` through the module-level Snowpark
    ``session`` and keeps rows whose ``TICKER`` equals *sym* and whose ``TS``
    lies in the closed interval ``[start_d, end_d]``.

    Returns:
        A pandas DataFrame ordered by ``TS`` ascending.
    """
    lower = pd.Timestamp(start_d)
    upper = pd.Timestamp(end_d)
    in_window = (col('TICKER') == sym) & (col('TS') >= lower) & (col('TS') <= upper)
    predictions = session.table('PREDICTIONS_SP500_RET3M').filter(in_window)
    ordered = predictions.sort(col('TS'))
    return ordered.to_pandas()

# Helper to on-demand score via registry
def score_on_demand(sym: str, start_d: pd.Timestamp, end_d: pd.Timestamp, version: str) -> pd.DataFrame:
    """Score one ticker's feature rows with a registered model version.

    Loads version *version* of model ``XGB_SP500_RET3M`` from the
    ``SP500_STOCK_DEMO.DATA`` registry. It then runs the model's ``PREDICT``
    function over the ``PRICE_FEATURES`` rows for *sym* with ``TS`` in the
    closed interval ``[start_d, end_d]``, ordered by ``TS``.

    Returns:
        The scored rows as a pandas DataFrame.
    """
    # Imported inside the function so the registry package is only loaded
    # when on-demand scoring is actually requested.
    from snowflake.ml.registry import Registry

    reg = Registry(session=session, database_name='SP500_STOCK_DEMO', schema_name='DATA')
    mv = reg.get_model('XGB_SP500_RET3M').version(version)
    feats = (
        session.table('PRICE_FEATURES')
               .filter((col('TICKER') == sym) & (col('TS') >= pd.Timestamp(start_d)) & (col('TS') <= pd.Timestamp(end_d)))
               .sort(col('TS'))
    )
    # Whole PRICE_FEATURES rows are passed to the model. The expected inputs
    # are presumably RET_1, SMA_5, SMA_20, VOL_20, RSI_PROXY, VOLUME and CLOSE.
    # NOTE(review): confirm against the registered model signature.
    preds_sp = mv.run(feats, function_name='PREDICT')
    return preds_sp.to_pandas()

# Main panels
tab_overview, tab_preds, tab_drift, tab_explain = st.tabs(["Overview", "Predictions", "Drift", "Explainability"])

# Compute datasets for the selected controls. Streamlit re-executes this script
# top to bottom on every widget interaction (including the "Update view"
# button), so the views are always rebuilt from the current sidebar values.
start_ts = pd.to_datetime(start_date)
# Inclusive upper bound covering the whole end day: the loaders filter with
# TS <= end, so use the last microsecond of the day rather than 23:59.
end_ts = pd.to_datetime(end_date) + pd.Timedelta(days=1) - pd.Timedelta(microseconds=1)

if start_date > end_date:
    st.warning("Start date is after end date; no rows will match the selected window.")

# Load predictions
try:
    if source_mode == "Existing predictions":
        preds_pd = load_existing_predictions(ticker, start_ts, end_ts)
    else:
        preds_pd = score_on_demand(ticker, start_ts, end_ts, selected_version)
except Exception as e:
    preds_pd = pd.DataFrame()
    st.warning(f"Could not load predictions: {e}")

# Join with CLOSE for context
try:
    feats_pd = (
        session.table('PRICE_FEATURES')
               .filter((col('TICKER') == ticker) & (col('TS') >= pd.Timestamp(start_ts)) & (col('TS') <= pd.Timestamp(end_ts)))
               .select('TICKER','TS','CLOSE')
               .sort(col('TS'))
               .to_pandas()
    )
except Exception as e:
    feats_pd = pd.DataFrame(columns=['TICKER','TS','CLOSE'])
    st.warning(f"Could not load close prices: {e}")

if preds_pd.empty:
    merged = feats_pd
elif 'CLOSE' in preds_pd.columns:
    # Scored PRICE_FEATURES rows may already carry CLOSE. Merging again would
    # produce CLOSE_x / CLOSE_y and hide the close-price chart below.
    merged = preds_pd
else:
    merged = preds_pd.merge(feats_pd, on=['TICKER','TS'], how='left')

# NOTE(review): with on-demand scoring the output column names come from the
# model signature; confirm it emits PREDICTED_RETURN, which the panels expect.
has_preds = 'PREDICTED_RETURN' in merged.columns

# Overview KPIs
with tab_overview:
    c1, c2, c3, c4 = st.columns(4)
    num_rows = len(merged)
    avg_pred = float(merged['PREDICTED_RETURN'].mean()) if has_preds else 0.0
    std_pred = float(merged['PREDICTED_RETURN'].std()) if has_preds else 0.0
    c1.metric("Rows", f"{num_rows:,}")
    c2.metric("Avg predicted", f"{avg_pred:.5f}")
    c3.metric("Std predicted", f"{std_pred:.5f}")
    c4.metric("Model version", selected_version)
    st.divider()
    st.subheader(f"{ticker} — Predictions (selected window)")
    if not merged.empty and has_preds:
        chart_df = merged[['TS','PREDICTED_RETURN']].set_index('TS')
        st.line_chart(chart_df)
    else:
        st.info("No predictions available for the selection.")

# Predictions table and close context
with tab_preds:
    st.subheader("Detail table")
    if not merged.empty:
        st.dataframe(merged.sort_values('TS').reset_index(drop=True))
        st.subheader("Close price context")
        if 'CLOSE' in merged.columns:
            st.line_chart(merged[['TS','CLOSE']].set_index('TS'))
    else:
        st.info("No data to display for current filters.")

# Drift panel: show the PSI table if it exists. Only the load is guarded, so
# display problems are not misreported as a missing table.
with tab_drift:
    st.subheader("Recent feature drift (PSI)")
    try:
        psi_pd = session.table('DRIFT_PSI_SP500').to_pandas()
    except Exception:
        st.info("PSI table not found. Run the inference/monitoring notebook to generate it.")
    else:
        if psi_pd.empty:
            st.info("PSI table is empty.")
        else:
            if 'FEATURE' in psi_pd.columns:
                psi_pd = psi_pd.sort_values('FEATURE')
            st.dataframe(psi_pd.reset_index(drop=True))

# Explainability: show global SHAP importances if logged
with tab_explain:
    st.subheader("Global feature importance (mean |SHAP|)")
    try:
        shap_pd = session.table('FEATURE_SHAP_GLOBAL_TOP').to_pandas()
    except Exception:
        st.info("No SHAP importance table found.")
    else:
        # Resolve column names case-insensitively: the table may come back with
        # upper-case (unquoted Snowflake identifiers) or lower-case names.
        by_lower = {str(c).lower(): c for c in shap_pd.columns}
        feature_col = by_lower.get('feature')
        shap_col = by_lower.get('mean_abs_shap')
        if shap_pd.empty:
            st.info("SHAP importance table is empty.")
        elif feature_col is None or shap_col is None:
            st.info("SHAP importance table lacks 'feature' / 'mean_abs_shap' columns.")
        else:
            topn = shap_pd.sort_values(shap_col, ascending=False).head(15)
            st.bar_chart(topn.set_index(feature_col)[shap_col])
            st.dataframe(topn.reset_index(drop=True))
