In [16]:
from sam2.sam2_video_predictor import SAM2VideoPredictor
import skimage as ski
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize, resize, rotate
import matplotlib.pyplot as plt
from tqdm import tqdm
from cellpose import models
from skimage.measure import regionprops

In [None]:
from pathlib import Path
from tifffile import imread

root_path = Path("../../pvc/scratch/interaction_cells/datasets/")
print("Loading image...")
image = imread(root_path / "series003_cCAR_tumor.tif")
print("Loaded image")
image = image[:, 0, ...]

In [3]:
print("Resizing...")
example_image =  ski.transform.resize(image, (image.shape[0], image.shape[1] // 2, image.shape[2] // 2), anti_aliasing=True)


# Load corresponding masks


def percentile_norm(image):
    """Normalize the image to the 99th percentile."""
    image = image.astype(np.float32)
    for t, frame in tqdm(enumerate(image), desc="Normalizing..."):
        p99 = np.percentile(frame, 99)
        p1 = np.percentile(frame, 1)
        frame = (frame - p1) / (p99 - p1)
        frame = (frame - np.min(frame)) / (np.max(frame) - np.min(frame))
        # image[t] = np.clip(frame, 0, 1)
        image[t] = frame
        # print(f"Min {np.min(image[t])}, Max {np.max(image[t])}, Mean {np.mean(image[t])}")
        # print(f"Min {np.min(frame)}, Max {np.max(frame)}, Mean {np.mean(frame)}")
    return image

example_image = percentile_norm(example_image)

Resizing...


162it [00:01, 91.30it/s]


In [9]:
cellpose_model = models.CellposeModel(gpu=True, model_type='cyto3')

100%|██████████| 25.3M/25.3M [00:02<00:00, 12.6MB/s]


In [13]:
masks = np.zeros_like(example_image, dtype=np.uint16)
for i in tqdm(range(example_image.shape[0])):
    masks[i], flows, styles = cellpose_model.eval(
        example_image[i], diameter=45, do_3D=False, channels=[0, 0], normalize=True, flow_threshold=0.6, cellprob_threshold=-1.0
    )

100%|██████████| 162/162 [01:18<00:00,  2.06it/s]


In [17]:
# get df of centroids per frame for each label

def get_centroids(masks):
    """Get centroids of each label in the masks."""
    centroids = []
    for t, frame in enumerate(masks):
        props = regionprops(frame)
        for prop in props:
            centroids.append({
                "t": t,
                "label": prop.label,
                "centroid": prop.centroid
            })
    return pd.DataFrame(centroids)

In [18]:
centroids_df = get_centroids(masks)

In [20]:
masks.shape

(162, 706, 706)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Display the original image (frame 0)
axes[0].imshow(example_image[0], cmap='gray')
axes[0].set_title("Original Image (Frame 0)")
axes[0].axis('off')

# Display the mask (frame 0)
axes[1].imshow(masks[0], cmap='jet', alpha=0.7)
axes[1].set_title("Mask (Frame 0)")
axes[1].axis('off')

plt.tight_layout()
plt.show()

In [None]:
from dataclasses import dataclass
import numpy as np

@dataclass
class Prompt:
    original_label: int
    centroid: np.ndarray
    positive_prompts_coords: np.ndarray
    negative_prompts_coords: np.ndarray

    def get_in_sam_format(self):
        points = np.zeros((self.positive_prompts_coords.shape[0] + self.negative_prompts_coords.shape[0], 2), dtype=np.float32)
        points[:self.positive_prompts_coords.shape[0]] = self.positive_prompts_coords
        if self.negative_prompts_coords.shape[0] > 0:
            points[self.positive_prompts_coords.shape[0]:] = self.negative_prompts_coords

        prompt_type = np.zeros((self.positive_prompts_coords.shape[0] + self.negative_prompts_coords.shape[0],), dtype=np.int32)
        prompt_type[:self.positive_prompts_coords.shape[0]] = 1  # Positive prompts
        if self.negative_prompts_coords.shape[0] > 0:
            prompt_type[self.positive_prompts_coords.shape[0]:] = 0  # Negative prompts
        return points, prompt_type


def sample_positive_and_negative_prompts(masks, centroids_df, n_samples_pos=10, n_samples_neg=10, vicinity=10, min_distance=5):
    """Samples positive points strictly inside the mask and negative points outside the mask within a vicinity."""
    n_labels = len(np.unique(masks[0])) - 1  # Ignore background label 0
    prompts = []  # List to store Prompt objects for each label

    for label in np.unique(masks[0])[1:]:  # Skip background label 0
        # Get the centroid for the current label
        label_centroids = centroids_df[(centroids_df["label"] == label) & (centroids_df["t"] == 0)]["centroid"].values
        assert len(label_centroids) == 1, f"Label {label} has {len(label_centroids)} centroids, instead of being unique."
        centroid = label_centroids[0]

        # Get the mask for the current label
        mask = masks[0] == label

        # Get all coordinates strictly within the mask (positive points)
        mask_coords = np.column_stack(np.where(mask))

        # Check if there are enough points to sample
        if len(mask_coords) < n_samples_pos:
            raise ValueError(f"Not enough points in mask for label {label} to sample {n_samples_pos} points.")

        # Randomly sample positive points
        sampled_pos_indices = np.random.choice(len(mask_coords), size=n_samples_pos, replace=False)
        positive_prompts_coords = mask_coords[sampled_pos_indices]

        # Define the vicinity region (expand the bounding box by `vicinity` pixels)
        y_min, x_min = mask_coords.min(axis=0)
        y_max, x_max = mask_coords.max(axis=0)
        y_min_vicinity = max(0, y_min - vicinity)
        x_min_vicinity = max(0, x_min - vicinity)
        y_max_vicinity = min(masks[0].shape[0], y_max + vicinity)
        x_max_vicinity = min(masks[0].shape[1], x_max + vicinity)

        # Get all coordinates within the vicinity region
        vicinity_mask = np.zeros_like(mask, dtype=bool)
        vicinity_mask[y_min_vicinity:y_max_vicinity + 1, x_min_vicinity:x_max_vicinity + 1] = True

        # Ensure negative points are strictly outside the mask
        negative_mask = vicinity_mask & ~mask

        # Remove points that are too close to the mask border
        mask_border = np.zeros_like(mask, dtype=bool)
        mask_border[max(0, y_min - min_distance):min(mask.shape[0], y_max + min_distance + 1),
                    max(0, x_min - min_distance):min(mask.shape[1], x_max + min_distance + 1)] = True
        negative_mask = negative_mask & ~mask_border

        # Get all valid negative coordinates
        outside_coords = np.column_stack(np.where(negative_mask))

        # Check if there are enough points to sample
        if len(outside_coords) < n_samples_neg:
            raise ValueError(f"Not enough points outside mask for label {label} to sample {n_samples_neg} points.")

        # Randomly sample negative points
        sampled_neg_indices = np.random.choice(len(outside_coords), size=n_samples_neg, replace=False)
        negative_prompts_coords = outside_coords[sampled_neg_indices]

        # Create a Prompt object for the current label
        prompt = Prompt(
            original_label=label,
            centroid=centroid,
            positive_prompts_coords=positive_prompts_coords,
            negative_prompts_coords=negative_prompts_coords
        )
        prompts.append(prompt)

    return prompts

In [22]:
prompts = sample_positive_and_negative_prompts(masks[0], centroids_df, n_samples_pos=4, n_samples_neg=3, vicinity=10)

ValueError: Not enough points in mask for label 1 to sample 4 points.

In [None]:
from collections import OrderedDict

class SAM2VideoPredictorWrapper(SAM2VideoPredictor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.to(self.device)

    @classmethod
    def from_pretrained(cls, model_id, **kwargs) -> 'SAM2VideoPredictorWrapper':
        """Load a pretrained model."""
        from sam2.build_sam import build_sam2_video_predictor_hf

        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
        wrapper_instance = cls(
            image_encoder=sam_model.image_encoder,
            memory_attention=sam_model.memory_attention, 
            memory_encoder=sam_model.memory_encoder
        )
        wrapper_instance.__dict__.update(sam_model.__dict__)  # Copy all attributes from the base model
        wrapper_instance.to(wrapper_instance.device)  # Move to the correct device
        return wrapper_instance

    @torch.inference_mode()
    def init_state(
        self,
        # video_path,
        video_array,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        # async_loading_frames=False,
    ):
        """Initialize an inference state."""
        compute_device = self.device  # device of the model
        # images, video_height, video_width = load_video_frames(
        #     video_path=video_path,
        #     image_size=self.image_size,
        #     offload_video_to_cpu=offload_video_to_cpu,
        #     async_loading_frames=async_loading_frames,
        #     compute_device=compute_device,
        # )
        if len(video_array.shape) == 3:
            # add fake RGB channels by repeating the image 3 times at axis 1
            video_array = np.repeat(video_array[:, :, :, np.newaxis], 3, axis=3)
        video_array = video_array.swapaxes(3, 1)
        print(f"Input video shape: {video_array.shape}")
        video_height, video_width = video_array.shape[-2:]
        images = torch.from_numpy(video_array).to(compute_device)
        images = F.interpolate(images, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False)
        print(f"Input image shape: {images.shape}")
        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
        # and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = compute_device
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = compute_device
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # A temporary storage to hold new outputs when user interact with a frame
        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already holds consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["frames_tracked_per_obj"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state

In [None]:
predictor = SAM2VideoPredictorWrapper.from_pretrained("facebook/sam2.1-hiera-base-plus", device="cuda" if torch.cuda.is_available() else "cpu")

In [None]:
inference_state = predictor.init_state(example_image)

In [None]:
for p in prompts:
    tracker_id = p.original_label
    points, labels = p.get_in_sam_format()
    print(f"Adding {points.shape[0]} points for tracker {tracker_id}.")
    _, object_ids, mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=tracker_id,
        points=points,
        labels=labels,
    )

In [None]:
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }


In [None]:
masks_propagated = np.zeros_like(masks, dtype=np.uint8)
for t, frame in enumerate(masks_propagated):
    for obj_id, mask in video_segments[t].items():
        mask = mask.squeeze().swapaxes(0, 1)
        masks_propagated[t][mask] = obj_id

In [None]:
import plotly.graph_objects as go
import numpy as np

# Assuming `example_image` is the original image and `masks_propagated` is the propagated masks
time_steps = masks_propagated.shape[0]

# Create a figure
fig = go.Figure()

# Add traces for the original image and the propagated masks for each time step
for t in range(time_steps):
    # Add the original image
    fig.add_trace(go.Heatmap(
        z=example_image[t],
        colorscale='Gray',
        showscale=False,
        visible=(t == 0),  # Only the first frame is visible initially
        name=f"Original Image {t}"
    ))
    # Add the propagated masks
    fig.add_trace(go.Heatmap(
        z=masks_propagated[t],
        colorscale='Jet',
        showscale=False,
        visible=(t == 0),  # Only the first frame is visible initially
        name=f"Propagated Mask {t}"
    ))

# Update layout for the slider
fig.update_layout(
    updatemenus=[
        {
            "buttons": [
                {
                    "args": [None, {"frame": {"duration": 500, "redraw": True}, "fromcurrent": True}],
                    "label": "Play",
                    "method": "animate"
                },
                {
                    "args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate", "transition": {"duration": 0}}],
                    "label": "Pause",
                    "method": "animate"
                }
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }
    ],
    sliders=[
        {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Time: ",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    "args": [[f"frame{t}"], {"frame": {"duration": 300, "redraw": True}, "mode": "immediate", "transition": {"duration": 300}}],
                    "label": str(t),
                    "method": "animate"
                } for t in range(time_steps)
            ]
        }
    ]
)

# Add frames for animation
frames = []
for t in range(time_steps):
    frames.append(go.Frame(
        data=[
            go.Heatmap(z=example_image[t], colorscale='Gray', showscale=False),
            go.Heatmap(z=masks_propagated[t], colorscale='Jet', showscale=False)
        ],
        name=f"frame{t}"
    ))
fig.frames = frames

# Show the figure
fig.show()