# In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 1) INLINE SETUP & IMPORTS
# ──────────────────────────────────────────────────────────────────────────────
%matplotlib inline
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# add your model_AWP folder to the import path if needed:
import sys
sys.path.append("model_AWP")  

from AWP_inference import load_model, predict_regimes

# helper to extract contiguous segments
def get_segments(reg):
    """Split a 1-D label sequence into runs of identical consecutive values.

    Parameters
    ----------
    reg : array-like
        One label per time step. Converted with ``np.asarray`` so lists and
        pandas Series are compared positionally, not by index.

    Returns
    -------
    list of (start, end, label) tuples
        ``start`` and ``end`` are positions into ``reg`` and ``end`` is
        *inclusive*; ``label`` is the value shared by the whole run.
        An empty input yields an empty list.
    """
    reg = np.asarray(reg)
    if reg.size == 0:
        # Without this guard reg[starts] would index position 0 of an empty array.
        return []
    # Positions i where reg[i] differs from reg[i + 1], i.e. the last index of a run.
    changes = np.flatnonzero(reg[1:] != reg[:-1])
    starts = np.concatenate(([0], changes + 1))
    ends = np.concatenate((changes, [len(reg) - 1]))
    return list(zip(starts, ends, reg[starts]))

# vivid palettes: three colours each; the plotting code below looks a colour
# up by calling the colormap with a regime label.
true_cmap = ListedColormap(
    [
        "#ff0000",  # red
        "#808080",  # grey
        "#00ff00",  # green
    ]
)
pred_cmap = ListedColormap(
    [
        "#cc0000",  # darker red
        "#444444",  # dark grey
        "#00cc00",  # darker green
    ]
)


# ──────────────────────────────────────────────────────────────────────────────
# 2) LOAD MODEL
# ──────────────────────────────────────────────────────────────────────────────
# (assumes model_AWP/bilstm_tagger.pth sits next to AWP_inference.py)
# NOTE(review): how load_model resolves this relative name is not visible
# here — confirm against AWP_inference.py.
weights_file = "bilstm_tagger.pth"
model, device = load_model(weights_file)


# ──────────────────────────────────────────────────────────────────────────────
# 3) LOAD DATA
# ──────────────────────────────────────────────────────────────────────────────
# features + true labels
df = pd.read_csv("features_all_models4.csv")

# Drop the first 100 rows (warm-up) of every instrument.
# cumcount() numbers the rows of each "inst" group 0, 1, 2, ... in file order,
# so the mask keeps rows 100 onward per instrument. This replaces
# groupby(...).apply(lambda g: g.iloc[100:]): newer pandas excludes the
# grouping column inside apply, which would drop "inst" from the result.
# Rows keep their original file order; within each instrument the order is
# unchanged, which is all the per-instrument filtering below relies on.
df = df[df.groupby("inst").cumcount() >= 100].reset_index(drop=True)

# raw prices: whitespace-separated, no header row; a column is picked
# positionally per instrument further down.
price_df = pd.read_csv("prices.txt", sep=r"\s+", header=None)

# which columns are our features? (everything except ids, time and labels)
non_feature_cols = ("inst", "time", "true_regime")
feat_cols = [c for c in df.columns if c not in non_feature_cols]


# ──────────────────────────────────────────────────────────────────────────────
# 4) RUN INFERENCE & PLOT
# ──────────────────────────────────────────────────────────────────────────────
# Rows skipped at the start of each price column; must equal the number of
# warm-up rows dropped from every instrument's features when the data is loaded.
PRICE_OFFSET = 100


def _shade_regimes(ax, regimes, cmap):
    """Shade every contiguous run of identical labels in *regimes* on *ax*."""
    for start, end, label in get_segments(regimes):
        # get_segments returns inclusive end indices. Padding half a sample on
        # each side gives one-sample runs a visible width and makes adjacent
        # runs tile the axis without gaps.
        # int(): ListedColormap treats integers as colour indices but floats
        # as positions in [0, 1], so force index lookup.
        # NOTE(review): assumes labels are integer-valued — confirm.
        ax.axvspan(
            start - 0.5,
            end + 0.5,
            color=cmap(int(label)),
            alpha=0.5,
            linewidth=0,
        )


for inst in sorted(df["inst"].unique()):
    sub      = df[df["inst"] == inst].reset_index(drop=True)
    X_inst   = sub[feat_cols].values.astype(np.float32)   # (T, D)
    true_seq = sub["true_regime"].values                  # (T,)
    pred_seq = predict_regimes(model, device, X_inst)     # presumably (T,) — confirm
    # The price column is selected positionally by instrument id and sliced to
    # line up with the feature rows that remain after the warm-up drop.
    price    = price_df.iloc[PRICE_OFFSET:PRICE_OFFSET + len(sub), inst].values

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(12, 5))

    # — True
    _shade_regimes(ax1, true_seq, true_cmap)
    ax1.plot(price, "k-", label="Price")
    ax1.set_title(f"Instrument {inst} — TRUE regimes")
    ax1.legend(loc="upper right")

    # — Predicted
    _shade_regimes(ax2, pred_seq, pred_cmap)
    ax2.plot(price, "k-", label="Price")
    ax2.set_title(f"Instrument {inst} — PREDICTED regimes")
    ax2.legend(loc="upper right")

    plt.tight_layout()
    plt.show()
