## YOLOv11 + SAM2 Video Segmentation with Grid Sampling & HSV Filtering
Processes each video frame with YOLOv11 for fire detection and SAM2 for segmentation. SAM2 is prompted with a grid of points sampled inside each detected box, filtered by an HSV fire-color heuristic.  
Saves:
- detection frames  
- masks  
- overlays  
- grid points before/after filtering  
Writes the final annotated video and reports model size, FPS, average per-frame processing time, and peak GPU memory usage.


In [None]:
import os
import cv2
import time
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


VIDEO_PATH   = ""   # Path to the input video file (opened with cv2.VideoCapture)
YOLO_MODEL   = ""   # Path to YOLO model weights
SAM2_CFG     = ""   # Path to SAM2 base_plus config
SAM2_WEIGHTS = ""   # Path to SAM2 base_plus weights
OUT_DIR      = ""   # Output directory for results
IMG_SIZE     = 960  # YOLO inference image size (passed as imgsz)
CONF_THRESH  = 0.3  # YOLO confidence threshold (passed as conf)
GRID_SIZE    = 3    # grid points per box side -> GRID_SIZE x GRID_SIZE prompts per box
DEVICE       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One sub-directory per kind of per-frame artefact written by the main loop.
mask_dir = os.path.join(OUT_DIR, 'masks')
overlay_dir = os.path.join(OUT_DIR, 'overlays')
det_dir = os.path.join(OUT_DIR, 'detected_fires')
prompt_before_dir = os.path.join(OUT_DIR, 'prompts_before_filter')
prompt_after_dir = os.path.join(OUT_DIR, 'prompts_after_filter')
for _out_subdir in (mask_dir, overlay_dir, det_dir, prompt_before_dir, prompt_after_dir):
    os.makedirs(_out_subdir, exist_ok=True)

def load_sam(cfg, ckpt, device):
    """Build a SAM2 model from *cfg* and *ckpt* on *device*, wrapped in a SAM2ImagePredictor."""
    sam2_model = build_sam2(cfg, ckpt, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    return predictor


# Load the YOLO detector first, then the SAM2 segmenter.
yolo_model = YOLO(YOLO_MODEL).to(DEVICE)
sam_predictor = load_sam(SAM2_CFG, SAM2_WEIGHTS, DEVICE)

def grid_points_from_box(x1, y1, x2, y2, grid_size):
    """Return ``grid_size**2`` integer [x, y] points evenly spaced inside a box.

    Each axis is split into ``grid_size + 2`` evenly spaced integer positions
    (numpy int cast of ``linspace``). The two endpoints are dropped so no
    point lies on the box border. Points are ordered column-major: every y for
    the first x, then every y for the next x, and so on.
    """
    n_positions = grid_size + 2
    col_positions = np.linspace(x1, x2, n_positions, dtype=int)
    row_positions = np.linspace(y1, y2, n_positions, dtype=int)
    inner_cols = col_positions[1:-1]
    inner_rows = row_positions[1:-1]

    points = []
    for col in inner_cols:
        for row in inner_rows:
            points.append([int(col), int(row)])
    return points

def is_fire_pixel(h, s, v, *, h_min=0, h_max=60, s_min=100, v_min=150):
    """Return True if one HSV pixel falls inside the fire-colour window.

    The test is ``h_min <= h <= h_max and s >= s_min and v >= v_min``; the
    defaults reproduce the original hard-coded thresholds.

    NOTE(review): in this file the values come from
    ``cv2.cvtColor(..., cv2.COLOR_RGB2HSV)``. For 8-bit images (presumably
    the case here), OpenCV stores hue as degrees/2 (range 0-179). The default
    ``h_max=60`` therefore admits hues up to ~120 degrees, which includes
    yellow-green and green as well as red/orange/yellow. If only red-to-yellow
    flame colours are intended, ``h_max=30`` (~60 degrees) is the matching
    OpenCV value; confirm. Reds that wrap around near hue 179 are never
    accepted by this window.

    Args:
        h, s, v: hue, saturation and value of a single pixel.
        h_min, h_max: inclusive hue window.
        s_min: inclusive lower bound on saturation.
        v_min: inclusive lower bound on value (brightness).
    """
    return (h_min <= h <= h_max) and (s >= s_min) and (v >= v_min)

def filter_fire_points(img, points):
    """Keep only the prompt points whose pixel passes the HSV fire-colour test.

    Args:
        img: RGB image array; pixels are converted with cv2.COLOR_RGB2HSV.
        points: sequence of [x, y] pixel coordinates (x = column, y = row).

    Returns:
        A list of the [x, y] points that lie inside *img* and satisfy
        ``is_fire_pixel``. If none qualify, it returns a single fallback point
        at the int-truncated mean of all *points*, so SAM2 still gets one
        positive prompt. An empty *points* input returns an empty list,
        because there is no mean to fall back to (the mean of an empty list
        is NaN).
    """
    if len(points) == 0:
        return []

    height, width = img.shape[:2]
    in_bounds = [(x, y) for (x, y) in points if 0 <= x < width and 0 <= y < height]

    fire_points = []
    if in_bounds:
        xs = np.array([x for x, _ in in_bounds])
        ys = np.array([y for _, y in in_bounds])
        # Colour conversion is per-pixel, so converting only the sampled
        # pixels (as a 1xN image) gives the same HSV values as converting the
        # whole frame, without the full-frame cost on every box.
        sampled_rgb = img[ys, xs].reshape(1, -1, 3)
        sampled_hsv = cv2.cvtColor(sampled_rgb, cv2.COLOR_RGB2HSV)[0]
        for (x, y), (h, s, v) in zip(in_bounds, sampled_hsv):
            if is_fire_pixel(h, s, v):
                fire_points.append([x, y])

    if not fire_points:
        cx = int(np.mean([p[0] for p in points]))
        cy = int(np.mean([p[1] for p in points]))
        fire_points.append([cx, cy])
    return fire_points

def visualize_points(img, points, path, color=(0,255,0)):
    """Save a copy of RGB *img* with each [x, y] point drawn as a filled dot.

    Each dot has radius 4 and is drawn in *color* (RGB order, since it is
    drawn before the RGB->BGR conversion). The result is written to *path* in
    BGR order for cv2.imwrite. The input image is not modified.
    """
    canvas = img.copy()
    for x, y in points:
        cv2.circle(canvas, (x, y), 4, color, -1)
    canvas_bgr = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, canvas_bgr)

yolo_params = list(yolo_model.model.parameters())
sam2_params = list(sam_predictor.model.parameters())
param_count_yolo = sum(p.numel() for p in yolo_params)
param_count_sam2 = sum(p.numel() for p in sam2_params)
# Use each tensor's actual element size instead of assuming 4-byte fp32
# weights, so half-precision (fp16/bf16) parameters are not over-reported 2x.
model_bytes = sum(p.numel() * p.element_size() for p in yolo_params + sam2_params)
model_size_mb = model_bytes / (1024**2)
print(f"Model Size: {model_size_mb:.2f} MB")
print(f"YOLO Parameters: {param_count_yolo:,}")
print(f"SAM2 Parameters: {param_count_sam2:,}")

def _draw_detections(img_rgb, boxes_xyxy, confs):
    """Return a copy of RGB *img_rgb* with each detection box and its confidence drawn in green."""
    det_vis = img_rgb.copy()
    for box, conf in zip(boxes_xyxy, confs):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # The label sits 50 px left of the box corner; clamp it so boxes that
        # touch the left edge still show a visible score.
        cv2.putText(det_vis, f"{conf:.2f}", (max(x1 - 50, 0), y1 + 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return det_vis


def _segment_boxes(img_rgb, boxes_xyxy, frame_name):
    """Segment every detection with SAM2 and return the union of the masks.

    Each box is prompted with its HSV-filtered grid points plus the box
    itself. It writes one before-filter and one after-filter prompt image
    named *frame_name*, each containing the points of every box in the frame.

    Returns:
        A uint8 array of shape ``img_rgb.shape[:2]`` that is 1 inside any
        predicted mask and 0 elsewhere.
    """
    final_mask = np.zeros(img_rgb.shape[:2], dtype=np.uint8)
    if len(boxes_xyxy) == 0:
        # No detections: skip the SAM2 image encoder entirely.
        return final_mask

    sam_predictor.set_image(img_rgb)
    all_grid_points = []
    all_filtered_points = []
    for box in boxes_xyxy:
        x1, y1, x2, y2 = map(int, box)
        grid_points = grid_points_from_box(x1, y1, x2, y2, GRID_SIZE)
        filtered_points = filter_fire_points(img_rgb, grid_points)
        all_grid_points.extend(grid_points)
        all_filtered_points.extend(filtered_points)

        # If no prompt points are available, fall back to a box-only prompt
        # instead of passing SAM2 an empty coordinate array.
        has_points = len(filtered_points) > 0
        masks, _scores, _ = sam_predictor.predict(
            point_coords=np.array(filtered_points) if has_points else None,
            point_labels=np.ones(len(filtered_points)) if has_points else None,
            box=box[None, :],
            multimask_output=False,
        )
        final_mask = np.maximum(final_mask, masks[0].astype(np.uint8))

    # Written once per frame with every box's points. Writing inside the loop
    # would make each box overwrite the same file.
    visualize_points(img_rgb, all_grid_points,
                     os.path.join(prompt_before_dir, frame_name), color=(0, 0, 255))
    visualize_points(img_rgb, all_filtered_points,
                     os.path.join(prompt_after_dir, frame_name), color=(0, 255, 0))
    return final_mask


def _red_overlay(img_rgb, mask, alpha=0.4):
    """Blend pure red into *img_rgb* with weight *alpha* where *mask* == 1; leave other pixels unchanged."""
    red = np.zeros_like(img_rgb)
    red[:] = (255, 0, 0)
    blended = cv2.addWeighted(img_rgb, 1 - alpha, red, alpha, 0)
    return np.where(mask[..., None] == 1, blended, img_rgb)


cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    raise RuntimeError(f"Could not open input video: {VIDEO_PATH!r}")
fourcc = cv2.VideoWriter_fourcc(*'XVID')
# Some containers report 0 FPS; VideoWriter needs a positive frame rate.
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
out_video_path = os.path.join(OUT_DIR, "final_output.avi")
out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (frame_w, frame_h))

# torch.cuda memory-stat calls fail without CUDA, so gate them on DEVICE.
use_cuda = DEVICE.type == 'cuda'
frame_count = 0
total_time = 0.0
if use_cuda:
    torch.cuda.reset_peak_memory_stats()

try:
    if not out_writer.isOpened():
        raise RuntimeError(f"Could not open output video for writing: {out_video_path!r}")
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        frame_name = f"{frame_count:05d}.png"
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img_pil = Image.fromarray(img_rgb)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        # NOTE: the timed span covers detection, segmentation AND the PNG/video
        # writes below, i.e. end-to-end per-frame time rather than pure inference.
        start_time = time.perf_counter()

        results = yolo_model.predict([img_pil], imgsz=IMG_SIZE, conf=CONF_THRESH, verbose=False)
        boxes_xyxy = results[0].boxes.xyxy.cpu().numpy()
        confs = results[0].boxes.conf.cpu().numpy()

        det_vis = _draw_detections(img_rgb, boxes_xyxy, confs)
        cv2.imwrite(os.path.join(det_dir, frame_name),
                    cv2.cvtColor(det_vis, cv2.COLOR_RGB2BGR))

        final_mask = _segment_boxes(img_rgb, boxes_xyxy, frame_name)
        cv2.imwrite(os.path.join(mask_dir, frame_name), final_mask * 255)

        overlay_bgr = cv2.cvtColor(_red_overlay(img_rgb, final_mask), cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(overlay_dir, frame_name), overlay_bgr)
        out_writer.write(overlay_bgr)

        elapsed = time.perf_counter() - start_time
        total_time += elapsed
        print(f"[{frame_count}] Frame processed in {elapsed*1000:.2f} ms")
finally:
    # Release both handles even on error so the partial output video is finalised.
    cap.release()
    out_writer.release()

if frame_count == 0:
    raise RuntimeError(f"No frames could be read from {VIDEO_PATH!r}")

avg_time_ms = (total_time / frame_count) * 1000
fps_actual = frame_count / total_time
peak_mem_mb = torch.cuda.max_memory_allocated() / (1024**2) if use_cuda else 0.0

print("\n--- Computational Efficiency ---")
print(f"Average Inference Time: {avg_time_ms:.2f} ms/frame")
print(f"FPS: {fps_actual:.2f}")
if use_cuda:
    print(f"Peak GPU Memory: {peak_mem_mb:.2f} MB")
else:
    print("Peak GPU Memory: n/a (running on CPU)")
print(f"Model Size: {model_size_mb:.2f} MB")
print(f"Output video saved to: {out_video_path}")
print(f"All detection images, masks, overlays, and prompt visualizations saved to: {OUT_DIR}")
