In [None]:
!pip install git+https://github.com/facebookresearch/sam2.git
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt

In [None]:
import os

import cv2
import numpy as np
import pandas as pd
import torch
from scipy import ndimage
from scipy.spatial.distance import directed_hausdorff
from tqdm import tqdm

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# ================= CONFIGURATION =================
# Input frames and their ground-truth masks. main() pairs files from these two
# directories by file stem (file name without extension).
IMG_DIR = "/kaggle/input/datasets/gonoszgonosz/rodent-data-2/processed/images"
MASK_DIR = "/kaggle/input/datasets/gonoszgonosz/rodent-data-2/processed/masks"

# SAM 2 Specifics
# Checkpoint file fetched into the working directory by the wget cell above.
SAM2_CHECKPOINT = "sam2_hiera_small.pt"
# Model config name handed to build_sam2; presumably it must describe the same
# architecture variant as the checkpoint (Hiera-Small) — TODO confirm.
MODEL_CFG = "sam2_hiera_s.yaml"
# Per-frame metrics are written to this CSV file.
OUTPUT_CSV = "/kaggle/working/sam2_comprehensive_metrics.csv"

# Polygon annotation tolerance: side length (pixels) of the square kernel used
# to build the boundary bands compared by calculate_boundary_iou.
BOUNDARY_DILATION = 7

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# =================================================

def calculate_iou(pred_mask, gt_mask):
    """Return the intersection-over-union of two masks.

    Any non-zero element counts as foreground. When both masks are empty
    there is nothing to disagree about, so the score is 1.0.
    """
    union = np.sum(np.logical_or(pred_mask, gt_mask))
    if not union:
        # An empty union implies an empty intersection: perfect agreement.
        return 1.0
    overlap = np.sum(np.logical_and(pred_mask, gt_mask))
    return overlap / union

def calculate_dice(pred_mask, gt_mask):
    """Return the Dice coefficient 2*|P & G| / (|P| + |G|) of two masks.

    Any non-zero element counts as foreground, which is the same convention
    as calculate_iou. Masks stored as bool, 0/1 floats or 0/255 bytes
    therefore all score correctly. Two empty masks are a perfect match (1.0).
    """
    pred = np.asarray(pred_mask) != 0
    gt = np.asarray(gt_mask) != 0
    # Count pixels rather than summing raw values: .sum() on a 0/255 mask
    # would add 255 per pixel and wreck the ratio.
    total = np.count_nonzero(pred) + np.count_nonzero(gt)
    if total == 0:
        return 1.0
    return 2.0 * np.count_nonzero(pred & gt) / total

def calculate_boundary_iou(pred_mask, gt_mask, dilation=7):
    """Return the IoU between the boundary bands of two 2-D masks.

    A mask's boundary band is its morphological gradient (dilation minus
    erosion) with a ``dilation x dilation`` square kernel. This is a strip of
    pixels on both sides of every object edge. Comparing bands instead of
    exact edge pixels tolerates small offsets, such as those introduced by
    polygon annotations. Any non-zero element counts as foreground.

    Each mask is zero-padded by one pixel before the morphology. Object edges
    cut off by the image border are then treated as boundary, as in the
    reference Boundary IoU implementation. With OpenCV's default border value,
    erosion ignores the outside of the frame, so those edges would silently
    drop out of the band.

    Args:
        pred_mask: Predicted 2-D mask.
        gt_mask: Ground-truth 2-D mask of the same shape.
        dilation: Kernel side length in pixels (controls band width).

    Returns:
        Boundary IoU in [0, 1]; 1.0 when neither mask has any boundary band.
    """
    kernel = np.ones((dilation, dilation), dtype=np.uint8)

    def boundary_band(mask):
        # Binarise, pad with background, take the gradient, then crop back.
        padded = np.pad((np.asarray(mask) != 0).astype(np.uint8), 1)
        band = cv2.morphologyEx(padded, cv2.MORPH_GRADIENT, kernel)
        return band[1:-1, 1:-1] > 0

    gt_boundary = boundary_band(gt_mask)
    pred_boundary = boundary_band(pred_mask)

    union = np.logical_or(pred_boundary, gt_boundary).sum()
    if union == 0:
        # No boundary in either mask (both empty, or kernel too small).
        return 1.0
    return np.logical_and(pred_boundary, gt_boundary).sum() / union

def _mask_boundary_points(mask):
    """Return the (row, col) coordinates of the inner boundary pixels of *mask*."""
    foreground = np.asarray(mask) != 0
    # A foreground pixel is interior when all its 4-neighbours are foreground.
    # border_value=0 treats everything outside the frame as background, so
    # object edges lying on the image border are kept as boundary.
    interior = ndimage.binary_erosion(foreground, border_value=0)
    return np.argwhere(foreground & ~interior)


def calculate_hausdorff(pred_mask, gt_mask):
    """Return the symmetric Hausdorff distance, in pixels, between mask boundaries.

    Each boundary is the set of foreground pixels that have at least one
    4-neighbour in the background. The area outside the image counts as
    background. Canny edge detection is deliberately not used: on a binary
    mask it finds no edge where the object touches the image border, so it
    dropped those edges and returned NaN for a full-frame mask.

    Args:
        pred_mask: Predicted 2-D mask (non-zero = foreground).
        gt_mask: Ground-truth 2-D mask (non-zero = foreground).

    Returns:
        max(directed(pred -> gt), directed(gt -> pred)) in pixels, or NaN when
        either mask is empty and the distance is undefined.
    """
    pred_pts = _mask_boundary_points(pred_mask)
    gt_pts = _mask_boundary_points(gt_mask)

    if len(pred_pts) == 0 or len(gt_pts) == 0:
        return np.nan

    d1 = directed_hausdorff(pred_pts, gt_pts)[0]
    d2 = directed_hausdorff(gt_pts, pred_pts)[0]
    return max(d1, d2)

def get_gt_bbox(mask):
    """Return the tight ``[x_min, y_min, x_max, y_max]`` box around positive pixels.

    Coordinates are inclusive pixel indices (column = x, row = y) of a 2-D
    mask. Returns ``None`` when the mask has no positive pixels.
    """
    foreground = np.argwhere(mask > 0)
    if not foreground.size:
        return None
    (top, left), (bottom, right) = foreground.min(axis=0), foreground.max(axis=0)
    return np.array([left, top, right, bottom])

def _index_by_stem(directory):
    """Map file stem -> file name for the .jpg/.jpeg/.png files in *directory*.

    Extension matching is case-insensitive. Names are visited in sorted order,
    so if two files share a stem the one kept is deterministic (the later one).
    """
    return {
        os.path.splitext(name)[0]: name
        for name in sorted(os.listdir(directory))
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    }


def main():
    """Run zero-shot, box-prompted SAM 2 evaluation over every image/mask pair.

    For each frame whose stem appears in both IMG_DIR and MASK_DIR:

    1. The tight bounding box of the ground-truth mask is used as an oracle
       prompt.
    2. SAM 2's highest-scoring candidate mask is taken as the prediction.
    3. IoU, Dice, Boundary IoU and Hausdorff distance are recorded.

    Frames that cannot be scored (unreadable file, image/mask size mismatch,
    empty ground truth) are skipped and reported. Per-frame results go to
    OUTPUT_CSV and a summary is printed.
    """
    print("--- LOADING SAM 2 MODEL ---")
    sam2_model = build_sam2(MODEL_CFG, SAM2_CHECKPOINT, device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    img_map = _index_by_stem(IMG_DIR)
    mask_map = _index_by_stem(MASK_DIR)
    common_ids = sorted(img_map.keys() & mask_map.keys())

    evaluation_data = []
    skipped = []  # (frame_id, reason) for frames that could not be scored
    boundary_iou_sum = 0.0  # running total, so the progress average is O(1)

    print(f"--- STARTING ZERO-SHOT EVALUATION ON {len(common_ids)} FRAMES ---")
    pbar = tqdm(common_ids)

    for cid in pbar:
        img_path = os.path.join(IMG_DIR, img_map[cid])
        mask_path = os.path.join(MASK_DIR, mask_map[cid])

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        gt_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if image is None or gt_mask is None:
            skipped.append((cid, "unreadable image or mask"))
            continue
        if image.shape[:2] != gt_mask.shape:
            # The metrics compare masks pixel-by-pixel, so both need one grid.
            skipped.append(
                (cid, f"size mismatch: image {image.shape[:2]} vs mask {gt_mask.shape}")
            )
            continue

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NOTE(review): any non-zero value is foreground. If masks are stored
        # as JPEG, compression noise around edges is counted too.
        gt_binary = gt_mask > 0

        input_box = get_gt_bbox(gt_binary)
        if input_box is None:
            skipped.append((cid, "empty ground-truth mask"))
            continue

        # 1. Provide image to SAM 2 Predictor
        predictor.set_image(image_rgb)

        # 2. Prompt SAM 2 with the oracle bounding box (tight box around the GT)
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=True,
        )

        # 3. Keep SAM 2's highest-scoring candidate, binarised so every metric
        #    receives a boolean mask whatever dtype predict() returns.
        sam_pred_binary = masks[np.argmax(scores)] > 0

        # 4. Calculate metrics
        iou = calculate_iou(sam_pred_binary, gt_binary)
        dice = calculate_dice(sam_pred_binary, gt_binary)
        bound_iou = calculate_boundary_iou(sam_pred_binary, gt_binary, BOUNDARY_DILATION)
        hausdorff = calculate_hausdorff(sam_pred_binary, gt_binary)

        evaluation_data.append({
            "Frame_ID": cid,
            "mIoU": iou,  # per-frame IoU; its mean over frames is the mIoU
            "Dice": dice,
            "Boundary_IoU": bound_iou,
            "Hausdorff_Distance": hausdorff,
        })

        boundary_iou_sum += bound_iou
        pbar.set_postfix({"Avg B-IoU": f"{boundary_iou_sum / len(evaluation_data):.4f}"})

    # Explicit columns keep the CSV header valid even when no frame was scored.
    df = pd.DataFrame(
        evaluation_data,
        columns=["Frame_ID", "mIoU", "Dice", "Boundary_IoU", "Hausdorff_Distance"],
    )
    df.to_csv(OUTPUT_CSV, index=False)

    if skipped:
        print(f"\n Skipped {len(skipped)} frame(s); first few:")
        for cid, reason in skipped[:10]:
            print(f"   - {cid}: {reason}")

    if df.empty:
        print("\nWARNING: no frames were evaluated; check IMG_DIR / MASK_DIR.")
        print(f" Data saved to          : {OUTPUT_CSV}")
        return

    # pandas' mean() skips NaN, so say how many Hausdorff values were excluded.
    n_hd_undefined = int(df["Hausdorff_Distance"].isna().sum())

    print("\n" + "="*50)
    print(" SAM 2 ZERO-SHOT EVALUATION COMPLETED")
    print("="*50)
    print(f" Total Frames Evaluated : {len(df)}")
    print(f" Frames Skipped         : {len(skipped)}")
    print(f" Mean mIoU              : {df['mIoU'].mean():.4f}")
    print(f" Mean Dice Score        : {df['Dice'].mean():.4f}")
    print(f" Mean Boundary IoU      : {df['Boundary_IoU'].mean():.4f}")
    print(f" Mean Hausdorff Distance: {df['Hausdorff_Distance'].mean():.2f} pixels")
    if n_hd_undefined:
        print(f"   (excludes {n_hd_undefined} frame(s) with undefined (NaN) distance)")
    print(f" Data saved to          : {OUTPUT_CSV}")
    print("="*50)

# Script entry point. Inside an IPython/Jupyter cell __name__ is also
# "__main__", so running the cell starts the evaluation.
if __name__ == "__main__":
    main()