In [None]:
# ============================================================
# Comparative Training Curves (Paper Style)
# SRCNN vs SRResNet vs EDSR
# ============================================================

import json
from pathlib import Path
import matplotlib.pyplot as plt

# ------------------------------------------------------------
# Base directory: the current working directory (expected to be notebooks/said/)
# ------------------------------------------------------------
BASE_DIR = Path.cwd()

# ------------------------------------------------------------
# Paths to JSON training histories
# ------------------------------------------------------------
# Every history file sits under <BASE_DIR>/../../src/models/checkpoints,
# in a per-model sub-directory; only the sub-directory and file name vary.
PATHS = {
    model: BASE_DIR / "../../src/models/checkpoints" / subdir / filename
    for model, subdir, filename in (
        ("SRCNN", "SRCNN", "srcnn_training_history.json"),
        ("SRResNet", "SRRESNET", "srresnet.json"),
        ("EDSR", "EDSR", "edsr_training_history.json"),
    )
}

# ------------------------------------------------------------
# Load history
# ------------------------------------------------------------
def load_history(path):
    """Load a training-history JSON file and return the decoded document.

    Args:
        path: Location of the JSON file, as a ``pathlib.Path`` or a ``str``.

    Returns:
        Whatever ``json.load`` decodes from the file.

    Raises:
        FileNotFoundError: if *path* is not an existing regular file.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    # is_file() (rather than exists()) also rejects directories up front,
    # giving the same clear message instead of a confusing open() error.
    if not path.is_file():
        raise FileNotFoundError(f"❌ File not found: {path}")
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's default
    # encoding, which differs e.g. on Windows.
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)

def _get_metric(history, keys, source):
    """Return ``history[key]`` for the first of *keys* present in *history*.

    Args:
        history: Decoded training-history mapping (from ``load_history``).
        keys: Candidate key names, tried in order.
        source: Model name, used only in the error message.

    Raises:
        KeyError: if none of *keys* is present. The message names the model
            and lists the keys that do exist, so a schema mismatch between
            history files is immediately visible.
    """
    for key in keys:
        if key in history:
            return history[key]
    raise KeyError(
        f"{source}: none of {list(keys)} found in training history; "
        f"available keys: {sorted(history)}"
    )


srcnn = load_history(PATHS["SRCNN"])
srresnet = load_history(PATHS["SRResNet"])
edsr = load_history(PATHS["EDSR"])

# ------------------------------------------------------------
# Extract metrics
# ------------------------------------------------------------
# The plural key names are tried first; the singular forms are accepted as a
# fallback so histories written with either convention load.
_LOSS_KEYS = ("val_losses", "val_loss")
_PSNR_KEYS = ("val_psnrs", "val_psnr")

srcnn_val_loss = _get_metric(srcnn, _LOSS_KEYS, "SRCNN")
srresnet_val_loss = _get_metric(srresnet, _LOSS_KEYS, "SRResNet")
edsr_val_loss = _get_metric(edsr, _LOSS_KEYS, "EDSR")

srcnn_val_psnr = _get_metric(srcnn, _PSNR_KEYS, "SRCNN")
srresnet_val_psnr = _get_metric(srresnet, _PSNR_KEYS, "SRResNet")
edsr_val_psnr = _get_metric(edsr, _PSNR_KEYS, "EDSR")

# ------------------------------------------------------------
# Epoch axes
# ------------------------------------------------------------
def _epoch_axis(values):
    """Return 1-based epoch indices with the same length as *values*."""
    return range(1, len(values) + 1)


epochs_srcnn = _epoch_axis(srcnn_val_loss)
epochs_srresnet = _epoch_axis(srresnet_val_loss)
epochs_edsr = _epoch_axis(edsr_val_loss)

# ------------------------------------------------------------
# Colors (publication style)
# ------------------------------------------------------------
COLORS = {
    "SRCNN": "#1f77b4",     # blue
    "SRResNet": "#ff7f0e",  # orange
    "EDSR": "#2ca02c",      # green
}

# (model name, validation losses, validation PSNRs), in legend order.
_SERIES = (
    ("SRCNN", srcnn_val_loss, srcnn_val_psnr),
    ("SRResNet", srresnet_val_loss, srresnet_val_psnr),
    ("EDSR", edsr_val_loss, edsr_val_psnr),
)

# ------------------------------------------------------------
# Create figure: validation loss (left) and validation PSNR (right)
# ------------------------------------------------------------
fig, (ax_loss, ax_psnr) = plt.subplots(1, 2, figsize=(13, 5))

for model, losses, psnrs in _SERIES:
    # Each curve gets an x-axis matching its own length, so a PSNR history
    # whose length differs from the loss history no longer raises a
    # shape-mismatch ValueError.
    ax_loss.plot(_epoch_axis(losses), losses,
                 label=model, color=COLORS[model], linewidth=2)
    ax_psnr.plot(_epoch_axis(psnrs), psnrs,
                 label=model, color=COLORS[model], linewidth=2)

# NOTE(review): the loss label assumes the stored losses are L1 — confirm
# against the training scripts.
for ax, ylabel, title in (
    (ax_loss, "Validation Loss (L1)", "Validation Loss Comparison"),
    (ax_psnr, "PSNR (dB)", "Validation PSNR Comparison"),
):
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.legend()

# ------------------------------------------------------------
# Save & show (saved relative to the current working directory)
# ------------------------------------------------------------
OUTPUT_PATH = "comparison_SRCNN_SRResNet_EDSR.png"
fig.tight_layout()
fig.savefig(OUTPUT_PATH, dpi=300)  # save before show(), which may close the figure
plt.show()

print(f"✅ Figure successfully generated: {OUTPUT_PATH}")


KeyError: 'val_loss'