In [None]:
import logging
import sys
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
import seaborn as sns
from hmmlearn.hmm import GaussianHMM
from scipy.optimize import linear_sum_assignment
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tqdm


# Make the project root importable. NOTE(review): assumes the notebook's working
# directory sits one level below the project root — confirm for your setup.
PROJECT_ROOT_PATH = Path.cwd().parent
if str(PROJECT_ROOT_PATH) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT_PATH))

# Project-local import; presumably resolved through the sys.path entry added above.
from regime_predictor_lib.utils.database_manager import DatabaseManager

# Plot styling and logging shared by the cells below.
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_context("talk")
plt.rcParams["figure.figsize"] = (20, 8)
plt.rcParams["figure.dpi"] = 100
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("HMM_Stability_Analysis")

# Experiment configuration.
DB_PATH = PROJECT_ROOT_PATH / "data" / "db" / "volume" / "quant.db"
FEATURES_TO_USE = ["ret_126d", "log_vol_21d"]  # feature columns fed to the HMM
SP500_PRICE_COL = "sp500_adjusted_close"  # price column used for the regime-colored charts
N_HMM_STATES = 3  # number of hidden states (GaussianHMM n_components)
HMM_N_ITER = 1000  # passed to GaussianHMM as n_iter
HMM_TOL = 1e-3  # passed to GaussianHMM as tol
SMOOTHING_WINDOW = 200  # window (in rows) of the centered rolling-mode smoother

db_manager = DatabaseManager(db_path=DB_PATH)

In [None]:
def train_hmm_model(
    features_df: pd.DataFrame, random_state: int
) -> tuple[GaussianHMM, pd.Series]:
    """Fit a diagonal-covariance Gaussian HMM and return it with smoothed states.

    The features are standardized with a freshly fitted ``StandardScaler``
    before training, so the fitted model's ``means_`` live in scaled space.

    Args:
        features_df: Observation matrix, one row per time step.
        random_state: Seed forwarded to ``GaussianHMM``.

    Returns:
        A ``(model, states)`` pair. ``states`` is an integer Series on
        ``features_df.index`` holding the decoded state sequence after a
        centered rolling-mode filter of width ``SMOOTHING_WINDOW``.
    """
    standardized = StandardScaler().fit_transform(features_df)

    fitted_hmm = GaussianHMM(
        n_components=N_HMM_STATES,
        covariance_type="diag",
        n_iter=HMM_N_ITER,
        tol=HMM_TOL,
        random_state=random_state,
    )
    fitted_hmm.fit(standardized)

    raw_states = pd.Series(
        fitted_hmm.predict(standardized),
        index=features_df.index,
        name="hmm_state_raw",
    )

    def _window_mode(window: pd.Series):
        # First entry of Series.mode(); with several modes this is the first one listed.
        return window.mode()[0]

    rolling_windows = raw_states.rolling(
        window=SMOOTHING_WINDOW, center=True, min_periods=1
    )
    majority_vote = rolling_windows.apply(_window_mode, raw=False)

    # Fill any gaps in both directions before casting back to integer labels.
    smoothed_states = majority_vote.bfill().ffill().astype(int)

    return fitted_hmm, smoothed_states


def map_regimes(model1_means: np.ndarray, model2_means: np.ndarray) -> dict:
    """Match the states of two models by minimizing total distance between means.

    State labels of independently trained HMMs are arbitrary, so two runs can
    describe the same regimes under permuted labels. This solves the linear
    assignment problem on the pairwise Euclidean distances between state mean
    vectors to find the best one-to-one correspondence.

    The number of states is taken from the inputs themselves rather than from
    a module-level constant, so the function works for any state count.

    Args:
        model1_means: Array of shape ``(n_states_1, n_features)`` (reference).
        model2_means: Array of shape ``(n_states_2, n_features)``.

    Returns:
        Dict mapping a state index of model 1 to its matched state index of
        model 2. Keys and values are plain Python ints.
    """
    ref_means = np.asarray(model1_means, dtype=float)
    other_means = np.asarray(model2_means, dtype=float)

    # Treat each row as one state's mean vector (a 1-D input is one feature per state).
    ref_means = ref_means.reshape(ref_means.shape[0], -1)
    other_means = other_means.reshape(other_means.shape[0], -1)

    # Pairwise distances via broadcasting: entry (i, j) = ||ref_i - other_j||.
    distance_matrix = np.linalg.norm(
        ref_means[:, None, :] - other_means[None, :, :], axis=-1
    )

    row_ind, col_ind = linear_sum_assignment(distance_matrix)
    return {int(r): int(c) for r, c in zip(row_ind, col_ind)}

def plot_single_regime_price_chart(ax, price_series, regime_series, title, n_states=N_HMM_STATES):
    """Draw a log-scale price line on ``ax`` colored by the regime in force.

    Each segment between consecutive price points takes the color of the
    regime at its starting point. Consecutive segments that share a regime are
    drawn as a single line, so the number of artists scales with the number of
    regime changes rather than with the number of data points.

    Args:
        ax: Matplotlib axes to draw on.
        price_series: Price Series; its index is used for the x axis.
        regime_series: Regime labels; reindexed onto ``price_series.index`` and
            forward/backward filled where missing.
        title: Axes title.
        n_states: Number of regimes; sets the palette size and legend entries.
    """
    colors = sns.color_palette("tab10", n_states)

    # Faint full-length price line as a backdrop.
    ax.plot(price_series.index, price_series.values, color='lightgray', lw=1, alpha=0.5, zorder=0)
    aligned_regimes = regime_series.reindex(price_series.index).ffill().bfill()

    n_segments = len(price_series) - 1
    if n_segments > 0:
        # Regime attached to segment k (from point k to point k + 1).
        segment_regimes = aligned_regimes.iloc[:n_segments].to_numpy()

        # Split the segments into maximal runs of identical regime.
        change_points = np.flatnonzero(segment_regimes[1:] != segment_regimes[:-1]) + 1
        run_starts = np.concatenate(([0], change_points))
        run_ends = np.concatenate((change_points, [n_segments]))  # exclusive

        for start, end in zip(run_starts.tolist(), run_ends.tolist()):
            try:
                regime_idx = int(segment_regimes[start])
                if regime_idx < 0:
                    # A negative label must not wrap to the end of the palette.
                    raise IndexError(regime_idx)
                segment_color = colors[regime_idx]
            except (ValueError, IndexError):
                # NaN or out-of-range label: fall back to a neutral color.
                segment_color = "gray"

            # Segments start..end-1 span the points start..end (inclusive).
            ax.plot(
                price_series.index[start:end + 1],
                price_series.iloc[start:end + 1],
                color=segment_color,
                lw=1.5
            )

    ax.set_yscale('log')
    ax.set_title(title, fontsize=14)
    ax.set_ylabel("Price (Log Scale)")
    ax.grid(True, which='both', linestyle='--', alpha=0.6)

    legend_elements = [Line2D([0], [0], color=colors[i], lw=2, label=f'Regime {i}') for i in range(n_states)]
    ax.legend(handles=legend_elements, loc='upper left')

In [None]:
# Load the feature and price columns from the database.
# A failed load is logged and re-raised: continuing with an empty frame would
# only fail a few lines later with an unrelated KeyError.
try:
    with db_manager.engine.connect() as connection:
        # dict.fromkeys de-duplicates while keeping a deterministic column order.
        cols_to_fetch = list(dict.fromkeys(FEATURES_TO_USE + [SP500_PRICE_COL]))
        select_cols_str = ", ".join([f'"{col}"' for col in cols_to_fetch])
        query = f"""
            SELECT date, {select_cols_str}
            FROM sp500_derived_indicators
            ORDER BY date ASC;
        """
        data_df_full = pd.read_sql_query(query, connection, parse_dates=["date"])
    data_df_full = data_df_full.set_index("date")
    logger.info(f"Successfully loaded {len(data_df_full)} S&P 500 records.")

except Exception as e:
    logger.error(f"Failed to load data: {e}")
    raise

# Rows with any missing feature cannot be fed to the HMM.
features_df = data_df_full[FEATURES_TO_USE].dropna().copy()
logger.info(f"Feature set prepared with {len(features_df)} rows after dropping NaNs.")
if features_df.empty:
    raise ValueError("No usable feature rows after dropping NaNs; cannot train the HMM.")

# Prices restricted to the dates that have a complete feature vector.
price_df = data_df_full.loc[features_df.index, SP500_PRICE_COL].copy()

In [None]:
# Stochastic stability: retrain on identical data, changing only the HMM seed.
N_RUNS = 10
run_seeds = [i * 42 for i in range(N_RUNS)]

logger.info(f"Starting stochastic stability analysis with {N_RUNS} runs...")

stochastic_results = {}
for i, random_state in enumerate(tqdm(run_seeds, desc="Stochastic Runs")):
    model, states = train_hmm_model(features_df, random_state=random_state)
    stochastic_results[f"run_{i}"] = dict(model=model, states=states)

logger.info("Stochastic stability experiment finished.")

In [None]:
# Express every run in the reference run's state labels and score, per run,
# the fraction of time steps whose regime matches the reference.
ref_run = "run_0"
reference_entry = stochastic_results[ref_run]
ref_model, ref_states = reference_entry["model"], reference_entry["states"]

remapped_states_df = pd.DataFrame({ref_run: ref_states})
agreement_scores = {}

for run_name, run_data in stochastic_results.items():
    if run_name != ref_run:
        current_model, current_states = run_data["model"], run_data["states"]

        # mapping: reference state -> matched state of the current run.
        mapping = map_regimes(ref_model.means_, current_model.means_)
        logger.info(f"Mapping for {run_name}: {mapping}")

        # Invert it to translate the current run's labels into reference labels.
        current_to_ref = {current: ref for ref, current in mapping.items()}
        remapped_states = current_states.map(current_to_ref)
        remapped_states_df[run_name] = remapped_states

        agreement = (ref_states == remapped_states).mean()
        agreement_scores[run_name] = agreement
        logger.info(f"Agreement between {ref_run} and {run_name}: {agreement:.4f}")

print("\n--- Stochastic Stability Agreement Summary ---")
agreement_series = pd.Series(agreement_scores)
print(agreement_series.describe())

In [None]:
# --- Regime timelines: every run, expressed in the reference run's labels ---
fig, ax = plt.subplots(figsize=(20, 8))
for col in remapped_states_df.columns:
    # Unlabelled: only the reference line below gets a legend entry.
    ax.plot(
        remapped_states_df.index,
        remapped_states_df[col],
        alpha=0.4,
        lw=1.5,
    )
ax.plot(
    remapped_states_df.index,
    remapped_states_df[ref_run],
    color="black",
    lw=2,
    label=f"Reference ({ref_run})",
)
ax.set_yticks(range(N_HMM_STATES))
ax.set_ylabel("Regime")
ax.set_title(
    f"Stochastic Stability: Regime Timelines from {len(remapped_states_df.columns)} Different Random Seeds"
)
ax.legend()
plt.show()

# --- Spread of the fitted state means across runs ---
# HMM state labels are arbitrary per run, so each run's means are first aligned
# to the reference run's labels; grouping by the raw state index would mix
# different regimes in the same box.
means_data = []
for run_name, run_data in stochastic_results.items():
    run_means = run_data["model"].means_
    ref_to_run = map_regimes(ref_model.means_, run_means)
    for ref_state, run_state in ref_to_run.items():
        for feat_idx, feature in enumerate(FEATURES_TO_USE):
            means_data.append(
                {
                    "run": run_name,
                    "state": int(ref_state),
                    "feature": feature,
                    "mean_value": run_means[run_state, feat_idx],
                }
            )
means_df = pd.DataFrame(means_data)

# squeeze=False keeps `axes` indexable even when only one feature is configured.
fig, axes = plt.subplots(
    1, len(FEATURES_TO_USE), figsize=(12 * len(FEATURES_TO_USE), 8), sharey=False, squeeze=False
)
axes = axes.ravel()
fig.suptitle("Stability of Regime Characteristics (Means)", fontsize=20)
for i, feature in enumerate(FEATURES_TO_USE):
    ax = axes[i]
    sns.boxplot(data=means_df[means_df["feature"] == feature], x="state", y="mean_value", ax=ax)
    ax.set_title(f"Scaled Mean of '{feature}'")
    ax.set_xlabel("HMM State")
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

### Cell 6a (NEW): Visualize Stochastic Stability with Individual Price Charts
This visualization provides a more direct, intuitive view of the stochastic stability. Each plot shows the S&P 500 price chart colored by the regimes identified in one specific run. If the models are stable, the color patterns across all 10 charts should look nearly identical. All states are mapped to the reference run (`run_0`) to ensure color consistency.

In [None]:
logger.info("Generating vertically stacked price charts for stochastic stability analysis...")

# One row per run, all sharing the date axis.
fig, axes = plt.subplots(nrows=N_RUNS, ncols=1, figsize=(20, N_RUNS * 5), sharex=True)
fig.suptitle('Stochastic Stability: S&P 500 Price Colored by Regime for Each Run', fontsize=24, y=0.995)

for i, run_name in enumerate(remapped_states_df.columns):
    seed = i * 42
    # The reference run has no stored score; show it as perfect agreement.
    score = agreement_scores.get(run_name, 1.0)
    plot_single_regime_price_chart(
        axes[i],
        price_df,
        remapped_states_df[run_name],
        f'Run {i} (Random State: {seed}) - Agreement to Ref: {score:.4f}',
    )

# With a shared x axis only the bottom panel needs the label.
axes[-1].set_xlabel("Date")
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()

In [None]:
# Temporal stability: same seed every time, progressively dropping the most
# recent observations from the training window.
FIXED_RANDOM_STATE = 42
leave_out_periods = list(range(700, 5601, 700))  # Number of recent days to exclude
temporal_results = {}

logger.info("Starting temporal stability analysis...")

logger.info("Training model on full dataset...")
full_model, full_states = train_hmm_model(features_df, random_state=FIXED_RANDOM_STATE)
temporal_results["full_data"] = dict(model=full_model, states=full_states)

# Truncated sets shorter than two smoothing windows are skipped.
min_rows_required = SMOOTHING_WINDOW * 2

for periods_to_exclude in tqdm(leave_out_periods, desc="Temporal Runs"):
    run_name = f"exclude_last_{periods_to_exclude}d"
    logger.info(f"Training model for: {run_name}")

    truncated_features_df = features_df.iloc[:-periods_to_exclude]
    if len(truncated_features_df) >= min_rows_required:
        model, states = train_hmm_model(truncated_features_df, random_state=FIXED_RANDOM_STATE)
        temporal_results[run_name] = dict(model=model, states=states)
    else:
        logger.warning(f"Skipping run {run_name}, dataset too small after truncation.")

logger.info("Temporal stability experiment finished.")

In [None]:
# Score each truncated model against the full-data model on the dates both cover.
ref_run_temporal = "full_data"
reference_temporal = temporal_results[ref_run_temporal]
ref_model_temporal = reference_temporal["model"]
ref_states_temporal = reference_temporal["states"]

temporal_agreement_scores = {}

for run_name, run_data in temporal_results.items():
    if run_name != ref_run_temporal:
        current_model, current_states = run_data["model"], run_data["states"]

        # mapping: reference state -> matched state of the truncated model.
        mapping = map_regimes(ref_model_temporal.means_, current_model.means_)
        logger.info(f"Mapping for {run_name}: {mapping}")

        # Inverted, it relabels the truncated model's states in reference terms.
        current_to_ref = {current: ref for ref, current in mapping.items()}
        remapped_states = current_states.map(current_to_ref)

        # Compare only on the overlapping dates.
        common_index = ref_states_temporal.index.intersection(remapped_states.index)
        ref_on_common = ref_states_temporal.loc[common_index]
        run_on_common = remapped_states.loc[common_index]
        agreement = (ref_on_common == run_on_common).mean()

        temporal_agreement_scores[run_name] = agreement
        logger.info(f"Agreement between {ref_run_temporal} and {run_name}: {agreement:.4f}")

print("\n--- Temporal Stability Agreement Summary ---")
print(pd.Series(temporal_agreement_scores))

In [None]:
# --- Timeline: full-data model vs. the most heavily truncated model ---
fig, ax = plt.subplots(figsize=(20, 8))
ax.plot(
    ref_states_temporal.index,
    ref_states_temporal,
    lw=2,
    label="Full Data Model",
    color="black",
    drawstyle="steps-post",
)

# Only truncated runs are overlay candidates; if every truncation was skipped
# there is nothing to compare against the reference.
truncated_run_names = [name for name in temporal_results if name != ref_run_temporal]
if truncated_run_names:
    last_truncated_run = truncated_run_names[-1]
    last_truncated_states = temporal_results[last_truncated_run]["states"]
    last_mapping = map_regimes(ref_model_temporal.means_, temporal_results[last_truncated_run]["model"].means_)
    # last_mapping is reference -> truncated; invert it to relabel the truncated states.
    last_remapped = last_truncated_states.map({v: k for k, v in last_mapping.items()})

    ax.plot(
        last_remapped.index,
        last_remapped,
        lw=2,
        alpha=0.7,
        label=f"Truncated Model ({last_truncated_run})",
        color="red",
        drawstyle="steps-post",
    )
ax.set_yticks(range(N_HMM_STATES))
ax.set_ylabel("Regime")
ax.set_title("Temporal Stability: Full Model vs. Truncated Model")
ax.legend()
plt.show()

# --- State means per training window, aligned to the reference labels ---
# map_regimes returns {reference_state: run_state}. Each mean is therefore read
# at the run's own state index and filed under the reference label. Records
# carry the feature name so each subplot shows only its own feature.
temporal_means_data = []
for run_name, run_data in temporal_results.items():
    run_means = run_data["model"].means_
    mapping = map_regimes(ref_model_temporal.means_, run_means)
    for ref_state, run_state in mapping.items():
        for feat_idx, feature in enumerate(FEATURES_TO_USE):
            temporal_means_data.append({
                "run": run_name,
                "state": int(ref_state),
                "feature": feature,
                "mean_value": run_means[run_state, feat_idx],
            })

temporal_means_df = pd.DataFrame(temporal_means_data)

# squeeze=False keeps `axes` indexable even when only one feature is configured.
fig, axes = plt.subplots(
    1, len(FEATURES_TO_USE), figsize=(12 * len(FEATURES_TO_USE), 8), sharey=False, squeeze=False
)
axes = axes.ravel()
fig.suptitle("Temporal Stability of Regime Characteristics (Means)", fontsize=20)

for i, feature in enumerate(FEATURES_TO_USE):
    ax = axes[i]
    feature_means = temporal_means_df[
        (temporal_means_df["feature"] == feature) & temporal_means_df["mean_value"].notna()
    ]
    sns.boxplot(data=feature_means, x="state", y="mean_value", hue="run", ax=ax)
    ax.set_title(f"Scaled Mean of '{feature}'")
    ax.set_xlabel("Mapped HMM State (Ref: Full Data)")
    ax.legend(title="Model Run")
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

### Cell 9a (NEW): Visualize Temporal Stability with Individual Price Charts
This visualization shows the impact of training the HMM on different windows of historical data. The top plot uses all available data, and each subsequent plot excludes more of the most recent data. If the model is temporally stable, the regime classifications for the older, overlapping periods should remain consistent across the plots.

In [None]:
logger.info("Generating vertically stacked price charts for temporal stability analysis...")

# One panel per trained model; x axes are independent because each model
# covers a different date range.
n_temporal_runs = len(temporal_results)
fig, axes = plt.subplots(nrows=n_temporal_runs, ncols=1, figsize=(20, n_temporal_runs * 5), sharex=False)

fig.suptitle('Temporal Stability: S&P 500 Price Colored by Regime for Different Training Windows', fontsize=24, y=0.995)

# Reference model first, then the truncated runs in insertion order.
run_order = ['full_data', *(name for name in temporal_results if name != 'full_data')]

for i, run_name in enumerate(run_order):
    ax = axes[i]
    run_data = temporal_results[run_name]
    model, states = run_data['model'], run_data['states']

    if run_name != ref_run_temporal:
        # Relabel this model's states in the reference model's terms.
        mapping = map_regimes(ref_model_temporal.means_, model.means_)
        current_to_ref = {current: ref for ref, current in mapping.items()}
        remapped_states = states.map(current_to_ref)
        agreement = temporal_agreement_scores.get(run_name, np.nan)
        excluded_label = run_name.rpartition("_")[2]
        title = f'Model Trained on Data Excluding Last {excluded_label} (Agreement: {agreement:.4f})'
    else:
        remapped_states = states
        title = 'Reference Model (Full Data)'

    # Plot only the price history this model was trained on.
    price_to_plot = price_df.loc[states.index]

    plot_single_regime_price_chart(ax, price_to_plot, remapped_states, title)

axes[-1].set_xlabel("Date")
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()