### Notebook for testing Meta AI's (formerly Facebook AI) Segment Anything Model 2 (SAM 2) on video

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from PIL import ImageFilter
from PIL import ImageOps
import os
from tqdm import tqdm
from sam2.sam2_video_predictor import SAM2VideoPredictor

ROOT_DIR = "../.."
VIDEO_NAME = "face_test"

# Source JPEG frames are read from <ROOT_DIR>/<VIDEO_NAME>/frames; the cropped
# copies that SAM 2 will track on are written to <ROOT_DIR>/<VIDEO_NAME>/crop_frames.
source_dir, video_dir = (
    os.path.join(ROOT_DIR, VIDEO_NAME, sub) for sub in ("frames", "crop_frames")
)
os.makedirs(video_dir, exist_ok=True)

# Preview one source frame so the crop box below can be sanity-checked by eye
frame = Image.open(os.path.join(source_dir, "0001.jpg"))
plt.imshow(frame)
plt.show()

# Frames are first resized to RESIZE = (width, height), then cropped to the
# CROP = (left, upper, right, lower) box, giving a 144x170 image.
RESIZE = (144, 312)
CROP = (0, 110, 144, 280)

# Prefer the GPU when one is present
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: ", device)

if device.type == "cuda":
    # Enter a bfloat16 autocast context for the rest of the session
    # (deliberately never exited) to speed up inference.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

In [None]:
# Other checkpoints that can be swapped in here:
#   "facebook/sam2-hiera-large", "facebook/sam2-hiera-base-plus", "facebook/sam2-hiera-small"
# The tiny checkpoint is used, moved to the device selected above.
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-tiny").to(device)

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a binary mask on ``ax`` as a single colour with alpha 0.6.

    mask: array whose last two dimensions are (H, W); any leading dimensions
        must have size 1 (e.g. a (1, H, W) mask), since it is reshaped to (H, W, 1).
    obj_id: index into the tab10 colormap that picks the colour (0 when None).
    random_color: if true, use a random RGB colour instead of the colormap.
    """
    if not random_color:
        rgb = plt.get_cmap("tab10")(0 if obj_id is None else obj_id)[:3]
    else:
        rgb = np.random.random(3)
    rgba = np.append(rgb, 0.6)

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the RGBA colour -> (H, W, 4) image
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt clicks on ``ax``.

    Positive clicks (label 1) are drawn as green stars and negative clicks
    (label 0) as red stars, both with a white edge.

    coords: (N, 2) array of (x, y) positions.
    labels: (N,) array of 0/1 labels aligned with ``coords``.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Draw the box (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle, line width 2."""
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle((left, top), right - left, bottom - top,
                            edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

In [None]:
def process_frame(image, crop=CROP, resize=RESIZE):
    """Resize a PIL image, then crop it.

    resize: target (width, height) passed to ``Image.resize``.
    crop: (left, upper, right, lower) box, in resized coordinates, passed to ``Image.crop``.
    Returns the resized-then-cropped image.
    """
    resized = image.resize(resize)
    return resized.crop(crop)

def reverse_process_frame(image, size, crop=CROP, resize=RESIZE):
    """Undo ``process_frame`` for an image produced in cropped coordinates.

    The image is zero-padded back onto a ``resize``-sized canvas (undoing the
    crop) and then scaled to ``size`` = (width, height). For example, a
    144x170 image is padded to 144x312 and then resized to ``size``.
    """
    canvas_w, canvas_h = resize
    # ImageOps.expand takes a (left, top, right, bottom) border
    border = (crop[0], crop[1], canvas_w - crop[2], canvas_h - crop[3])
    padded = ImageOps.expand(image, border)
    return padded.resize(size)

# Crop every source JPEG and save it under the same file name in video_dir.
# Only JPEGs are processed (the same extensions frame_names filters on below),
# so stray non-image files in source_dir are skipped instead of crashing Image.open.
source_frames = sorted(
    name for name in os.listdir(source_dir)
    if os.path.splitext(name)[-1] in (".jpg", ".jpeg", ".JPG", ".JPEG")
)
for frame in tqdm(source_frames, desc="Cropping frames"):
    # `with` closes each source file as soon as its cropped copy is written
    with Image.open(os.path.join(source_dir, frame)) as image:
        process_frame(image).save(os.path.join(video_dir, frame))

In [None]:
# Collect the cropped JPEG frame names, ordered by their integer file stem
# (e.g. "0001.jpg" -> 1), so frame_names[i] is the i-th frame.
# NOTE(review): presumably this must match how SAM 2 orders the frames it loads
# from video_dir, so that frame indices agree — keep the two in sync.
frame_names = sorted(
    (name for name in os.listdir(video_dir)
     if os.path.splitext(name)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]),
    key=lambda name: int(os.path.splitext(name)[0]),
)

# Preview the first cropped frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

In [None]:
# !python sam_interactive_init.py --root_dir "../.." # This doesn't work for the interactive plot
# Since the interactive plot does not work from inside the notebook, run the
# script manually; its absolute path is printed here for convenience.
print(os.path.abspath("./sam_interactive_init.py"))
# True: load the prompt/state from <ROOT_DIR>/<VIDEO_NAME>/init_sam_state.pth
# (presumably written by sam_interactive_init.py — confirm) in the cells below.
# False: build a fresh state and use the hard-coded clicks instead.
RAN_INTERACTIVE_SAM = True

In [None]:
if not RAN_INTERACTIVE_SAM:
    # Build a fresh tracking state over the cropped frames; the offload_* flags
    # keep frames/state on the CPU (presumably to save GPU memory).
    inference_state = predictor.init_state(
        video_path=video_dir,
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
        async_loading_frames=False,
    )
else:
    # Load the saved state dict; its "inference_state", "frame_idx", "obj_id",
    # "points" and "labels" entries are used here and in the prompt cell.
    # NOTE(review): weights_only=False unpickles arbitrary Python objects —
    # only load a file you produced yourself.
    state_path = os.path.join(ROOT_DIR, VIDEO_NAME, "init_sam_state.pth")
    state = torch.load(state_path, weights_only=False)
    inference_state = state["inference_state"]

In [None]:
if not RAN_INTERACTIVE_SAM:
    # Reset the tracking held in inference_state before the hard-coded clicks are
    # added in the next cell. Re-run this cell whenever tracking should start over.
    # Skipped when the state was loaded from the interactive script's file.
    predictor.reset_state(inference_state)

In [None]:
# Choose the prompt (frame, object id, clicks), then register it with SAM 2 once.
if not RAN_INTERACTIVE_SAM:
    ann_frame_idx = 0  # the frame index we interact with
    ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

    # Hand-picked (x, y) clicks on the cropped frame. The whole array is scaled by 2;
    # NOTE(review): presumably the coordinates were picked at half scale — confirm.
    points = np.array([
        [32, 20],
        [32, 40],
        [32, 45],
        [20, 60],
        [30, 70],
        # [25, 44],
        # [40, 44],
    ], dtype=np.float32) * 2
    # for labels, `1` means positive click and `0` means negative click:
    # the first two clicks are positive, the last three negative
    labels = np.array([
        1,
        1,
        0,
        0,
        0,
        # 1,
        # 1,
    ], np.int32)
else:
    # Reuse the prompt stored in the state file loaded earlier
    ann_frame_idx = state["frame_idx"]
    ann_obj_id = state["obj_id"]
    points = state["points"]
    labels = state["labels"]

# Both branches register their prompt identically
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

In [None]:
# Propagate the prompt through the whole video. For each frame index,
# video_segments stores {obj_id: boolean mask}, where a pixel is in the mask
# when its logit is > 0.
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    # Logits are ordered like out_obj_ids, so pair them up directly
    video_segments[out_frame_idx] = {
        obj_id: (logits > 0.0).cpu().numpy()
        for obj_id, logits in zip(out_obj_ids, out_mask_logits)
    }

In [None]:
# Overlay the propagated masks on every vis_frame_stride-th cropped frame
plt.close("all")
vis_frame_stride = 10
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    frame_path = os.path.join(video_dir, frame_names[out_frame_idx])
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(frame_path))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

In [None]:
# Optional: add corrective clicks on a single frame after propagation.
# Set to True to run this cell's refinement.
RUN_REFINEMENT = False
if RUN_REFINEMENT:
    ann_frame_idx = 70  # further refine some details on this frame
    ann_obj_id = 1  # id of the object to refine; must be a key of video_segments[ann_frame_idx]

    # show the segment before further refinement
    plt.figure(figsize=(9, 6))
    plt.title(f"frame {ann_frame_idx} -- before refinement")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
    show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)

    # One negative click followed by one positive click, as (x, y) on the cropped frame
    points = np.array([
        [40, 55],
        [42, 30],
    ], dtype=np.float32)
    # for labels, `1` means positive click and `0` means negative click
    labels = np.array([
        0,
        1,
    ], np.int32)
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

    # out_mask_logits holds one entry per tracked object, ordered like out_obj_ids.
    # Select the refined object's entry; thresholding the whole tensor would
    # break show_mask's reshape as soon as more than one object is tracked.
    obj_idx = list(out_obj_ids).index(ann_obj_id)

    # show the segment after the further refinement
    plt.figure(figsize=(9, 6))
    plt.title(f"frame {ann_frame_idx} -- after refinement")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
    show_points(points, labels, plt.gca())
    show_mask((out_mask_logits[obj_idx] > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)

    # NOTE: video_segments is only filled by the propagation cell; re-run it for this
    # refinement to reach the exported mask / blur videos.

In [None]:
import subprocess

# Write the tracked object's mask for every frame as a full-resolution PNG, then encode a video.
# Paths use VIDEO_NAME (not a hard-coded folder) so changing the video at the top works.
mask_dir = os.path.join(ROOT_DIR, VIDEO_NAME, "mask_frames")
os.makedirs(mask_dir, exist_ok=True)

# Cropped frames were saved under the same file names, so frame_names also indexes source_dir
with Image.open(os.path.join(source_dir, frame_names[0])) as first_frame:
    original_size = first_frame.size

for out_frame_idx, frame_masks in video_segments.items():
    # Same object as the blur cell below, so mask.mp4 and mask_blur.mp4 agree.
    # squeeze() drops the singleton leading dimension -> (H, W) uint8 array (PIL mode "L").
    mask = (frame_masks[ann_obj_id].squeeze() * 255).astype(np.uint8)
    mask_image = reverse_process_frame(Image.fromarray(mask), original_size)
    mask_image.save(os.path.join(mask_dir, f"{out_frame_idx:04d}.png"))

# create a video from the segmentation results
video_save_path = os.path.join(ROOT_DIR, VIDEO_NAME, "mask.mp4")
fps = 15
# Argument list (no shell): paths with spaces work, and check=True surfaces ffmpeg failures
subprocess.run(
    ["ffmpeg", "-y", "-r", str(fps), "-f", "image2",
     "-i", os.path.join(mask_dir, "%04d.png"),
     "-vcodec", "libx264", "-crf", "25", "-pix_fmt", "yuv420p",
     video_save_path],
    check=True,
)

In [None]:
import subprocess

# Pixelate the tracked object in every original frame and encode the result as a video.
mask_blur_dir = os.path.join(ROOT_DIR, VIDEO_NAME, "mask_blur_frames")
os.makedirs(mask_blur_dir, exist_ok=True)

# Size (in original pixels) of the pixelation blocks
pixelate_factor = 50

for out_frame_idx, frame_masks in video_segments.items():
    image = Image.open(os.path.join(source_dir, frame_names[out_frame_idx]))
    frame_size = image.size

    # Map the object's mask back to full resolution and feather its edge so the
    # pixelated region blends into the frame
    mask = (frame_masks[ann_obj_id].squeeze() * 255).astype(np.uint8)
    mask_image = reverse_process_frame(Image.fromarray(mask), frame_size)
    mask_image = mask_image.filter(ImageFilter.GaussianBlur(5))

    # Pixelate by downscaling and upscaling again with nearest-neighbour. Clamp to
    # >= 1 px so frames smaller than pixelate_factor don't yield a zero-size image.
    # (Smooth-blur alternative: image.filter(ImageFilter.GaussianBlur(20)).)
    small_size = (max(1, frame_size[0] // pixelate_factor),
                  max(1, frame_size[1] // pixelate_factor))
    pixelated = image.resize(small_size).resize(frame_size, Image.NEAREST)

    # Paste pixelated pixels where the (soft) mask is set; keep the original elsewhere
    image.paste(pixelated, mask=mask_image)
    image.save(os.path.join(mask_blur_dir, f"{out_frame_idx:04d}.jpg"))

# Plot the first processed frame
first_idx = min(video_segments)
plt.figure(figsize=(9, 6))
plt.title(f"frame {first_idx}")
plt.imshow(Image.open(os.path.join(mask_blur_dir, f"{first_idx:04d}.jpg")))
plt.show()

# create a video from the segmentation results
video_save_path = os.path.join(ROOT_DIR, VIDEO_NAME, "mask_blur.mp4")
fps = 15
# Argument list (no shell): paths with spaces work, and check=True surfaces ffmpeg failures
subprocess.run(
    ["ffmpeg", "-y", "-r", str(fps),
     "-i", os.path.join(mask_blur_dir, "%04d.jpg"),
     "-vcodec", "libx264", "-crf", "25", "-pix_fmt", "yuv420p",
     video_save_path],
    check=True,
)