In [None]:
import numpy as np
import cv2
from tqdm import tqdm
import torch
import os
import sys

# Ensure segment_anything is installed in Colab (after torch)
!pip install -q segment-anything

# Import the necessary components
from segment_anything import sam_model_registry, SamPredictor

# --- CONFIGURATION ---
# Input video (expected in the working directory) and the output archive
# that will hold the per-frame masks.
VIDEO_FILENAME = 'sam2Demo.mp4'
VIDEO_PATH = './' + VIDEO_FILENAME
OUTPUT_NPZ_PATH = 'segmentation.npz'

# SAM weights: the checkpoint file and the registry key of the matching
# model architecture.
MODEL_TYPE = 'vit_h'
SAM_CHECKPOINT_PATH = 'sam_vit_h_4b8939.pth'

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Download the SAM checkpoint only if it is not already on disk.
# The file is fetched into a temporary "<name>.part" file and renamed into
# place only after the transfer finishes. An interrupted download therefore
# never leaves a truncated checkpoint at SAM_CHECKPOINT_PATH. (The previous
# version deleted and re-downloaded the file on every run, and never noticed
# when `wget -q` failed.)
import urllib.request

SAM_CHECKPOINT_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

if os.path.exists(SAM_CHECKPOINT_PATH):
    print(f"Using existing SAM checkpoint '{SAM_CHECKPOINT_PATH}'.")
else:
    print(f"Downloading SAM checkpoint '{SAM_CHECKPOINT_PATH}'...")
    partial_checkpoint_path = SAM_CHECKPOINT_PATH + '.part'
    # urlretrieve raises on HTTP/network errors, so a failed download stops
    # here instead of printing "Download complete.".
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, partial_checkpoint_path)
    # Atomic rename: the final path only ever holds a fully downloaded file.
    os.replace(partial_checkpoint_path, SAM_CHECKPOINT_PATH)
    print("Download complete.")

# --- CRITICAL ACTION REQUIRED: UPDATE THESE COORDINATES! ---
# Point prompt for frame 0: one row per point, each row a pixel position on
# the object to track. SAM expects (x, y) order. You MUST set this to the
# precise centre of the object in frame 0, e.g. [[350, 400]].
INITIAL_POINT_PROMPT = np.array([(320, 420)])
# One label per prompt point; 1 marks a foreground point.
INITIAL_POINT_LABEL = np.ones(1, dtype=int)

# SPEED: SAM runs only on every FRAME_SKIP_COUNT-th frame; the frames in
# between reuse the previous mask.
FRAME_SKIP_COUNT = 3

# FILE SIZE: saved masks are shrunk by this factor in both width and height
# (8 -> 1/64 of the original pixel count).
DOWNSCALE_FACTOR = 8

# --- 1. Initialization ---
# Build the SAM model from SAM_CHECKPOINT_PATH, move it to DEVICE and wrap it
# in a SamPredictor. On failure, raise an error that chains the original
# exception, so its traceback stays visible. The old sys.exit() reported a
# success status (0) and discarded the traceback.
print(f"Initializing SAM model on device: {DEVICE}...")
try:
    sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH)
    sam.to(device=DEVICE)
    predictor = SamPredictor(sam)
except Exception as e:
    raise RuntimeError(f"Initialization Error: {e}. Check GPU/File download.") from e

# Initialize video capture
# Open the input video and read its frame size and frame count from OpenCV's
# capture properties.
cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    # Raise instead of sys.exit(): sys.exit() reports success (status 0) and
    # lets nothing upstream see that the run actually failed.
    raise RuntimeError(
        f"CRITICAL ERROR: Could not open video file at {VIDEO_PATH}. Ensure the name is correct."
    )

frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

print(f"Video opened: {frame_width}x{frame_height} @ {total_frames} frames. Ready to segment.")

# --- 2. Process Video and Generate Masks (Targeted Tracking) ---
# SAM runs only on key frames: every FRAME_SKIP_COUNT-th frame, starting with
# frame 0. Every other frame repeats the most recent key-frame mask.
#
# Prompting:
# - A key frame is prompted with the bounding box of the previous key frame's
#   accepted mask.
# - On frame 0, or after tracking is lost (mask rejected or empty),
#   INITIAL_POINT_PROMPT is used instead.
#
# Every saved mask, including empty ones, is downscaled to the same size. The
# old code resized only non-empty masks, so empty frames stayed full-res and
# the later np.stack() either failed on mixed shapes or saved a full-res array.

MIN_MASK_SCORE = 0.8  # SAM's predicted mask score must exceed this to keep a mask

# Guard against FRAME_SKIP_COUNT == 0 (would divide by zero in the modulo).
key_interval = max(1, FRAME_SKIP_COUNT)
downscale = max(1, DOWNSCALE_FACTOR)
# cv2.resize takes (width, height); never shrink a side below one pixel.
mask_size = (max(1, frame_width // downscale), max(1, frame_height // downscale))


def _mask_bbox(mask):
    """Return [x_min, y_min, x_max, y_max] of the non-zero pixels, or None if empty."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])


def _downscale_mask(mask):
    """Resize an (H, W) mask to ``mask_size``; nearest-neighbour keeps it binary."""
    if (mask.shape[1], mask.shape[0]) == mask_size:
        return mask
    return cv2.resize(mask, mask_size, interpolation=cv2.INTER_NEAREST)


all_masks = []
current_bbox = None

print("Generating masks frame by frame with tracking...")

try:
    for i in tqdm(range(total_frames), desc="Segmenting Frames"):
        ret, frame = cap.read()
        if not ret:
            break

        if i % key_interval != 0:
            # Non-key frame: frame 0 is always a key frame, so all_masks[-1]
            # exists and is already downscaled.
            all_masks.append(all_masks[-1])
            continue

        predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        if current_bbox is not None:
            # Tracking: prompt with the box around the last accepted mask.
            masks, scores, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=current_bbox,
                multimask_output=False,
            )
        else:
            # Frame 0 or tracking lost: retry with the initial point. A single
            # point is ambiguous, so ask SAM for several candidates and keep
            # the best-scoring one (as recommended by SamPredictor.predict).
            masks, scores, _ = predictor.predict(
                point_coords=INITIAL_POINT_PROMPT,
                point_labels=INITIAL_POINT_LABEL,
                box=None,
                multimask_output=True,
            )

        best = int(np.argmax(scores))
        if scores[best] > MIN_MASK_SCORE:
            full_res_mask = masks[best].astype(np.uint8)
            # The box prompt is in full-frame pixel coordinates, so it must be
            # computed before downscaling.
            current_bbox = _mask_bbox(full_res_mask)
        else:
            full_res_mask = np.zeros(masks.shape[-2:], dtype=np.uint8)
            current_bbox = None

        all_masks.append(_downscale_mask(full_res_mask))
finally:
    cap.release()

# --- 3. Save the Data ---
# Stack the per-frame masks into one (frames, height, width) array and write it
# to OUTPUT_NPZ_PATH under the key 'masks' as a compressed .npz archive.
if not all_masks:
    print("ERROR: No frames processed. Check video path.")
else:
    final_masks_array = np.stack(all_masks, axis=0)
    print(f"\nSegmentation complete. Final array shape: {final_masks_array.shape}")

    np.savez_compressed(OUTPUT_NPZ_PATH, masks=final_masks_array)
    print(f"SUCCESS: {OUTPUT_NPZ_PATH} created. This file is now small enough for GitHub. DOWNLOAD AND REPLACE.")

Downloading SAM checkpoint 'sam_vit_h_4b8939.pth'...
Download complete.
Initializing SAM model on device: cpu...
Video opened: 1920x1080 @ 72 frames. Ready to segment.
Generating masks frame by frame with tracking...


Segmenting Frames: 100%|██████████| 72/72 [55:06<00:00, 45.93s/it]



Segmentation complete. Final array shape: (72, 1080, 1920)
SUCCESS: segmentation.npz created in Colab. DOWNLOAD THIS FILE NOW.
