# 07 Gate Visualization

Visualize LSTM/GRU gate values captured during model-training notebooks.

This notebook auto-discovers all files matching:
- `reports/predictions/*_gate_values.csv`

Examples include pure, feature-hybrid, and residual-hybrid gate outputs.


In [None]:
from __future__ import annotations

from pathlib import Path
import sys

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
try:
    import yfinance as yf
except ImportError:
    yf = None

# Resolve the project root: keep the working directory when it contains `src/`,
# otherwise step up one level.
PROJECT_ROOT = Path.cwd().resolve()
PROJECT_ROOT = PROJECT_ROOT if (PROJECT_ROOT / "src").exists() else PROJECT_ROOT.parent

# Register the root on the import path only if it is not already listed.
project_root_str = str(PROJECT_ROOT)
if project_root_str not in sys.path:
    sys.path.append(project_root_str)

# Notebook-wide plotting and display defaults.
pd.set_option("display.max_columns", 120)
sns.set_theme(style="whitegrid")


In [None]:
# Discover every gate export matching `*_gate_values.csv` under reports/predictions.
pred_dir = PROJECT_ROOT / "reports" / "predictions"
gate_files = sorted(pred_dir.glob("*_gate_values.csv"))

if not gate_files:
    raise FileNotFoundError(
        "No gate files found. Run notebooks 03/04/05 (and 05a if desired) first."
    )

print("Using gate files:")
for path in gate_files:
    print(" -", path.name)

required_cols = {"date", "variant", "architecture", "lag", "gate_name", "gate_value_mean"}

# Validate each file's header *before* the full read:
# - `parse_dates=["date"]` raises inside read_csv when the column is absent, so a
#   check performed afterwards can never report a missing `date` column;
# - checking the concatenated frame hides a column that is missing from only one
#   file, because concat takes the union of columns and fills the gap with NaN.
gate_frames = []
for path in gate_files:
    header_cols = pd.read_csv(path, nrows=0).columns
    missing_cols = sorted(required_cols.difference(header_cols))
    if missing_cols:
        raise ValueError(
            f"Gate file {path.name} is missing required columns: {missing_cols}"
        )
    gate_frames.append(pd.read_csv(path, parse_dates=["date"]))

gates_df = pd.concat(gate_frames, ignore_index=True)

print(f"Loaded gate rows: {len(gates_df):,}")
gates_df.head()


In [None]:
# Mean gate activation for every (variant, architecture, gate, lag) combination,
# averaged over all remaining rows of the gate table.
lag_group_keys = ["variant", "architecture", "gate_name", "lag"]
summary_by_lag = gates_df.groupby(lag_group_keys, as_index=False)["gate_value_mean"].mean()
summary_by_lag.head()


In [None]:
# One panel per architecture; each line is a variant/gate pair traced across lags.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 5), sharey=False)

for ax, arch in zip(axes, ("lstm", "gru")):
    is_arch = summary_by_lag["architecture"] == arch
    plot_df = summary_by_lag.loc[is_arch].copy()
    # Legend key combining variant and gate name.
    plot_df["series"] = plot_df["variant"].str.title() + " | " + plot_df["gate_name"]
    sns.lineplot(x="lag", y="gate_value_mean", hue="series", data=plot_df, ax=ax)
    ax.set(
        title=f"{arch.upper()} Gate Profiles by Lag",
        xlabel="Lag (1 = most recent input day)",
        ylabel="Mean Gate Activation",
    )
    # Flip the x-axis so larger lags are drawn on the left.
    ax.invert_xaxis()

plt.tight_layout()
plt.show()


In [None]:
# Compare gate behavior in high-vol vs normal regimes using lag=1 gate values.
all_pred_files = sorted(pred_dir.glob("*_predictions.csv"))
# Keep per-model prediction files only: skip raw prediction dumps and
# evaluation-metric exports that share the `_predictions.csv` suffix.
pred_files = [
    path
    for path in all_pred_files
    if "raw_predictions" not in path.name
    and not path.name.startswith("evaluation_metrics")
]

if not pred_files:
    raise FileNotFoundError("No prediction files found for regime analysis.")

required_pred_cols = {"date", "variant", "architecture", "y_true_var"}

# Validate each file's header before the full read: `parse_dates=["date"]` raises
# inside read_csv when the column is absent, and a check on the concatenated frame
# would hide a column missing from only one file (concat unions the columns).
pred_frames = []
for path in pred_files:
    header_cols = pd.read_csv(path, nrows=0).columns
    missing_pred_cols = sorted(required_pred_cols.difference(header_cols))
    if missing_pred_cols:
        raise ValueError(
            f"Prediction file {path.name} is missing required columns: {missing_pred_cols}"
        )
    pred_frames.append(pd.read_csv(path, parse_dates=["date"]))

pred_df = pd.concat(pred_frames, ignore_index=True)

# Rows at or above the pooled 90th percentile of `y_true_var` are labelled high-vol.
threshold = pred_df["y_true_var"].quantile(0.90)
pred_df["vol_regime"] = pred_df["y_true_var"].ge(threshold).map({True: "high_vol", False: "normal"})

# Attach the regime label to lag-1 gate rows of the same date/variant/architecture.
gate_last = gates_df[gates_df["lag"] == 1].copy()
gate_regime = gate_last.merge(
    pred_df[["date", "variant", "architecture", "vol_regime"]],
    on=["date", "variant", "architecture"],
    how="inner",
)

regime_summary = (
    gate_regime
    .groupby(["architecture", "variant", "gate_name", "vol_regime"], as_index=False)["gate_value_mean"]
    .mean()
)
regime_summary.head()


In [None]:
# One panel per architecture: lag-1 gate activation per gate, split by volatility regime.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 5), sharey=False)

for ax, arch in zip(axes, ("lstm", "gru")):
    plot_df = regime_summary.loc[regime_summary["architecture"] == arch]
    sns.barplot(x="gate_name", y="gate_value_mean", hue="vol_regime", data=plot_df, ax=ax)
    ax.set(
        title=f"{arch.upper()} Lag-1 Gate Activation by Regime",
        xlabel="Gate",
        ylabel="Mean Gate Activation",
    )

plt.tight_layout()
plt.show()


In [None]:
# Build lag-1 daily gate series and merge with VIX for gate-vs-market plots.
gate_daily = (
    gates_df[gates_df["lag"] == 1]
    .groupby(["date", "architecture", "variant", "gate_name"], as_index=False)["gate_value_mean"]
    .mean()
)

vix_cache_path = PROJECT_ROOT / "data" / "processed" / "vix_daily.csv"
# Request a window padded by 30 days before and 7 days after the gate date range.
start_date = (gate_daily["date"].min() - pd.Timedelta(days=30)).strftime("%Y-%m-%d")
end_date = (gate_daily["date"].max() + pd.Timedelta(days=7)).strftime("%Y-%m-%d")

# Load any previously cached VIX series up front so it can serve as a fallback.
cached_vix = None
if vix_cache_path.exists():
    cached_vix = pd.read_csv(vix_cache_path, parse_dates=["date"])

try:
    if yf is None:
        raise ImportError("yfinance is not installed in this environment.")
    vix_raw = yf.download("^VIX", start=start_date, end=end_date, auto_adjust=False, progress=False)
    if vix_raw.empty:
        raise ValueError("VIX download returned empty dataframe.")

    # Flatten multi-level column labels by keeping only their first level.
    if isinstance(vix_raw.columns, pd.MultiIndex):
        vix_raw.columns = [col[0] for col in vix_raw.columns]

    vix_raw = vix_raw.rename_axis("date").reset_index()
    # Normalize headers to snake_case (e.g. "Adj Close" -> "adj_close").
    vix_raw.columns = [str(c).strip().lower().replace(" ", "_") for c in vix_raw.columns]
    close_col = "adj_close" if "adj_close" in vix_raw.columns else "close"

    vix_df = vix_raw[["date", close_col]].rename(columns={close_col: "vix_close"})
    # Drop timezone info so the dates can be merged with the gate table's dates.
    vix_df["date"] = pd.to_datetime(vix_df["date"], errors="raise").dt.tz_localize(None)
    vix_df = vix_df.dropna(subset=["vix_close"]).sort_values("date").reset_index(drop=True)
except Exception as exc:
    # Deliberate best-effort: any download/parsing failure falls back to the cache.
    if cached_vix is None or cached_vix.empty:
        raise RuntimeError("Unable to load VIX data from yfinance and no cache exists.") from exc
    print(f"Using cached VIX data due to download error: {exc}")
    vix_df = cached_vix.copy()
else:
    # Refresh the cache only after a successful download, and keep the write
    # outside the download `try`: a missing directory or write error must not
    # discard data that was fetched successfully.
    try:
        vix_cache_path.parent.mkdir(parents=True, exist_ok=True)
        vix_df.to_csv(vix_cache_path, index=False)
    except OSError as cache_exc:
        print(f"Warning: could not write VIX cache to {vix_cache_path}: {cache_exc}")

gate_vix = gate_daily.merge(vix_df, on="date", how="inner")
print(f"Rows in gate+VIX panel: {len(gate_vix):,}")
gate_vix.head()


In [None]:
# Plot each gate separately with VIX (lag-1 gate values).
# One figure per (architecture, variant); one subplot per gate on a 2-column grid.
for arch in sorted(gate_vix["architecture"].unique()):
    for variant in sorted(gate_vix[gate_vix["architecture"] == arch]["variant"].unique()):
        subset = gate_vix[(gate_vix["architecture"] == arch) & (gate_vix["variant"] == variant)].copy()
        if subset.empty:
            continue

        gates = sorted(subset["gate_name"].unique())
        ncols = 2
        nrows = (len(gates) + ncols - 1) // ncols  # ceiling division
        # squeeze=False always yields a 2-D Axes array, so it can be flattened
        # directly without wrapping it in another container.
        fig, axes = plt.subplots(
            nrows, ncols, figsize=(16, 3.6 * nrows), sharex=True, squeeze=False
        )
        axes = axes.ravel()

        for ax, gate in zip(axes, gates):
            d = subset[subset["gate_name"] == gate]
            ax.plot(d["date"], d["gate_value_mean"], color="tab:blue", linewidth=1.2)
            ax.set_title(gate.replace("_", " ").title())
            ax.set_ylabel("Lag-1 Gate Value", color="tab:blue")
            ax.tick_params(axis="y", labelcolor="tab:blue")
            ax.grid(alpha=0.2)

            # VIX is drawn on a secondary y-axis sharing the date axis.
            ax2 = ax.twinx()
            ax2.plot(d["date"], d["vix_close"], color="tab:orange", linewidth=1.0, alpha=0.75)
            ax2.set_ylabel("VIX", color="tab:orange")
            ax2.tick_params(axis="y", labelcolor="tab:orange")

        # Hide grid cells left over when the gate count is odd.
        for ax in axes[len(gates):]:
            ax.axis("off")

        fig.suptitle(f"{arch.upper()} | {variant.title()}: Lag-1 Gate Values vs VIX", y=1.01)
        plt.tight_layout()
        plt.show()


In [None]:
# Build lag-aggregated gate series (mean over lags 1-20) and merge with VIX.
LAG_MIN = 1
LAG_MAX = 20

# Rows inside the lag window, shared by both aggregations below.
gates_in_window = gates_df[gates_df["lag"].between(LAG_MIN, LAG_MAX)]

gate_lag120 = (
    gates_in_window
    .groupby(["date", "architecture", "variant", "gate_name"], as_index=False)["gate_value_mean"]
    .mean()
    .rename(columns={"gate_value_mean": "gate_value_lag1_20_mean"})
)

if gate_lag120.empty:
    raise ValueError("No gate rows found for lags 1-20.")

gate_lag120_vix = gate_lag120.merge(vix_df[["date", "vix_close"]], on="date", how="inner")
if gate_lag120_vix.empty:
    raise ValueError("Lag-aggregated gate panel is empty after merging with VIX.")

# Bin edges [0, 5], (5, 10], (10, 20] map lags 1-5, 6-10 and 11-20 to their labels.
lag_bucket_bins = [0, 5, 10, 20]
lag_bucket_labels = ["lag1_5", "lag6_10", "lag11_20"]
gate_bucket = (
    gates_in_window
    .assign(
        lag_bucket=lambda d: pd.cut(
            d["lag"],
            bins=lag_bucket_bins,
            labels=lag_bucket_labels,
            include_lowest=True,
        )
    )
    # `lag_bucket` is categorical. Without observed=True, pandas < 3 returns the
    # cartesian product of *all* group keys (e.g. gates paired with architectures
    # they never occur in), padding the panel with NaN rows and emitting a
    # FutureWarning on pandas >= 2.1.
    .groupby(
        ["date", "architecture", "variant", "gate_name", "lag_bucket"],
        as_index=False,
        observed=True,
    )["gate_value_mean"]
    .mean()
)

gate_bucket_vix = gate_bucket.merge(vix_df[["date", "vix_close"]], on="date", how="inner")

print(f"Rows in lag1-20 gate+VIX panel: {len(gate_lag120_vix):,}")
print(f"Rows in lag-bucket gate+VIX panel: {len(gate_bucket_vix):,}")
gate_lag120_vix.head()


In [None]:
# Plot each gate separately: mean lag-1..20 gate value vs VIX.
# One figure per (architecture, variant); one subplot per gate on a 2-column grid.
for arch in sorted(gate_lag120_vix["architecture"].unique()):
    variants = sorted(gate_lag120_vix[gate_lag120_vix["architecture"] == arch]["variant"].unique())
    for variant in variants:
        subset = gate_lag120_vix[(gate_lag120_vix["architecture"] == arch) & (gate_lag120_vix["variant"] == variant)].copy()
        if subset.empty:
            continue

        gates = sorted(subset["gate_name"].unique())
        ncols = 2
        nrows = (len(gates) + ncols - 1) // ncols  # ceiling division
        # squeeze=False always yields a 2-D Axes array, so it can be flattened
        # directly without wrapping it in another container.
        fig, axes = plt.subplots(
            nrows, ncols, figsize=(16, 3.6 * nrows), sharex=True, squeeze=False
        )
        axes = axes.ravel()

        for ax, gate in zip(axes, gates):
            d = subset[subset["gate_name"] == gate]
            ax.plot(d["date"], d["gate_value_lag1_20_mean"], color="tab:blue", linewidth=1.2)
            ax.set_title(gate.replace("_", " ").title())
            ax.set_ylabel("Gate Mean (Lags 1-20)", color="tab:blue")
            ax.tick_params(axis="y", labelcolor="tab:blue")
            ax.grid(alpha=0.2)

            # VIX is drawn on a secondary y-axis sharing the date axis.
            ax2 = ax.twinx()
            ax2.plot(d["date"], d["vix_close"], color="tab:orange", linewidth=1.0, alpha=0.75)
            ax2.set_ylabel("VIX", color="tab:orange")
            ax2.tick_params(axis="y", labelcolor="tab:orange")

        # Hide grid cells left over when the gate count is odd.
        for ax in axes[len(gates):]:
            ax.axis("off")

        fig.suptitle(f"{arch.upper()} | {variant.title()}: Mean Gate (Lags 1-20) vs VIX", y=1.01)
        plt.tight_layout()
        plt.show()


In [None]:
# Optional detail: lag-bucket gate values vs VIX to retain temporal structure.
# One figure per (architecture, variant); one subplot per gate, one line per lag bucket.
for arch in sorted(gate_bucket_vix["architecture"].unique()):
    variants = sorted(gate_bucket_vix[gate_bucket_vix["architecture"] == arch]["variant"].unique())
    for variant in variants:
        subset = gate_bucket_vix[(gate_bucket_vix["architecture"] == arch) & (gate_bucket_vix["variant"] == variant)].copy()
        if subset.empty:
            continue

        gates = sorted(subset["gate_name"].unique())
        ncols = 2
        nrows = (len(gates) + ncols - 1) // ncols  # ceiling division
        # squeeze=False always yields a 2-D Axes array, so it can be flattened
        # directly without wrapping it in another container.
        fig, axes = plt.subplots(
            nrows, ncols, figsize=(16, 3.8 * nrows), sharex=True, squeeze=False
        )
        axes = axes.ravel()

        for ax, gate in zip(axes, gates):
            d = subset[subset["gate_name"] == gate].copy()
            sns.lineplot(
                data=d,
                x="date",
                y="gate_value_mean",
                hue="lag_bucket",
                hue_order=["lag1_5", "lag6_10", "lag11_20"],
                ax=ax,
            )
            ax.set_title(gate.replace("_", " ").title())
            ax.set_ylabel("Gate Value")
            ax.grid(alpha=0.2)

            # VIX is drawn faintly on a secondary y-axis sharing the date axis.
            ax2 = ax.twinx()
            ax2.plot(d["date"], d["vix_close"], color="tab:orange", linewidth=1.0, alpha=0.35)
            ax2.set_ylabel("VIX", color="tab:orange")
            ax2.tick_params(axis="y", labelcolor="tab:orange")

        # Hide grid cells left over when the gate count is odd.
        for ax in axes[len(gates):]:
            ax.axis("off")

        fig.suptitle(f"{arch.upper()} | {variant.title()}: Lag-Bucket Gates vs VIX", y=1.01)
        plt.tight_layout()
        plt.show()


In [None]:
# Correlation summary: gate mean (lags 1-20) vs VIX by model and gate.
# For each (architecture, variant, gate) this records Pearson and Spearman
# correlation of the raw series plus Pearson correlation of their 21-row rolling means.
corr_rows = []
corr_group_keys = ["architecture", "variant", "gate_name"]
for (arch, variant, gate), d in gate_lag120_vix.groupby(corr_group_keys):
    d = d.sort_values("date")  # rolling windows must follow date order
    gate_series = d["gate_value_lag1_20_mean"]
    vix_series = d["vix_close"]

    pearson = gate_series.corr(vix_series)
    spearman = gate_series.corr(vix_series, method="spearman")

    # Smoothed series: a value exists only once a full 21-row window is available.
    gate_ma21 = gate_series.rolling(21, min_periods=21).mean()
    vix_ma21 = vix_series.rolling(21, min_periods=21).mean()
    corr_ma21 = gate_ma21.corr(vix_ma21)

    row = dict(
        architecture=arch,
        variant=variant,
        gate_name=gate,
        pearson_corr=pearson,
        spearman_corr=spearman,
        corr_21d_ma=corr_ma21,
        n_obs=len(d),
    )
    corr_rows.append(row)

gate_vix_corr = pd.DataFrame(corr_rows)
gate_vix_corr = gate_vix_corr.sort_values(["architecture", "variant", "gate_name"])
gate_vix_corr = gate_vix_corr.reset_index(drop=True)
gate_vix_corr


In [None]:
# Heatmap: Pearson correlation (gate mean lags 1-20 vs VIX).
# One figure per architecture; rows are gates, columns are title-cased variants.
heatmap_style = dict(
    annot=True,
    fmt=".2f",
    cmap="RdBu_r",
    center=0.0,
    vmin=-1.0,
    vmax=1.0,
    linewidths=0.5,
)

for arch in sorted(gate_vix_corr["architecture"].unique()):
    d = gate_vix_corr.loc[gate_vix_corr["architecture"] == arch].copy()
    d["model_name"] = d["variant"].str.title()
    pivot = d.pivot(index="gate_name", columns="model_name", values="pearson_corr")

    fig, ax = plt.subplots(figsize=(8, 4.5))
    sns.heatmap(pivot, ax=ax, **heatmap_style)
    ax.set_title(f"{arch.upper()} Gate-VIX Correlation (Pearson)")
    ax.set_xlabel("Variant")
    ax.set_ylabel("Gate")
    plt.tight_layout()
    plt.show()


In [None]:
# Scatter + fitted line: Pure LSTM gate means (lags 1-20) vs VIX.
is_pure_lstm = (gate_lag120_vix["architecture"] == "lstm") & (gate_lag120_vix["variant"] == "pure")
pure_lstm = gate_lag120_vix.loc[is_pure_lstm].copy()
if pure_lstm.empty:
    raise ValueError("No pure LSTM rows found in lag1-20 gate panel.")

# One facet per gate, wrapped onto two columns.
scatter_style = {"alpha": 0.25, "s": 12}
fit_line_style = {"color": "tab:red", "linewidth": 2}
g = sns.lmplot(
    x="vix_close",
    y="gate_value_lag1_20_mean",
    data=pure_lstm,
    col="gate_name",
    col_wrap=2,
    height=3.6,
    aspect=1.3,
    scatter_kws=scatter_style,
    line_kws=fit_line_style,
)
# Reserve headroom above the facets for the figure-level title.
g.fig.subplots_adjust(top=0.9)
g.fig.suptitle("Pure LSTM: Gate Mean (Lags 1-20) vs VIX")
plt.show()


In [None]:
# Bucket-level correlation summary: gate value vs VIX by lag bucket.
# For each (architecture, variant, gate, lag bucket) this records Pearson and
# Spearman correlation plus Pearson correlation of the 21-row rolling means.
bucket_corr_rows = []
bucket_group_keys = ["architecture", "variant", "gate_name", "lag_bucket"]
# observed=True: `lag_bucket` is presumably still categorical here (it is built
# with pd.cut upstream). Restricting to observed combinations avoids relying on
# the pandas default, which changes between major versions and warns on >= 2.1.
# The flag is ignored for non-categorical keys.
for (arch, variant, gate, lag_bucket), d in gate_bucket_vix.groupby(bucket_group_keys, observed=True):
    d = d.sort_values("date")  # rolling windows must follow date order
    pearson = d["gate_value_mean"].corr(d["vix_close"])
    spearman = d["gate_value_mean"].corr(d["vix_close"], method="spearman")
    # Smoothed series: a value exists only once a full 21-row window is available.
    gate_ma21 = d["gate_value_mean"].rolling(21, min_periods=21).mean()
    vix_ma21 = d["vix_close"].rolling(21, min_periods=21).mean()
    corr_ma21 = gate_ma21.corr(vix_ma21)
    bucket_corr_rows.append(
        {
            "architecture": arch,
            "variant": variant,
            "gate_name": gate,
            # Stored as a plain string so the summary no longer carries the categorical dtype.
            "lag_bucket": str(lag_bucket),
            "pearson_corr": pearson,
            "spearman_corr": spearman,
            "corr_21d_ma": corr_ma21,
            "n_obs": len(d),
        }
    )

gate_bucket_corr = pd.DataFrame(bucket_corr_rows).sort_values(
    ["architecture", "variant", "gate_name", "lag_bucket"]
).reset_index(drop=True)
gate_bucket_corr


In [None]:
# Heatmap: bucket-level Pearson gate-VIX correlation by architecture (columns = variant | lag bucket).
plot_df = gate_bucket_corr.copy()
missing_n = plot_df["pearson_corr"].isna().sum()
if missing_n > 0:
    print(f"Skipping {missing_n} rows with NaN Pearson correlation (typically constant gate series).")

plot_df = plot_df.dropna(subset=["pearson_corr"]).copy()
if plot_df.empty:
    raise ValueError("All bucket-level Pearson correlations are NaN; no heatmap can be drawn.")

for arch in sorted(plot_df["architecture"].unique()):
    d = plot_df[plot_df["architecture"] == arch].copy()
    # Order buckets chronologically (1-5, 6-10, 11-20) rather than alphabetically.
    d["lag_bucket"] = pd.Categorical(d["lag_bucket"], categories=["lag1_5", "lag6_10", "lag11_20"], ordered=True)
    d = d.sort_values(["variant", "lag_bucket", "gate_name"])
    d["variant_bucket"] = d["variant"].str.title() + " | " + d["lag_bucket"].astype(str)

    # pivot_table re-sorts column labels lexicographically, which would place
    # "lag11_20" before "lag1_5" and undo the sort above. Capture the intended
    # first-appearance order here and restore it after pivoting.
    column_order = list(dict.fromkeys(d["variant_bucket"]))

    pivot = d.pivot_table(
        index="gate_name",
        columns="variant_bucket",
        values="pearson_corr",
        aggfunc="mean",
    )
    pivot = pivot.reindex(columns=[col for col in column_order if col in pivot.columns])

    if pivot.empty:
        print(f"No non-NaN data for architecture={arch}; skipping.")
        continue

    # Widen the figure as more variant/bucket columns are present.
    fig_w = max(10, 0.75 * len(pivot.columns))
    plt.figure(figsize=(fig_w, 4.8))
    sns.heatmap(
        pivot,
        annot=True,
        fmt=".2f",
        cmap="RdBu_r",
        center=0.0,
        vmin=-1.0,
        vmax=1.0,
        linewidths=0.5,
    )
    plt.title(f"{arch.upper()} Bucket Gate-VIX Correlation (Pearson)")
    plt.xlabel("Variant | Lag Bucket")
    plt.ylabel("Gate")
    plt.tight_layout()
    plt.show()


In [None]:
# Focus plot: Pure LSTM bucket-level correlations by gate.
is_pure_lstm_row = (gate_bucket_corr["architecture"] == "lstm") & (gate_bucket_corr["variant"] == "pure")
pure_lstm_bucket_corr = gate_bucket_corr.loc[is_pure_lstm_row].copy()
if pure_lstm_bucket_corr.empty:
    raise ValueError("No Pure LSTM bucket-correlation rows found.")

fig, ax = plt.subplots(figsize=(10, 4.8))
sns.barplot(
    x="gate_name",
    y="pearson_corr",
    hue="lag_bucket",
    hue_order=["lag1_5", "lag6_10", "lag11_20"],
    data=pure_lstm_bucket_corr,
    ax=ax,
)
# Dashed zero line separating positive from negative correlations.
ax.axhline(0.0, color="gray", linewidth=1.0, linestyle="--")
ax.set_title("Pure LSTM: Bucket-Level Gate-VIX Correlation (Pearson)")
ax.set_xlabel("Gate")
ax.set_ylabel("Correlation with VIX")
plt.tight_layout()
plt.show()


In [None]:
# Persist every summary table as CSV in the predictions directory.
# All output locations are declared first, then written group by group.
summary_path = pred_dir / "gate_summary_by_lag.csv"
regime_path = pred_dir / "gate_summary_by_regime_lag1.csv"
lag120_path = pred_dir / "gate_summary_lag1_20_by_date.csv"
bucket_path = pred_dir / "gate_summary_lag_buckets_by_date.csv"
corr_path = pred_dir / "gate_vix_correlation_summary.csv"
bucket_corr_path = pred_dir / "gate_vix_correlation_by_bucket.csv"

# Per-lag and per-regime gate summaries.
summary_by_lag.to_csv(summary_path, index=False)
regime_summary.to_csv(regime_path, index=False)
print(f"Saved lag summary: {summary_path}")
print(f"Saved regime summary: {regime_path}")

# Date-level gate panels merged with VIX.
gate_lag120_vix.to_csv(lag120_path, index=False)
gate_bucket_vix.to_csv(bucket_path, index=False)
print(f"Saved lag1-20 summary: {lag120_path}")
print(f"Saved lag-bucket summary: {bucket_path}")

# Gate-vs-VIX correlation tables.
gate_vix_corr.to_csv(corr_path, index=False)
print(f"Saved gate-vix correlation summary: {corr_path}")

gate_bucket_corr.to_csv(bucket_corr_path, index=False)
print(f"Saved gate-vix bucket correlation summary: {bucket_corr_path}")
