# SAM-VMNet Full Evaluation on ARCADE Test Set

Quantitative evaluation of the VMUNet model on the full 300-image test set.
Computes mIoU, Dice/F1, accuracy, sensitivity, and specificity, and visualizes the results.

Metrics computation matches `engine_branch1.py` exactly for reproducibility.

In [None]:
import sys, os
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from PIL import Image
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

sys.path.insert(0, '.')

from utils import set_seed, BceDiceLoss
from configs.config_setting import setting_config
from models.vmunet.vmunet import VMUNet
from dataset import Branch1_datasets

# Seed the run through the project helper so repeated executions are comparable.
set_seed(42)

# Prefer the GPU when one is visible. The config's gpu_id is overridden here,
# i.e. before any model is constructed from setting_config further down.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
if cuda_available:
    setting_config.gpu_id = '0'
print(f"Using device: {device}")

## Load Model

Load the VMUNet checkpoint (adjust `CKPT_PATH` if the identification in `showcase_model.ipynb` showed a different mapping).

In [None]:
# Locations of the fine-tuned checkpoint and of the dataset root.
CKPT_PATH = './pre_trained_weights/best-epoch142-loss0.3230.pth'
DATA_PATH = './data/vessel/'

# Build the network from the project's model configuration.
model_cfg = setting_config.model_config
model = VMUNet(
    num_classes=model_cfg['num_classes'],
    input_channels=model_cfg['input_channels'],
    depths=model_cfg['depths'],
    depths_decoder=model_cfg['depths_decoder'],
    drop_path_rate=model_cfg['drop_path_rate'],
    load_ckpt_path=model_cfg['load_ckpt_path'],
)
# Project-defined initialisation hook; presumably loads the weights referenced
# by `load_ckpt_path` -- confirm in models/vmunet/vmunet.py.
model.load_from()

# Deserialise onto the CPU so loading does not depend on the device the
# checkpoint was saved from.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
# Drop 'total_ops' / 'total_params' entries (presumably bookkeeping left by a
# FLOPs/parameter profiler -- they are not model weights).
filtered_state_dict = {k: v for k, v in checkpoint.items()
                       if 'total_ops' not in k and 'total_params' not in k}
# strict=False tolerates key mismatches, so report them explicitly: otherwise a
# checkpoint that does not match the architecture would be evaluated with
# partly uninitialised weights without any visible sign.
load_result = model.load_state_dict(filtered_state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint, "
          f"e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} unexpected checkpoint keys ignored, "
          f"e.g. {load_result.unexpected_keys[:5]}")
model.eval()

print(f"Model loaded from {CKPT_PATH}")
print(f"Model device: {model.device}")

## Load Full Test Set

In [None]:
# Test split of the Branch-1 dataset, served one image per batch in a fixed
# (unshuffled) order so that prediction i lines up with dataset entry i.
test_dataset = Branch1_datasets(DATA_PATH, setting_config, train=False, test=True)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1)

num_test_images = len(test_dataset)
num_test_batches = len(test_loader)
print(f"Test set: {num_test_images} images")
print(f"Batch size: 1, Total batches: {num_test_batches}")

## Run Full Evaluation

Iterate over all test images, collecting predictions and ground truths.
Per-image predictions are stored for later analysis.

In [None]:
# Binarisation threshold and loss function come from the shared config
# (threshold is expected to be 0.5 -- confirm in configs/config_setting.py).
threshold = setting_config.threshold
criterion = setting_config.criterion

all_preds = []   # one raw (un-thresholded) prediction array per image
all_gts = []     # one ground-truth array per image, same order as all_preds
loss_list = []   # one scalar loss per batch

model.eval()
with torch.no_grad():
    for img, msk in tqdm(test_loader, desc="Evaluating"):
        img = img.float()
        msk = msk.float()

        # NOTE(review): the input is not moved to a device here; this relies on
        # the model transferring it internally -- confirm in VMUNet.forward.
        out = model(img)
        # Unwrap multi-output models BEFORE touching the tensor: the loss call
        # below uses `out.cpu()`, which would fail on a tuple.
        if isinstance(out, tuple):
            out = out[0]

        # The mask was never moved off the CPU, so compare against a CPU copy
        # of the prediction.
        loss = criterion(out.cpu(), msk)
        loss_list.append(loss.item())

        # squeeze(1) removes axis 1 when it has size 1 (presumably the channel
        # axis, i.e. (1, 1, H, W) -> (1, H, W) with batch_size=1 -- confirm).
        msk_np = msk.squeeze(1).cpu().detach().numpy()
        out_np = out.squeeze(1).cpu().detach().numpy()

        all_preds.append(out_np)
        all_gts.append(msk_np)

print(f"\nEvaluation complete. Mean loss: {np.mean(loss_list):.4f}")
print(f"Collected predictions for {len(all_preds)} images")

## Compute Aggregate Metrics

Replicate the exact metrics computation from `engine_branch1.py:144-157`.

In [None]:
# Pool every pixel of every test image into one flat vector each for
# predictions and ground truth, then binarise both.
preds_flat = np.array(all_preds).reshape(-1)
gts_flat = np.array(all_gts).reshape(-1)

y_pre = np.where(preds_flat >= threshold, 1, 0)
y_true = np.where(gts_flat >= 0.5, 1, 0)

# labels=[0, 1] pins the matrix to 2x2 even if only one class occurs in the
# data (without it sklearn returns a 1x1 matrix and the indexing below fails).
# For inputs containing both classes the result is unchanged.
confusion = confusion_matrix(y_true, y_pre, labels=[0, 1])
TN, FP, FN, TP = confusion[0, 0], confusion[0, 1], confusion[1, 0], confusion[1, 1]

# Every ratio falls back to 0 when its denominator is empty.
total_pixels = float(np.sum(confusion))
positives = float(TP + FN)
negatives = float(TN + FP)
dice_denominator = float(2 * TP + FP + FN)
iou_denominator = float(TP + FP + FN)

accuracy = float(TN + TP) / total_pixels if total_pixels != 0 else 0
sensitivity = float(TP) / positives if positives != 0 else 0
specificity = float(TN) / negatives if negatives != 0 else 0
f1_or_dsc = float(2 * TP) / dice_denominator if dice_denominator != 0 else 0
# Reported as "mIoU" throughout this notebook; as computed here it is the IoU
# of the positive (vessel) class only, TP / (TP + FP + FN).
miou = float(TP) / iou_denominator if iou_denominator != 0 else 0

print("=" * 60)
print("AGGREGATE METRICS (Full Test Set)")
print("=" * 60)
print(f"  Loss:         {np.mean(loss_list):.4f}")
print(f"  mIoU:         {miou:.4f}")
print(f"  Dice/F1:      {f1_or_dsc:.4f}")
print(f"  Accuracy:     {accuracy:.4f}")
print(f"  Sensitivity:  {sensitivity:.4f}")
print(f"  Specificity:  {specificity:.4f}")
print(f"\nConfusion Matrix:")
print(f"  TN={TN:,}  FP={FP:,}")
print(f"  FN={FN:,}  TP={TP:,}")

## Per-Image Metrics Distribution

In [None]:
# Score each image separately: binarise its prediction and mask, build a 2x2
# confusion matrix and derive Dice and IoU from it. An image with no positive
# pixel in either mask (empty denominator) is scored as a perfect 1.0.
dice_scores = []
iou_scores = []

for image_pred, image_gt in zip(all_preds, all_gts):
    pred_labels = np.where(image_pred.reshape(-1) >= threshold, 1, 0)
    gt_labels = np.where(image_gt.reshape(-1) >= 0.5, 1, 0)

    # Rows are the true class, columns the predicted class.
    (tn, fp), (fn, tp) = confusion_matrix(gt_labels, pred_labels, labels=[0, 1])

    dice_denominator = float(2 * tp + fp + fn)
    if dice_denominator != 0:
        dice_scores.append(float(2 * tp) / dice_denominator)
    else:
        dice_scores.append(1.0)

    iou_denominator = float(tp + fp + fn)
    if iou_denominator != 0:
        iou_scores.append(float(tp) / iou_denominator)
    else:
        iou_scores.append(1.0)

per_image_dice = np.array(dice_scores)
per_image_miou = np.array(iou_scores)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel:
# (scores, bar colour, median-line colour, x-axis label, panel title)
histogram_panels = (
    (per_image_dice, 'steelblue', 'orange', 'Dice Score',
     'Per-Image Dice Score Distribution'),
    (per_image_miou, 'darkorange', 'blue', 'mIoU',
     'Per-Image mIoU Distribution'),
)

for panel, (scores, bar_color, median_color, x_label, panel_title) in zip(axes, histogram_panels):
    mean_score = np.mean(scores)
    median_score = np.median(scores)

    panel.hist(scores, bins=30, color=bar_color, edgecolor='white', alpha=0.85)
    # Dashed vertical markers for the mean (always red) and the median.
    panel.axvline(mean_score, color='red', linestyle='--', linewidth=2,
                  label=f'Mean={mean_score:.4f}')
    panel.axvline(median_score, color=median_color, linestyle='--', linewidth=2,
                  label=f'Median={median_score:.4f}')
    panel.set_xlabel(x_label, fontsize=12)
    panel.set_ylabel('Count', fontsize=12)
    panel.set_title(panel_title, fontsize=13)
    panel.legend(fontsize=10)

plt.tight_layout()
plt.show()

# Text summary of the same two distributions.
for metric_name, scores in (("Dice", per_image_dice), ("mIoU", per_image_miou)):
    print(f"Per-image {metric_name}: mean={np.mean(scores):.4f}, std={np.std(scores):.4f}, "
          f"min={np.min(scores):.4f}, max={np.max(scores):.4f}")

## Confusion Matrix Heatmap

In [None]:
fig, ax = plt.subplots(figsize=(7, 6))

# Rows are the actual class, columns the predicted class; each cell is shown as
# its share of all pixels.
cm_display = np.array([[TN, FP], [FN, TP]])
cm_normalized = cm_display.astype(float) / cm_display.sum()

im = ax.imshow(cm_normalized, cmap='Blues', interpolation='nearest')
plt.colorbar(im, ax=ax, label='Proportion')

class_names = ['Background (0)', 'Vessel (1)']
tick_positions = [0, 1]
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(class_names, fontsize=11)
ax.set_yticklabels(class_names, fontsize=11)
ax.set_xlabel('Predicted', fontsize=13)
ax.set_ylabel('Actual', fontsize=13)
ax.set_title('Confusion Matrix (Normalized)', fontsize=14)

# Write the raw count and the percentage into every cell; cells holding more
# than half of all pixels get white text, the others black.
for row, col in np.ndindex(2, 2):
    share = cm_normalized[row, col]
    text_color = 'white' if share > 0.5 else 'black'
    ax.text(col, row, f'{cm_display[row, col]:,}\n({share * 100:.2f}%)',
            ha='center', va='center', fontsize=12, color=text_color, fontweight='bold')

plt.tight_layout()
plt.show()

## Results Summary Table

In [None]:
heavy_rule = "=" * 60
light_rule = "-" * 60

# Banner
print("\n" + heavy_rule)
print("         FINAL EVALUATION RESULTS")
print(heavy_rule)

# Run description
for description_line in (
    "  Model:        VMUNet (Branch 1)",
    f"  Checkpoint:   {os.path.basename(CKPT_PATH)}",
    f"  Test images:  {len(test_dataset)}",
    f"  Threshold:    {threshold}",
):
    print(description_line)
print(light_rule)

# Aggregate metrics, one aligned row each
print(f"  {'Metric':<20} {'Value':>10}")
print(f"  {'-'*20} {'-'*10}")
for metric_name, metric_value in (
    ('Loss', np.mean(loss_list)),
    ('mIoU', miou),
    ('Dice / F1', f1_or_dsc),
    ('Accuracy', accuracy),
    ('Sensitivity', sensitivity),
    ('Specificity', specificity),
):
    print(f"  {metric_name:<20} {metric_value:>10.4f}")
print(light_rule)

# Per-image distributions as mean +/- standard deviation
for metric_name, scores in (
    ('Per-Image Dice', per_image_dice),
    ('Per-Image mIoU', per_image_miou),
):
    print(f"  {metric_name:<20} {np.mean(scores):>10.4f} +/- {np.std(scores):.4f}")
print(heavy_rule)

## Best and Worst Cases

Show the 3 best and 3 worst predictions by Dice score.

In [None]:
# argsort is ascending: the lowest-Dice images come first, the highest last.
sorted_indices = np.argsort(per_image_dice)
descending_indices = sorted_indices[::-1]
best_3 = descending_indices[:3]    # highest Dice first
worst_3 = sorted_indices[:3]       # lowest Dice first

def show_cases(indices, title):
    """Plot image, ground truth and prediction side by side for selected cases.

    Draws one row per entry of ``indices`` with three columns: the original
    image, the binarised ground-truth mask and the binarised prediction
    (annotated with that image's Dice score).

    Args:
        indices: Sequence of positions into ``test_dataset.data``,
            ``all_preds`` and ``per_image_dice``.
        title: Figure-level title.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when ``indices`` is empty.
    """
    num_cases = len(indices)
    if num_cases == 0:
        # plt.subplots cannot create a grid with zero rows.
        return

    fig, axes = plt.subplots(num_cases, 3, figsize=(14, 4.5 * num_cases))
    if num_cases == 1:
        # A single row comes back as a 1-D array; restore the (row, col) layout.
        axes = axes[np.newaxis, :]

    for row, idx in enumerate(indices):
        img_path, msk_path = test_dataset.data[idx]

        prediction = all_preds[idx].squeeze()
        # Resize the mask to the prediction's own size instead of a fixed
        # 256x256, so both panels match for any model input size.
        pred_height, pred_width = prediction.shape[-2:]

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into arrays.
        with Image.open(img_path) as img_file:
            raw_img = np.array(img_file.convert('RGB'))
        with Image.open(msk_path) as msk_file:
            # PIL's resize takes (width, height).
            gt_resized = np.array(
                msk_file.convert('L').resize((pred_width, pred_height), Image.NEAREST)
            )

        gt_binary = (gt_resized / 255.0 >= 0.5).astype(np.float32)
        pred_binary = np.where(prediction >= threshold, 1, 0).astype(np.float32)

        # Original
        axes[row, 0].imshow(raw_img)
        axes[row, 0].set_title(f'{os.path.basename(img_path)}', fontsize=10)
        axes[row, 0].axis('off')

        # Ground truth
        axes[row, 1].imshow(gt_binary, cmap='gray')
        axes[row, 1].set_title('Ground Truth', fontsize=10)
        axes[row, 1].axis('off')

        # Prediction
        axes[row, 2].imshow(pred_binary, cmap='gray')
        axes[row, 2].set_title(f'Prediction (Dice={per_image_dice[idx]:.4f})', fontsize=10)
        axes[row, 2].axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold', y=1.0)
    plt.tight_layout()
    plt.show()

# Render the best cases first, then the worst ones.
for case_indices, case_title in (
    (best_3, 'Top 3 Best Predictions (Highest Dice)'),
    (worst_3, 'Top 3 Worst Predictions (Lowest Dice)'),
):
    show_cases(case_indices, case_title)