# 02 - Classical Interpolation Baselines

This notebook runs classical interpolation baselines:
1. Nearest neighbor interpolation
2. Linear (bilinear) interpolation
3. Cubic interpolation

Evaluated on the test set with R=2 and R=3 sparse ratios.

In [2]:
# Mount Drive and setup
from google.colab import drive
drive.mount('/content/drive')

!pip install nibabel SimpleITK scikit-image PyYAML tqdm seaborn -q

import sys, os
PROJECT_ROOT = "/content/drive/MyDrive/TLCN"
sys.path.insert(0, PROJECT_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
from pathlib import Path

from src.utils.config import load_config
from src.utils.seed import set_seed
from src.data.ct_org_loader import CTORGLoader
from src.data.sparse_simulator import SparseSimulator
from src.models.classical_interp import ClassicalInterpolator
from src.evaluation.metrics import evaluate_volume

# Load config
config = load_config(os.path.join(PROJECT_ROOT, "configs/default.yaml"))
set_seed(config["training"]["seed"])

# Output directory
OUTPUT_DIR = os.path.join(config["data"]["output_root"], "classical_baselines")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

Output directory: /content/drive/MyDrive/TLCN/outputs/classical_baselines


In [4]:
# Initialize data loader
loader = CTORGLoader(
    dataset_root=config["data"]["dataset_root"],
    hu_min=config["data"]["hu_min"],
    hu_max=config["data"]["hu_max"],
)

available_cases = loader.get_available_cases()
split = CTORGLoader.get_split(
    available_cases,
    config["data"]["test_cases"],
    config["data"]["val_cases"],
)

test_cases = split["test"]
print(f"Test cases ({len(test_cases)}): {test_cases}")

Test cases (21): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]


In [None]:
# Run classical baselines on test set
methods = config["classical"]["methods"]
sparse_ratios = config["data"]["sparse_ratios"]

# Results storage
all_results = {}

for R in sparse_ratios:
    print(f"\n{'='*60}")
    print(f"Sparse Ratio R={R}")
    print(f"{'='*60}")

    simulator = SparseSimulator(sparse_ratio=R)

    for method in methods:
        key = f"{method}_R{R}"
        all_results[key] = {}

        interp = ClassicalInterpolator(method=method)
        psnr_list = []
        ssim_list = []

        for case_idx in tqdm(test_cases, desc=f"{method} R={R}"):
            try:
                volume, labels, metadata = loader.load_and_preprocess(case_idx)
                sparse_data = simulator.simulate(volume)

                # Interpolate
                interpolated = interp.interpolate_volume(
                    sparse_data["observed_slices"],
                    sparse_data["observed_indices"],
                    sparse_data["target_indices"],
                )

                # Evaluate
                result = evaluate_volume(
                    interpolated,
                    sparse_data["target_slices"],
                    sparse_data["target_indices"],
                    labels=labels,
                )

                all_results[key][case_idx] = result["summary"]
                psnr_list.append(result["summary"]["mean_psnr"])
                ssim_list.append(result["summary"]["mean_ssim"])

            except Exception as e:
                print(f"  Error on case {case_idx}: {e}")

        avg_psnr = np.mean(psnr_list) if psnr_list else 0
        avg_ssim = np.mean(ssim_list) if ssim_list else 0
        print(f"  {method} R={R}: PSNR={avg_psnr:.2f} dB, SSIM={avg_ssim:.4f}")


Sparse Ratio R=2


nearest R=2: 100%|██████████| 21/21 [10:49<00:00, 30.92s/it]


  nearest R=2: PSNR=34.86 dB, SSIM=0.9400


linear R=2: 100%|██████████| 21/21 [15:23<00:00, 44.00s/it]


  linear R=2: PSNR=41.28 dB, SSIM=0.9708


cubic R=2:  10%|▉         | 2/21 [02:06<20:57, 66.20s/it]

In [5]:
# Create summary table
rows = []
for key, cases in all_results.items():
    method, ratio = key.rsplit("_R", 1)
    psnr_vals = [v["mean_psnr"] for v in cases.values()]
    ssim_vals = [v["mean_ssim"] for v in cases.values()]

    rows.append({
        "Method": method.capitalize(),
        "R": int(ratio),
        "PSNR (dB)": f"{np.mean(psnr_vals):.2f} +/- {np.std(psnr_vals):.2f}",
        "SSIM": f"{np.mean(ssim_vals):.4f} +/- {np.std(ssim_vals):.4f}",
        "Mean PSNR": np.mean(psnr_vals),
        "Mean SSIM": np.mean(ssim_vals),
    })

df = pd.DataFrame(rows)
print("\n" + "="*60)
print("Classical Baseline Results")
print("="*60)
print(df[["Method", "R", "PSNR (dB)", "SSIM"]].to_string(index=False))

NameError: name 'all_results' is not defined

In [6]:
# Save results
results_path = os.path.join(OUTPUT_DIR, "classical_results.json")

# Convert numpy types for JSON serialization
def convert_to_serializable(obj):
    if isinstance(obj, (np.integer,)):
        return int(obj)
    elif isinstance(obj, (np.floating,)):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {str(k): convert_to_serializable(v) for k, v in obj.items()}
    return obj

with open(results_path, "w") as f:
    json.dump(convert_to_serializable(all_results), f, indent=2)

df.to_csv(os.path.join(OUTPUT_DIR, "classical_summary.csv"), index=False)
print(f"\nResults saved to {OUTPUT_DIR}")

NameError: name 'all_results' is not defined

In [None]:
# Visualize example results
from src.evaluation.visualization import plot_slice_comparison, plot_error_map
import matplotlib.pyplot as plt

# Pick a sample case and slice
sample_case = test_cases[0]
volume, labels, _ = loader.load_and_preprocess(sample_case)
R = 2
simulator = SparseSimulator(sparse_ratio=R)
sparse_data = simulator.simulate(volume)

# Get interpolation results for each method
results_dict = ClassicalInterpolator.interpolate_all_methods(
    sparse_data["observed_slices"],
    sparse_data["observed_indices"],
    sparse_data["target_indices"],
)

# Pick a target slice to visualize
target_idx = len(sparse_data["target_indices"]) // 2
z_val = sparse_data["target_indices"][target_idx]
gt_slice = sparse_data["target_slices"][:, :, target_idx]

predictions = {
    method: results_dict[method][:, :, target_idx]
    for method in methods
}

# Plot comparison
fig = plot_slice_comparison(
    gt_slice, predictions, z_idx=z_val,
    save_path=os.path.join(OUTPUT_DIR, f"comparison_case{sample_case}_z{z_val}.png"),
)
plt.show()

# Plot error map
fig = plot_error_map(
    gt_slice, predictions, z_idx=z_val,
    save_path=os.path.join(OUTPUT_DIR, f"error_case{sample_case}_z{z_val}.png"),
)
plt.show()

print("Classical baselines evaluation complete!")