# Market Regime-Switching Transformer (Colab)

This notebook clones the repo, installs dependencies, trains the model, evaluates it, and visualizes results. GPU is used automatically if available.

In [None]:
# Environment check: print GPU status via nvidia-smi. If that command exits
# with an error (for example, no NVIDIA GPU/driver on this runtime), the shell
# `||` fallback prints a notice that the notebook will run on CPU instead.
!nvidia-smi || echo "No GPU detected; running on CPU."

In [None]:
# Clone repository
!git clone https://github.com/Rohanjain2312/market-regime-transformer-codex.git
%cd market-regime-transformer-codex/market_regime_transformer

In [None]:
# Install dependencies
!pip install -r requirements.txt

In [None]:
# Optional: build the run configuration with an explicit seed and display it.
# NOTE(review): this `cfg` lives only in the notebook kernel. The training and
# evaluation cells launch separate `python -m ...` processes, and the
# visualization cell rebinds `cfg` via get_config() with defaults, so this
# seed may not reach those steps -- confirm how src.config propagates it.
from src import config

cfg = config.get_config(seed=42)
cfg

In [None]:
# Train the model (rolling validation, checkpoints saved under data/processed)
# Runs the project's `src.train` module as a script in a separate Python
# process, so variables defined in this notebook are not visible to it.
# NOTE(review): the attention cell later looks for "best_model_split0.pt" under
# cfg.processed_data_path -- presumably written by this step; confirm in src.train.
!python -m src.train

In [None]:
# Evaluate the first split checkpoint and save confusion matrix
# Runs the project's `src.evaluate` module as a script in a separate Python
# process; run the training cell first so a checkpoint is available.
!python -m src.evaluate

In [None]:
# Visualize regime transitions from the engineered dataset.
#
# This cell rebuilds the configuration and feature windows inside the kernel,
# so it does not need any variable from earlier cells. Note that it rebinds
# `cfg` using get_config() defaults, replacing a seeded config created earlier.
# `cfg` and `X` defined here are reused by the attention cell that follows.
import pandas as pd
from pathlib import Path
from src.data_loader import load_data
from src.features import build_feature_windows
from src.config import get_config
from src.visualize import plot_regime_transitions

cfg = get_config()
raw = load_data(cfg)

# Three outputs are returned; only the third (y_cls) is plotted here.
# NOTE(review): presumably X = feature windows, y_reg / y_cls = regression and
# class targets for the "SPY" column -- confirm against src.features.
X, y_reg, y_cls = build_feature_windows(raw, cfg, target_col="SPY")

plot_path = cfg.processed_data_path / "regime_transitions.png"
# The output directory may not exist yet if training has not been run in this
# checkout; create it so saving the figure cannot fail on a missing folder.
plot_path.parent.mkdir(parents=True, exist_ok=True)
plot_regime_transitions(y_cls, plot_path)
print(f"Saved regime transitions plot to {plot_path}")

In [None]:
# (Optional) Run a trained model on a small batch as a starting point for
# attention visualization.
#
# Prerequisites: `X` and `cfg` from the regime-transition cell, and the
# checkpoint "best_model_split0.pt" under cfg.processed_data_path.
# This cell only performs a forward pass; it does not draw attention maps yet.
import torch
from src.model import RegimeTransformer
from src.visualize import plot_attention  # imported for follow-up use; not called in this cell

feature_dim = X.shape[-1]

# NOTE(review): these hyperparameters (including dim_feedforward = 2 * embedding_dim
# and num_regimes = 2) must match the ones used when the checkpoint was saved,
# otherwise load_state_dict will fail -- confirm against src.train.
model = RegimeTransformer(
    input_dim=feature_dim,
    d_model=cfg.embedding_dim,
    nhead=cfg.num_heads,
    num_layers=cfg.num_layers,
    dim_feedforward=cfg.embedding_dim * 2,
    dropout=cfg.dropout,
    num_regimes=2,
).to(cfg.device)

ckpt = cfg.processed_data_path / "best_model_split0.pt"
if ckpt.exists():
    model.load_state_dict(torch.load(ckpt, map_location=cfg.device))
    model.eval()
    # Use a small batch to extract attention
    sample = torch.tensor(X[:4], dtype=torch.float32).to(cfg.device)
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        _ = model(sample)
    if hasattr(model, "encoder"):
        # Access encoder layer attention via hooks not exposed here; this is a placeholder for custom hooks.
        print("Model run complete. Add attention hooks in model if detailed maps are needed.")
else:
    print(f"Checkpoint {ckpt} not found. Train first.")