# 06 - Visualization & Figure Generation

Generate all publication-ready figures:
1. Qualitative slice comparisons (ground truth (GT) vs. all methods)
2. Error maps
3. Z-axis (per-slice) error profile analysis
4. ROI-specific analysis (organ segmentation overlay)
5. Training curve analysis for 3DGS

In [None]:
# Mount Google Drive and set up the notebook environment.
# Project code and outputs live on Drive under PROJECT_ROOT; putting it on
# sys.path lets the `src.*` imports in the next cell resolve.
from google.colab import drive
drive.mount('/content/drive')

# IPython shell magic (only valid inside a notebook): install extra packages.
!pip install nibabel SimpleITK scikit-image PyYAML tqdm seaborn -q

import sys, os
PROJECT_ROOT = "/content/drive/MyDrive/TLCN"
# Prepend (not append) so the project's `src` package wins any name clash.
sys.path.insert(0, PROJECT_ROOT)

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import json
from pathlib import Path

from src.utils.config import load_config
from src.utils.seed import set_seed
from src.data.ct_org_loader import CTORGLoader
from src.data.sparse_simulator import SparseSimulator
from src.models.classical_interp import ClassicalInterpolator
from src.evaluation.metrics import compute_psnr, compute_ssim, evaluate_volume
from src.evaluation.visualization import (
    plot_slice_comparison, plot_error_map,
    plot_training_curves, plot_z_error_profile
)

# Load the experiment configuration and seed every RNG for reproducibility.
config = load_config(os.path.join(PROJECT_ROOT, "configs/default.yaml"))
set_seed(config["training"]["seed"])

data_cfg = config["data"]

# Every figure produced by this notebook is written under <output_root>/figures.
OUTPUT_ROOT = data_cfg["output_root"]
FIG_DIR = os.path.join(OUTPUT_ROOT, "figures")
os.makedirs(FIG_DIR, exist_ok=True)

# CT-ORG loader configured with the HU bounds from the config, followed by
# the train/val/test partition of the cases found on disk.
loader = CTORGLoader(
    dataset_root=data_cfg["dataset_root"],
    hu_min=data_cfg["hu_min"],
    hu_max=data_cfg["hu_max"],
)
available_cases = loader.get_available_cases()
split = CTORGLoader.get_split(
    available_cases,
    data_cfg["test_cases"],
    data_cfg["val_cases"],
)

print(f"Output directory: {FIG_DIR}")

In [None]:
# Helper: load 3DGS predictions for a case
def load_3dgs_predictions(case_idx, R, output_root, tile_size=16,
                          z_threshold=3.0, device=None):
    """Load a trained 3DGS checkpoint for one case and render its target slices.

    The checkpoint is read from
    ``<output_root>/3dgs/R<R>/case_<case_idx>/final.pt``. It must contain the
    keys ``volume_shape``, ``target_indices``, ``positions``, ``log_scales``,
    ``raw_opacity`` and ``intensity``.

    Args:
        case_idx: Case identifier used in the checkpoint directory name.
        R: Sparse ratio used in the checkpoint directory name.
        output_root: Root output directory that contains the ``3dgs`` folder.
        tile_size: ``tile_size`` passed to ``SliceRenderer``. The default of 16
            matches the value this function previously hard-coded.
        z_threshold: ``z_threshold`` passed to ``SliceRenderer``. The default
            is 3.0, the previously hard-coded value.
        device: Torch device used for rendering. Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.

    Returns:
        ``None`` if the checkpoint file does not exist. Otherwise a tuple
        ``(results, target_indices)``:
        - ``results`` is a float32 array of shape
          ``(H, W, len(target_indices))`` with values clamped to [0, 1].
        - ``target_indices`` is returned exactly as stored in the checkpoint.
    """
    ckpt_path = os.path.join(
        output_root, "3dgs", f"R{R}", f"case_{case_idx}", "final.pt"
    )
    if not os.path.exists(ckpt_path):
        print(f"  Checkpoint not found: {ckpt_path}")
        return None

    # Deferred project imports: only needed once a checkpoint is known to exist.
    from src.models.gaussian_volume import GaussianVolume
    from src.models.slice_renderer import SliceRenderer

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(ckpt_path, map_location=device)

    volume_shape = checkpoint["volume_shape"]
    target_indices = checkpoint["target_indices"]
    H, W = volume_shape[0], volume_shape[1]

    # Rebuild the Gaussian model with as many Gaussians as were trained.
    model = GaussianVolume(checkpoint["positions"].shape[0], volume_shape, device)
    renderer = SliceRenderer(
        H, W, tile_size=tile_size, z_threshold=z_threshold
    ).to(device)
    results = np.zeros((H, W, len(target_indices)), dtype=np.float32)

    # Inference only. get_params() also runs under no_grad, so any
    # differentiable transforms it applies do not record an autograd graph.
    with torch.no_grad():
        model.positions.data = checkpoint["positions"].to(device)
        model.log_scales.data = checkpoint["log_scales"].to(device)
        model.raw_opacity.data = checkpoint["raw_opacity"].to(device)
        model.intensity.data = checkpoint["intensity"].to(device)

        params = model.get_params()
        for i, z_idx in enumerate(target_indices):
            rendered = renderer(
                params["positions"], params["scales"],
                params["opacity"], params["intensity"],
                float(z_idx)
            )
            results[:, :, i] = torch.clamp(rendered, 0.0, 1.0).squeeze().cpu().numpy()

    # Drop GPU references before the next case is loaded.
    del model, renderer, params, checkpoint
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return results, target_indices

In [None]:
# Figure 1: Qualitative slice comparison and error maps for a few test cases
R = 2
simulator = SparseSimulator(sparse_ratio=R)

# Visualize the first 3 cases of the test split.
viz_cases = split["test"][:3]

for case_idx in viz_cases:
    print(f"\nGenerating figures for case {case_idx}...")

    volume, _, _ = loader.load_and_preprocess(case_idx)
    sparse_data = simulator.simulate(volume)

    # Classical baselines reconstructed from the observed slices.
    classical_results = ClassicalInterpolator.interpolate_all_methods(
        sparse_data["observed_slices"],
        sparse_data["observed_indices"],
        sparse_data["target_indices"],
    )

    # 3DGS predictions; None when no checkpoint exists for this case.
    gs_result = load_3dgs_predictions(case_idx, R, OUTPUT_ROOT)

    # Middle target slice. target_idx is its position among the targets;
    # z_val is the slice index it refers to.
    target_idx = len(sparse_data["target_indices"]) // 2
    z_val = int(sparse_data["target_indices"][target_idx])
    gt_slice = sparse_data["target_slices"][:, :, target_idx]

    predictions = {
        "Nearest": classical_results["nearest"][:, :, target_idx],
        "Linear": classical_results["linear"][:, :, target_idx],
        "Cubic": classical_results["cubic"][:, :, target_idx],
    }

    if gs_result is not None:
        gs_preds, gs_targets = gs_result
        # Compare as plain ints so numpy/torch scalar types match reliably.
        gs_target_list = [int(z) for z in gs_targets]
        if z_val in gs_target_list:
            predictions["3DGS (Ours)"] = gs_preds[:, :, gs_target_list.index(z_val)]
        else:
            print(f"  3DGS checkpoint has no prediction for z={z_val}; omitting 3DGS")

    # Fixed central crop: the middle half of each image axis.
    H, W = gt_slice.shape
    zoom = (H//4, 3*H//4, W//4, 3*W//4)

    # Plot comparison with zoom
    fig = plot_slice_comparison(
        gt_slice, predictions, z_idx=z_val,
        save_path=os.path.join(FIG_DIR, f"comparison_case{case_idx}_z{z_val}.png"),
        zoom_region=zoom,
    )
    plt.show()
    plt.close(fig)  # close this figure explicitly (None falls back to current)

    # Plot error maps
    fig = plot_error_map(
        gt_slice, predictions, z_idx=z_val,
        save_path=os.path.join(FIG_DIR, f"error_case{case_idx}_z{z_val}.png"),
    )
    plt.show()
    plt.close(fig)

In [None]:
# Figure 2: Per-slice (z-axis) error profile of 3DGS for each sparse ratio
print("\nGenerating z-axis error profiles...")

sample_case = split["test"][0]
volume = None  # loaded once, lazily, the first time a checkpoint is found

for R in config["data"]["sparse_ratios"]:
    gs_result = load_3dgs_predictions(sample_case, R, OUTPUT_ROOT)
    if gs_result is None:
        continue

    gs_preds, target_indices = gs_result
    if volume is None:
        volume, _, _ = loader.load_and_preprocess(sample_case)
    simulator = SparseSimulator(sparse_ratio=R)
    sparse_data = simulator.simulate(volume)

    # gs_preds is ordered by the checkpoint's target indices, and the GT slices
    # by the simulator's. Scoring is only meaningful when both agree.
    ckpt_z = [int(z) for z in target_indices]
    sim_z = [int(z) for z in sparse_data["target_indices"]]
    if ckpt_z != sim_z:
        print(f"  R={R}: checkpoint target indices differ from simulator's; skipping")
        continue

    eval_result = evaluate_volume(
        gs_preds,
        sparse_data["target_slices"],
        sparse_data["target_indices"],
    )

    fig = plot_z_error_profile(
        eval_result["per_slice"],
        method_name=f"3DGS (R={R})",
        save_path=os.path.join(FIG_DIR, f"z_profile_case{sample_case}_R{R}.png"),
    )
    plt.show()
    plt.close(fig)

print("Z-axis profiles generated.")

In [None]:
# Figure 3: Training curves for one sample 3DGS run (first test case, R=2)
print("\nLoading training history...")

sample_case = split["test"][0]
R = 2
history_path = os.path.join(OUTPUT_ROOT, "3dgs", f"R{R}", f"case_{sample_case}", "history.json")

if os.path.exists(history_path):
    # Explicit encoding: JSON is UTF-8 and must not depend on the platform locale.
    with open(history_path, encoding="utf-8") as f:
        history = json.load(f)

    fig = plot_training_curves(
        history,
        save_path=os.path.join(FIG_DIR, f"training_curves_case{sample_case}_R{R}.png"),
    )
    plt.show()
    plt.close(fig)
    print("Training curves plotted.")
else:
    print(f"History not found: {history_path}")

In [None]:
# Figure 4: Organ segmentation overlay (ROI visualization) for one test case
print("\nGenerating organ-specific analysis...")

sample_case = split["test"][0]
volume, labels, _ = loader.load_and_preprocess(sample_case)

if labels is not None:
    R = 2
    simulator = SparseSimulator(sparse_ratio=R)
    sparse_data = simulator.simulate(volume)

    # Middle target slice of the R=2 simulation. The CT values and labels at
    # that z are read directly from the full volume.
    target_idx = len(sparse_data["target_indices"]) // 2
    z_val = sparse_data["target_indices"][target_idx]
    gt_slice = volume[:, :, z_val]
    label_slice = labels[:, :, z_val]

    organ_names = config["eval"]["organ_labels"]
    # The overlay, the label map and the legend share one normalisation, so a
    # label value gets the same tab10 colour in all three.
    label_vmax = 6

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # CT slice
    axes[0].imshow(gt_slice, cmap='gray', vmin=0, vmax=1)
    axes[0].set_title('CT Slice')
    axes[0].axis('off')

    # Label overlay (background label 0 is masked out)
    axes[1].imshow(gt_slice, cmap='gray', vmin=0, vmax=1)
    masked = np.ma.masked_where(label_slice == 0, label_slice)
    axes[1].imshow(masked, cmap='tab10', alpha=0.4, vmin=0, vmax=label_vmax)
    axes[1].set_title('Organ Overlay')
    axes[1].axis('off')

    # Label map
    axes[2].imshow(label_slice, cmap='tab10', vmin=0, vmax=label_vmax)
    axes[2].set_title('Segmentation Labels')
    axes[2].axis('off')

    # Legend: only organs actually present in this slice. The label set is
    # computed once rather than once per organ.
    from matplotlib.patches import Patch
    present_labels = set(np.unique(label_slice).tolist())
    legend_elements = [
        Patch(facecolor=plt.cm.tab10(val / label_vmax), label=f"{name} ({val})")
        for name, val in organ_names.items()
        if val in present_labels
    ]
    if legend_elements:
        axes[2].legend(handles=legend_elements, loc='lower right', fontsize=8)

    fig.suptitle(f'Organ Segmentation - Volume {sample_case}, z={z_val}', fontsize=14)
    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, f"organ_overlay_case{sample_case}.png"), dpi=150)
    plt.show()
    plt.close(fig)  # release the figure like the other cells do
else:
    print("No labels available for this case.")

In [None]:
# Summary: enumerate every PNG in FIG_DIR together with its size on disk
banner = "=" * 60
print(f"\n{banner}")
print("Generated Figures")
print(banner)

fig_files = sorted(Path(FIG_DIR).glob("*.png"))
for png in fig_files:
    print(f"  {png.name} ({png.stat().st_size / 1024:.1f} KB)")

print(f"\nTotal: {len(fig_files)} figures saved to {FIG_DIR}")
print("\nAll visualizations complete!")