## YOLOv11 → SAM2 pipeline for fire segmentation (single positive point prompting)
Uses YOLOv11 for fire detection and SAM2 for segmentation, prompting SAM2 with a single positive point (the centroid of each detected bounding box).  
Saves:
- detection visualizations (boxes, confidence scores, and prompt points)
- binary masks
- red mask overlays  
Includes optional metric computation (IoU, Dice, MAE, pixel accuracy) against ground-truth masks.


In [None]:
import os
import glob
from PIL import Image
import numpy as np
import cv2
import torch
from ultralytics import YOLO
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


# ---------------------------------------------------------------------------
# Configuration -- fill these in before running.
# ---------------------------------------------------------------------------
FRAME_DIR = ""      # directory holding the input frames (*.jpg / *.png)
YOLO_MODEL = ""     # path to the YOLO weights file
SAM2_CFG = ""       # SAM2 model config, passed to build_sam2
SAM2_WEIGHTS = ""   # SAM2 checkpoint, passed to build_sam2
OUT_DIR = ""        # root for the masks/, overlays/ and detected_fires/ outputs
IMG_SIZE = 960      # inference size handed to YOLO predict (imgsz)
CONF_THRESH = 0.3   # minimum YOLO confidence for a detection to be kept

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Load the fire detector once and move it onto the chosen device.
yolo_model = YOLO(YOLO_MODEL).to(DEVICE)

def load_sam(cfg, ckpt, device):
    """Construct a SAM2 image predictor.

    Args:
        cfg: SAM2 model config, forwarded to ``build_sam2``.
        ckpt: SAM2 checkpoint path, forwarded to ``build_sam2``.
        device: device the model is built on.

    Returns:
        A ``SAM2ImagePredictor`` wrapping the freshly built model.
    """
    return SAM2ImagePredictor(build_sam2(cfg, ckpt, device=device))

sam_predictor = load_sam(SAM2_CFG, SAM2_WEIGHTS, DEVICE)

# Output sub-directories, created up front so every cv2.imwrite below has an
# existing destination folder.
mask_dir = os.path.join(OUT_DIR, 'masks')
overlay_dir = os.path.join(OUT_DIR, 'overlays')
det_dir = os.path.join(OUT_DIR, 'detected_fires')
for _out_subdir in (mask_dir, overlay_dir, det_dir):
    os.makedirs(_out_subdir, exist_ok=True)

# Input frames (.jpg and .png), sorted by path so the output index of each
# frame is its position in this sorted list. A flat generator replaces the
# quadratic `sum(list_of_lists, [])` flattening idiom.
frame_paths = sorted(
    frame_path
    for ext in ("*.jpg", "*.png")
    for frame_path in glob.glob(os.path.join(FRAME_DIR, ext))
)

def _draw_detections(image_rgb, boxes_xyxy, confs):
    """Return a copy of *image_rgb* with every box and its confidence drawn in green."""
    vis = image_rgb.copy()
    for box, conf in zip(boxes_xyxy, confs):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Confidence text to the LEFT of the box; clamp x at 0 so boxes that
        # touch the left border do not push the label off the canvas.
        cv2.putText(vis, f"{conf:.2f}", (max(x1 - 50, 0), y1 + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return vis


def _segment_with_points(image_rgb, boxes_xyxy):
    """Union of SAM2 masks, each prompted by one positive point at a box centroid.

    Returns:
        (mask, points): a uint8 0/1 mask of shape image_rgb.shape[:2] and the
        list of (cx, cy) prompt points, in box order.
    """
    union = np.zeros(image_rgb.shape[:2], dtype=np.uint8)
    points = []
    if len(boxes_xyxy) == 0:
        # No detections: skip the SAM2 image-encoder pass (set_image) entirely.
        return union, points

    sam_predictor.set_image(image_rgb)
    for box in boxes_xyxy:
        x1, y1, x2, y2 = map(int, box)
        cx = int((x1 + x2) / 2)
        cy = int((y1 + y2) / 2)
        masks, _, _ = sam_predictor.predict(
            point_coords=np.array([[cx, cy]]),
            point_labels=np.array([1]),  # 1 = positive point
            multimask_output=False,
        )
        union = np.maximum(union, masks[0].astype(np.uint8))
        points.append((cx, cy))
    return union, points


def _red_overlay(image_rgb, mask, alpha=0.4):
    """Blend pure red into *image_rgb* with weight *alpha* wherever *mask* == 1."""
    red = np.zeros_like(image_rgb)
    red[:] = (255, 0, 0)  # image is RGB, so this is red
    blended = cv2.addWeighted(image_rgb, 1 - alpha, red, alpha, 0)
    return np.where(mask[..., None] == 1, blended, image_rgb)


for idx, path in enumerate(frame_paths):
    img_pil = Image.open(path).convert('RGB')
    img_np = np.array(img_pil)  # RGB, (H, W, 3)

    results = yolo_model.predict([img_pil], imgsz=IMG_SIZE, conf=CONF_THRESH)
    boxes_xyxy = results[0].boxes.xyxy.cpu().numpy()
    confs = results[0].boxes.conf.cpu().numpy()

    det_vis = _draw_detections(img_np, boxes_xyxy, confs)
    final_mask, prompt_points = _segment_with_points(img_np, boxes_xyxy)

    # Mark each single-positive-point prompt on the detection image.
    # NOTE(review): det_vis is RGB and converted RGB->BGR on save, so
    # (0, 0, 255) renders blue; use (255, 0, 0) if red was intended.
    for cx, cy in prompt_points:
        cv2.circle(det_vis, (cx, cy), 5, (0, 0, 255), -1)

    # All three outputs share the frame's sorted index as filename.
    mask_filename = f"{idx:05d}.png"
    cv2.imwrite(os.path.join(det_dir, mask_filename),
                cv2.cvtColor(det_vis, cv2.COLOR_RGB2BGR))
    cv2.imwrite(os.path.join(mask_dir, mask_filename), final_mask * 255)
    cv2.imwrite(os.path.join(overlay_dir, mask_filename),
                cv2.cvtColor(_red_overlay(img_np, final_mask), cv2.COLOR_RGB2BGR))

    print(f"[{idx+1}/{len(frame_paths)}] Processed: {os.path.basename(path)}")

print("All frames processed.")

import numpy as np

# Directory of ground-truth masks (*.png, 8-bit grayscale; pixels > 127 are foreground).
GT_MASK_DIR = ""
# Directory of predicted masks to score, e.g. the pipeline's OUT_DIR/masks.
PRED_MASK_DIR = ""

def compute_metrics(gt, pred):
    """Score a predicted mask against a ground-truth mask.

    Args:
        gt: ground-truth mask, same shape as *pred*; pixels > 127 are foreground.
        pred: predicted mask; pixels > 127 are foreground.

    Returns:
        Tuple ``(iou, dice, mae, pixel_acc)``.
        - IoU and Dice are 1.0 when both masks are empty (perfect agreement).
        - MAE is the mean absolute difference of the binarized masks, i.e. the
          fraction of pixels that disagree.
    """
    gt = gt > 127
    pred = pred > 127

    intersection = np.logical_and(gt, pred).sum()
    union = np.logical_or(gt, pred).sum()
    total = gt.sum() + pred.sum()

    iou = intersection / union if union > 0 else 1.0
    # Same empty-vs-empty convention as IoU (the old epsilon form returned 0).
    dice = (2 * intersection) / total if total > 0 else 1.0
    # Compare booleans: subtracting uint8 0/1 arrays wraps 0 - 1 to 255.
    mae = np.not_equal(gt, pred).mean()
    pixel_acc = (gt == pred).sum() / gt.size
    return iou, dice, mae, pixel_acc

gt_paths = sorted(glob.glob(os.path.join(GT_MASK_DIR, "*.png")))
pred_paths = sorted(glob.glob(os.path.join(PRED_MASK_DIR, "*.png")))

# Masks are paired by sorted position (the pipeline names predictions by frame
# index, not by source filename), so a count mismatch means some pairs are
# dropped and the remainder may be misaligned.
if len(gt_paths) != len(pred_paths):
    print(f"Warning: {len(gt_paths)} GT masks vs {len(pred_paths)} predicted masks; "
          f"only the first {min(len(gt_paths), len(pred_paths))} sorted pairs are "
          "evaluated and they may be misaligned.")

ious, dices, maes, accs = [], [], [], []
for gt_path, pred_path in zip(gt_paths, pred_paths):
    gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    if gt is None or pred is None:
        # cv2.imread returns None instead of raising for unreadable files.
        print(f"Warning: skipping unreadable pair {gt_path} / {pred_path}")
        continue
    if gt.shape != pred.shape:
        # Nearest-neighbour keeps the resized prediction's values in {0, 255}.
        pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST)
    iou, dice, mae, acc = compute_metrics(gt, pred)
    ious.append(iou); dices.append(dice); maes.append(mae); accs.append(acc)

if ious:
    print(f"Mean IoU : {np.mean(ious):.4f}")
    print(f"Mean Dice: {np.mean(dices):.4f}")
    print(f"Mean MAE : {np.mean(maes):.4f}")
    print(f"Pixel Acc: {np.mean(accs):.4f}")
else:
    # Avoid printing NaN means (np.mean of an empty list).
    print("No mask pairs were evaluated; check GT_MASK_DIR and PRED_MASK_DIR.")
