# Ground Truth Metric Validation
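This notebook computes a pairwise metric (Chamfer distance or volumetric IoU, selected in the configuration cell) between ground-truth `.mrc` volumes and visualizes the results in two ways: (a) scatter plots comparing every volume against a subset of reference volumes (every 10th volume), and (b) a full all-vs-all heatmap.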

In [None]:
import glob
import os

import matplotlib.pyplot as plt
import mrcfile
import numpy as np
from scipy.spatial import cKDTree

from metrics_utils import chamfer_distance, load_mrc, volumetric_iou

In [None]:
# CONFIGURATION
# Directory containing the ground-truth .mrc volumes. Override with the GT_PATH
# environment variable instead of editing a hard-coded absolute path.
gt_path = os.environ.get("GT_PATH", '/path/to/ground_truth_volumes')

# Select the metric by name so the function and its plot label cannot drift apart.
METRICS = {
    "Chamfer Distance": chamfer_distance,
    "Volumetric IoU": volumetric_iou,
}
metric_name = "Chamfer Distance"
metric_fn = METRICS[metric_name]

# Extra parameters forwarded positionally to every call:
#   metric_fn(vol_a, vol_b, threshold, grid_size, voxel_size)
# NOTE(review): their exact meaning/units are defined in metrics_utils — confirm there.
threshold = 0.05
grid_size = 128
voxel_size = 4.5

In [None]:
gt_files = sorted(glob.glob(os.path.join(gt_path, '*.mrc')))
num_volumes = len(gt_files)
if num_volumes == 0:
    # Fail with a clear message instead of `min() arg is an empty sequence` further down.
    raise FileNotFoundError(f"No .mrc files found in {gt_path!r}")

# Every `reference_stride`-th volume is used as a reference; each gets one subplot.
reference_stride = 10
reference_indices = list(range(0, num_volumes, reference_stride))

# Compute each reference's metric row exactly once and reuse it for both the
# shared y-range and the plots (previously the metric was evaluated twice per pair).
metric_rows = {}
for ref_idx in reference_indices:
    gt_ref = load_mrc(gt_files[ref_idx])
    metric_rows[ref_idx] = [
        metric_fn(gt_ref, load_mrc(gt), threshold, grid_size, voxel_size)
        for gt in gt_files
    ]
all_metric_values = [value for row in metric_rows.values() for value in row]

# Determine global y-range shared by all subplots. The padding is 5% of the data
# span (0.1 when all values are equal) so it scales with the metric's units.
y_min, y_max = min(all_metric_values), max(all_metric_values)
y_pad = 0.05 * (y_max - y_min) or 0.1

# (a) Plot pairwise comparisons on a 4-column grid with as many rows as needed,
# so more than 12 references no longer overflows the axes array.
n_cols = 4
n_rows = (len(reference_indices) + n_cols - 1) // n_cols
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10 * n_rows / 3), squeeze=False)
axs = axs.flatten()
for ax, ref_idx in zip(axs, reference_indices):
    ax.plot(range(num_volumes), metric_rows[ref_idx], marker='o', linestyle='None', markersize=3)
    # 0-based to match the x-axis: the reference itself sits at x == ref_idx.
    ax.set_title(f"Reference #{ref_idx}")
    ax.set_xlabel("Volume Index")
    ax.set_ylabel(metric_name)
    ax.set_ylim(y_min - y_pad, y_max + y_pad)

# Hide the unused panels left over in the last row.
for ax in axs[len(reference_indices):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# (b) Heatmap across all GT volumes.
# heatmap_data[i, j] = metric_fn(volume i, volume j, ...): rows (i, first argument)
# map to the y-axis and columns (j, second argument) to the x-axis. Every entry is
# computed explicitly because metric_fn is not assumed to be symmetric.
# NOTE(review): volume j is re-read from disk for every row i; if all volumes fit
# in memory, loading them once up front would reduce N**2 reads to N.
if num_volumes == 0:
    raise FileNotFoundError(f"No .mrc files found in {gt_path!r}")

heatmap_data = np.zeros((num_volumes, num_volumes))
for i in range(num_volumes):
    vol1 = load_mrc(gt_files[i])
    for j in range(num_volumes):
        vol2 = load_mrc(gt_files[j])
        heatmap_data[i, j] = metric_fn(vol1, vol2, threshold, grid_size, voxel_size)

fig, ax = plt.subplots(figsize=(8, 8))
image = ax.imshow(heatmap_data, cmap='magma', origin='lower')
fig.colorbar(image, ax=ax, label=metric_name)
ax.set_xlabel("Volume Index j (second argument)")
ax.set_ylabel("Volume Index i (first argument)")
ax.set_title(f"Heatmap of {metric_name}")
plt.show()