In [None]:
import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

from semantic_sam import (
    prepare_image,
    plot_results,
    build_semantic_sam,
    SemanticSamAutomaticMaskGenerator,
)

from sam2_main.sam2.build_sam import build_sam2_video_predictor

  @autocast(enabled=False)
  @autocast(enabled=False)


In [2]:
def extract_regions_with_Semantic_SAM(image_sequence, auto_mask_generator, freq=1):
    """Generate Semantic-SAM mask candidates for every sampled frame.

    Args:
        image_sequence: Iterable of pre-processed images, one per sampled frame.
        auto_mask_generator: Object exposing ``generate(image)`` that returns a
            list of mask dicts for the given image.
        freq: Stride between consecutive sampled frames in the original video.
            Entry ``t`` of ``image_sequence`` is tagged as video frame ``t * freq``.

    Returns:
        list[list[dict]]: One list of mask dicts per sampled frame. Every dict
        gets a ``"frame"`` key holding the original video frame index so the
        mask can later be used as a SAM2 prompt on the right frame.
    """
    masks_candidates = []

    for t, image in enumerate(image_sequence):
        # Index of this sample in the full video (not in the sampled sequence).
        frame_idx = t * freq

        print(f"⏳ start auto mask generation using Semantic-SAM on frame {frame_idx}")
        current_masks = auto_mask_generator.generate(image)
        print(f"✅ auto mask generation finish, {len(current_masks)} masks generated")

        # Tag with the real frame index; this used to be hard-coded to t*10,
        # which was only right when freq == 10.
        for mask in current_masks:
            mask["frame"] = frame_idx

        masks_candidates.append(current_masks)

    print("✅ Semantic-SAM processed all frames")

    return masks_candidates

In [3]:
# tool functions
def _drop_leading_singleton_dims(mask):
    """Strip leading size-1 axes so that e.g. a (1, H, W) mask becomes (H, W)."""
    arr = np.asarray(mask)
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]
    return arr


def compute_iou(mask1, mask2):
    """Intersection-over-union of two binary masks.

    Leading singleton axes are ignored, so a ``(1, H, W)`` mask can be compared
    directly with an ``(H, W)`` one. If the two masks still differ in shape,
    ``mask2`` is bilinearly resized to ``mask1``'s spatial size and
    re-thresholded at 0.5 before the comparison.

    Args:
        mask1: numpy array; non-zero entries are foreground. Defines the
            reference resolution.
        mask2: numpy array; non-zero entries are foreground.

    Returns:
        ``intersection / union``, or ``0`` when the union is empty.
    """
    mask1 = _drop_leading_singleton_dims(mask1)
    mask2 = _drop_leading_singleton_dims(mask2)

    # Fast path: same shape, no resampling needed.
    if mask1.shape == mask2.shape:
        intersection = np.logical_and(mask1, mask2).sum()
        union = np.logical_or(mask1, mask2).sum()
        return intersection / union if union != 0 else 0

    # interpolate() needs float tensors laid out as (B, C, H, W).
    m1 = torch.from_numpy(mask1.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    m2 = torch.from_numpy(mask2.astype(np.float32)).unsqueeze(0).unsqueeze(0)

    # Bring mask2 to mask1's resolution, then binarise the blended values again.
    m2_resized = torch.nn.functional.interpolate(
        m2, size=m1.shape[-2:], mode='bilinear', align_corners=False
    )
    m2_resized = (m2_resized > 0.5).squeeze()
    m1 = m1.squeeze()

    intersection = torch.logical_and(m1, m2_resized).sum().item()
    union = torch.logical_or(m1, m2_resized).sum().item()

    return intersection / union if union != 0 else 0

# loading and selecting scene color images into one list
def load_scene_img(datapath, freq=1, start_idx=0, end_idx=200):
    """Load every ``freq``-th image of a scene folder through ``prepare_image``.

    Files are taken in lexicographically sorted order from ``datapath``.

    Args:
        datapath: Directory holding the scene's color frames.
        freq: Sampling stride (positive int); every ``freq``-th file is kept.
        start_idx: Index of the first file to consider (non-negative).
        end_idx: Exclusive upper bound of the sampling window. It is clamped
            to the number of files available, so short sequences no longer
            raise ``IndexError``.

    Returns:
        list: The second element returned by ``prepare_image`` for each
        selected file, in order.
    """
    color_img = sorted(os.listdir(datapath))

    # Clamp the window to what actually exists on disk.
    stop = min(end_idx, len(color_img))
    selected_img_idx = color_img[start_idx:stop:freq]
    print(f"Selected image number: {len(selected_img_idx)}")
    print("selected frame:", selected_img_idx)

    # pre-processing images
    selected_img = []
    for img_name in selected_img_idx:
        image_path = os.path.join(datapath, img_name)
        _, image = prepare_image(image_path)
        selected_img.append(image)

    return selected_img

def bbox_transform(bboxes):
    """Convert boxes from ``[x, y, w, h]`` to ``[x1, y1, x2, y2]``.

    A new list of new boxes is returned and the input is left untouched.
    The previous in-place version rewrote the callers' ``mask['bbox']`` lists,
    so re-running the tracking cell converted already-converted boxes again.

    Args:
        bboxes: Iterable of 4-element ``[x, y, w, h]`` boxes.

    Returns:
        list[list]: One ``[x, y, x + w, y + h]`` box per input box.
    """
    return [[x, y, x + w, y + h] for x, y, w, h in bboxes]

def bbox_scalar_fit(bboxes, scalar_x, scalar_y):
    """Rescale ``[x1, y1, x2, y2]`` boxes to another image resolution.

    x coordinates are multiplied by ``scalar_x`` and y coordinates by
    ``scalar_y``; each result is truncated with ``int()``. A new list is
    returned and the input boxes are not modified, so calling this twice on
    the same data does not scale twice.

    Args:
        bboxes: Iterable of boxes indexed as ``[x1, y1, x2, y2]``.
        scalar_x: Horizontal scale factor.
        scalar_y: Vertical scale factor.

    Returns:
        list[list[int]]: The rescaled boxes.
    """
    return [
        [
            int(bbox[0] * scalar_x),
            int(bbox[1] * scalar_y),
            int(bbox[2] * scalar_x),
            int(bbox[3] * scalar_y),
        ]
        for bbox in bboxes
    ]

def mask_candidate_refine(mask_candidates, min_area=300, max_area=50000, stability_score=0.9):
    """Drop mask candidates that are too small, too large, or unstable.

    A mask is kept when ``min_area <= mask['area'] <= max_area`` and
    ``mask['stability_score'] >= stability_score``.

    Args:
        mask_candidates: List of frames, each a list of mask dicts carrying
            ``'area'`` and ``'stability_score'`` keys.
        min_area: Smallest accepted mask area (inclusive).
        max_area: Largest accepted mask area (inclusive).
        stability_score: Minimum accepted stability score (inclusive).

    Returns:
        list[list[dict]]: Same frame structure as the input, holding only the
        masks that passed both checks.
    """
    final_candidates = []
    filtered_out = 0

    for frame in mask_candidates:
        refined_candidates = [
            mask
            for mask in frame
            if min_area <= mask['area'] <= max_area
            and mask['stability_score'] >= stability_score
        ]
        # Count every rejected mask. The earlier version had a stray `else`
        # bound to the `for` loop, which added 1 per frame and never counted
        # masks rejected for low stability.
        filtered_out += len(frame) - len(refined_candidates)
        final_candidates.append(refined_candidates)

    print(f"Filtered out {filtered_out} masks based on area and stability score.")

    return final_candidates

def store_output(output_path, data_path, video_segments):
    """
    Stores the masked frames for each object in separate folders.

    For every frame index present in ``video_segments`` the matching original
    frame is loaded, everything outside each object's mask is set to zero, and
    the result is written to ``<output_path>/<obj_id>/<frame stem>.png``.

    Args:
        output_path (str): The root directory to save the output folders.
        data_path (str): The path to the directory of original video frames.
        video_segments (dict): A dictionary containing the segmentation masks.
                               It is structured as {frame_idx: {obj_id: mask_array, ...}}.

    Returns:
        None. If ``data_path`` does not exist an error is printed and the
        function returns after creating the output folders.
    """
    print(f"⏳ Storing output masks to {output_path}")
    os.makedirs(output_path, exist_ok=True)

    # Find all unique object IDs from the tracking results
    all_obj_ids = set()
    for frame_masks in video_segments.values():
        all_obj_ids.update(frame_masks.keys())

    # Create a folder for each object ID
    for obj_id in all_obj_ids:
        os.makedirs(os.path.join(output_path, str(obj_id)), exist_ok=True)

    print(f"Found {len(all_obj_ids)} unique objects to store.")

    # Collect the original frame filenames
    try:
        frame_filenames = [
            f for f in os.listdir(data_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
    except FileNotFoundError:
        print(f"Error: The data_path '{data_path}' was not found.")
        return

    # frame_idx below must match the index the tracker used. SAM2's JPEG
    # folder loader orders frames by the integer in the file stem, so prefer
    # that order ("999.jpg" before "1000.jpg"); fall back to plain name order
    # when stems are not integers.
    # NOTE(review): confirm against the SAM2 version in use.
    try:
        frame_filenames.sort(key=lambda name: (int(os.path.splitext(name)[0]), name))
    except ValueError:
        frame_filenames.sort()

    # Process each frame
    for frame_idx, frame_filename in enumerate(tqdm(frame_filenames, desc="Applying masks and saving frames")):
        if frame_idx not in video_segments:
            continue

        # Load the original image; the context manager closes the file handle.
        original_img_path = os.path.join(data_path, frame_filename)
        with Image.open(original_img_path) as original_img:
            original_img_array = np.array(original_img.convert('RGB'))

        save_filename = os.path.splitext(frame_filename)[0] + ".png"  # Save as png

        for obj_id, mask in video_segments[frame_idx].items():
            # Ensure the mask is boolean and drop singleton axes
            bool_mask = np.squeeze(mask > 0)

            # Broadcast the (H, W) mask over the colour axis: (H, W, 1) * (H, W, 3)
            masked_img_array = original_img_array * bool_mask[:, :, np.newaxis]

            # Convert array back to image and save
            masked_img = Image.fromarray(masked_img_array.astype(np.uint8))
            masked_img.save(os.path.join(output_path, str(obj_id), save_filename))

    print("✅ Storing output finished successfully.")

In [11]:
def SAM2_tracking(data_path, SAM2_video_predictor, mask_candidates, freq=1):
    """Track Semantic-SAM mask candidates through the video with SAM2.

    For each sampled frame, the candidates that do not overlap (IoU above 0.6)
    an already tracked mask on that frame are turned into box prompts, added
    to a fresh SAM2 state on video frame ``frame * freq``, and propagated. The
    per-frame results of all rounds are merged into one dictionary.

    Args:
        data_path: Directory of video frames handed to ``init_state``.
        SAM2_video_predictor: Predictor exposing ``init_state``,
            ``add_new_points_or_box`` and ``propagate_in_video``.
        mask_candidates: List (one entry per sampled frame) of lists of mask
            dicts with ``'segmentation'`` (2-D array) and ``'bbox'``
            (``[x, y, w, h]``) keys. The candidates are not modified.
        freq: Stride between sampled frames in the original video.

    Returns:
        dict: ``{frame_idx: {obj_id: bool_mask_array}}`` for every frame
        reached by propagation. Empty when there are no candidates at all.

    Note:
        Two "demo only" limits are kept from the original: only the first
        three sampled frames are processed, and prompting on a frame stops
        whenever the running object counter hits a multiple of 3.
    """
    iou_threshold = 0.6

    with torch.inference_mode(), torch.autocast("cuda"):
        obj_count = 1  # to keep track of the total number of objects added
        final_video_segments = {}

        # Any candidate will do for reading the Semantic-SAM mask resolution;
        # using the first one available avoids an IndexError when frame 0 is empty.
        reference_mask = next(
            (mask for frame_masks in mask_candidates for mask in frame_masks), None
        )
        if reference_mask is None:
            print("No mask candidates to track.")
            print("✅ Iterative video propagation finished.")
            return final_video_segments

        # Initialize SAM2 video predictor state
        inference_state = SAM2_video_predictor.init_state(video_path=data_path)
        state_has_prompts = False

        # Scaling factors from Semantic-SAM mask resolution to SAM2 video resolution
        scalar_x = inference_state['video_width'] / reference_mask['segmentation'].shape[1]
        scalar_y = inference_state['video_height'] / reference_mask['segmentation'].shape[0]

        for frame, frame_masks in enumerate(mask_candidates):
            if frame == 3:
                break  # demo only: process just the first three sampled frames

            ann_frame_idx = frame * freq

            # Start every round from a clean state so only the new prompts are
            # propagated. The state is rebuilt only if the previous round used
            # it, which avoids loading the whole video twice up front.
            if state_has_prompts:
                inference_state = SAM2_video_predictor.init_state(video_path=data_path)
                state_has_prompts = False

            if frame > 0:
                # Masks already tracked on this frame by earlier rounds (may be none).
                tracked_here = final_video_segments.get(ann_frame_idx, {})
                untracked_masks = [
                    mask
                    for mask in frame_masks
                    if not any(
                        compute_iou(prev_mask, mask["segmentation"]) > iou_threshold
                        for prev_mask in tracked_here.values()
                    )
                ]
                print(f"{len(untracked_masks)} untracked masks found on frame: {frame*freq}.")
                prompt_masks = untracked_masks
            else:
                prompt_masks = frame_masks

            # Copy the boxes so the conversions below never touch the candidates.
            bboxes = [list(mask['bbox']) for mask in prompt_masks]

            # Convert Semantic-SAM bbox format (XYWH) to SAM2 bbox format (x1y1x2y2)
            bboxes = bbox_transform(bboxes)

            # Scale the bboxes to fit the SAM2 video input image size
            bboxes = bbox_scalar_fit(bboxes, scalar_x, scalar_y)

            # Add each box as a new object prompt on this frame
            for bbox in bboxes:
                if obj_count % 3 == 0:
                    obj_count += 1
                    break  # demo only: cap the number of prompts added per round

                ann_obj_id = obj_count
                obj_count += 1

                print(f"⏳ start adding masklet {ann_obj_id} from frame {ann_frame_idx} as prompt for SAM2")
                SAM2_video_predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=ann_frame_idx,
                    obj_id=ann_obj_id,
                    box=bbox,
                )
                state_has_prompts = True

            # Nothing was prompted this round, so there is nothing to propagate.
            if not state_has_prompts:
                print(f"No new prompts on frame {ann_frame_idx}; skipping propagation.")
                continue

            # Propagate the masks through the video and merge them into the
            # accumulated per-frame results.
            for out_frame_idx, out_obj_ids, out_mask_logits in SAM2_video_predictor.propagate_in_video(inference_state):
                frame_results = final_video_segments.setdefault(out_frame_idx, {})
                for i, out_obj_id in enumerate(out_obj_ids):
                    frame_results[out_obj_id] = (out_mask_logits[i] > 0.0).cpu().numpy()

        print("✅ Iterative video propagation finished.")
        return final_video_segments

In [5]:
# Scene configuration shared by the cells below.
# data_path: folder with the scene's color frames (relative path).
data_path = "Multiscan/scene0065_00/color/JPG"
freq = 10 # sampling stride: every freq-th frame is sent to Semantic-SAM

In [6]:
# Load the sampled scene color images.
# Use the configured stride instead of a literal so this cell stays in sync
# with the frame indices computed from `freq` in the later cells.
print("⏳ start loading scene images")
image_sequence = load_scene_img(data_path, freq=freq)
print("✅ loading scene image finish")

⏳ start loading scene images
Selected image number: 20
selected frame: ['1400.jpg', '1410.jpg', '1420.jpg', '1430.jpg', '1440.jpg', '1450.jpg', '1460.jpg', '1470.jpg', '1480.jpg', '1490.jpg', '1500.jpg', '1510.jpg', '1520.jpg', '1530.jpg', '1540.jpg', '1551.jpg', '1561.jpg', '1571.jpg', '1581.jpg', '1591.jpg']
✅ loading scene image finish


In [7]:
# Load the Semantic-SAM model and wrap it in the automatic mask generator
# that is later called once per sampled frame.
print("⏳ start loading Semantic-SAM model")
model_type = 'L'
ckpt_path = "ckpts/swinl_only_sam_many2many.pth"
auto_mask_generator = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type=model_type, ckpt=ckpt_path))
print("✅ Semantic-SAM model loaded successfully")

⏳ start loading Semantic-SAM model


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
$UNUSED$ criterion.empty_weight, Ckpt Shape: torch.Size([2])


✅ Semantic-SAM model loaded successfully


In [8]:
# Build the SAM2 video predictor from its config file and checkpoint.
print("⏳ start loading SAM2 model")
checkpoint = "sam2_main/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
SAM2_video_predictor = build_sam2_video_predictor(model_cfg, checkpoint)
print("✅ SAM2 model loaded successfully")

⏳ start loading SAM2 model
✅ SAM2 model loaded successfully


In [9]:
# First half of Algorithm 1: generate mask candidates with Semantic-SAM on
# every sampled frame, then filter them.
# Note: despite its name, `tracked_masks` holds per-frame candidates at this
# point; tracking happens in the SAM2_tracking call further down.
tracked_masks = extract_regions_with_Semantic_SAM(image_sequence, auto_mask_generator, freq)
tracked_masks = mask_candidate_refine(tracked_masks) # keep only candidates passing the area and stability-score filters

⏳ start auto mask generation using Semantic-SAM on frame 0
✅ auto mask generation finish, 83 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 10
✅ auto mask generation finish, 74 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 20
✅ auto mask generation finish, 75 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 30
✅ auto mask generation finish, 73 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 40
✅ auto mask generation finish, 62 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 50
✅ auto mask generation finish, 72 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 60
✅ auto mask generation finish, 61 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 70
✅ auto mask generation finish, 75 masks generated
⏳ start auto mask generation using Semantic-SAM on frame 80
✅ auto mask generation finish, 75 masks generated
⏳ start aut

In [12]:

# Prompt SAM2 with the refined candidates and propagate them through the video.
# The result maps frame index -> {object id -> mask}.
video_segments = SAM2_tracking(data_path, SAM2_video_predictor, tracked_masks, freq)

frame loading (JPEG): 100%|██████████| 200/200 [00:06<00:00, 31.65it/s]
frame loading (JPEG): 100%|██████████| 200/200 [00:05<00:00, 34.25it/s]


⏳ start adding masklet 1 from frame 0 as prompt for SAM2
⏳ start adding masklet 2 from frame 0 as prompt for SAM2


propagate in video: 100%|██████████| 200/200 [03:35<00:00,  1.08s/it]
frame loading (JPEG): 100%|██████████| 200/200 [00:09<00:00, 22.19it/s]


62 untracked masks found on frame: 10.
⏳ start adding masklet 4 from frame 10 as prompt for SAM2
⏳ start adding masklet 5 from frame 10 as prompt for SAM2


propagate in video: 100%|██████████| 190/190 [00:28<00:00,  6.77it/s]
frame loading (JPEG): 100%|██████████| 200/200 [00:06<00:00, 31.89it/s]


60 untracked masks found on frame: 20.
⏳ start adding masklet 7 from frame 20 as prompt for SAM2
⏳ start adding masklet 8 from frame 20 as prompt for SAM2


propagate in video: 100%|██████████| 180/180 [01:47<00:00,  1.68it/s]

✅ Iterative video propagation finished.





In [14]:
# Write the tracking results to disk: one sub-folder per object id under
# output_path, each holding the masked frames for that object.
output_path = "output/test_scene_0065_00"
store_output(output_path, data_path, video_segments)

⏳ Storing output masks to output/test_scene_0065_00
Found 6 unique objects to store.


Applying masks and saving frames: 100%|██████████| 200/200 [00:51<00:00,  3.88it/s]

✅ Storing output finished successfully.



