# SAM-VMNet Showcase: Visual Demo of Model Predictions

This notebook demonstrates the SAM-VMNet / VMUNet model on angiography vessel segmentation.
We identify which checkpoint belongs to which architecture, load the best model,
and visualize predictions on sample test images.

In [None]:
import sys, os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import confusion_matrix

sys.path.insert(0, '.')

from utils import set_seed, BceDiceLoss
from configs.config_setting import setting_config
from models.vmunet.vmunet import VMUNet
from models.vmunet.samvmnet import SAMVMNet
from dataset import Branch1_datasets

# Seed the run through the project's set_seed helper.
set_seed(42)

# Use the GPU when one is visible, otherwise fall back to the CPU. The config
# is patched here, ahead of any model construction.
# NOTE(review): presumably the models read setting_config.gpu_id when they are
# built -- confirm against the model code.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
if cuda_available:
    setting_config.gpu_id = '0'
print(f"Using device: {device}")

## Step 1: Identify Checkpoints

Test-load each checkpoint into VMUNet and SAMVMNet with `strict=True` to determine which architecture each belongs to.

In [None]:
# Candidate checkpoints, keyed by a human-readable label used in the report.
ckpt_paths = {
    'loss0.3230 (105MB)': './pre_trained_weights/best-epoch142-loss0.3230.pth',
    'loss0.3488 (112MB)': './pre_trained_weights/best-epoch142-loss0.3488.pth',
}

model_cfg = setting_config.model_config

# Architectures to probe, in reporting order: VMUNet (Branch 1), SAMVMNet (Branch 2).
candidate_archs = (
    ('VMUNet', VMUNet),
    ('SAMVMNet', SAMVMNet),
)

# Maps checkpoint label -> list of architecture names that accepted it.
results = {}

for ckpt_name, ckpt_path in ckpt_paths.items():
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Drop bookkeeping entries whose keys mention total_ops / total_params
    # (presumably left behind by a FLOPs/params profiler -- confirm); they
    # would make a strict load fail.
    filtered = {k: v for k, v in checkpoint.items()
                if 'total_ops' not in k and 'total_params' not in k}
    results[ckpt_name] = []

    for arch_name, arch_cls in candidate_archs:
        model_test = None
        try:
            model_test = arch_cls(
                num_classes=model_cfg['num_classes'],
                input_channels=model_cfg['input_channels'],
                depths=model_cfg['depths'],
                depths_decoder=model_cfg['depths_decoder'],
                drop_path_rate=model_cfg['drop_path_rate'],
                load_ckpt_path=None,
            )
            # strict=True raises unless every key matches, which is the
            # compatibility test itself.
            model_test.load_state_dict(filtered, strict=True)
        except Exception as e:
            # Deliberately broad: any failure just means "not this architecture".
            print(f"  [--] {ckpt_name} NOT compatible with {arch_name}: {str(e)[:120]}")
        else:
            results[ckpt_name].append(arch_name)
            print(f"  [OK] {ckpt_name} is compatible with {arch_name}")
        finally:
            # Release the probe model on success AND failure so two models are
            # never alive at once while the next candidate is being built.
            model_test = None

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n=== Checkpoint Identity Summary ===")
for ckpt_name, archs in results.items():
    arch_str = ', '.join(archs) if archs else 'UNKNOWN'
    print(f"  {ckpt_name}: {arch_str}")

## Step 2: Load the Best VMUNet Model

Load the identified Branch 1 (VMUNet) checkpoint. We use the lower-loss checkpoint (`loss0.3230`).
Adjust `CKPT_PATH` below if the identification step reveals a different mapping.

In [None]:
# --- Adjust this path based on the identification results above ---
CKPT_PATH = './pre_trained_weights/best-epoch142-loss0.3230.pth'

model_cfg = setting_config.model_config
model = VMUNet(
    num_classes=model_cfg['num_classes'],
    input_channels=model_cfg['input_channels'],
    depths=model_cfg['depths'],
    depths_decoder=model_cfg['depths_decoder'],
    drop_path_rate=model_cfg['drop_path_rate'],
    load_ckpt_path=model_cfg['load_ckpt_path'],
)
# NOTE(review): presumably initialises weights from load_ckpt_path before the
# fine-tuned checkpoint below is applied on top -- confirm in VMUNet.load_from.
model.load_from()

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
# Drop bookkeeping entries whose keys mention total_ops / total_params.
filtered_state_dict = {k: v for k, v in checkpoint.items()
                       if 'total_ops' not in k and 'total_params' not in k}

# strict=False never raises on a key mismatch, so a checkpoint from the wrong
# architecture would otherwise load "successfully" while leaving weights
# untouched. Surface any mismatch instead of ignoring it.
load_result = model.load_state_dict(filtered_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"WARNING: {CKPT_PATH} does not fully match VMUNet "
        f"({len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys); "
        "predictions may be unreliable."
    )
model.eval()

print(f"Model loaded from {CKPT_PATH}")
# NOTE(review): relies on the project model exposing a `device` attribute.
print(f"Model device: {model.device}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params/1e6:.2f}M")

## Step 3: Load Sample Test Images

Pick 8 evenly spaced images from the test set to show a diverse range of examples.

In [None]:
DATA_PATH = './data/vessel/'

test_dataset = Branch1_datasets(DATA_PATH, setting_config, train=False, test=True)
print(f"Test set size: {len(test_dataset)} images")

# Pick evenly spaced indices across the whole test set
num_samples = 8
sample_indices = np.linspace(0, len(test_dataset) - 1, num_samples, dtype=int)
print(f"Sample indices: {sample_indices}")

# Load transformed images (for inference) and raw images (for display)
sample_imgs = []      # transformed tensors
sample_msks = []      # transformed mask tensors
sample_raw_imgs = []  # raw images as RGB arrays, for display
sample_raw_msks = []  # raw masks as single-channel arrays, for display
sample_names = []     # image file names, used as plot titles

for idx in sample_indices:
    img_tensor, msk_tensor = test_dataset[idx]
    sample_imgs.append(img_tensor)
    sample_msks.append(msk_tensor)

    # Load raw for display. The context managers close each file as soon as
    # its pixels have been copied into a NumPy array.
    img_path, msk_path = test_dataset.data[idx]
    with Image.open(img_path) as raw_img:
        sample_raw_imgs.append(np.array(raw_img.convert('RGB')))
    with Image.open(msk_path) as raw_msk:
        sample_raw_msks.append(np.array(raw_msk.convert('L')))
    sample_names.append(os.path.basename(img_path))

# Display the raw images in a grid. The grid is derived from num_samples
# (at most 4 columns) so changing the sample count does not break the plot;
# with the default of 8 this is the same 2x4 layout and 16x8 figure as before.
n_cols = max(1, min(4, num_samples))
n_rows = max(1, (num_samples + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < len(sample_raw_imgs):
        ax.imshow(sample_raw_imgs[i])
        ax.set_title(sample_names[i], fontsize=10)
    # Cells beyond the last sample are simply left blank.
    ax.axis('off')
fig.suptitle('Selected Test Images', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Step 4: Run Inference

In [None]:
# Binarisation cut-off from the project config (expected to be 0.5 -- confirm).
threshold = setting_config.threshold
pred_probs = []    # raw model outputs, one array per sample
predictions = []   # thresholded float32 masks, one array per sample

model.eval()
with torch.no_grad():
    for sample in sample_imgs:
        # Prepend a batch axis: (C, H, W) -> (1, C, H, W)
        batch = sample.unsqueeze(0).float()
        # NOTE(review): the input is not moved to a device here; the model is
        # expected to do that internally -- confirm.
        prob_map = model(batch).squeeze().cpu().numpy()
        pred_probs.append(prob_map)
        predictions.append((prob_map >= threshold).astype(np.float32))

print(f"Inference complete on {len(predictions)} images")
print(f"Prediction shape: {predictions[0].shape}")
print(f"Probability range: [{pred_probs[0].min():.4f}, {pred_probs[0].max():.4f}]")

## Step 5: 3-Panel Visualization (Original | Ground Truth | Prediction)

For each sample, show the original image, the ground-truth mask, and the predicted mask side by side,
with the per-image Dice score annotated.

In [None]:
def compute_dice(pred, gt):
    """Return the Dice coefficient of a prediction mask and a ground-truth mask.

    Both arrays are flattened and the score is
    ``2 * sum(pred * gt) / (sum(pred) + sum(gt))``. The inputs are expected
    to be binary (0/1) arrays of the same size. When both masks sum to zero
    the score is defined as 1.0.
    """
    p = pred.reshape(-1)
    g = gt.reshape(-1)
    overlap = np.sum(p * g)
    denominator = p.sum() + g.sum()
    if denominator == 0:
        # Two empty masks agree perfectly; also avoids dividing by zero.
        return 1.0
    return (2.0 * overlap) / denominator


def compute_miou(pred, gt):
    """Return the intersection-over-union of a prediction and a ground-truth mask.

    Both arrays are flattened; the intersection is ``sum(pred * gt)`` and the
    union is ``sum(pred) + sum(gt) - intersection``. This is a single IoU
    over the mask values, not an average over several classes. The inputs
    are expected to be binary (0/1) arrays of the same size. When the union
    is zero the score is defined as 1.0.
    """
    p = pred.reshape(-1)
    g = gt.reshape(-1)
    overlap = np.sum(p * g)
    combined = p.sum() + g.sum() - overlap
    # Two empty masks agree perfectly; also avoids dividing by zero.
    return 1.0 if combined == 0 else overlap / combined


# One row per sample, three columns: original | ground truth | prediction.
# squeeze=False keeps `axes` two-dimensional even when num_samples == 1.
fig, axes = plt.subplots(num_samples, 3, figsize=(14, 4 * num_samples), squeeze=False)

dice_scores = []
miou_scores = []

for row, pred, raw_img, raw_msk, name in zip(
        axes, predictions, sample_raw_imgs, sample_raw_msks, sample_names):
    # Ground truth: resize the raw mask to the prediction's own resolution
    # rather than a hard-coded size. PIL expects (width, height).
    pred_h, pred_w = pred.shape[-2:]
    gt_resized = np.array(Image.fromarray(raw_msk).resize((pred_w, pred_h), Image.NEAREST))
    # Masks are treated as 8-bit: anything at or above mid-grey counts as vessel.
    gt_binary = (gt_resized / 255.0 >= 0.5).astype(np.float32)

    dice = compute_dice(pred, gt_binary)
    miou = compute_miou(pred, gt_binary)
    dice_scores.append(dice)
    miou_scores.append(miou)

    # Original image
    row[0].imshow(raw_img)
    row[0].set_title(f'{name}', fontsize=10)
    row[0].axis('off')

    # Ground truth
    row[1].imshow(gt_binary, cmap='gray')
    row[1].set_title('Ground Truth', fontsize=10)
    row[1].axis('off')

    # Prediction
    row[2].imshow(pred, cmap='gray')
    row[2].set_title(f'Prediction (Dice={dice:.4f})', fontsize=10)
    row[2].axis('off')

fig.suptitle('Model Predictions: Original | Ground Truth | Prediction', fontsize=14, fontweight='bold', y=1.0)
plt.tight_layout()
plt.show()

## Step 6: Overlay Visualization

Overlay predicted vessels on the original image:
- **Green**: True Positive (correctly predicted vessel)
- **Red**: False Positive (incorrectly predicted vessel)
- **Blue**: False Negative (missed vessel)

In [None]:
# Grid derived from num_samples (at most 4 columns); with the default of 8
# samples this is the same 2x4 layout and 20x10 figure as before.
n_cols = max(1, min(4, num_samples))
n_rows = max(1, (num_samples + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)

# Blend weight of the class colour over the underlying image.
alpha = 0.45

for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= num_samples:
        # Spare grid cell beyond the last sample: leave it blank.
        continue

    pred = predictions[i]

    # Resize the raw image and mask to the prediction's own resolution rather
    # than a hard-coded size. PIL expects (width, height).
    pred_h, pred_w = pred.shape[-2:]
    raw_resized = np.array(Image.fromarray(sample_raw_imgs[i]).resize((pred_w, pred_h), Image.BILINEAR))
    gt_resized = np.array(Image.fromarray(sample_raw_msks[i]).resize((pred_w, pred_h), Image.NEAREST))
    gt_binary = (gt_resized / 255.0 >= 0.5).astype(np.float32)

    # Create overlay (astype already returns a fresh float copy)
    overlay = raw_resized.astype(np.float32)

    # The three masks are mutually exclusive, so blending order is irrelevant.
    blend_rules = (
        ((pred == 1) & (gt_binary == 1), np.array([0, 255, 0])),    # True Positive -> green
        ((pred == 1) & (gt_binary == 0), np.array([255, 0, 0])),    # False Positive -> red
        ((pred == 0) & (gt_binary == 1), np.array([0, 100, 255])),  # False Negative -> blue
    )
    for mask, colour in blend_rules:
        overlay[mask] = overlay[mask] * (1 - alpha) + colour * alpha

    ax.imshow(overlay.astype(np.uint8))
    ax.set_title(f'{sample_names[i]}\nDice={dice_scores[i]:.4f}', fontsize=10)

fig.suptitle('Overlay: Green=TP, Red=FP, Blue=FN', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Step 7: Quick Summary

In [None]:
# Report the number of images that were actually scored instead of a
# hard-coded count, so the header stays correct if num_samples changes.
n_scored = len(dice_scores)
rule = f"  {'-'*15} {'-'*8} {'-'*8}"

print("=" * 50)
print(f"Quick Summary (on {n_scored} sample images)")
print("=" * 50)
print(f"  Mean Dice:  {np.mean(dice_scores):.4f} +/- {np.std(dice_scores):.4f}")
print(f"  Mean mIoU:  {np.mean(miou_scores):.4f} +/- {np.std(miou_scores):.4f}")
print()
print("Per-image breakdown:")
print(f"  {'Image':<15} {'Dice':>8} {'mIoU':>8}")
print(rule)
# One table row per scored image, in sampling order.
for name, dice, miou in zip(sample_names, dice_scores, miou_scores):
    print(f"  {name:<15} {dice:>8.4f} {miou:>8.4f}")
print(rule)
print(f"  {'Mean':<15} {np.mean(dice_scores):>8.4f} {np.mean(miou_scores):>8.4f}")