# Training Monitoring — BD-Gen MDLM

This notebook loads training run data from wandb (or local Hydra outputs)
and visualises loss curves, learning rate schedules, per-class accuracy,
and generated sample statistics.

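**Prerequisites:**
- A completed (or in-progress) training run logged to wandb, plus its run ID (set `RUN_ID` in the config cell).
- `pip install wandb matplotlib` (included in `[dev]` extras).
- For the checkpoint-inspection cell: PyTorch, the `bd_gen` package (importable from `BD_GEN_ROOT`), and a saved checkpoint file (set `CKPT_PATH`).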

In [None]:
import matplotlib.pyplot as plt
import wandb

# --- wandb run coordinates: edit these before running the notebook ---
PROJECT = "bd-generation"
# Entity (user or team). Leaving it as None builds a "project/run_id" path
# with no entity component.
ENTITY = None  # e.g. "your-username"
# Run ID copied from the wandb dashboard; the placeholder must be replaced.
RUN_ID = "<run_id>"

In [None]:
# Fetch run history from the wandb public API.
import pandas as pd

# Fail fast with a clear message instead of an opaque API lookup error.
if RUN_ID == "<run_id>":
    raise ValueError("Set RUN_ID (and optionally ENTITY) in the config cell first.")

# `Run.history()` returns a *sampled* subset of rows (500 by default). That
# distorts row-based smoothing (e.g. the rolling mean below) and may
# under-sample sparsely logged metrics such as validation loss.
# `scan_history()` streams every logged row instead. Set FULL_HISTORY = False
# for a faster, sampled preview of very long runs.
FULL_HISTORY = True

api = wandb.Api()
run_path = f"{ENTITY}/{PROJECT}/{RUN_ID}" if ENTITY else f"{PROJECT}/{RUN_ID}"
run = api.run(run_path)
if FULL_HISTORY:
    history = pd.DataFrame(list(run.scan_history()))
else:
    history = run.history(pandas=True)
print(f"Run: {run.name}")
print(f"Config: {run.config}")
print(f"Columns: {list(history.columns)}")
history.head()

In [None]:
# Plot training loss over logged steps, with a rolling mean to smooth noise.
# Guard on the column first: `dropna(subset=[...])` raises KeyError when the
# metric was never logged, so fall back to an empty frame instead.
LOSS_ROLLING_WINDOW = 50  # counts logged rows, not optimizer steps
loss_data = (
    history.dropna(subset=["train/loss"])
    if "train/loss" in history.columns
    else history.iloc[:0]
)
if not loss_data.empty:
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(loss_data["_step"], loss_data["train/loss"], alpha=0.3, label="per-step")
    # min_periods=1 averages over however many rows are available, so the
    # first LOSS_ROLLING_WINDOW - 1 points are not NaN.
    smoothed = (
        loss_data["train/loss"]
        .rolling(window=LOSS_ROLLING_WINDOW, min_periods=1)
        .mean()
    )
    ax.plot(loss_data["_step"], smoothed, label=f"rolling avg ({LOSS_ROLLING_WINDOW})")
    ax.set_xlabel("Step")
    ax.set_ylabel("ELBO Loss")
    ax.set_title("Training Loss")
    ax.legend()
    plt.tight_layout()
    plt.show()
else:
    print("No training loss data found.")

In [None]:
# Plot validation loss.
# Check the column exists first: `dropna(subset=[...])` raises KeyError for a
# metric that was never logged, which would bypass the "no data" message below.
val_data = (
    history.dropna(subset=["val/loss"])
    if "val/loss" in history.columns
    else history.iloc[:0]
)
if not val_data.empty:
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(val_data["_step"], val_data["val/loss"], marker="o", label="val loss")
    ax.set_xlabel("Step")
    ax.set_ylabel("ELBO Loss")
    ax.set_title("Validation Loss")
    ax.legend()
    plt.tight_layout()
    plt.show()
else:
    print("No validation data found.")

In [None]:
# Plot the learning-rate schedule.
# Guard on the column: `dropna(subset=[...])` raises KeyError when the LR was
# never logged. Report that explicitly, consistent with the other plot cells.
lr_data = (
    history.dropna(subset=["train/lr"])
    if "train/lr" in history.columns
    else history.iloc[:0]
)
if not lr_data.empty:
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(lr_data["_step"], lr_data["train/lr"])
    ax.set_xlabel("Step")
    ax.set_ylabel("Learning Rate")
    # NOTE(review): title assumes the configured schedule; confirm against run.config.
    ax.set_title("LR Schedule (linear warmup → constant)")
    plt.tight_layout()
    plt.show()
else:
    print("No learning-rate data found.")

In [None]:
# Plot per-class accuracy (node and edge).
# Only metrics actually present in the history are used: `dropna(subset=...)`
# raises KeyError for a column that was never logged, which would crash the
# cell before the "no data" branch is reached.
acc_cols = ["val/node_accuracy", "val/edge_accuracy"]
present_acc_cols = [col for col in acc_cols if col in history.columns]
acc_data = (
    history.dropna(subset=present_acc_cols, how="all")
    if present_acc_cols
    else history.iloc[:0]
)
if not acc_data.empty:
    fig, ax = plt.subplots(figsize=(10, 4))
    for col in present_acc_cols:
        # Each metric may be logged on different rows; plot only its own rows.
        data = acc_data.dropna(subset=[col])
        if not data.empty:
            ax.plot(data["_step"], data[col], marker="o", label=col)
    ax.set_xlabel("Step")
    ax.set_ylabel("Accuracy")
    ax.set_title("Per-Class Accuracy at Masked Positions")
    ax.set_ylim(0, 1)
    ax.legend()
    plt.tight_layout()
    plt.show()
else:
    print("No accuracy data found.")

In [None]:
# Load a checkpoint and inspect the model.
import sys
from pathlib import Path

# Adjust this path to point to your BD_Generation directory
BD_GEN_ROOT = Path("../").resolve()
# Prepend only once so re-running this cell does not keep growing sys.path.
if str(BD_GEN_ROOT) not in sys.path:
    sys.path.insert(0, str(BD_GEN_ROOT))

import torch
from bd_gen.data.vocab import RPLAN_VOCAB_CONFIG
from bd_gen.model.denoiser import BDDenoiser
from bd_gen.utils.checkpoint import load_checkpoint

# Replace with actual checkpoint path
CKPT_PATH = "<path/to/checkpoint_final.pt>"
# Fail early with a clear message (this also catches the unreplaced placeholder).
if not Path(CKPT_PATH).is_file():
    raise FileNotFoundError(
        f"Checkpoint not found: {CKPT_PATH!r}. Set CKPT_PATH to a saved checkpoint."
    )

vc = RPLAN_VOCAB_CONFIG
# NOTE(review): these hyperparameters are hard-coded and presumably must match
# the architecture the checkpoint was trained with; compare against the
# "Config:" printed below if loading fails.
model = BDDenoiser(
    d_model=128, n_layers=4, n_heads=4, vocab_config=vc, dropout=0.0,
)
meta = load_checkpoint(CKPT_PATH, model, device="cpu")
print(f"Checkpoint epoch: {meta['epoch']}")
print(f"Config: {meta['config']}")
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,} ({n_params/1e6:.2f}M)")