# VAE Discovery: Latent Tactical State Exploration

Train a beta-VAE on combined match feature data, extract latent codes,
and correlate each latent dimension with observable features to produce
an interpretation table.

**Target environment:** Kaggle (2 x T4 GPUs via `DataParallel`)

**Sections:**
1. Setup & GPU detection
2. Data loading & preprocessing
3. VAE training (multi-GPU)
4. Training diagnostics
5. Latent code extraction
6. Latent-feature Pearson correlation
7. Interpretation table (top-5 features per dimension)
8. Latent space UMAP visualisation
9. Save model artifacts
10. Summary & next steps

In [None]:
from __future__ import annotations

import re
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
import torch
import torch.nn as nn
from scipy.stats import pearsonr
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from umap import UMAP

# -- Project imports ------------------------------------------------
# The project root is taken to be the parent of the current working
# directory; its `src` folder is pushed to the front of sys.path so the
# `tactical` imports below resolve. The path insertion must stay ahead of
# those imports.
PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT / "src"))

from tactical.config import VAEConfig
from tactical.models.preprocessing import PreprocessingPipeline
from tactical.models.vae import TacticalVAEModule, vae_loss

# Plot defaults shared by every figure, and silence FutureWarning noise.
sns.set_theme(style="whitegrid", font_scale=0.9)
plt.rcParams["figure.dpi"] = 120
warnings.filterwarnings("ignore", category=FutureWarning)

# Input feature table and the directory that receives every artifact
# written by this notebook (created up front if it does not exist).
FEATURES_PATH = PROJECT_ROOT / "data" / "output" / "features.parquet"
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "vae_discovery"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Presumably the non-feature (identifier / timing) columns of the feature
# table; not referenced by the cells visible here -- TODO confirm usage.
METADATA_COLS: set[str] = {
    "match_id",
    "team_id",
    "segment_type",
    "start_time",
    "end_time",
    "period",
    "match_minute",
}

# Feature columns start with `t<digits>_`; the captured digits are the tier.
TIER_PATTERN = re.compile(r"^t(\d+)_")
# Columns whose tier number exceeds this value are dropped before training.
MAX_FEATURE_TIER = 2

---
## 1. Setup & GPU Detection

In [None]:
# Count the visible CUDA devices and pick the compute device: CUDA when
# available, otherwise CPU.
n_gpus = torch.cuda.device_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"PyTorch {torch.__version__}")
print(f"Device : {device}  |  GPUs available: {n_gpus}")
for i in range(n_gpus):
    name = torch.cuda.get_device_name(i)
    # The device-properties attribute is `total_memory` (bytes); the
    # previous `total_mem` does not exist and raised AttributeError on any
    # machine with at least one GPU. Convert bytes to GiB for display.
    mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
    print(f"  GPU {i}: {name}  ({mem:.1f} GB)")

---
## 2. Data Loading & Preprocessing

Load the combined feature matrix, filter to **window** segments and
Tier 1+2 features, then z-score normalise **without PCA** so that
every latent dimension can be correlated back to named features.

In [None]:
# Abort early with a hint when the feature table has not been generated yet.
if not FEATURES_PATH.exists():
    print(
        f"{FEATURES_PATH} not found.\n"
        "Run `python scripts/run_feature_extraction.py` first."
    )
    raise SystemExit(1)

df_raw = pl.read_parquet(FEATURES_PATH)
print(f"Loaded {FEATURES_PATH}")

# Restrict to "window" segments so every row covers a comparable span
is_window = pl.col("segment_type") == "window"
df_win = df_raw.filter(is_window)

# Collect the columns whose tier prefix is above MAX_FEATURE_TIER
# (Tier 3+ features, sparse 360 data) and remove them.
tier3_cols = []
for col_name in df_win.columns:
    tier_match = TIER_PATTERN.match(col_name)
    if tier_match and int(tier_match.group(1)) > MAX_FEATURE_TIER:
        tier3_cols.append(col_name)
df_win = df_win.drop(tier3_cols)

print(f"Window rows: {df_win.height:,}")
print(f"Dropped {len(tier3_cols)} Tier 3+ columns")

In [None]:
# Preprocessing: median imputation and z-score scaling with PCA switched
# off (threshold=None), so the columns of X stay aligned with named features.
pipeline = PreprocessingPipeline(
    pca_variance_threshold=None,
    null_strategy="impute_median",
    feature_prefix="t",
)
X = pipeline.fit_transform(df_win)
feature_names: list[str] = pipeline._feature_columns

# Keep the source rows that the pipeline retained
retained_mask = pipeline.get_retained_row_mask(df_win)
df_retained = df_win.filter(pl.Series(retained_mask))

sample_count, column_count = X.shape[0], X.shape[1]
print(f"Samples : {sample_count:,}")
print(f"Features: {column_count} ({len(feature_names)} named columns)")

---
## 3. VAE Training (Multi-GPU)

Architecture from spec: beta-VAE with `(256, 128, 64)` encoder,
`ReduceLROnPlateau` scheduler, wrapped in `DataParallel` for 2 x T4.

In [None]:
# Hyper-parameters for this run, gathered in one mapping and handed to
# VAEConfig as keyword arguments.
vae_settings = {
    "latent_dim": 8,
    "hidden_dims": (256, 128, 64),
    "beta": 4.0,
    "learning_rate": 1e-3,
    "batch_size": 512,
    "n_epochs": 150,
    "dropout": 0.2,
    "random_state": 42,
}
cfg = VAEConfig(**vae_settings)
print(cfg)

In [None]:
# Seed both RNGs before the model is built so weight initialisation is
# repeatable for a given cfg.random_state.
seed = cfg.random_state
torch.manual_seed(seed)
np.random.seed(seed)  # noqa: NPY002

input_dim = X.shape[1]
module = TacticalVAEModule(
    input_dim=input_dim,
    latent_dim=cfg.latent_dim,
    hidden_dims=cfg.hidden_dims,
    dropout=cfg.dropout,
)
module = module.to(device)

# Use DataParallel only when more than one GPU is visible; otherwise the
# bare module is trained directly.
use_data_parallel = n_gpus > 1
module_dp: nn.Module = nn.DataParallel(module) if use_data_parallel else module
print(
    f"DataParallel enabled across {n_gpus} GPUs"
    if use_data_parallel
    else "Single-device training"
)

# The full feature matrix is moved to the compute device once and then
# served in shuffled mini-batches.
tensor_X = torch.as_tensor(X, dtype=torch.float32, device=device)
dataset = TensorDataset(tensor_X)
loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=False)

# Adam on the unwrapped module's parameters; the learning rate is halved
# after 10 epochs without improvement in the monitored loss.
optimiser = torch.optim.Adam(module.parameters(), lr=cfg.learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimiser,
    mode="min",
    factor=0.5,
    patience=10,
)

In [None]:
# Per-epoch means of the total, reconstruction and KL loss terms
epoch_losses: list[float] = []
epoch_recon_losses: list[float] = []
epoch_kl_losses: list[float] = []

module_dp.train()

for epoch in tqdm(range(cfg.n_epochs), desc="VAE training"):
    # Running sums for this epoch, ordered as [total, recon, kl]
    running = [0.0, 0.0, 0.0]
    n_batches = 0

    for (batch,) in loader:
        recon, mu, logvar = module_dp(batch)
        loss, recon_l, kl_l = vae_loss(recon, batch, mu, logvar, beta=cfg.beta)

        optimiser.zero_grad()
        loss.backward()  # type: ignore[no-untyped-call]
        optimiser.step()

        for slot, term in enumerate((loss, recon_l, kl_l)):
            running[slot] += term.item()
        n_batches += 1

    # Average over batches (guarding against an empty loader)
    denom = max(n_batches, 1)
    mean_loss = running[0] / denom
    mean_recon = running[1] / denom
    mean_kl = running[2] / denom

    epoch_losses.append(mean_loss)
    epoch_recon_losses.append(mean_recon)
    epoch_kl_losses.append(mean_kl)

    # The plateau scheduler monitors the mean total training loss
    scheduler.step(mean_loss)

module_dp.eval()
print(
    f"Final loss: {epoch_losses[-1]:.6f}  "
    f"(recon={epoch_recon_losses[-1]:.6f}, kl={epoch_kl_losses[-1]:.6f})"
)

---
## 4. Training Diagnostics

In [None]:
# One panel per loss term: (title, per-epoch series, extra plot kwargs).
# The first panel keeps matplotlib's default line colour.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
epochs_arr = np.arange(1, len(epoch_losses) + 1)

curve_panels = (
    ("Total Loss", epoch_losses, {}),
    ("Reconstruction Loss (MSE)", epoch_recon_losses, {"color": "#dd8452"}),
    ("KL Divergence", epoch_kl_losses, {"color": "#55a868"}),
)
for panel_ax, (panel_title, series, line_style) in zip(axes, curve_panels):
    panel_ax.plot(epochs_arr, series, linewidth=1.2, **line_style)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Epoch")
    panel_ax.set_ylabel("Loss")

fig.suptitle("VAE Training Curves", fontsize=13, y=1.02)
fig.tight_layout()
curves_path = OUTPUT_DIR / "vae_training_curves.png"
fig.savefig(curves_path, dpi=150, bbox_inches="tight")
plt.show()

---
## 5. Latent Code Extraction

Extract the **mu** vectors (posterior means) in batches of 2048
to avoid GPU OOM.

In [None]:
# Number of rows encoded per forward pass
ENCODE_BATCH = 2048

# Same tensor as used for training, but in original row order
encode_loader = DataLoader(
    TensorDataset(tensor_X),
    shuffle=False,
    batch_size=ENCODE_BATCH,
)

# Encode with the unwrapped module (single-device, deterministic) in eval
# mode and without gradient tracking; only the first value returned by
# `encode` (mu) is kept.
module.eval()
with torch.no_grad():
    latent_parts: list[np.ndarray] = [
        module.encode(batch)[0].cpu().numpy()
        for (batch,) in tqdm(encode_loader, desc="Encoding latent codes")
    ]

Z = np.concatenate(latent_parts, axis=0)
print(f"Latent codes: {Z.shape}  (samples x latent_dim)")

---
## 6. Latent-Feature Pearson Correlation

For each latent dimension $z_i$, compute the Pearson correlation
with every observable feature.  Since Pearson $r$ is invariant to
positive affine transformations (a shift plus a positive rescaling),
correlations are identical whether computed on the raw or the
z-scored feature values.

In [None]:
n_latent = Z.shape[1]
n_features = len(feature_names)

# Row = latent dimension, column = feature
matrix_shape = (n_latent, n_features)
corr_matrix = np.empty(matrix_shape, dtype=np.float64)
pval_matrix = np.empty(matrix_shape, dtype=np.float64)

for zi in tqdm(range(n_latent), desc="Computing correlations"):
    z_col = Z[:, zi]
    for fi in range(n_features):
        # pearsonr returns (r, p-value); store both in one step
        corr_matrix[zi, fi], pval_matrix[zi, fi] = pearsonr(z_col, X[:, fi])

print(f"Correlation matrix: {corr_matrix.shape}")

In [None]:
# Full latent-vs-feature correlation heatmap; the figure grows with the
# number of features / latent dimensions but never below 14 x 4 inches.
fig_width = max(14, n_features * 0.35)
fig_height = max(4, n_latent * 0.6)
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
im = ax.imshow(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")

latent_labels = [f"z{i}" for i in range(n_latent)]
ax.set_yticks(range(n_latent))
ax.set_yticklabels(latent_labels)
ax.set_xticks(range(n_features))
ax.set_xticklabels(feature_names, rotation=90, fontsize=6)
ax.set_title("Pearson Correlation: Latent Dimensions vs Features")

fig.colorbar(im, ax=ax, label="Pearson r", shrink=0.8)
fig.tight_layout()
heatmap_path = OUTPUT_DIR / "vae_correlation_heatmap.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches="tight")
plt.show()

---
## 7. Interpretation Table (Top-5 Features per Latent Dimension)

For each $z_i$, rank features by absolute Pearson $|r|$ and report
the five strongest correlates with direction and significance.

In [None]:
# Number of strongest correlates reported per latent dimension
TOP_K = 5

# Long-format table: one row per (latent dimension, rank) pair.
table_rows: list[dict[str, object]] = []
for zi in range(n_latent):
    abs_corr = np.abs(corr_matrix[zi])
    # A correlation is NaN when one of the inputs is constant. argsort
    # places NaN last, so the reversed order used to rank NaN entries
    # *first*, reporting undefined correlations as the top correlates.
    # Map NaN to -inf for ranking only so such features sink to the bottom;
    # the stored values below are still taken from the original matrices.
    rank_key = np.where(np.isnan(abs_corr), -np.inf, abs_corr)
    top_idx = np.argsort(rank_key)[::-1][:TOP_K]
    for rank, fi in enumerate(top_idx, start=1):
        table_rows.append({
            "latent_dim": f"z{zi}",
            "rank": rank,
            "feature": feature_names[fi],
            "pearson_r": round(float(corr_matrix[zi, fi]), 4),
            "abs_r": round(float(abs_corr[fi]), 4),
            "p_value": float(pval_matrix[zi, fi]),
        })

df_interp = pl.DataFrame(table_rows)
print(df_interp)

In [None]:
# Wide view of the interpretation table: one row per latent dimension and
# one "#<rank>" column per rank, each cell formatted as "feature (+r)".
pivot_rows: list[dict[str, object]] = []
for zi in range(n_latent):
    dim_label = f"z{zi}"
    dim_df = df_interp.filter(pl.col("latent_dim") == dim_label).sort("rank")
    feature_col = dim_df["feature"]
    r_col = dim_df["pearson_r"]

    row: dict[str, object] = {"dim": dim_label}
    row.update(
        {
            f"#{pos + 1}": f"{feature_col[pos]} ({r_col[pos]:+.3f})"
            for pos in range(TOP_K)
        }
    )
    pivot_rows.append(row)

df_pivot = pl.DataFrame(pivot_rows)
print(df_pivot)

In [None]:
# Persist both interpretation tables: long format as Parquet, pivot as CSV
interp_parquet_path = OUTPUT_DIR / "latent_interpretation.parquet"
pivot_csv_path = OUTPUT_DIR / "latent_interpretation_pivot.csv"
df_interp.write_parquet(interp_parquet_path)
df_pivot.write_csv(pivot_csv_path)
print(f"Saved interpretation tables to {OUTPUT_DIR}")

In [None]:
# Signed top-TOP_K correlations, one horizontal bar panel per latent
# dimension, laid out on a 2-row grid.
grid_cols = (n_latent + 1) // 2
fig, axes = plt.subplots(2, grid_cols, figsize=(4 * grid_cols, 7), sharey=False)
axes_flat = axes.flatten()

# Rank 1 is drawn at the top of each panel, hence the descending positions
bar_positions = range(TOP_K - 1, -1, -1)

for zi, ax in enumerate(axes_flat[:n_latent]):
    dim_df = df_interp.filter(pl.col("latent_dim") == f"z{zi}").sort("rank")
    features_top = dim_df["feature"].to_list()
    r_vals = dim_df["pearson_r"].to_list()
    # Red for negative correlations, blue otherwise
    colors = ["#c44e52" if r < 0 else "#4c72b0" for r in r_vals]

    ax.barh(bar_positions, r_vals, color=colors)
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(features_top, fontsize=7)
    ax.set_xlim(-1, 1)
    ax.axvline(0, color="black", linewidth=0.5)
    ax.set_title(f"z{zi}", fontsize=10)
    ax.set_xlabel("Pearson r", fontsize=8)

# Hide the grid cells that have no latent dimension assigned
for spare_ax in axes_flat[n_latent:]:
    spare_ax.set_visible(False)

fig.suptitle("Top-5 Feature Correlates per Latent Dimension", fontsize=13, y=1.01)
fig.tight_layout()
top5_path = OUTPUT_DIR / "vae_top5_correlations.png"
fig.savefig(top5_path, dpi=150, bbox_inches="tight")
plt.show()

---
## 8. Latent Space UMAP Visualisation

Project the 8-D latent codes to 2-D with UMAP, colouring points first
by each sample's dominant latent dimension (the one with the largest
absolute activation) and then by the activation of each individual
latent dimension.

In [None]:
# Project the latent codes to 2-D. A fixed random_state makes the embedding
# reproducible; umap-learn runs single-threaded in that case and overrides
# any other n_jobs value with a warning, so n_jobs=1 is requested explicitly
# instead of the ineffective n_jobs=-1.
umap_model = UMAP(n_components=2, random_state=cfg.random_state, n_jobs=1)
Z_2d = umap_model.fit_transform(Z)
print(f"UMAP embedding: {Z_2d.shape}")

In [None]:
# For each sample, the index of the latent dimension with the largest
# absolute activation; used as the point colour.
dominant_dim = np.abs(Z).argmax(axis=1)

umap_x = Z_2d[:, 0]
umap_y = Z_2d[:, 1]
dim_tick_labels = [f"z{i}" for i in range(n_latent)]

fig, ax = plt.subplots(figsize=(9, 7))
scatter = ax.scatter(
    umap_x,
    umap_y,
    c=dominant_dim,
    cmap="tab10",
    s=4,
    alpha=0.5,
    rasterized=True,
)

# Colour bar labelled with one tick per latent dimension
cbar = fig.colorbar(scatter, ax=ax, label="Dominant Latent Dim")
cbar.set_ticks(range(n_latent))
cbar.set_ticklabels(dim_tick_labels)

ax.set_title("UMAP of VAE Latent Space (coloured by dominant dimension)")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
fig.tight_layout()
umap_path = OUTPUT_DIR / "vae_umap_latent.png"
fig.savefig(umap_path, dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Same UMAP embedding repeated once per latent dimension, each panel
# coloured by that dimension's activation value.
panel_cols = (n_latent + 1) // 2
fig, axes = plt.subplots(2, panel_cols, figsize=(4 * panel_cols, 7))
axes_flat = axes.flatten()

embed_x, embed_y = Z_2d[:, 0], Z_2d[:, 1]

for zi, ax in enumerate(axes_flat[:n_latent]):
    sc = ax.scatter(
        embed_x,
        embed_y,
        c=Z[:, zi],
        cmap="coolwarm",
        s=2,
        alpha=0.4,
        rasterized=True,
    )
    ax.set_title(f"z{zi}", fontsize=10)
    # Axis ticks carry no meaning for a UMAP embedding, so drop them
    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(sc, ax=ax, shrink=0.7)

# Blank out grid cells beyond the last latent dimension
for unused_ax in axes_flat[n_latent:]:
    unused_ax.set_visible(False)

fig.suptitle("UMAP coloured by individual latent activations", fontsize=13, y=1.01)
fig.tight_layout()
per_dim_path = OUTPUT_DIR / "vae_umap_per_dim.png"
fig.savefig(per_dim_path, dpi=150, bbox_inches="tight")
plt.show()

---
## 9. Save Model Artifacts

In [None]:
# Checkpoint for downstream use: weights of the unwrapped module, the input
# width, every config value used for this run, and the total-loss history.
model_path = OUTPUT_DIR / "vae_model.pt"
checkpoint = {
    "state_dict": module.state_dict(),
    "input_dim": input_dim,
    "config_latent_dim": cfg.latent_dim,
    "config_hidden_dims": cfg.hidden_dims,
    "config_beta": cfg.beta,
    "config_learning_rate": cfg.learning_rate,
    "config_batch_size": cfg.batch_size,
    "config_n_epochs": cfg.n_epochs,
    "config_dropout": cfg.dropout,
    "config_random_state": cfg.random_state,
    "training_losses": epoch_losses,
}
torch.save(checkpoint, model_path)

# Latent codes, kept for the hybrid GMM/HMM experiments
latent_codes_path = OUTPUT_DIR / "latent_codes.npy"
np.save(latent_codes_path, Z)

# Fitted preprocessing pipeline
pipeline_path = OUTPUT_DIR / "preprocessing_pipeline.pkl"
pipeline.save(pipeline_path)

# Correlations and p-values, bundled with the feature names that label
# their columns
correlation_path = OUTPUT_DIR / "correlation_matrix.npz"
np.savez(
    correlation_path,
    corr=corr_matrix,
    pval=pval_matrix,
    feature_names=np.array(feature_names),
)

print(f"All artifacts saved to {OUTPUT_DIR}")

---
## Summary & Next Steps

**Produced:**
- Trained beta-VAE (8-D latent space) on combined dataset
- Latent dimension interpretation table (top-5 correlated features per dim)
- UMAP visualisation of latent space structure
- Saved latent codes for downstream hybrid analysis (GMM on latent codes)

**Next steps (Task 5.5):**
- Fit GMM on VAE latent codes and compare to direct-feature GMM
- Compare state agreement, BIC, and silhouette between the two approaches
- Evaluate whether the VAE latent space yields more interpretable states