In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [127]:
import pandas as pd
from regime_ml.data.common.loaders import load_dataframe
from regime_ml.utils.config import load_configs
from regime_ml.regimes.hmm import HMMRegimeDetector
from regime_ml.features.macro.selection import get_top_features

In [128]:
# Walk the loaded configuration down to the regime-universe section of the
# macro-data settings, then read the features path stored under it.
_macro_data_cfg = load_configs()["macro_data"]
macro_cfg = _macro_data_cfg["regime_universe"]
feat_path = macro_cfg["raw_features_path"]

In [129]:
# Load the feature frame and discard every row that contains a missing
# value (the burn-in period at the start of the sample).
df_feat = load_dataframe(feat_path).dropna()

In [141]:
# Chronological train/test split of the feature frame.
split_date = '2020-01-01'

# Label slices (.loc[:split_date] and .loc[split_date:]) are inclusive of the
# endpoint on both sides, so any row stamped on split_date would end up in
# BOTH the training and the test set. A boolean mask gives a strict
# partition instead: train is everything before split_date, test is
# split_date onwards.
# NOTE(review): assumes df_feat.index is a DatetimeIndex (or otherwise
# comparable with the date string) -- confirm against load_dataframe.
is_train = df_feat.index < split_date

df_train = df_feat.loc[is_train]
X_train = df_train.values
feature_names = df_train.columns.to_list()

df_test = df_feat.loc[~is_train]
X_test = df_test.values

# Full-sample matrix (train followed by test), used for in-sample plus
# out-of-sample inference.
X_full = df_feat.values

In [155]:
# --Initialisation--
from regime_ml.regimes.hmm import (
    initialise_emissions,
    initialise_probabilities,
    initialise_transitions,
)

n_regimes = 3

# 1. Emission parameters: initial means/covariances derived from the training
#    frame, together with the scaler object returned by the helper.
init_means, init_covars, scaler = initialise_emissions(
    df_train,
    n_clusters=n_regimes,
    covariance_type='full',
)

# 2. Transition matrix with p_stay=0.99 (displayed in the next cell).
init_transmat = initialise_transitions(n_regimes=n_regimes, p_stay=0.99)

# 3. Initial state probabilities.
init_startprob = initialise_probabilities(n_regimes=n_regimes)

In [156]:
# Display the initial transition matrix built with p_stay=0.99.
init_transmat

array([[0.99 , 0.005, 0.005],
       [0.005, 0.99 , 0.005],
       [0.005, 0.005, 0.99 ]])

In [160]:
# Build the detector from the pre-computed initial parameters.
model = HMMRegimeDetector(
    n_regimes=n_regimes,
    covariance_type='full',
    startprob=init_startprob,
    transmat=init_transmat,
    means=init_means,
    covars=init_covars
    )

# Fit the scaler on the training window ONLY, then reuse those statistics for
# the full sample. Calling fit_transform on X_full would re-fit the scaler
# with post-split data (look-ahead leakage) and would put the prediction
# inputs on a different scale from the data the model was trained on.
# NOTE(review): assumes `scaler` follows the scikit-learn transformer API
# (fit_transform / transform) -- confirm against initialise_emissions.
X_train_scaled = scaler.fit_transform(X_train)
X_full_scaled = scaler.transform(X_full)

# Train on the scaled training window, then infer the regime path and the
# per-regime probabilities over the whole sample.
model.fit(X_train_scaled)
regimes = model.predict(X_full_scaled)
proba = model.predict_proba(X_full_scaled)

In [161]:
from regime_ml.regimes.visualisation import (
    plot_regime_timeseries,
    plot_ticker_by_regime,
)

# SPY plotted against the inferred regimes over the feature frame's index.
fig1 = plot_ticker_by_regime('SPY', df_feat.index, regimes)
fig1.show()

# Feature time series together with the regime labels and probabilities.
fig2 = plot_regime_timeseries(df_feat, regimes, proba)
fig2.show()

Downloading SPY data from 2004-09-01 to 2026-01-23...
