In [None]:
# Install required dependencies and clone SAM2 repository
!pip install git+https://github.com/facebookresearch/sam2.git

# Download SAM2 model checkpoint
!mkdir -p ../checkpoints/
!wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

In [None]:
# Checkpoint downloaded by the setup cell and the model config that matches it.
sam2_checkpoint = "../checkpoints/sam2.1_hiera_small.pt"
# The config name is resolved relative to the installed `sam2` package, so it
# must not repeat the leading "sam2/" directory (a doubled prefix is expected
# to fail with a Hydra "missing config" error on a pip-installed package).
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"

# Load the SAM2 model on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

# Initialize the automatic mask generator.
# Parameter notes follow the SAM2AutomaticMaskGenerator docstring.
mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2_model,
    points_per_side=64,                # prompt grid is points_per_side ** 2 points
    points_per_batch=128,              # prompts run through the model at once
    pred_iou_thresh=0.7,               # drop masks below this predicted quality
    stability_score_thresh=0.92,       # drop masks that are unstable under cutoff shifts
    stability_score_offset=0.7,        # cutoff shift used for the stability score
    crop_n_layers=1,                   # also run on one extra layer of image crops
    box_nms_thresh=0.7,                # box IoU cutoff for duplicate suppression
    crop_n_points_downscale_factor=2,  # fewer prompt points per side on crop layers
    min_mask_region_area=25,           # pixel area (int) below which regions/holes are removed
    use_m2m=True,                      # one-step refinement from previous mask predictions
)


In [None]:
# Define the function to display segmentation masks
def show_anns(anns, borders=True, ax=None):
    """Overlay segmentation masks on a matplotlib axes as translucent colors.

    Masks are painted from largest to smallest 'area', so smaller masks stay
    visible on top of larger ones. Each mask gets a random RGB color (drawn
    from NumPy's global random state) at 50% opacity.

    Args:
        anns: Sequence of annotation dicts. Each entry must provide 'area'
            (used for ordering) and 'segmentation' (used as a boolean index
            into the overlay, which takes its height and width from the
            largest mask).
        borders: When True, outline each mask's external contours using
            OpenCV (imported only in that case).
        ax: Axes to draw on. Defaults to the current pyplot axes, which is
            the original behavior.

    Returns:
        None. Nothing is drawn when `anns` is empty.
    """
    if len(anns) == 0:
        return

    ordered = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    if ax is None:
        ax = plt.gca()
    ax.set_autoscale_on(False)

    if borders:
        # Imported once up front rather than on every loop iteration.
        import cv2

    # RGBA canvas: white and fully transparent until a mask paints over it.
    height, width = ordered[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0

    for ann in ordered:
        mask = ann['segmentation']
        overlay[mask] = np.concatenate([np.random.random(3), [0.5]])
        if borders:
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            cv2.drawContours(overlay, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(overlay)


In [None]:
# Read the example frame from disk and convert it to an RGB NumPy array.
image_path = '/content/splits_final_deblurred/train/data/04_frame_036100.PNG'
image = Image.open(image_path).convert("RGB")
image_np = np.array(image)

# Run SAM2 automatic mask generation with gradient tracking disabled.
with torch.no_grad():
    masks = mask_generator.generate(image_np)

# Show the frame with the generated masks overlaid on a 12x12 inch figure.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(image_np)
show_anns(masks)  # draws on the current axes, i.e. the one created above
ax.axis('off')
plt.show()
