# SAM-VMNet Evaluation on DCA1 Test Set

Quantitative evaluation of the VMUNet model trained on the DCA1 dataset.
Computes mIoU, Dice/F1, accuracy, sensitivity, specificity, and visualizes results.

Also includes cross-dataset evaluation:
- DCA1-trained model on ARCADE test set (generalization check)
- ARCADE-trained model on DCA1 test set (baseline comparison)

In [None]:
import sys, os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

sys.path.insert(0, 'SAM_VMNet')

from utils import set_seed, BceDiceLoss
from models.vmunet.vmunet import VMUNet
from dataset import Branch1_datasets

# Seed via the project helper so repeated notebook runs start from the same
# state (exact scope of the seeding is defined in utils.set_seed).
set_seed(42)

# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Configuration

Update these paths to point to your trained checkpoints.

In [None]:
# --- Dataset roots ---------------------------------------------------------
# Root directory of the DCA1 dataset.
DCA1_DATA_PATH = './data/dca1/'
# Root directory used for the ARCADE dataset in the cross-dataset cells.
ARCADE_DATA_PATH = './data/vessel/'

# --- Trained checkpoints ---------------------------------------------------
# Weights of the model trained on DCA1.
DCA1_CKPT_PATH = './pre_trained_weights/dca1-best-epoch103-loss0.2863.pth'
# Weights of the model trained on ARCADE (used as the baseline).
ARCADE_CKPT_PATH = './pre_trained_weights/best-epoch142-loss0.3230.pth'

# Probability cut-off for binarising predictions. The helper functions below
# default to the same value through their own `threshold` parameter.
THRESHOLD = 0.5

## Helper Functions

In [None]:
def load_model(ckpt_path, gpu_id='0'):
    """Build a VMUNet from the project config and load weights for evaluation.

    The architecture hyper-parameters come from
    ``configs.config_setting.setting_config.model_config``. After the model's
    own ``load_from()`` initialisation, the weights stored in ``ckpt_path``
    are loaded on top of it.

    Args:
        ckpt_path: Path to a checkpoint file. It may be either a raw state
            dict or a dict that stores the state dict under
            ``'model_state_dict'``.
        gpu_id: Forwarded unchanged to the ``VMUNet`` constructor.

    Returns:
        The model, switched to eval mode.
    """
    from configs.config_setting import setting_config
    model_cfg = setting_config.model_config
    model = VMUNet(
        num_classes=model_cfg['num_classes'],
        input_channels=model_cfg['input_channels'],
        depths=model_cfg['depths'],
        depths_decoder=model_cfg['depths_decoder'],
        drop_path_rate=model_cfg['drop_path_rate'],
        load_ckpt_path=model_cfg['load_ckpt_path'],
        gpu_id=gpu_id,
    )
    model.load_from()

    # Deserialise onto the CPU so loading does not depend on the device the
    # checkpoint was saved from.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Drop 'total_ops' / 'total_params' entries; they are not model weights
    # (presumably counters left behind by a FLOPs/params profiler).
    filtered = {k: v for k, v in state_dict.items()
                if 'total_ops' not in k and 'total_params' not in k}

    # strict=False tolerates key mismatches, so surface them explicitly:
    # otherwise a checkpoint with the wrong key names would load nothing and
    # the evaluation would silently run on un-finetuned weights.
    load_result = model.load_state_dict(filtered, strict=False)
    missing = list(getattr(load_result, 'missing_keys', []))
    unexpected = list(getattr(load_result, 'unexpected_keys', []))
    if missing or unexpected:
        print(f"WARNING: checkpoint/model key mismatch for {ckpt_path}: "
              f"{len(missing)} missing, {len(unexpected)} unexpected")
        if missing:
            print(f"  first missing keys: {missing[:5]}")
        if unexpected:
            print(f"  first unexpected keys: {unexpected[:5]}")

    model.eval()
    print(f"Loaded checkpoint: {ckpt_path}")
    return model


def evaluate_model(model, test_loader, criterion, threshold=0.5):
    """Run the model over a loader and collect predictions, masks and losses.

    Args:
        model: Callable network. If it returns a tuple, the first element is
            taken as the prediction.
        test_loader: Iterable yielding ``(image, mask)`` batches.
        criterion: Loss callable invoked as ``criterion(prediction, mask)``
            on CPU tensors.
        threshold: Accepted for signature compatibility; no thresholding is
            done here (see ``compute_metrics``).

    Returns:
        ``(all_preds, all_gts, loss_list)``: two lists of NumPy arrays (one
        entry per batch, with dimension 1 squeezed) and a list of float
        losses. With ``batch_size=1`` each entry corresponds to one image.
    """
    all_preds, all_gts, loss_list = [], [], []

    model.eval()

    # Feed inputs on the device that holds the weights. This is a no-op for a
    # CPU model and avoids a device mismatch for a GPU-resident one.
    params = getattr(model, 'parameters', None)
    first_param = next(iter(params()), None) if callable(params) else None
    model_device = first_param.device if first_param is not None else None

    with torch.no_grad():
        for img, msk in tqdm(test_loader, desc="Evaluating"):
            img, msk = img.float(), msk.float()
            if model_device is not None:
                img = img.to(model_device)

            out = model(img)
            # Unwrap multi-output models BEFORE touching the tensor; calling
            # .cpu() on a tuple would raise AttributeError.
            if isinstance(out, tuple):
                out = out[0]

            # Loss and metrics are computed on the CPU.
            out = out.detach().cpu()
            msk = msk.detach().cpu()
            loss = criterion(out, msk)
            loss_list.append(loss.item())

            all_preds.append(out.squeeze(1).numpy())
            all_gts.append(msk.squeeze(1).numpy())

    return all_preds, all_gts, loss_list


def _binary_confusion(y_true, y_pred):
    """Return the 2x2 matrix ``[[TN, FP], [FN, TP]]`` for 0/1 label arrays."""
    # Encode each (true, pred) pair as 2*true + pred in {0, 1, 2, 3}.
    # minlength=4 keeps all four cells present even when a class never occurs,
    # so the result is always 2x2 (rows: actual, columns: predicted).
    return np.bincount(2 * y_true + y_pred, minlength=4).reshape(2, 2)


def compute_metrics(all_preds, all_gts, loss_list, threshold=0.5):
    """Compute pixel-level aggregate metrics and per-image Dice/IoU.

    Args:
        all_preds: Sequence of prediction arrays (any shape; may differ in
            size from one entry to the next).
        all_gts: Sequence of ground-truth arrays aligned with ``all_preds``.
        loss_list: Per-batch loss values; their mean is reported as 'loss'.
        threshold: Predictions ``>= threshold`` count as positive. Ground
            truth is always binarised at 0.5.

    Returns:
        Dict with keys 'loss', 'miou', 'dice', 'accuracy', 'sensitivity',
        'specificity', 'TN', 'FP', 'FN', 'TP', 'per_image_dice' and
        'per_image_miou'. 'miou' is the IoU of the positive class,
        ``TP / (TP + FP + FN)``. Aggregate ratios fall back to 0 when their
        denominator is 0; per-image Dice/IoU fall back to 1.0 (an image with
        no positive pixels in either mask counts as a perfect match).
    """
    # Flatten entry by entry rather than stacking, so differently sized
    # images do not break the aggregation.
    if len(all_preds) > 0:
        preds_flat = np.concatenate([np.asarray(p).reshape(-1) for p in all_preds])
        gts_flat = np.concatenate([np.asarray(g).reshape(-1) for g in all_gts])
    else:
        preds_flat = np.zeros(0)
        gts_flat = np.zeros(0)

    y_pre = np.where(preds_flat >= threshold, 1, 0)
    y_true = np.where(gts_flat >= 0.5, 1, 0)

    # Always 2x2, even if only one class is present in the data.
    confusion = _binary_confusion(y_true, y_pre)
    TN, FP, FN, TP = confusion[0, 0], confusion[0, 1], confusion[1, 0], confusion[1, 1]

    total = float(np.sum(confusion))
    accuracy = float(TN + TP) / total if total != 0 else 0
    sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0
    specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0
    f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0
    miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0

    # Per-image metrics
    per_image_dice, per_image_miou = [], []
    for pred, gt in zip(all_preds, all_gts):
        y_p = np.where(np.asarray(pred).reshape(-1) >= threshold, 1, 0)
        y_t = np.where(np.asarray(gt).reshape(-1) >= 0.5, 1, 0)
        cm = _binary_confusion(y_t, y_p)
        fp, fn, tp = cm[0, 1], cm[1, 0], cm[1, 1]
        dice_den = float(2 * tp + fp + fn)
        iou_den = float(tp + fp + fn)
        per_image_dice.append(float(2 * tp) / dice_den if dice_den != 0 else 1.0)
        per_image_miou.append(float(tp) / iou_den if iou_den != 0 else 1.0)

    return {
        'loss': np.mean(loss_list),
        'miou': miou, 'dice': f1_or_dsc,
        'accuracy': accuracy, 'sensitivity': sensitivity, 'specificity': specificity,
        'TN': TN, 'FP': FP, 'FN': FN, 'TP': TP,
        'per_image_dice': np.array(per_image_dice),
        'per_image_miou': np.array(per_image_miou),
    }


def print_metrics(metrics, title):
    """Write a framed, human-readable summary of a metrics dict to stdout.

    Args:
        metrics: Dict as returned by ``compute_metrics``; the keys 'loss',
            'miou', 'dice', 'accuracy', 'sensitivity', 'specificity',
            'per_image_dice' and 'per_image_miou' are read.
        title: Heading shown between the two top rules.
    """
    rule = "=" * 60
    dice_scores = metrics['per_image_dice']
    iou_scores = metrics['per_image_miou']

    report = [
        rule,
        f"  {title}",
        rule,
        f"  Loss:         {metrics['loss']:.4f}",
        f"  mIoU:         {metrics['miou']:.4f}",
        f"  Dice/F1:      {metrics['dice']:.4f}",
        f"  Accuracy:     {metrics['accuracy']:.4f}",
        f"  Sensitivity:  {metrics['sensitivity']:.4f}",
        f"  Specificity:  {metrics['specificity']:.4f}",
        # Per-image statistics: mean +/- population standard deviation.
        f"  Per-Image Dice: {np.mean(dice_scores):.4f} +/- {np.std(dice_scores):.4f}",
        f"  Per-Image mIoU: {np.mean(iou_scores):.4f} +/- {np.std(iou_scores):.4f}",
        rule,
    ]
    for line in report:
        print(line)

## 1. Evaluate DCA1-Trained Model on DCA1 Test Set

In [None]:
from configs.config_setting_dca1 import setting_config as dca1_config

# Pin the config to GPU '0' only when CUDA is actually available.
if torch.cuda.is_available():
    dca1_config.gpu_id = '0'

# DCA1 test split, served one image per batch and unshuffled so that the
# i-th prediction lines up with the i-th dataset entry.
dca1_test_dataset = Branch1_datasets(DCA1_DATA_PATH, dca1_config, train=False, test=True)
dca1_test_loader = DataLoader(dataset=dca1_test_dataset, batch_size=1, shuffle=False)
print(f"DCA1 test set: {len(dca1_test_dataset)} images")

# Loss shared by every evaluation below; both weights (wb, wd) are set to 1.
criterion = BceDiceLoss(wb=1, wd=1)

In [None]:
# Restore the network trained on DCA1.
dca1_model = load_model(ckpt_path=DCA1_CKPT_PATH)

# Run it over the DCA1 test loader, then aggregate and report the metrics.
dca1_preds, dca1_gts, dca1_losses = evaluate_model(
    model=dca1_model,
    test_loader=dca1_test_loader,
    criterion=criterion,
)
dca1_metrics = compute_metrics(
    all_preds=dca1_preds,
    all_gts=dca1_gts,
    loss_list=dca1_losses,
)
print_metrics(metrics=dca1_metrics, title='DCA1-Trained Model on DCA1 Test Set')

## 2. Per-Image Metrics Distribution

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel:
# (axis, metrics key, bar colour, median-line colour, x label, panel title)
panels = (
    (axes[0], 'per_image_dice', 'steelblue', 'orange',
     'Dice Score', 'Per-Image Dice Score Distribution (DCA1)'),
    (axes[1], 'per_image_miou', 'darkorange', 'blue',
     'mIoU', 'Per-Image mIoU Distribution (DCA1)'),
)

for panel, key, bar_color, median_color, x_label, panel_title in panels:
    scores = dca1_metrics[key]
    mean_val = np.mean(scores)
    median_val = np.median(scores)

    # Histogram plus dashed markers for the mean (red) and the median.
    panel.hist(scores, bins=15, color=bar_color, edgecolor='white', alpha=0.85)
    panel.axvline(mean_val, color='red', linestyle='--', linewidth=2,
                  label=f"Mean={mean_val:.4f}")
    panel.axvline(median_val, color=median_color, linestyle='--', linewidth=2,
                  label=f"Median={median_val:.4f}")
    panel.set_xlabel(x_label, fontsize=12)
    panel.set_ylabel('Count', fontsize=12)
    panel.set_title(panel_title, fontsize=13)
    panel.legend(fontsize=10)

plt.tight_layout()
plt.show()

## 3. Confusion Matrix

In [None]:
fig, ax = plt.subplots(figsize=(7, 6))

# Rows are the actual class, columns the predicted class.
cm_display = np.array([[dca1_metrics['TN'], dca1_metrics['FP']],
                       [dca1_metrics['FN'], dca1_metrics['TP']]])
# Normalise by the total pixel count so each cell is a share of all pixels.
cm_normalized = cm_display.astype(float) / cm_display.sum()

im = ax.imshow(cm_normalized, cmap='Blues', interpolation='nearest')
plt.colorbar(im, ax=ax, label='Proportion')

labels = ['Background (0)', 'Vessel (1)']
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])
ax.set_xticklabels(labels, fontsize=11)
ax.set_yticklabels(labels, fontsize=11)
ax.set_xlabel('Predicted', fontsize=13)
ax.set_ylabel('Actual', fontsize=13)
# The title previously contained a mis-decoded em dash; \u2014 is written as
# an escape so the source stays correct regardless of file encoding.
ax.set_title('Confusion Matrix \u2014 DCA1-Trained on DCA1 Test', fontsize=14)

# Annotate every cell with the raw count and its percentage of all pixels.
for i in range(2):
    for j in range(2):
        count = cm_display[i, j]
        pct = cm_normalized[i, j] * 100
        # White text on dark (high-proportion) cells, black otherwise.
        color = 'white' if cm_normalized[i, j] > 0.5 else 'black'
        ax.text(j, i, f'{count:,}\n({pct:.2f}%)',
                ha='center', va='center', fontsize=12, color=color, fontweight='bold')

plt.tight_layout()
plt.show()

## 4. Best and Worst Cases

In [None]:
# Image indices ordered by per-image Dice, lowest score first.
sorted_indices = np.argsort(dca1_metrics['per_image_dice'])
# The three lowest-scoring images (worst first) ...
worst_3 = sorted_indices[:3]
# ... and the three highest-scoring ones (best first).
best_3 = sorted_indices[::-1][:3]

def show_cases(indices, title, test_dataset, all_preds, per_image_dice, threshold=0.5):
    """Plot input image, ground truth and binarised prediction for given cases.

    One figure row is drawn per index, with three columns: the raw image, the
    binarised ground-truth mask and the binarised prediction.

    Args:
        indices: Sequence of dataset indices to display.
        title: Figure-level title.
        test_dataset: Dataset whose ``data[idx]`` is an
            ``(image_path, mask_path)`` pair.
        all_preds: Per-image prediction arrays, indexed like the dataset.
        per_image_dice: Per-image Dice scores, indexed like the dataset.
        threshold: Predictions ``>= threshold`` are shown as foreground.
    """
    n_rows = len(indices)
    if n_rows == 0:
        # Nothing to draw; plt.subplots cannot build a zero-row grid.
        return

    # squeeze=False always yields a 2-D axes array, so a single row needs no
    # special handling.
    fig, axes = plt.subplots(n_rows, 3, figsize=(14, 4.5 * n_rows), squeeze=False)

    for row, idx in enumerate(indices):
        img_path, msk_path = test_dataset.data[idx]
        # Context managers close the file handles as soon as the pixel data
        # has been copied into NumPy arrays.
        with Image.open(img_path) as img_file:
            raw_img = np.array(img_file.convert('RGB'))
        with Image.open(msk_path) as msk_file:
            raw_msk = np.array(msk_file.convert('L'))

        pred_binary = (np.asarray(all_preds[idx]).squeeze() >= threshold).astype(np.float32)

        # Show the ground truth at the prediction's resolution; fall back to
        # 256x256 if the prediction is not a plain 2-D map. PIL takes the
        # size as (width, height).
        if pred_binary.ndim == 2:
            target_size = (pred_binary.shape[1], pred_binary.shape[0])
        else:
            target_size = (256, 256)
        # NEAREST keeps the resized mask strictly two-valued.
        gt_resized = np.array(Image.fromarray(raw_msk).resize(target_size, Image.NEAREST))
        gt_binary = (gt_resized / 255.0 >= 0.5).astype(np.float32)

        axes[row, 0].imshow(raw_img)
        axes[row, 0].set_title(f'{os.path.basename(img_path)}', fontsize=10)
        axes[row, 0].axis('off')

        axes[row, 1].imshow(gt_binary, cmap='gray')
        axes[row, 1].set_title('Ground Truth', fontsize=10)
        axes[row, 1].axis('off')

        axes[row, 2].imshow(pred_binary, cmap='gray')
        axes[row, 2].set_title(f'Prediction (Dice={per_image_dice[idx]:.4f})', fontsize=10)
        axes[row, 2].axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold', y=1.0)
    plt.tight_layout()
    plt.show()

# Visualise the highest-scoring cases first, then the lowest-scoring ones.
show_cases(
    indices=best_3,
    title='Top 3 Best Predictions (DCA1)',
    test_dataset=dca1_test_dataset,
    all_preds=dca1_preds,
    per_image_dice=dca1_metrics['per_image_dice'],
)
show_cases(
    indices=worst_3,
    title='Top 3 Worst Predictions (DCA1)',
    test_dataset=dca1_test_dataset,
    all_preds=dca1_preds,
    per_image_dice=dca1_metrics['per_image_dice'],
)

---
## 5. Cross-Dataset Evaluation

### 5a. ARCADE-Trained Model on DCA1 Test Set (Baseline)

In [None]:
# Baseline: the ARCADE-trained network scored on the same DCA1 test loader.
arcade_model = load_model(ckpt_path=ARCADE_CKPT_PATH)

arcade_on_dca1_preds, arcade_on_dca1_gts, arcade_on_dca1_losses = evaluate_model(
    model=arcade_model,
    test_loader=dca1_test_loader,
    criterion=criterion,
)
arcade_on_dca1_metrics = compute_metrics(
    all_preds=arcade_on_dca1_preds,
    all_gts=arcade_on_dca1_gts,
    loss_list=arcade_on_dca1_losses,
)
print_metrics(
    metrics=arcade_on_dca1_metrics,
    title='ARCADE-Trained Model on DCA1 Test Set (Baseline)',
)

### 5b. DCA1-Trained Model on ARCADE Test Set (Generalization Check)

In [None]:
from configs.config_setting import setting_config as arcade_config

# Pin the config to GPU '0' only when CUDA is actually available.
if torch.cuda.is_available():
    arcade_config.gpu_id = '0'

# The ARCADE data is optional: without it this experiment is skipped.
if not os.path.exists(ARCADE_DATA_PATH):
    print(f"ARCADE data not found at {ARCADE_DATA_PATH}, skipping cross-dataset eval.")
else:
    # Load ARCADE test set
    arcade_test_dataset = Branch1_datasets(ARCADE_DATA_PATH, arcade_config, train=False, test=True)
    arcade_test_loader = DataLoader(dataset=arcade_test_dataset, batch_size=1, shuffle=False)
    print(f"ARCADE test set: {len(arcade_test_dataset)} images")

    # Score the DCA1-trained network on data from the other dataset.
    dca1_on_arcade_preds, dca1_on_arcade_gts, dca1_on_arcade_losses = evaluate_model(
        model=dca1_model,
        test_loader=arcade_test_loader,
        criterion=criterion,
    )
    dca1_on_arcade_metrics = compute_metrics(
        all_preds=dca1_on_arcade_preds,
        all_gts=dca1_on_arcade_gts,
        loss_list=dca1_on_arcade_losses,
    )
    print_metrics(metrics=dca1_on_arcade_metrics, title='DCA1-Trained Model on ARCADE Test Set')

### 5c. Comparison Summary

In [None]:
def _summary_row(label, metrics):
    """Print one table row: experiment label, aggregate Dice and mIoU."""
    print(f"  {label:<45} {metrics['dice']:>8.4f} {metrics['miou']:>8.4f}")


_summary_rule = "=" * 70

# Table header.
print("\n" + _summary_rule)
print("  CROSS-DATASET COMPARISON SUMMARY")
print(_summary_rule)
print(f"  {'Experiment':<45} {'Dice':>8} {'mIoU':>8}")
print(f"  {'-'*45} {'-'*8} {'-'*8}")

# One row per experiment; the ARCADE row exists only if that data was found.
_summary_row('DCA1-trained on DCA1 test', dca1_metrics)
_summary_row('ARCADE-trained on DCA1 test (baseline)', arcade_on_dca1_metrics)
if os.path.exists(ARCADE_DATA_PATH):
    _summary_row('DCA1-trained on ARCADE test (generalization)', dca1_on_arcade_metrics)
print(_summary_rule)