In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
import imutils

In [2]:
from sam2.build_sam import build_sam2_video_predictor

# Paths to the SAM 2.1 checkpoint and its model config, plus the torch device.
# The defaults are the original author's local checkout; set SAM2_CHECKPOINT,
# SAM2_MODEL_CFG and/or SAM2_DEVICE in the environment to run the notebook on
# another machine without editing it.
sam2_checkpoint = os.environ.get(
    "SAM2_CHECKPOINT",
    "C:\\Users\\gerritsm\\Github\\sam2\\checkpoints\\sam2.1_hiera_large.pt",
)
model_cfg = os.environ.get(
    "SAM2_MODEL_CFG",
    "C:\\Users\\gerritsm\\Github\\sam2\\sam2\\configs\\sam2.1\\sam2.1_hiera_l.yaml",
)
sam2_device = os.environ.get("SAM2_DEVICE", "cpu")

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=sam2_device)

In [3]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay one binary mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: array whose last two axes are (height, width); it must contain
            exactly height*width elements so it can be reshaped to (h, w, 1).
        ax: matplotlib Axes (anything exposing ``imshow``).
        obj_id: index into the ``tab10`` colormap used for the colour;
            ``None`` selects the first colour.
            NOTE(review): tab10 has 10 entries; ids >= 10 presumably do not
            get distinct colours — confirm if tracking many objects.
        random_color: when True, ignore ``obj_id`` and draw a random RGB
            colour from NumPy's global random generator.
    """
    alpha = 0.6
    if not random_color:
        palette = plt.get_cmap("tab10")
        rgb = palette(obj_id if obj_id is not None else 0)[:3]
        rgba = np.array([*rgb, alpha])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4): masked pixels get the colour, others stay 0.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on ``ax`` as white-outlined stars.

    Positive points (label 1) are drawn green and negative points (label 0)
    red, in that order; points with any other label are not drawn.

    Args:
        coords: 2-D array with x in column 0 and y in column 1, one row per point.
        labels: array with one label per row of ``coords``.
        ax: matplotlib Axes (anything exposing ``scatter``).
        marker_size: marker area, forwarded to ``scatter`` as ``s``.
    """
    style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **style)

In [None]:
# `frames_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`.
# Built with os.path.join so it also works outside Windows; on Windows it yields
# the same ".\\atlanticFlumeTests\\..." path as a hand-written literal.
frames_dir = os.path.join(".", "atlanticFlumeTests", "video_20241212_144325_jpg_test")

# Scan all JPEG frame names in this directory and order them by frame index.
# Later cells use frame_names[i] as the image for SAM 2 frame index i.
# NOTE(review): this extension filter and integer sort are assumed to mirror how
# predictor.init_state indexes the folder — keep the two in sync.
frame_names = [
    p for p in os.listdir(frames_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
# Raises ValueError if a JPEG filename stem is not a plain integer.
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
if not frame_names:
    # Fail early with a clear message instead of an IndexError in later cells.
    raise FileNotFoundError(f"No JPEG frames found in {frames_dir!r}")

inference_state = predictor.init_state(video_path=frames_dir)

In [5]:
### SEGMENT AND TRACK OBJECT
# Reset the predictor's per-video state before adding new prompts, so re-running
# the annotation cell below starts from scratch.
# NOTE(review): presumably clears prompts/masks stored in `inference_state` while
# keeping the frames loaded by init_state — confirm against SAM 2's reset_state.
predictor.reset_state(inference_state)

In [None]:
ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # unique id for each object we interact with (any integer works)

# Prompt points as (x, y) pixel coordinates on frame `ann_frame_idx`.
points = np.array([[250, 980], [950, 850], [1750, 1000], [2000, 780]], dtype=np.float32)
labels = np.array([1, 1, 1, 1], np.int32)  # 1: positive (object), 0: negative (background)

# Condition the tracker on the prompts. `add_new_points` is only a deprecated
# alias of `add_new_points_or_box` in SAM 2, so call the current API directly.
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Show the prompts and the resulting mask on the annotated frame.
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title(f"frame {ann_frame_idx}")
# Context manager closes the image file; imshow copies the pixel data on the call.
with Image.open(os.path.join(frames_dir, frame_names[ann_frame_idx])) as ann_frame:
    ax.imshow(ann_frame)
show_points(points, labels, ax)
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), ax, obj_id=out_obj_ids[0])
plt.show()

In [None]:
def _masks_by_object(obj_ids, mask_logits):
    """Map each object id to its binary mask (logits > 0) as a NumPy array.

    ``mask_logits[k]`` holds the logits of ``obj_ids[k]``.
    """
    return {
        obj_id: (mask_logits[idx] > 0.0).cpu().numpy()
        for idx, obj_id in enumerate(obj_ids)
    }


# Propagate the prompts through the whole video and collect, per frame, the
# binary mask of every tracked object: video_segments[frame_idx][obj_id] -> mask.
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = _masks_by_object(out_obj_ids, out_mask_logits)

In [35]:
# Settings for the side-by-side (original | contours) review video.
# Assumes video_segments, frame_names and frames_dir are defined by the cells above.
vis_frame_stride = 1
output_video_path = "p:\\computervisiondeltares\\testing_CNNprotocol\\atlanticFlumeTests\\output_{}.mp4".format(os.path.basename(frames_dir))
frame_rate = 10  # Define the frame rate for the output video
resize_size = 1800  # width in pixels of each written composite frame

# Derive the output frame size from the first frame. Load it with cv2, exactly
# like the frames written in the render loop (3-channel BGR regardless of the
# JPEG's colour mode), so the size given to VideoWriter matches what is written.
first_frame_path = os.path.join(frames_dir, frame_names[0])
first_frame = cv2.imread(first_frame_path)
if first_frame is None:  # cv2.imread signals failure by returning None
    raise FileNotFoundError(f"Could not read first frame {first_frame_path!r}")
composite_image = np.hstack((first_frame, first_frame))
composite_image = imutils.resize(composite_image, width=resize_size)

frame_height, frame_width, _ = composite_image.shape

# Define the codec and create the VideoWriter. The constructor does not raise
# when the target cannot be opened (e.g. missing output folder), so check it.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for mp4
out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
if not out.isOpened():
    raise IOError(f"Could not open video writer for {output_video_path!r}")

def show_contours_on_frame(mask, frame, color=(0, 255, 0), thickness=10):
    """Draw the outer contours of a binary mask onto ``frame``.

    Args:
        mask: single-channel 8-bit binary image (non-zero = foreground);
            assumed to have the same height/width as ``frame``.
        frame: image to draw on. cv2 draws in place, so it is modified.
        color: contour colour in the frame's channel order (BGR for frames
            loaded with cv2); defaults to green.
        thickness: contour line thickness in pixels.

    Returns:
        The frame with the contours drawn, as returned by ``cv2.drawContours``.
    """
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in 2.4 and 4.x; [-2] selects the contours in all of them.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    return cv2.drawContours(frame, contours, -1, color, thickness)

show_preview = True  # live cv2 window while rendering; set False on headless machines

# Frame-ID text overlay settings (loop-invariant, so defined once).
font = cv2.FONT_HERSHEY_SIMPLEX
position = (10, 30)  # Top left corner
font_scale = 1
font_color = (0, 255, 0)  # Green color in BGR
thickness = 2
line_type = cv2.LINE_AA

# try/finally guarantees the writer is released (file finalized and unlocked)
# even if a frame fails part-way through the loop.
try:
    for out_frame_idx in tqdm(range(0, len(frame_names), vis_frame_stride)):
        frame_path = os.path.join(frames_dir, frame_names[out_frame_idx])
        frame = cv2.imread(frame_path)
        if frame is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read frame {frame_path!r}")
        result_frame = frame.copy()  # keep `frame` untouched for the left panel

        # Frames for which propagation produced no masks are written without contours.
        for out_mask in video_segments.get(out_frame_idx, {}).values():
            # NOTE(review): out_mask presumably has shape (1, H, W); [0] takes the 2-D mask.
            mask = out_mask[0]

            if mask.dtype == bool:
                mask = mask.astype(np.uint8) * 255

            _, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

            result_frame = show_contours_on_frame(binary_mask, result_frame)

        # Composite: original (left) | contours (right), resized to `resize_size` wide.
        composite_image = np.hstack((frame, result_frame))
        composite_image = imutils.resize(composite_image, width=resize_size)

        # Add frame ID text in the top left corner, then write the frame.
        text = f"Frame: {out_frame_idx}"
        cv2.putText(composite_image, text, position, font, font_scale, font_color, thickness, line_type)
        out.write(composite_image)

        if show_preview:
            cv2.imshow('Contours on Frame', composite_image)
            cv2.waitKey(10)
finally:
    # Release the VideoWriter object
    out.release()
    if show_preview:
        cv2.destroyAllWindows()

100%|██████████| 100/100 [00:06<00:00, 15.06it/s]
