#### Imports

In [19]:
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter

In [2]:
# Make modules one directory up importable (presumably where the `src`
# package imported below lives — confirm against the repo layout).
# Guarded so that re-running this cell does not append duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
from src.utils.image_utils import load_czi_images
from cell_tracking import get_masks_paths
from src.utils.plot_utils import show_3d_segmentation_overlay
from src.track import get_cell_centers, centers_array_to_label_position_map
from src.utils.image_utils import enhance_cell_image_contrast



Welcome to CellposeSAM, cellpose v
cellpose version: 	4.0.5.dev16+g79b0fcb 
platform:       	linux 
python version: 	3.12.9 
torch version:  	2.7.1+cu126! The neural network component of
CPSAM is much larger than in previous versions and CPU excution is slow. 
We encourage users to use GPU/MPS if available. 




#### Functions

In [4]:
def create_gif_from_figures(figures, output_path='animation.gif', fps=5, titles=None):
    """
    Create an animated GIF from a sequence of matplotlib figures.

    Each figure is rendered off-screen with the Agg backend, encoded to PNG in
    memory, decoded with PIL and appended as one GIF frame. Every figure in
    ``figures`` is closed afterwards, including when rendering or saving
    fails, so long loops do not accumulate open figures.

    Parameters:
    - figures: iterable of matplotlib figure objects, one per frame
    - output_path: path to save the GIF
    - fps: frames per second; must be > 0. Each frame lasts ``int(1000 / fps)`` ms.
    - titles: optional sequence of titles. ``titles[i]`` becomes the suptitle
      of ``figures[i]``; figures beyond ``len(titles)`` keep their own title.

    Raises:
    - ValueError: if ``fps`` is not positive.
    """
    # Validate before doing any work so a bad fps cannot surface later as a
    # ZeroDivisionError (fps == 0) or a negative frame duration (fps < 0).
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import io
    from PIL import Image

    figures = list(figures)
    # list() instead of a truthiness test, so numpy arrays of titles also work.
    titles = list(titles) if titles is not None else []
    frames = []

    try:
        for i, fig in enumerate(figures):
            if i < len(titles):
                fig.suptitle(titles[i], fontsize=16)

            # Render the figure to PNG bytes in memory.
            canvas = FigureCanvasAgg(fig)
            canvas.draw()
            with io.BytesIO() as buf:
                canvas.print_png(buf)
                buf.seek(0)
                with Image.open(buf) as img:
                    # copy() forces the pixel data to load before buf is closed.
                    frames.append(img.copy())

        if frames:
            frames[0].save(
                output_path,
                save_all=True,
                append_images=frames[1:],
                duration=int(1000 / fps),  # per-frame duration in milliseconds
                loop=0,  # loop forever
            )
            print(f"GIF saved to {output_path}")
        else:
            print("No frames were created")
    finally:
        # Free figure memory even if rendering/saving raised.
        for fig in figures:
            plt.close(fig)

In [5]:
def create_trajectory_gif(positions, output_path='trajectory.gif', fps=2, 
                         figsize=(10, 8), point_size=50, line_width=2):
    """
    Create a GIF animation of a 3D trajectory where points are added progressively
    
    Parameters:
    - positions: numpy array of shape (n_timepoints, 3) containing x, y, z coordinates
    - output_path: path to save the GIF
    - fps: frames per second
    - figsize: figure size tuple
    - point_size: size of the markers
    - line_width: width of the connecting lines
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import io
    from PIL import Image
    
    frames = []
    
    # Get the full range for consistent axis limits
    x_min, x_max = positions[:, 0].min(), positions[:, 0].max()
    y_min, y_max = positions[:, 1].min(), positions[:, 1].max()
    z_min, z_max = positions[:, 2].min(), positions[:, 2].max()
    
    # Add some padding
    x_padding = (x_max - x_min) * 0.1
    y_padding = (y_max - y_min) * 0.1
    z_padding = (z_max - z_min) * 0.1
    
    for frame_idx in range(len(positions)):
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')
        
        # Get points up to current frame
        current_positions = positions[:frame_idx + 1]
        
        if len(current_positions) > 0:
            # Plot the trajectory line up to current point
            if len(current_positions) > 1:
                ax.plot(current_positions[:, 0], 
                       current_positions[:, 1], 
                       current_positions[:, 2], 
                       color='blue', linewidth=line_width, alpha=0.7)
            
            # Plot all points up to current frame
            ax.scatter(current_positions[:, 0], 
                      current_positions[:, 1], 
                      current_positions[:, 2], 
                      c=range(len(current_positions)), 
                      cmap='viridis', s=point_size, alpha=0.8)
            
            # Highlight the current (newest) point
            current_point = current_positions[-1]
            ax.scatter(current_point[0], current_point[1], current_point[2], 
                      c='red', s=point_size*1.5, marker='o', alpha=1.0)
        
        # Set consistent axis limits
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)
        ax.set_zlim(z_min - z_padding, z_max + z_padding)
        
        # Labels and title
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_zlabel('Z Position')
        ax.set_title(f'Cell Trajectory - Time Point {frame_idx + 1}/{len(positions)}')
        
        # Convert figure to image
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        buf = io.BytesIO()
        canvas.print_png(buf)
        buf.seek(0)
        
        # Convert to PIL Image
        img = Image.open(buf)
        frames.append(img.copy())
        
        # Clean up
        plt.close(fig)
        buf.close()
    
    # Create and save the GIF
    if frames:
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            duration=int(1000/fps),
            loop=0
        )
        print(f"Trajectory GIF saved to {output_path}")
    else:
        print("No frames were created")

In [6]:
def show_3d_segmentation_overlay_with_highlight(z_stack, masks, highlight_label, 
                                               base_color=(0, 255, 0), highlight_color=(255, 0, 0),
                                               save_path=None, return_fig=False):
    """
    Show 3D segmentation overlay with all labels in base color and a specific label highlighted.

    Each z-slice is contrast-enhanced, overlaid with its mask via
    ``create_highlight_overlay``, and drawn in one row of subplots.

    Args:
        z_stack: 3D numpy array of images (z, h, w)
        masks: 3D numpy array of segmentation masks (z, h, w)
        highlight_label: integer label to highlight
        base_color: RGB color for all other labels (R, G, B) (default: green)
        highlight_color: RGB color for highlighted label (R, G, B) (default: red)
        save_path: path to save the figure (optional)
        return_fig: if True, return the figure object instead of showing

    Returns:
        fig object if return_fig=True, otherwise None

    Raises:
        ValueError: if ``z_stack`` has no slices.
    """
    n_slices = z_stack.shape[0]
    if n_slices == 0:
        raise ValueError("z_stack must contain at least one slice")

    frames = []
    for i in range(n_slices):
        img = enhance_cell_image_contrast(z_stack[i])
        frames.append(create_highlight_overlay(img, masks[i], highlight_label,
                                               base_color, highlight_color))

    # squeeze=False keeps `axes` 2-D even for a single slice; otherwise
    # plt.subplots returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, n_slices, figsize=(20, 10), squeeze=False)
    for i, (ax, frame) in enumerate(zip(axes[0], frames)):
        ax.imshow(frame)
        ax.axis('off')
        ax.set_title(f'Slice {i+1}')
    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
        print(f"Saved highlighted overlay images to {save_path}")

    if return_fig:
        return fig
    plt.show()
    plt.close(fig)

def create_highlight_overlay(image, mask, highlight_label, base_color, highlight_color):
    """
    Create an overlay with all labels in base color and highlighted label in different color.

    A 2-D image is first expanded to RGB, then converted to uint8:
    - maximum <= 1: treated as a [0, 1] image and scaled by 255;
    - maximum > 255 (e.g. raw 16-bit data): rescaled linearly so the maximum
      maps to 255;
    - otherwise: used as is.
    Values are clipped to [0, 255] before the cast so they can never wrap
    around modulo 256. Pixels with ``mask > 0`` are then painted ``base_color``,
    and pixels equal to ``highlight_label`` are painted ``highlight_color``.

    Args:
        image: 2D grayscale image (non-2D input is copied as is and presumably
            already has a trailing RGB axis — TODO confirm)
        mask: 2D segmentation mask aligned with ``image``; 0 is background
        highlight_label: label to highlight
        base_color: RGB color for all other labels
        highlight_color: RGB color for highlighted label

    Returns:
        RGB overlay image (uint8)
    """
    if image.ndim == 2:
        rgb_image = np.stack([image, image, image], axis=-1)
    else:
        rgb_image = image.copy()

    peak = rgb_image.max()
    if peak <= 1.0:
        # Looks like a [0, 1] float image.
        scaled = rgb_image * 255.0
    elif peak > 255:
        # Out of uint8 range: rescale instead of letting astype wrap mod 256.
        scaled = rgb_image * (255.0 / peak)
    else:
        scaled = rgb_image
    overlay = np.clip(scaled, 0, 255).astype(np.uint8)

    highlight_mask = mask == highlight_label
    # Assigning through an all-False mask is a no-op, so no emptiness guard needed.
    overlay[(mask > 0) & ~highlight_mask] = base_color
    overlay[highlight_mask] = highlight_color

    return overlay

In [20]:
def show_3d_segmentation_overlay_with_highlight2(z_stack, masks, highlight_label, 
                                               base_color=(0, 255, 0), highlight_color=(255, 0, 0),
                                               save_path=None, return_fig=False, color_neighbors=True):
    """
    Show 3D segmentation overlay with all labels in base color and a specific label highlighted.
    Optionally colors neighboring cells with different colors based on angular position.

    Each z-slice is contrast-enhanced, overlaid via
    ``create_highlight_overlay_with_neighbors``, and drawn in one row of subplots.

    Args:
        z_stack: 3D numpy array of images (z, h, w)
        masks: 3D numpy array of segmentation masks (z, h, w)
        highlight_label: integer label to highlight
        base_color: RGB color for all other labels (R, G, B) (default: green)
        highlight_color: RGB color for highlighted label (R, G, B) (default: red)
        save_path: path to save the figure (optional)
        return_fig: if True, return the figure object instead of showing
        color_neighbors: if True, color neighboring cells with angular-based colors

    Returns:
        fig object if return_fig=True, otherwise None

    Raises:
        ValueError: if ``z_stack`` has no slices.
    """
    n_slices = z_stack.shape[0]
    if n_slices == 0:
        raise ValueError("z_stack must contain at least one slice")

    frames = []
    for i in range(n_slices):
        img = enhance_cell_image_contrast(z_stack[i])
        frames.append(create_highlight_overlay_with_neighbors(
            img, masks[i], highlight_label, base_color, highlight_color, color_neighbors))

    # squeeze=False keeps `axes` 2-D even for a single slice; otherwise
    # plt.subplots returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, n_slices, figsize=(20, 10), squeeze=False)
    for i, (ax, frame) in enumerate(zip(axes[0], frames)):
        ax.imshow(frame)
        ax.axis('off')
        ax.set_title(f'Slice {i+1}')
    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
        print(f"Saved highlighted overlay images to {save_path}")

    if return_fig:
        return fig
    plt.show()
    plt.close(fig)

def create_highlight_overlay_with_neighbors(image, mask, highlight_label, base_color, highlight_color, color_neighbors=True):
    """
    Create an overlay with all labels in base color, highlighted label in different color,
    and neighboring cells colored based on their angular position relative to the highlighted cell.

    Image conversion to uint8 RGB:
    - maximum <= 1: scaled by 255;
    - maximum > 255: rescaled so the maximum maps to 255;
    - otherwise: used as is.
    Values are clipped to [0, 255] before the cast so they never wrap modulo 256.

    Painting order:
    1. Every labelled pixel (``mask > 0``) other than the highlighted label gets
       ``base_color``.
    2. Neighbours returned by ``get_angular_colors`` are repainted with their
       direction-dependent colour. This only happens when ``color_neighbors``
       is set and the highlighted label is present.
    3. The highlighted label gets ``highlight_color``.

    Args:
        image: 2D grayscale image (non-2D input is copied as is and presumably
            already has a trailing RGB axis — TODO confirm)
        mask: 2D segmentation mask aligned with ``image``; 0 is background
        highlight_label: label to highlight
        base_color: RGB color for all other labels
        highlight_color: RGB color for highlighted label
        color_neighbors: if True, color neighboring cells with angular-based colors

    Returns:
        RGB overlay image (uint8)
    """
    if image.ndim == 2:
        rgb_image = np.stack([image, image, image], axis=-1)
    else:
        rgb_image = image.copy()

    peak = rgb_image.max()
    if peak <= 1.0:
        # Looks like a [0, 1] float image.
        scaled = rgb_image * 255.0
    elif peak > 255:
        # Out of uint8 range: rescale instead of letting astype wrap mod 256.
        scaled = rgb_image * (255.0 / peak)
    else:
        scaled = rgb_image
    overlay = np.clip(scaled, 0, 255).astype(np.uint8)

    highlight_mask = mask == highlight_label

    neighbor_colors = {}
    if color_neighbors and np.any(highlight_mask):
        highlight_center = get_label_center(mask, highlight_label)
        if highlight_center is not None:
            unique_labels = np.unique(mask[mask > 0])
            neighbor_colors = get_angular_colors(mask, unique_labels, highlight_label, highlight_center)

    # Base colour first, then overwrite the coloured neighbours, then the highlight.
    overlay[(mask > 0) & ~highlight_mask] = base_color
    for label, color in neighbor_colors.items():
        overlay[mask == label] = color
    overlay[highlight_mask] = highlight_color

    return overlay

def get_label_center(mask, label):
    """
    Return the centroid of ``label`` in ``mask`` as a ``(row, col)`` tuple.

    The centroid is the mean index, along the first two axes of ``mask``, of
    all elements equal to ``label``. Returns ``None`` when ``label`` does not
    occur in ``mask``.
    """
    pixel_indices = np.nonzero(mask == label)
    if pixel_indices[0].size == 0:
        return None
    row_idx, col_idx = pixel_indices[0], pixel_indices[1]
    return (np.mean(row_idx), np.mean(col_idx))

def get_angular_colors(mask, labels, highlight_label, highlight_center, proximity_threshold=100):
    """
    Map each nearby label to a colour whose hue encodes its direction from the highlighted cell.

    For every label in ``labels`` other than ``highlight_label`` that occurs
    in ``mask``, the centroid (mean row/column index of its pixels) is compared
    with ``highlight_center``. Labels whose centroid lies within
    ``proximity_threshold`` (Euclidean distance in pixels) get a colour. Its
    hue is ``arctan2(dy, dx)`` mapped from [0, 2*pi) onto [0, 1), with fixed
    saturation 0.8 and value 0.9.

    Args:
        mask: 2D segmentation mask
        labels: iterable of candidate labels (e.g. the unique labels in the mask)
        highlight_label: the reference label; never part of the result
        highlight_center: (y, x) coordinates of the reference cell centre
        proximity_threshold: maximum centroid distance for a label to be coloured

    Returns:
        Dictionary mapping label to an [R, G, B] list of ints in 0..255
    """
    import colorsys

    full_turn = 2 * np.pi
    ref_row, ref_col = highlight_center[0], highlight_center[1]
    colours = {}

    for candidate in labels:
        if candidate == highlight_label:
            continue

        pixel_idx = np.where(mask == candidate)
        if pixel_idx[0].size == 0:
            # Label not present in this mask.
            continue

        d_row = np.mean(pixel_idx[0]) - ref_row
        d_col = np.mean(pixel_idx[1]) - ref_col
        separation = np.sqrt(d_col**2 + d_row**2)
        # Written as `not <=` so a NaN distance is skipped, as before.
        if not separation <= proximity_threshold:
            continue

        raw_angle = np.arctan2(d_row, d_col)
        angle = raw_angle + full_turn if raw_angle < 0 else raw_angle
        red, green, blue = colorsys.hsv_to_rgb(angle / full_turn, 0.8, 0.9)
        colours[candidate] = [int(red * 255), int(green * 255), int(blue * 255)]

    return colours

#### Inputs

In [7]:
# Root directories can be overridden via environment variables so the notebook
# runs on other machines; the defaults are the original absolute locations.
DATA_DIR = os.environ.get('MS2_DATA_DIR', '/home/dafei/data/MS2')
OUTPUT_DIR = os.environ.get('MS2_OUTPUT_DIR', '/home/dafei/output/MS2')
SAMPLE_NAME = 'gRNA2_12.03.25-st-13-II---'

czi_file_path = os.path.join(DATA_DIR, f'{SAMPLE_NAME}.czi')
seg_maps_dir = os.path.join(OUTPUT_DIR, '3d_cell_segmentation', SAMPLE_NAME, 'masks')

In [8]:
# Load the whole CZI file into memory; the printed shape below is
# (1, 80, 2, 9, 1024, 1024, 1). Later cells index it as
# image_data[0, t, 1, :, :, :, 0], so axis 1 is presumably time, axis 2 the
# channel and axis 3 the z-slices — TODO confirm against load_czi_images.
image_data = load_czi_images(czi_file_path)

Successfully loaded /home/dafei/data/MS2/gRNA2_12.03.25-st-13-II---.czi
data shape: (1, 80, 2, 9, 1024, 1024, 1)


In [9]:
# Mask files for the segmentation; later cells index this list with the time
# index t and read the 'masks' array from each file via np.load(...)['masks'].
# NOTE(review): assumes get_masks_paths returns paths ordered by time point — confirm.
masks_paths = get_masks_paths(seg_maps_dir)

In [10]:
# The tracklets JSON lives next to the mask files, so derive its path from
# seg_maps_dir instead of repeating the absolute directory.
tracklets_path = os.path.join(seg_maps_dir, 'tracklets_hungarian_algorithm_distance_metric.json')

In [11]:
# JSON is UTF-8 by specification; don't depend on the platform default encoding.
# Presumably maps track id (JSON object keys are always strings) to a list with
# one mask label per time point, -1 where the track is absent (see the loop below).
with open(tracklets_path, 'r', encoding='utf-8') as f:
    tracklets = json.load(f)

In [None]:
# For every track, render a GIF of the highlighted cell over time and an
# animated 3-D trajectory of its centre. Both GIFs are written to seg_maps_dir.
label_positions_by_t = {}  # time index -> label->position map, shared by all tracks

for track_id, cell_labels in tracklets.items():
    # -1 marks time points where the track has no matching cell.
    valid_timepoints = [t for t, label in enumerate(cell_labels) if label != -1]
    if not valid_timepoints:
        print(f"Track {track_id}: no valid time points, skipping")
        continue

    position = []
    figures = []
    for t in valid_timepoints:
        z_stack_t = image_data[0, t, 1, :, :, :, 0]
        with np.load(masks_paths[t]) as mask_file:
            mask_t = mask_file['masks']
        # Cell centres depend only on the time point, so compute them once.
        if t not in label_positions_by_t:
            label_positions_by_t[t] = centers_array_to_label_position_map(get_cell_centers(mask_t))
        position.append(label_positions_by_t[t][cell_labels[t]])
        fig = show_3d_segmentation_overlay_with_highlight2(z_stack_t, mask_t,
                                                           highlight_label=cell_labels[t],
                                                           return_fig=True)
        figures.append(fig)

    create_gif_from_figures(figures,
                            output_path=os.path.join(seg_maps_dir, f'3d_tracking_id_{track_id}.gif'),
                            fps=1,
                            titles=[f'Time {t + 1}' for t in valid_timepoints])
    position_array = np.array(position)
    create_trajectory_gif(position_array,
                          output_path=os.path.join(seg_maps_dir, f'cell_trajectory_id_{track_id}.gif'),
                          fps=1, figsize=(12, 9))

In [16]:
# NOTE(review): uses `figures` left over from the last iteration of the
# tracking loop above (hidden state). The "59" in the file name is not tied to
# that variable, and frames are numbered 1..n rather than by their time point.
create_gif_from_figures(
    figures,
    output_path='3d_tracking_id_59.gif',
    fps=1,
    titles=[f'Time {frame_number}' for frame_number in range(1, len(figures) + 1)],
)

GIF saved to 3d_tracking_id_59.gif


In [17]:
# NOTE(review): `position` is left over from the last iteration of the
# tracking loop above (hidden state); on a fresh kernel this cell only works
# after that loop has run.
position_array = np.array(position)

In [18]:
# Animate the trajectory collected in `position_array` (previous cell).
create_trajectory_gif(
    position_array,
    fps=1,
    figsize=(12, 9),
    output_path='cell_trajectory_animated_id_59_hungarian_matching.gif',
)

Trajectory GIF saved to cell_trajectory_animated_id_59_hungarian_matching.gif


In [None]:
# Inspect a single time point (0-based index t): image stack and mask volume.
t = 7
z_stack_t = image_data[0, t, 1, :, :, :, 0]
with np.load(masks_paths[t]) as mask_file:
    mask_t = mask_file['masks']

In [None]:
# Per-time-point file name (t is the 0-based time index) so this image is not
# overwritten by other inspection cells that save overlays.
show_3d_segmentation_overlay_with_highlight(
    z_stack_t, mask_t, 
    highlight_label=85,
    base_color=[0, 255, 0],      # green for other labels
    highlight_color=[255, 0, 0],  # red for highlighted label
    save_path=f"highlighted_overlay_t{t}_label85.png"
)

In [None]:
# Same inspection for the following time point, t + 1.
z_stack_t_plus1 = image_data[0, t + 1, 1, :, :, :, 0]
with np.load(masks_paths[t + 1]) as next_mask_file:
    mask_t_plus1 = next_mask_file['masks']

In [None]:
# Per-time-point file name (t + 1 is the 0-based time index) so this image is
# not overwritten by other inspection cells that save overlays.
show_3d_segmentation_overlay_with_highlight(
    z_stack_t_plus1, mask_t_plus1, 
    highlight_label=87,
    base_color=[0, 255, 0],      # green for other labels
    highlight_color=[255, 0, 0],  # red for highlighted label
    save_path=f"highlighted_overlay_t{t + 1}_label87.png"
)