# 03 - Training Analysis

This notebook summarizes training dynamics across experiments:

- loss and balanced accuracy curves
- learning-rate schedule behavior
- runtime and memory comparison

In [None]:
from __future__ import annotations

from pathlib import Path
import json

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sns.set_theme(style="whitegrid")

# Locate the project root: if the notebook is launched from the repository
# root (a `src/` directory is present), use the cwd itself; otherwise assume
# the notebook lives one level below the root (e.g. in `notebooks/`).
_cwd = Path.cwd().resolve()
PROJECT_ROOT = _cwd if (_cwd / "src").exists() else _cwd.parent
OUTPUT_DIR = PROJECT_ROOT / "outputs"
print("Using outputs dir:", OUTPUT_DIR)

In [None]:
# Load per-run training histories: one JSON object per file in
# OUTPUT_DIR/history, keyed by the file's stem (used as the run name).
# Files are decoded as UTF-8 explicitly so results don't depend on the
# platform's default locale encoding. A malformed or non-object file is
# reported and skipped instead of aborting the whole notebook.
history_dir = OUTPUT_DIR / "history"
history_files = sorted(history_dir.glob("*.json")) if history_dir.is_dir() else []
print(f"Found {len(history_files)} history files")

histories = {}
for fp in history_files:
    try:
        with fp.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        print(f"Skipping unreadable history file {fp.name}: {err}")
        continue
    # Downstream plotting calls `hist.get(...)`, so only JSON objects are usable.
    if not isinstance(data, dict):
        print(f"Skipping history file {fp.name}: expected a JSON object")
        continue
    histories[fp.stem] = data

if not histories:
    print("No training history files found yet. Run training first.")

In [None]:
# Plot loss and balanced accuracy curves for available runs.
# Each series is plotted against its own epoch range (1..len(series)), so a
# missing or shorter series (e.g. no validation metrics logged) is simply
# omitted instead of raising a matplotlib x/y shape-mismatch error.


def _plot_series(ax, hist, keys_and_labels, title):
    """Plot each non-empty ``hist[key]`` on *ax* under its label, then title it."""
    for key, label in keys_and_labels:
        values = hist.get(key) or []
        if values:
            ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set_title(title)
    ax.set_xlabel("epoch")
    # Only draw a legend when something was plotted (avoids an empty-legend warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()


for run_name, hist in histories.items():
    # Runs without any training-loss entries have nothing meaningful to show.
    if not hist.get("train_loss"):
        continue

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))
    _plot_series(
        loss_ax,
        hist,
        [("train_loss", "train"), ("val_loss", "val")],
        f"{run_name} - Loss",
    )
    _plot_series(
        acc_ax,
        hist,
        [("train_balanced_accuracy", "train"), ("val_balanced_accuracy", "val")],
        f"{run_name} - Balanced Accuracy",
    )
    fig.tight_layout()
    plt.show()

In [None]:
# Optional: summarize runtime/gpu memory logs if available.
# The full summary table is always displayed; the runtime bar chart is drawn
# only when both the x-axis column (run_name) and y-axis column
# (runtime_minutes) are present, so a partial CSV does not crash the cell.
summary_csv = OUTPUT_DIR / "training_summary.csv"
if summary_csv.exists():
    summary_df = pd.read_csv(summary_csv)
    display(summary_df)

    if {"run_name", "runtime_minutes"}.issubset(summary_df.columns):
        fig, ax = plt.subplots(figsize=(8, 4))
        sns.barplot(data=summary_df, x="run_name", y="runtime_minutes", ax=ax)
        plt.setp(ax.get_xticklabels(), rotation=30, ha="right")
        ax.set_title("Training Runtime Comparison")
        fig.tight_layout()
        plt.show()
else:
    print("No training_summary.csv found yet.")