In [None]:
!pip install git+https://github.com/facebookresearch/sam2.git
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt

In [None]:
import os
import cv2
import torch
import numpy as np
import shutil
from tqdm import tqdm
from PIL import Image

# Import Stage 1 Segformer for the initial prompt
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Import SAM 2 Video Predictor
from sam2.build_sam import build_sam2_video_predictor

# ================= CONFIGURATION =================
# All paths and model identifiers used by the pipeline live here so a run can
# be re-targeted without touching the functions below.

# Input / Output
# Source clip to segment, and where the annotated result is written.
INPUT_VIDEO = "/kaggle/input/datasets/gonoszgonosz/rat-test-video/test.mp4"
OUTPUT_VIDEO = "/kaggle/working/SAM2_Tracked_Output.mp4"
# Scratch directory for the extracted JPEG frames. extract_frames() wipes and
# recreates it, and main() deletes it again, so never point this at data you
# want to keep.
TEMP_FRAME_DIR = "/kaggle/working/sam2_temp_frames"

# Stage 1 Segformer (Auto-Prompter)
# Directory holding both the image-processor config and the model weights
# (loaded via from_pretrained in main()).
# NOTE(review): the dataset slug says "b2" while the folder says "b3" —
# confirm which backbone these weights actually are.
STAGE1_PATH = "/kaggle/input/datasets/gonoszgonosz/b2-1024-weights/final_rat_model_b3_1024"

# SAM 2 Settings
# Checkpoint file name matches the one downloaded by the wget cell above;
# the config name is passed as-is to build_sam2_video_predictor.
SAM2_CHECKPOINT = "sam2_hiera_small.pt"
MODEL_CFG = "sam2_hiera_s.yaml"

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# =================================================

def extract_frames(video_path, output_dir):
    """Extract all frames from a video as JPEGs for SAM 2 to read.

    ``output_dir`` is deleted (if it exists) and recreated, so it only ever
    contains the frames of the most recent call. Frames are written as
    ``00000.jpg``, ``00001.jpg``, ... in decoding order.

    Args:
        video_path: Path of the video file to decode.
        output_dir: Directory that will receive the JPEG frames. Any existing
            content is removed first.

    Returns:
        A ``(fps, width, height, frame_count)`` tuple. ``fps``, ``width`` and
        ``height`` are whatever the container reports; ``frame_count`` is the
        number of frames actually decoded and written.

    Raises:
        OSError: If the video cannot be opened or a frame cannot be written.
    """
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    cap = cv2.VideoCapture(video_path)
    try:
        # A missing or undecodable file does not raise in OpenCV; it just
        # yields zero frames. Fail loudly instead of returning an empty result.
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # SAM 2 expects sequentially numbered JPEGs
            frame_path = os.path.join(output_dir, f"{frame_idx:05d}.jpg")
            # imwrite signals failure through its return value, not an
            # exception; a silently missing frame would break the numbering.
            if not cv2.imwrite(frame_path, frame):
                raise OSError(f"Could not write frame: {frame_path}")
            frame_idx += 1
    finally:
        # Release the decoder even if writing a frame failed part-way.
        cap.release()

    return fps, w, h, frame_idx

def get_stage1_bbox(image_path, model, processor):
    """Use the Stage 1 Segformer to find the rat in a single frame.

    The image is run through ``model``, the logits are upsampled to the
    original resolution, and every pixel whose class-1 softmax probability
    exceeds 0.5 is treated as foreground. The tight box around all such
    pixels is returned.

    Args:
        image_path: Path of the image file to analyse.
        model: Segmentation model called as ``model(**inputs)`` and exposing
            ``.logits`` on its output. It must already be on the module-level
            ``device``, since the inputs are moved there.
        processor: Image processor producing the model inputs.

    Returns:
        ``np.float32`` array ``[x_min, y_min, x_max, y_max]`` in pixel
        coordinates (max values are inclusive pixel indices), or ``None``
        when no pixel passes the threshold.

    Raises:
        OSError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    # imread returns None instead of raising on a missing/corrupt file.
    if image is None:
        raise OSError(f"Could not read image: {image_path}")

    h, w = image.shape[:2]
    rgb_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    inputs = processor(images=Image.fromarray(rgb_img), return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        # Bring the logits back to the source resolution so the box is in
        # original pixel coordinates.
        logits = torch.nn.functional.interpolate(
            outputs.logits, size=(h, w), mode="bilinear", align_corners=False
        )
        probs = torch.nn.functional.softmax(logits, dim=1)
        # Channel 1 is presumably the rat/foreground class — confirm against
        # the model's id2label mapping.
        mask = (probs[0, 1, :, :] > 0.5).cpu().numpy()

    # Per-axis reductions give the same extents as collecting every
    # foreground coordinate, without building an N x 2 index array.
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    if rows.size == 0:
        return None

    return np.array([cols[0], rows[0], cols[-1], rows[-1]], dtype=np.float32)

def apply_overlay(image, mask, color=(0, 0, 255), alpha=0.5):
    """Return a copy of ``image`` tinted with ``color`` wherever ``mask`` is set.

    Pixels where ``mask > 0`` become ``(1 - alpha) * image + alpha * color``;
    all other pixels are copied through unchanged. The input image is not
    modified. ``color`` is given in the image's own channel order.
    """
    # Solid-colour canvas with the same shape and dtype as the frame.
    tint = np.full_like(image, color)
    mixed = cv2.addWeighted(image, 1 - alpha, tint, alpha, 0)

    # Evaluate the selection once and reuse it for both sides of the copy.
    selected = mask > 0
    result = image.copy()
    result[selected] = mixed[selected]
    return result

def main():
    """Run the full pipeline: extract frames, auto-prompt, track, render.

    Steps:
      1. Dump INPUT_VIDEO into TEMP_FRAME_DIR as numbered JPEGs.
      2. Load the Stage 1 Segformer and the SAM 2 video predictor.
      3. Derive a bounding box for frame 0 with Stage 1 and hand it to SAM 2
         as the prompt for object id 1.
      4. Propagate the mask through the video, overlaying it on every frame
         and writing the result to OUTPUT_VIDEO.

    The temporary frame directory is removed and the video writer released
    on every exit path, including early returns and exceptions.
    """
    out_video = None
    try:
        print("--- EXTRACTING FRAMES ---")
        fps, w, h, total_frames = extract_frames(INPUT_VIDEO, TEMP_FRAME_DIR)
        print(f"Extracted {total_frames} frames to {TEMP_FRAME_DIR}")

        # Nothing to prompt or track on; bail out before loading any model.
        if total_frames == 0:
            print("CRITICAL ERROR: No frames could be read from the input video.")
            return

        print("--- LOADING MODELS ---")
        # Load Stage 1
        proc1 = SegformerImageProcessor.from_pretrained(STAGE1_PATH)
        model1 = SegformerForSemanticSegmentation.from_pretrained(STAGE1_PATH).to(device)
        model1.eval()

        # Load SAM 2 Video Predictor
        predictor = build_sam2_video_predictor(MODEL_CFG, SAM2_CHECKPOINT, device=device)

        print("--- INITIALIZING SAM 2 MEMORY BANK ---")
        # SAM 2 scans the directory to build its internal temporal tensors
        inference_state = predictor.init_state(video_path=TEMP_FRAME_DIR)

        print("--- GENERATING INITIAL PROMPT VIA STAGE 1 ---")
        first_frame_path = os.path.join(TEMP_FRAME_DIR, "00000.jpg")
        init_box = get_stage1_bbox(first_frame_path, model1, proc1)

        if init_box is None:
            print("CRITICAL ERROR: Stage 1 could not find the rat in Frame 0. Cannot prompt SAM 2.")
            return

        print(f"Bounding Box Found: {init_box}")

        # Send the bounding box to SAM 2 for Frame 0
        # obj_id=1 represents the rat entity it needs to track.
        # The immediate result is not needed: propagation below yields frame 0 too.
        predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=1,
            box=init_box,
        )

        print("--- PROPAGATING THROUGH VIDEO ---")
        # Some containers report 0 (or NaN) FPS; a writer opened with that
        # value produces an unplayable file. 30 is an arbitrary fallback.
        if not fps > 0:
            print(f"WARNING: Invalid FPS ({fps}) reported by input video; falling back to 30.")
            fps = 30.0

        # Setup Video Writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_video = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (w, h))
        if not out_video.isOpened():
            print(f"CRITICAL ERROR: Could not open {OUTPUT_VIDEO} for writing.")
            return

        # SAM 2 yields the masks frame by frame as it calculates them
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
            frame_path = os.path.join(TEMP_FRAME_DIR, f"{out_frame_idx:05d}.jpg")
            frame = cv2.imread(frame_path)

            # Extract the binary mask (logits > 0.0); only one object is
            # tracked, so index 0 is the rat.
            mask = (out_mask_logits[0] > 0.0).cpu().numpy().squeeze().astype(np.uint8)

            # Draw Output
            res_frame = apply_overlay(frame, mask, color=(0, 255, 0))  # Green mask for SAM 2

            # Add labels
            cv2.putText(res_frame, "SAM 2 Temporal Tracking", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            if out_frame_idx == 0:
                # Draw the Stage 1 bounding box on the first frame so you can see the prompt
                x1, y1, x2, y2 = map(int, init_box)
                cv2.rectangle(res_frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                # Keep the caption on-screen when the box touches the top edge.
                label_y = max(y1 - 10, 15)
                cv2.putText(res_frame, "Stage 1 Auto-Prompt Box", (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

            out_video.write(res_frame)

            # Basic progress output
            if out_frame_idx % 50 == 0:
                print(f"Processed frame {out_frame_idx}/{total_frames}")

        # Finalise the file before announcing success.
        out_video.release()
        print(f"--- DONE. SAVED TO {OUTPUT_VIDEO} ---")
    finally:
        # Releasing an already-released writer is harmless; this covers the
        # exception path where the release above was never reached.
        if out_video is not None:
            out_video.release()
        # Cleanup temporary frames on every exit path.
        shutil.rmtree(TEMP_FRAME_DIR, ignore_errors=True)

# Run the pipeline when this file is executed directly (or as a notebook
# cell); importing it as a module only defines the functions above.
if __name__ == "__main__":
    main()