<a href="https://colab.research.google.com/github/Minyst/ML_DL_Portfolio_KR/blob/main/SAM/SAM2_with_YOLO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything-2.git

In [None]:
!pip install ultralytics

In [None]:
import torch
import cv2
import numpy as np
from sam2.build_sam import build_sam2_video_predictor_hf
import os
import tempfile
from ultralytics import YOLO

def convert_video_to_jpeg_sequence(video_path, output_folder, target_size=(640, 360)):
    """Decode a video and write every frame as a resized JPEG file.

    Frames are written to ``output_folder`` in decode order, using 1-based,
    zero-padded file names (``000001.jpg``, ``000002.jpg``, ...).

    Args:
        video_path: Path of the video file to read.
        output_folder: Existing directory that receives the JPEG files.
        target_size: ``(width, height)`` passed to ``cv2.resize`` for every frame.

    Returns:
        The number of frames written.

    Raises:
        OSError: If the video cannot be opened or a frame cannot be written.
    """
    cap = cv2.VideoCapture(video_path)
    # An unopenable video would otherwise just yield 0 frames and fail much later.
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open video: {video_path}")

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            frame = cv2.resize(frame, target_size)
            frame_path = os.path.join(output_folder, f"{frame_count:06d}.jpg")
            # cv2.imwrite reports failures such as a missing folder by returning
            # False rather than raising, so check it explicitly.
            if not cv2.imwrite(frame_path, frame):
                raise OSError(f"Could not write frame: {frame_path}")
    finally:
        # Release the decoder even if resizing or writing raised.
        cap.release()
    return frame_count

# --- Configuration ------------------------------------------------------------
video_path = "/kaggle/input/LA vs SD.mp4"
target_size = (640, 360)  # (width, height) every frame is resized to
FRAME_STEP = 5  # only every FRAME_STEP-th frame is detected, segmented and written
yolo_model = YOLO('yolov10m.pt')

# YOLO class name -> id written into the per-frame label mask (0 = background).
CLASS_IDS = {
    'person': 1,
    'sports ball': 2,
    'baseball bat': 3,
    'baseball glove': 4,
}
# Label-mask id -> overlay colour. OpenCV frames are BGR, not RGB.
CLASS_COLORS = {
    1: (0, 255, 0),    # green  - player
    2: (0, 0, 255),    # red    - baseball
    3: (255, 0, 0),    # blue   - bat
    4: (0, 255, 255),  # yellow - baseball glove ((255, 255, 0) would be cyan in BGR)
}

# The output contains only the sampled frames, so divide the source frame rate
# by FRAME_STEP to preserve the original playback duration. Fall back to the
# previously hard-coded 6 fps if the container does not report a frame rate.
_fps_probe = cv2.VideoCapture(video_path)
source_fps = _fps_probe.get(cv2.CAP_PROP_FPS)
_fps_probe.release()
output_fps = source_fps / FRAME_STEP if source_fps > 0 else 6.0

with tempfile.TemporaryDirectory() as temp_dir:
    frame_count = convert_video_to_jpeg_sequence(video_path, temp_dir, target_size)
    model_id = "facebook/sam2-hiera-small"
    predictor = build_sam2_video_predictor_hf(model_id)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter('output.mp4', fourcc, output_fps, target_size)
    if not out.isOpened():
        raise OSError("Could not open output.mp4 for writing")

    try:
        with torch.inference_mode():
            # Keep the decoded frames in host memory instead of on the accelerator;
            # only a handful of them are ever segmented.
            state = predictor.init_state(temp_dir, offload_video_to_cpu=True)
            for frame_idx in range(0, frame_count, FRAME_STEP):
                # JPEG names are 1-based (000001.jpg) while frame_idx is 0-based.
                frame_path = os.path.join(temp_dir, f"{frame_idx+1:06d}.jpg")
                current_frame = cv2.imread(frame_path)
                if current_frame is None:
                    raise OSError(f"Could not read frame: {frame_path}")

                results = yolo_model(current_frame, verbose=False)

                # Per-pixel class id for this frame (0 = background).
                frame_mask = np.zeros(current_frame.shape[:2], dtype=np.uint8)

                # Drop prompts and outputs from the previous frame. Each detection
                # below then gets its own object id, so SAM2 never reuses another
                # object's mask on the same frame and the state stays bounded.
                predictor.reset_state(state)
                next_obj_id = 0

                for result in results:
                    for box in result.boxes:
                        class_id = CLASS_IDS.get(result.names[int(box.cls)])
                        if class_id is None:
                            continue

                        obj_id = next_obj_id
                        next_obj_id += 1
                        # Prompt with the full detector box rather than its center
                        # point: the center of a thin or diagonal object (e.g. a bat)
                        # frequently lies on background. The box is in the pixel
                        # space of the JPEG frames SAM2 was initialised from.
                        _, obj_ids, mask_logits = predictor.add_new_points_or_box(
                            state,
                            frame_idx=frame_idx,
                            obj_id=obj_id,
                            box=box.xyxy[0].cpu().numpy(),
                        )
                        if mask_logits is None or len(mask_logits) == 0:
                            continue

                        logits = mask_logits[obj_ids.index(obj_id)].cpu().numpy().squeeze()
                        # SAM2 returns mask *logits*; foreground is logit > 0.
                        foreground = logits > 0.0
                        if foreground.any():
                            # Later detections overwrite earlier ones where they overlap.
                            frame_mask[foreground] = class_id

                overlay = np.zeros_like(current_frame)
                for class_id, color in CLASS_COLORS.items():
                    overlay[frame_mask == class_id] = color

                # Tint only segmented pixels; blending the whole frame would also
                # darken every background pixel by 30%.
                blended = cv2.addWeighted(current_frame, 0.7, overlay, 0.3, 0)
                result_frame = np.where(frame_mask[..., None] > 0, blended, current_frame)

                out.write(result_frame)

                if frame_idx % 50 == 0:
                    torch.cuda.empty_cache()
    finally:
        # Finalise the container even if processing fails midway.
        out.release()