# 05 - Benchmark & Ablation Studies

This notebook:
1. Aggregates results from all methods (classical, U-Net, 3DGS)
2. Creates comparison tables and statistical analysis
3. Runs ablation studies (3DGS with/without regularization)
4. Generates publication-ready tables

In [None]:
# Mount Drive and setup
from google.colab import drive
drive.mount('/content/drive')

!pip install nibabel SimpleITK scikit-image PyYAML tqdm seaborn -q

import sys, os
PROJECT_ROOT = "/content/drive/MyDrive/TLCN"
sys.path.insert(0, PROJECT_ROOT)

In [None]:
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns

from src.utils.config import load_config

# Read the experiment configuration and pick up the shared output directory.
config = load_config(os.path.join(PROJECT_ROOT, "configs/default.yaml"))
OUTPUT_ROOT = config["data"]["output_root"]

# Per-method result files, all located under OUTPUT_ROOT.
classical_path = os.path.join(OUTPUT_ROOT, "classical_baselines/classical_results.json")
unet_path = os.path.join(OUTPUT_ROOT, "unet_baseline/unet_results.json")
gs_path = os.path.join(OUTPUT_ROOT, "3dgs/3dgs_results.json")
gs_noreg_path = os.path.join(OUTPUT_ROOT, "3dgs_noreg/3dgs_noreg_results.json")

# (file, message printed once it has been merged), in merge order:
# classical baselines, U-Net, 3DGS, 3DGS without regularization (ablation).
result_sources = (
    (classical_path, "Loaded classical baseline results"),
    (unet_path, "Loaded U-Net results"),
    (gs_path, "Loaded 3DGS results"),
    (gs_noreg_path, "Loaded 3DGS (no reg) results"),
)

# Load all results; every file is optional and a missing one is skipped without
# a message. Later files overwrite earlier ones on identical top-level keys.
results = {}
for source_path, loaded_message in result_sources:
    if not os.path.exists(source_path):
        continue
    with open(source_path) as f:
        results.update(json.load(f))
    print(loaded_message)

print(f"\nTotal methods loaded: {len(results)}")
print(f"Methods: {list(results.keys())}")

In [None]:
# Create summary table
# Column order of the per-method summary. Passing it explicitly keeps the frame
# well-formed (and sortable) even when no result file could be loaded.
SUMMARY_COLUMNS = [
    "Method", "R", "PSNR (mean)", "PSNR (std)", "SSIM (mean)", "SSIM (std)", "N",
]

rows = []

for key, cases in results.items():
    # Parse method name and R: keys look like "<method>_R<ratio>"; a key without
    # the "_R" suffix is reported under R=2.
    parts = key.rsplit("_R", 1)
    method = parts[0].replace("_", " ").title()
    R = int(parts[1]) if len(parts) > 1 else 2

    if not isinstance(cases, dict):
        continue

    # Keep only per-case entries carrying BOTH metrics so the PSNR and SSIM
    # statistics are computed over the same cases (and a missing "mean_ssim"
    # no longer raises KeyError).
    case_metrics = [
        metrics for metrics in cases.values()
        if isinstance(metrics, dict) and "mean_psnr" in metrics and "mean_ssim" in metrics
    ]
    if not case_metrics:
        continue

    psnr_vals = [metrics["mean_psnr"] for metrics in case_metrics]
    ssim_vals = [metrics["mean_ssim"] for metrics in case_metrics]

    rows.append({
        "Method": method,
        "R": R,
        "PSNR (mean)": np.mean(psnr_vals),
        "PSNR (std)": np.std(psnr_vals),  # population std (numpy default, ddof=0)
        "SSIM (mean)": np.mean(ssim_vals),
        "SSIM (std)": np.std(ssim_vals),
        "N": len(psnr_vals),
    })

# With explicit columns an empty `rows` yields an empty frame that still has the
# "R" / "PSNR (mean)" columns, so the sort below cannot fail with KeyError.
df = pd.DataFrame(rows, columns=SUMMARY_COLUMNS)
df = df.sort_values(["R", "PSNR (mean)"], ascending=[True, False])

# Format for display. Plain lists are assigned positionally (row order), which
# also works on an empty frame where a row-wise apply would not return a Series.
df["PSNR"] = [
    f"{mean:.2f} +/- {std:.2f}"
    for mean, std in zip(df["PSNR (mean)"], df["PSNR (std)"])
]
df["SSIM"] = [
    f"{mean:.4f} +/- {std:.4f}"
    for mean, std in zip(df["SSIM (mean)"], df["SSIM (std)"])
]

print("\n" + "="*70)
print("BENCHMARK RESULTS")
print("="*70)
if df.empty:
    print("No per-case metrics found - check that the result files exist under OUTPUT_ROOT.")
else:
    print(df[["Method", "R", "PSNR", "SSIM", "N"]].to_string(index=False))

In [None]:
# Generate LaTeX table for paper
banner = "=" * 70
print("\n" + banner)
print("LaTeX Table")
print(banner)

# Only the formatted "mean +/- std" columns go into the paper table.
paper_columns = ["Method", "R", "PSNR", "SSIM"]
latex = df[paper_columns].to_latex(
    index=False,
    column_format="lrcc",
    label="tab:results",
    caption="Quantitative comparison of slice interpolation methods on CT-ORG test set.",
)
print(latex)

# Save LaTeX table next to the other benchmark outputs.
tex_path = os.path.join(OUTPUT_ROOT, "results_table.tex")
with open(tex_path, "w") as f:
    f.write(latex)

In [None]:
# Visualization: PSNR comparison bar chart
# One panel per configured sparse ratio. The panel count follows the config
# instead of being fixed at two, so three or more ratios no longer index past
# the axes array. Two ratios still give the original 14x6 figure.
sparse_ratios = list(config["data"]["sparse_ratios"])
n_panels = max(len(sparse_ratios), 1)  # plt.subplots needs at least one column
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 6), squeeze=False)
axes = axes[0]  # squeeze=False always returns a 2-D array; take its single row

for ax, R in zip(axes, sparse_ratios):
    # Ascending sort puts the best method at the top of the horizontal bars.
    df_r = df[df["R"] == R].sort_values("PSNR (mean)")

    colors = sns.color_palette("husl", len(df_r))
    bars = ax.barh(
        df_r["Method"], df_r["PSNR (mean)"],
        xerr=df_r["PSNR (std)"],
        color=colors, capsize=3
    )
    ax.set_xlabel("PSNR (dB)")
    ax.set_title(f"R = {R}")
    ax.grid(True, alpha=0.3, axis='x')

    # Add value labels just right of each bar end
    for bar, val in zip(bars, df_r["PSNR (mean)"]):
        ax.text(
            bar.get_width() + 0.1, bar.get_y() + bar.get_height() / 2,
            f"{val:.2f}", va='center', fontsize=9
        )

plt.suptitle("PSNR Comparison Across Methods", fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_ROOT, "psnr_comparison.png"), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# SSIM comparison
# One panel per configured sparse ratio. The panel count follows the config
# instead of being fixed at two, so three or more ratios no longer index past
# the axes array. Two ratios still give the original 14x6 figure.
sparse_ratios = list(config["data"]["sparse_ratios"])
n_panels = max(len(sparse_ratios), 1)  # plt.subplots needs at least one column
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 6), squeeze=False)
axes = axes[0]  # squeeze=False always returns a 2-D array; take its single row

for ax, R in zip(axes, sparse_ratios):
    # Ascending sort puts the best method at the top of the horizontal bars.
    df_r = df[df["R"] == R].sort_values("SSIM (mean)")

    colors = sns.color_palette("husl", len(df_r))
    bars = ax.barh(
        df_r["Method"], df_r["SSIM (mean)"],
        xerr=df_r["SSIM (std)"],
        color=colors, capsize=3
    )
    ax.set_xlabel("SSIM")
    ax.set_title(f"R = {R}")
    ax.grid(True, alpha=0.3, axis='x')

    # Add value labels just right of each bar end
    for bar, val in zip(bars, df_r["SSIM (mean)"]):
        ax.text(
            bar.get_width() + 0.001, bar.get_y() + bar.get_height() / 2,
            f"{val:.4f}", va='center', fontsize=9
        )

plt.suptitle("SSIM Comparison Across Methods", fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_ROOT, "ssim_comparison.png"), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Per-case distribution analysis
# One panel per configured sparse ratio. The panel count follows the config
# instead of being fixed at two, so three or more ratios no longer index past
# the axes array. Two ratios still give the original 14x6 figure.
sparse_ratios = list(config["data"]["sparse_ratios"])
n_panels = max(len(sparse_ratios), 1)  # plt.subplots needs at least one column
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 6), squeeze=False)
axes = axes[0]  # squeeze=False always returns a 2-D array; take its single row

for ax, R in zip(axes, sparse_ratios):
    # Collect one record per (method, case) evaluated at this sparse ratio.
    records = []

    for key, cases in results.items():
        # Keys look like "<method>_R<ratio>"; keys without the suffix count as R=2.
        parts = key.rsplit("_R", 1)
        method = parts[0].replace("_", " ").title()
        r_val = int(parts[1]) if len(parts) > 1 else 2

        if r_val != R:
            continue

        for case_id, metrics in cases.items():
            if isinstance(metrics, dict) and "mean_psnr" in metrics:
                records.append({"Method": method, "PSNR": metrics["mean_psnr"]})

    # A ratio without any per-case data leaves its panel blank.
    if records:
        plot_df = pd.DataFrame(records)
        sns.boxplot(x="Method", y="PSNR", data=plot_df, ax=ax)
        # Overlay the individual cases on top of the box summary.
        sns.stripplot(
            x="Method", y="PSNR", data=plot_df, ax=ax,
            color="black", alpha=0.4, size=4
        )
        ax.set_title(f"R = {R}")
        ax.set_ylabel("PSNR (dB)")
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3, axis='y')

plt.suptitle("Per-Case PSNR Distribution", fontsize=14)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_ROOT, "psnr_distribution.png"), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Ablation: Run 3DGS without regularization
# Skip if already done: previously saved, non-empty results are reused unless
# FORCE_RERUN_ABLATION is set to True.
import torch
from src.data.ct_org_loader import CTORGLoader
from src.data.sparse_simulator import SparseSimulator
from src.training.trainer_3dgs import Trainer3DGS
from src.utils.seed import set_seed
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
set_seed(config["training"]["seed"])

ABLATION_DIR = os.path.join(OUTPUT_ROOT, "3dgs_noreg")
os.makedirs(ABLATION_DIR, exist_ok=True)

R = 2  # Run ablation only on R=2
N_ABLATION_CASES = 5  # Use a subset of the test cases for the ablation
FORCE_RERUN_ABLATION = False  # set True to retrain even if saved results exist

# Config without regularization: only the two regularization weights are
# zeroed; every other loss setting is carried over so they are the only
# difference from the regularized run. The top-level copy is shallow, but
# "loss" is rebound to a new dict, so `config` itself is not modified.
# NOTE(review): assumes `config` is a plain dict (it is used with .copy()/.get()).
noreg_config = config.copy()
noreg_config["loss"] = {
    **(config.get("loss") or {}),
    "lambda_smooth": 0.0,
    "lambda_edge": 0.0,
}

loader = CTORGLoader(
    dataset_root=config["data"]["dataset_root"],
    hu_min=config["data"]["hu_min"],
    hu_max=config["data"]["hu_max"],
)
simulator = SparseSimulator(sparse_ratio=R)

# Reuse saved results when available (same file and key the save step writes).
ablation_results = {}
ablation_key = f"3dgs Noreg_R{R}"
ablation_cache_path = os.path.join(ABLATION_DIR, "3dgs_noreg_results.json")
if not FORCE_RERUN_ABLATION and os.path.exists(ablation_cache_path):
    try:
        with open(ablation_cache_path) as f:
            ablation_results = json.load(f).get(ablation_key) or {}
    except ValueError:
        # Unreadable/truncated JSON: ignore the cache and recompute.
        ablation_results = {}
    if ablation_results:
        print(
            f"Loaded cached ablation results for {len(ablation_results)} case(s); "
            "set FORCE_RERUN_ABLATION = True to retrain."
        )

if not ablation_results:
    try:
        test_cases = split["test"]
    except NameError:
        # NOTE(review): `split` is not defined anywhere in this notebook, so a
        # fresh kernel cannot resolve it. Fall back to the cases reported by
        # the regularized 3DGS run at the same R. JSON keys are strings;
        # numeric ones are converted back to int on the assumption that the
        # loader expects integer case indices - confirm.
        reference_cases = results.get(f"3dgs_R{R}", {})
        test_cases = [
            int(case_id) if str(case_id).isdigit() else case_id
            for case_id, metrics in reference_cases.items()
            if isinstance(metrics, dict) and "mean_psnr" in metrics
        ]
        if not test_cases:
            raise NameError(
                "`split` is not defined and no regularized 3DGS results are loaded; "
                "define split = {'test': [...]} before running the ablation."
            ) from None

    ablation_cases = list(test_cases)[:N_ABLATION_CASES]
    failed_cases = []

    for case_idx in ablation_cases:
        print(f"\nAblation: case {case_idx} (no reg)...")
        trainer = None
        try:
            volume, labels, _ = loader.load_and_preprocess(case_idx)
            sparse_data = simulator.simulate(volume)

            trainer = Trainer3DGS(
                volume=volume,
                observed_indices=sparse_data["observed_indices"],
                target_indices=sparse_data["target_indices"],
                config=noreg_config,
                labels=labels,
                device=device,
                checkpoint_dir=os.path.join(ABLATION_DIR, f"case_{case_idx}"),
            )
            trainer.train()
            eval_result = trainer.evaluate_on_targets()
            ablation_results[case_idx] = eval_result["summary"]

            print(f"  PSNR: {eval_result['summary']['mean_psnr']:.2f} dB")
        except Exception as e:
            # Best effort: report the failure and continue with the next case.
            failed_cases.append(case_idx)
            print(f"  Error: {e}")
        finally:
            # Drop the trainer and release cached GPU memory even when the case
            # failed, so one failure does not starve the remaining cases.
            trainer = None
            torch.cuda.empty_cache()

    if failed_cases:
        print(f"\n{len(failed_cases)} ablation case(s) failed: {failed_cases}")

# Save ablation results
def convert_to_serializable(obj):
    """Recursively convert NumPy values inside ``obj`` into JSON-friendly Python types.

    Args:
        obj: A value or an arbitrarily nested dict/list/tuple of values.

    Returns:
        NumPy integers, floats and booleans as ``int``/``float``/``bool``;
        arrays as nested lists; dicts with their keys turned into ``str`` and
        their values converted; lists and tuples as lists of converted items.
        Any other object is returned unchanged.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {str(k): convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Sequences can hold NumPy scalars too (e.g. per-slice metric lists),
        # which json cannot encode unless each item is converted.
        return [convert_to_serializable(v) for v in obj]
    return obj

# Serialize BEFORE opening the file: opening with "w" truncates it, so a
# serialization error would otherwise leave an empty or partial JSON file behind.
noreg_payload = json.dumps(
    {f"3dgs Noreg_R{R}": convert_to_serializable(ablation_results)}, indent=2
)
with open(os.path.join(ABLATION_DIR, "3dgs_noreg_results.json"), "w") as f:
    f.write(noreg_payload)

print("\nAblation study complete!")

In [None]:
# Ablation summary
print("\n" + "="*60)
print(f"ABLATION: Effect of Regularization (R={R})")
print("="*60)


def _case_metrics(cases):
    """Return {str(case id): metrics} for entries that carry both PSNR and SSIM."""
    return {
        str(case_id): metrics
        for case_id, metrics in cases.items()
        if isinstance(metrics, dict) and "mean_psnr" in metrics and "mean_ssim" in metrics
    }


def _ablation_row(label, metrics_by_case):
    """Build one table row ("mean +/- std" strings plus case count) for a configuration."""
    psnr = [m["mean_psnr"] for m in metrics_by_case.values()]
    ssim = [m["mean_ssim"] for m in metrics_by_case.values()]
    return {
        "Config": label,
        "PSNR": f"{np.mean(psnr):.2f} +/- {np.std(psnr):.2f}",
        "SSIM": f"{np.mean(ssim):.4f} +/- {np.std(ssim):.4f}",
        "N": len(psnr),
    }


# Compare with vs without regularization
gs_key = f"3dgs_R{R}"
with_reg = _case_metrics(results.get(gs_key, {}))
without_reg = _case_metrics(ablation_results)

# The ablation only covers a subset of cases, so both configurations are
# summarized over the cases they share; otherwise the means would be taken over
# different case sets. Case ids are compared as strings because ids loaded
# from JSON are strings.
shared_cases = sorted(with_reg.keys() & without_reg.keys())
if shared_cases:
    with_reg = {case_id: with_reg[case_id] for case_id in shared_cases}
    without_reg = {case_id: without_reg[case_id] for case_id in shared_cases}
elif with_reg and without_reg:
    print("Warning: no overlapping cases between the two runs; means are not directly comparable.")

ablation_comparison = []
if with_reg:
    ablation_comparison.append(_ablation_row("3DGS + Reg", with_reg))
if without_reg:
    ablation_comparison.append(_ablation_row("3DGS (no Reg)", without_reg))

if ablation_comparison:
    ablation_df = pd.DataFrame(ablation_comparison)
    print(ablation_df.to_string(index=False))

# Save full benchmark summary
df.to_csv(os.path.join(OUTPUT_ROOT, "full_benchmark.csv"), index=False)
print(f"\nAll results saved to {OUTPUT_ROOT}")