# SAM-2 Video Interactive Segmentation — Cells 1–3

This notebook contains the **first three cells** of the SAM-2 course demo:

1. **Install dependencies** (SAM2 + tooling)
2. **Prepare frames**: download a small example video (or use your own), downscale & subsample, extract frames into `frames/`
3. **Batch demo**: load **`SAM2VideoPredictor`**, add a box/point prompt, propagate through the video, overlay masks, and export **`sam2_result.mp4`**

> These cells are designed to run on Colab (GPU preferred, CPU also works). You can copy this notebook to GitHub as-is.


## Cell 1 — Install dependencies

In [None]:
!pip -q install "git+https://github.com/facebookresearch/sam2.git" \
                huggingface_hub opencv-python imageio[ffmpeg] matplotlib tqdm requests
print("✅ Dependencies installed")

## Cell 2 — Download video & extract frames (downscale + subsample)

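- Replace `VIDEO_URL` with your video link **or** upload your own video and set `VIDEO_MP4` to its filename. The download is skipped whenever `VIDEO_MP4` already exists, so after changing `VIDEO_URL` delete the old `demo.mp4` (or choose a new `VIDEO_MP4` name).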
- Frames are written to `frames/0.jpg`, `frames/1.jpg`, …; any existing `*.jpg` files in `frames/` are deleted first.
- Adjust `TARGET_LONG_EDGE`, `STEP`, `MAX_FRAMES` to control speed/memory.

In [None]:
import os, shutil, cv2, requests, glob
from pathlib import Path

# ---- Source video ----
VIDEO_URL = "https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/video/Coco%20Walking%20in%20Berkeley.mp4"
VIDEO_MP4 = "demo.mp4"   # if you upload a file, set this to your filename

# ---- Frame extraction params ----
TARGET_LONG_EDGE = 640   # downscale long edge to this (helps avoid OOM)
STEP = 2                 # take 1 frame every STEP frames (2–4 saves time)
MAX_FRAMES = 200         # hard limit for speed
assert STEP >= 1, "STEP must be a positive integer"

# Start from an empty frames/ folder so Cell 3 only sees frames from this run
# (a previous, longer run would otherwise leave stale high-numbered frames behind).
FRAMES_DIR = Path("frames")
FRAMES_DIR.mkdir(exist_ok=True)
for old_frame in FRAMES_DIR.glob("*.jpg"):
    old_frame.unlink()

# ---- Download video if not present ----
# Stream into a temporary ".part" file and rename it only after the download
# completes: an interrupted download must not leave a truncated VIDEO_MP4 that the
# existence check would silently reuse on the next run.
if not Path(VIDEO_MP4).exists():
    print("Downloading sample video…")
    part_path = Path(VIDEO_MP4 + ".part")
    with requests.get(VIDEO_URL, stream=True, timeout=60) as r:   # closes the connection
        r.raise_for_status()
        with open(part_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                if chunk:
                    f.write(chunk)
    os.replace(part_path, VIDEO_MP4)
print("Video:", VIDEO_MP4, f"({os.path.getsize(VIDEO_MP4)/1e6:.2f} MB)")

# ---- Extract frames ----
cap = cv2.VideoCapture(VIDEO_MP4)
assert cap.isOpened(), "Cannot open video"
fps = cap.get(cv2.CAP_PROP_FPS) or 30   # fall back to 30 when the container reports 0
idx = 0; out_idx = 0
try:
    # Stop decoding as soon as MAX_FRAMES frames are saved instead of reading
    # (and discarding) the remainder of the video.
    while out_idx < MAX_FRAMES:
        ret, frame = cap.read()
        if not ret:
            break
        if idx % STEP == 0:
            h, w = frame.shape[:2]
            long_edge = max(h, w)
            scale = TARGET_LONG_EDGE / long_edge
            if scale < 1.0:   # only ever downscale, never upscale
                new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
                frame = cv2.resize(frame, new_size, interpolation=cv2.INTER_AREA)
            ok = cv2.imwrite(str(FRAMES_DIR / f"{out_idx}.jpg"), frame)
            assert ok, f"Failed to write frame {out_idx} to {FRAMES_DIR}"
            out_idx += 1
        idx += 1
finally:
    cap.release()

assert out_idx > 0, f"No frames could be extracted from {VIDEO_MP4}"
first = cv2.imread(str(FRAMES_DIR/"0.jpg"))
h0, w0 = first.shape[:2]
print(f"✅ Frames: {out_idx}  |  size≈{w0}x{h0}  |  fps≈{fps:.1f}")

## Cell 3 — Run SAM-2 on frames and export `sam2_result.mp4`

This cell loads `SAM2VideoPredictor` with the **tiny** checkpoint. It adds an initial **box** on frame 0 and a corrective positive **point** on a later frame (frame 12 by default — move or remove it if it does not land on your object). It then **propagates** masks across the video, overlays them, and streams the result to an MP4 written at 12 FPS. Adjust the box/point as needed for your video.

In [None]:
import os, cv2, numpy as np, torch
from pathlib import Path
from sam2.sam2_video_predictor import SAM2VideoPredictor
import matplotlib.pyplot as plt

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

FRAMES_DIR = Path("frames")
# Frames are named 0.jpg, 1.jpg, … so sort by the integer stem
# (plain string order would put 10.jpg before 2.jpg).
frame_files = sorted(FRAMES_DIR.glob("*.jpg"), key=lambda p:int(p.stem))
assert len(frame_files)>0, "No frames found. Run Cell 2 first."

# Load predictor (tiny = memory friendly). Alternatives: 'facebook/sam2.1-hiera-tiny', 'facebook/sam2-hiera-small'
MODEL_ID = "facebook/sam2-hiera-tiny"
predictor = SAM2VideoPredictor.from_pretrained(MODEL_ID, device=device)
print("Loaded:", MODEL_ID)

# Init video state
state = predictor.init_state(video_path=str(FRAMES_DIR))
img0 = cv2.imread(str(frame_files[0]))
assert img0 is not None, f"Cannot read {frame_files[0]}"
H0, W0 = img0.shape[:2]

# --- Prompts: a robust box on frame 0, plus an optional positive point later ---
obj_id = 1
# Box as (x_min, y_min, x_max, y_max) in pixels of frame 0 — adjust to your video.
box0 = np.array([0.45*W0, 0.15*H0, 0.75*W0, 0.90*H0], np.float32)
predictor.add_new_points_or_box(state, frame_idx=0, obj_id=obj_id, box=box0)

# Optional corrective click on a later frame; set to False if it misses your object.
USE_FIX_POINT = True
fi_fix = min(12, len(frame_files)-1)
pt_fix = np.array([[0.60*W0, 0.25*H0]], np.float32)
lb_fix = np.array([1], np.int32)   # 1 = positive (foreground) click
# Skip when fi_fix == 0 (single-frame input): add_new_points_or_box clears a frame's
# existing prompts by default (clear_old_points=True), so a click on frame 0 would
# replace box0 instead of refining it.
if USE_FIX_POINT and fi_fix > 0:
    predictor.add_new_points_or_box(state, frame_idx=fi_fix, obj_id=obj_id, points=pt_fix, labels=lb_fix)

# --- Stream propagation & export MP4 ---
OUT_MP4 = "sam2_result.mp4"
# Output FPS. Cell 2 keeps only every STEP-th source frame, so
# source_fps / STEP would reproduce real-time playback speed.
fps_out = 12
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# NOTE: every frame passed to writer.write must be exactly (W0, H0) pixels.
writer = cv2.VideoWriter(OUT_MP4, fourcc, fps_out, (W0, H0))
# Fail fast if the codec/container cannot be opened; otherwise every write is a no-op
# and the problem only shows up as an empty or missing MP4.
assert writer.isOpened(), f"Cannot open VideoWriter for {OUT_MP4} (is the 'mp4v' codec available?)"
print("Writing:", OUT_MP4)

next_write = 0   # index of the next frame to hand to the writer (keeps output in frame order)
pending = {}     # frame index -> BGR overlay frame waiting for its turn to be written
color = (34, 204, 136)  # RGB for overlay

def overlay_mask(rgb, mask, color=(34,204,136), alpha=0.65):
    """Tint the masked pixels of an image with a solid colour.

    Args:
        rgb: image array whose leading axes are indexed by ``mask``
            (at the call site both are frame height x width).
        mask: array whose nonzero entries select the pixels to tint, or ``None``.
        color: tint colour, in the same channel order as ``rgb``.
        alpha: weight of ``color`` in the blend (0 keeps the pixel, 1 paints it solid).

    Returns:
        ``rgb`` itself (the same object) when ``mask`` is None; otherwise a new
        uint8 array. The final cast truncates blended values rather than rounding.
    """
    if mask is None:
        return rgb

    keep_weight = 1 - alpha
    tint = np.array(color, dtype=np.float32)
    canvas = rgb.astype(np.float32)
    hits = mask.astype(bool)
    canvas[hits] = keep_weight * canvas[hits] + alpha * tint
    return canvas.astype(np.uint8)

# Propagate the prompts through the video and stream overlaid frames to the writer.
# The writer is released in `finally` so an exception mid-propagation (e.g. GPU OOM
# or a KeyboardInterrupt) still closes the output file instead of leaking the handle.
try:
    for fi, out_obj_ids, out_mask_logits in predictor.propagate_in_video(state):
        rgb = cv2.cvtColor(cv2.imread(str(frame_files[fi])), cv2.COLOR_BGR2RGB)
        mask = None
        for k in range(len(out_obj_ids)):
            if int(out_obj_ids[k]) == obj_id:
                # Binarize the logits at 0; resize to the frame if the sizes differ.
                mk = (out_mask_logits[k] > 0).detach().cpu().numpy().squeeze().astype(np.uint8)
                if mk.shape != rgb.shape[:2]:
                    mk = cv2.resize(mk, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
                mask = mk
                break
        over = overlay_mask(rgb, mask, color=color, alpha=0.65)
        pending[fi] = cv2.cvtColor(over, cv2.COLOR_RGB2BGR)
        # Write frames strictly in index order, buffering any that arrive early.
        while next_write in pending:
            writer.write(pending.pop(next_write))
            next_write += 1

    # Flush anything still buffered (only non-empty if some frame index was never yielded).
    for fi in sorted(pending):
        writer.write(pending[fi])
finally:
    writer.release()
print("✅ Saved:", OUT_MP4)

# --- Quick preview of a few frames ---
import matplotlib.pyplot as plt
# Decode the exported MP4 back (as RGB) to confirm it is readable, then show the
# first frame, frame 12 (the default corrective-point frame) and one near the end.
cap = cv2.VideoCapture(OUT_MP4); frames=[]
while True:
    ret, f = cap.read()
    if not ret: break
    frames.append(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
cap.release()
assert frames, f"Could not read any frames back from {OUT_MP4}"
sel = [0, min(12, len(frames)-1), max(0, len(frames)-3)]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, fi in zip(axes, sel):
    ax.imshow(frames[fi]); ax.set_title(f"Frame {fi}"); ax.axis('off')
plt.show()