#### Imports

In [1]:
import os
import sys
import json
import tifffile
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter

In [2]:
sys.path.append('..')

In [3]:
from src.utils.image_utils import load_czi_images
from cell_tracking import get_masks_paths
from src.utils.plot_utils import show_3d_segmentation_overlay
from src.track import get_cell_centers, centers_array_to_label_position_map
from src.utils.image_utils import enhance_cell_image_contrast



Welcome to CellposeSAM, cellpose v
cellpose version: 	4.0.5.dev16+g79b0fcb 
platform:       	linux 
python version: 	3.12.9 
torch version:  	2.7.1+cu126! The neural network component of
CPSAM is much larger than in previous versions and CPU excution is slow. 
We encourage users to use GPU/MPS if available. 




#### Functions

Gif functions

In [4]:
def create_gif_from_figures(figures, output_path='animation.gif', fps=5, titles=None):
    """
    Create a GIF from a list of matplotlib figures
    
    Parameters:
    - figures: list of matplotlib figure objects
    - output_path: path to save the GIF
    - fps: frames per second
    - titles: optional list of titles for each frame
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import io
    from PIL import Image
    
    frames = []
    
    for i, fig in enumerate(figures):
        # Add title if provided
        if titles and i < len(titles):
             # Adjust the layout to make room for the title
            fig.subplots_adjust(top=0.9)  # Leave space at the top for title
            fig.suptitle(titles[i], fontsize=16, y=0.95)  # Position title higher
        
        # Convert figure to image
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        buf = io.BytesIO()
        canvas.print_png(buf)
        buf.seek(0)
        
        # Convert to PIL Image
        img = Image.open(buf)
        frames.append(img.copy())
        buf.close()
    
    # Create and save the GIF
    if frames:
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            duration=int(1000/fps),  # Convert fps to milliseconds
            loop=0
        )
        print(f"GIF saved to {output_path}")
        
        # Clean up figures to free memory
        for fig in figures:
            plt.close(fig)
    else:
        print("No frames were created")

In [5]:
def create_trajectory_gif(positions, output_path='trajectory.gif', fps=2, 
                         figsize=(10, 8), point_size=50, line_width=2):
    """
    Create a GIF animation of a 3D trajectory where points are added progressively
    
    Parameters:
    - positions: numpy array of shape (n_timepoints, 3) containing x, y, z coordinates
    - output_path: path to save the GIF
    - fps: frames per second
    - figsize: figure size tuple
    - point_size: size of the markers
    - line_width: width of the connecting lines
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import io
    from PIL import Image
    
    frames = []
    
    # Get the full range for consistent axis limits
    x_min, x_max = positions[:, 0].min(), positions[:, 0].max()
    y_min, y_max = positions[:, 1].min(), positions[:, 1].max()
    z_min, z_max = positions[:, 2].min(), positions[:, 2].max()
    
    # Add some padding
    x_padding = (x_max - x_min) * 0.1
    y_padding = (y_max - y_min) * 0.1
    z_padding = (z_max - z_min) * 0.1
    
    for frame_idx in range(len(positions)):
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')
        
        # Get points up to current frame
        current_positions = positions[:frame_idx + 1]
        
        if len(current_positions) > 0:
            # Plot the trajectory line up to current point
            if len(current_positions) > 1:
                ax.plot(current_positions[:, 0], 
                       current_positions[:, 1], 
                       current_positions[:, 2], 
                       color='blue', linewidth=line_width, alpha=0.7)
            
            # Plot all points up to current frame
            ax.scatter(current_positions[:, 0], 
                      current_positions[:, 1], 
                      current_positions[:, 2], 
                      c=range(len(current_positions)), 
                      cmap='viridis', s=point_size, alpha=0.8)
            
            # Highlight the current (newest) point
            current_point = current_positions[-1]
            ax.scatter(current_point[0], current_point[1], current_point[2], 
                      c='red', s=point_size*1.5, marker='o', alpha=1.0)
        
        # Set consistent axis limits
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)
        ax.set_zlim(z_min - z_padding, z_max + z_padding)
        
        # Labels and title
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.set_zlabel('Z Position')
        ax.set_title(f'Cell Trajectory - Time Point {frame_idx + 1}/{len(positions)}')
        
        # Convert figure to image
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        buf = io.BytesIO()
        canvas.print_png(buf)
        buf.seek(0)
        
        # Convert to PIL Image
        img = Image.open(buf)
        frames.append(img.copy())
        
        # Clean up
        plt.close(fig)
        buf.close()
    
    # Create and save the GIF
    if frames:
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            duration=int(1000/fps),
            loop=0
        )
        print(f"Trajectory GIF saved to {output_path}")
    else:
        print("No frames were created")

In [6]:
def create_extending_plot_gif(x_data, y_data, output_path='gene_Expression.gif', 
                             fps=2, figsize=(10, 6), line_color='blue', 
                             marker_color='red', line_width=2, marker_size=50,
                             xlabel='X', ylabel='Y', title='Extending Plot',
                             grid=True, show_current_point=True):
    """
    Create a GIF animation where a plot extends/grows progressively by adding one point at a time.
    
    Parameters:
    - x_data: array-like, x-coordinates of the data points
    - y_data: array-like, y-coordinates of the data points
    - output_path: path to save the GIF
    - fps: frames per second
    - figsize: figure size tuple
    - line_color: color of the line connecting points
    - marker_color: color of the current/newest point marker
    - line_width: width of the connecting line
    - marker_size: size of the current point marker
    - xlabel, ylabel, title: plot labels
    - grid: whether to show grid
    - show_current_point: whether to highlight the current point
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import io
    from PIL import Image
    
    # Convert to numpy arrays
    x_data = np.array(x_data)
    y_data = np.array(y_data)
    
    if len(x_data) != len(y_data):
        raise ValueError("x_data and y_data must have the same length")
    
    frames = []
    
    # Get the full range for consistent axis limits
    x_min, x_max = x_data.min(), x_data.max()
    y_min, y_max = y_data.min(), y_data.max()
    
    # Add some padding
    x_padding = (x_max - x_min) * 0.1 if x_max > x_min else 1
    y_padding = (y_max - y_min) * 0.1 if y_max > y_min else 1
    
    for i in range(1, len(x_data) + 1):
        fig, ax = plt.subplots(figsize=figsize)
        
        # Get data up to current point
        current_x = x_data[:i]
        current_y = y_data[:i]
        
        # Plot the line up to current point
        if len(current_x) > 1:
            ax.plot(current_x, current_y, color=line_color, 
                   linewidth=line_width, alpha=0.8, marker='o', 
                   markersize=4, markerfacecolor=line_color, 
                   markeredgecolor='white', markeredgewidth=0.5)
        
        # Highlight the current (newest) point
        if show_current_point and len(current_x) > 0:
            ax.scatter(current_x[-1], current_y[-1], 
                      c=marker_color, s=marker_size, 
                      marker='o', alpha=1.0, zorder=5,
                      edgecolors='white', linewidth=1)
        
        # Set consistent axis limits
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)
        
        # Labels and formatting
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(f'{title} - Point {i}/{len(x_data)}')
        
        if grid:
            ax.grid(True, alpha=0.3)
        
        # Add point count annotation
        ax.text(0.02, 0.98, f'Points: {i}', transform=ax.transAxes, 
               fontsize=10, verticalalignment='top',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
        
        plt.tight_layout()
        
        # Convert figure to image
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        buf = io.BytesIO()
        canvas.print_png(buf)
        buf.seek(0)
        
        # Convert to PIL Image
        img = Image.open(buf)
        frames.append(img.copy())
        
        # Clean up
        plt.close(fig)
        buf.close()
    
    # Create and save the GIF
    if frames:
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            duration=int(1000/fps),
            loop=0
        )
        print(f"Extending plot GIF saved to {output_path}")
    else:
        print("No frames were created")

cell segmentation coloring

In [7]:
def show_3d_segmentation_overlay_with_unique_colors(z_stack, masks, highlight_label, 
                                                   highlight_color=[255, 0, 0], 
                                                   background_color=[0, 0, 0],
                                                   color_scheme='hsv',
                                                   save_path=None, return_fig=False,
                                                   highlight_alpha=0.8, other_alpha=0.4):
    """
    Show 3D segmentation overlay with highlighted label in specific color and all other labels 
    colored uniquely using the color_cells_with_unique_colors function.
    
    Args:
        z_stack: 3D numpy array of images (z, h, w)
        masks: 3D numpy array of segmentation masks (z, h, w)
        highlight_label: integer label to highlight with specific color
        highlight_color: RGB color for highlighted label [R, G, B] (default: red)
        background_color: RGB color for background [R, G, B] (default: black)
        color_scheme: 'hsv', 'random', or 'gradient' for other cell coloring
        save_path: path to save the figure (optional)
        return_fig: if True, return the figure object instead of showing
        highlight_alpha: transparency for highlighted cell (0-1, higher = more opaque)
        other_alpha: transparency for other cells (0-1, lower = more transparent)
    
    Returns:
        fig object if return_fig=True, otherwise None
    """
    frames = []
    for i in range(z_stack.shape[0]):
        img = z_stack[i]
        img = enhance_cell_image_contrast(img)
        maski = masks[i]
        
        # Create colored overlay with different alpha values for highlight vs other cells
        colored_overlay = create_colored_overlay_with_dominant_highlight(
            img, maski, highlight_label, highlight_color, background_color, 
            color_scheme, highlight_alpha, other_alpha)
        frames.append(colored_overlay)

    # Handle single slice case
    if len(frames) == 1:
        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        ax.imshow(frames[0])
        ax.axis('off')
        ax.set_title('Slice 1')
    else:
        fig, ax = plt.subplots(1, len(frames), figsize=(4*len(frames), 8))
        for i, frame in enumerate(frames):
            ax[i].imshow(frame)
            ax[i].axis('off')
            ax[i].set_title(f'Slice {i+1}')
    
    plt.tight_layout()
    
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
        print(f"Saved colored overlay images to {save_path}")
    
    if return_fig:
        return fig
    else:
        plt.show()
        plt.close(fig)

def create_colored_overlay_with_dominant_highlight(image, mask, highlight_label, highlight_color, 
                                                 background_color, color_scheme, 
                                                 highlight_alpha=0.8, other_alpha=0.4):
    """
    Create an overlay with the base image and colored segmentation masks, 
    making the highlighted label more dominant.
    
    Args:
        image: 2D grayscale image
        mask: 2D segmentation mask
        highlight_label: label to highlight with specific color
        highlight_color: RGB color for highlighted label
        background_color: RGB color for background
        color_scheme: coloring scheme for other cells
        highlight_alpha: transparency for highlighted cell (higher = more dominant)
        other_alpha: transparency for other cells (lower = less dominant)
    
    Returns:
        RGB overlay image
    """
    # Convert grayscale to RGB
    if len(image.shape) == 2:
        rgb_image = np.stack([image, image, image], axis=-1)
    else:
        rgb_image = image.copy()
    
    # Normalize to 0-255 if needed
    if rgb_image.max() <= 1.0:
        rgb_image = (rgb_image * 255).astype(np.uint8)
    else:
        rgb_image = rgb_image.astype(np.uint8)
    
    # Get colored segmentation
    colored_mask = color_cells_with_unique_colors(mask, highlight_label, 
                                                highlight_color, background_color, 
                                                color_scheme)
    
    # Start with the base image
    overlay = rgb_image.copy().astype(np.float32)
    
    # Create separate masks for highlighted cell and other cells
    highlight_mask = mask == highlight_label
    other_cells_mask = (mask > 0) & (mask != highlight_label)
    
    # Blend highlighted cell with higher opacity (more dominant)
    if np.any(highlight_mask):
        overlay[highlight_mask] = (1 - highlight_alpha) * overlay[highlight_mask] + \
                                highlight_alpha * colored_mask[highlight_mask].astype(np.float32)
    
    # Blend other cells with lower opacity (less dominant)
    if np.any(other_cells_mask):
        overlay[other_cells_mask] = (1 - other_alpha) * overlay[other_cells_mask] + \
                                  other_alpha * colored_mask[other_cells_mask].astype(np.float32)
    
    return overlay.astype(np.uint8)

def color_cells_with_unique_colors(mask, highlight_label, highlight_color=[255, 0, 0], 
                                 background_color=[0, 0, 0], color_scheme='hsv',
                                 highlight_brightness=1.0, other_brightness=0.7):
    """
    Color a specific cell label with a given color and assign unique colors to all other labels.
    Enhanced to make highlight label more dominant through brightness control.
    
    Args:
        mask: 2D or 3D segmentation mask with integer labels
        highlight_label: integer label to color with specific color
        highlight_color: RGB color for the highlighted label [R, G, B] (default: red)
        background_color: RGB color for background (label 0) [R, G, B] (default: black)
        color_scheme: 'hsv', 'random', or 'gradient' for other cell coloring
        highlight_brightness: brightness multiplier for highlighted cell (>1 = brighter)
        other_brightness: brightness multiplier for other cells (<1 = dimmer)
    
    Returns:
        RGB image array with colored cells
    """
    import colorsys
    import numpy as np
    
    # Get unique labels excluding background (0)
    unique_labels = np.unique(mask[mask > 0])
    
    # Create RGB output image
    if len(mask.shape) == 2:
        colored_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    else:
        colored_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    
    # Set background color
    background_pixels = mask == 0
    colored_mask[background_pixels] = background_color
    
    # Color the highlighted label with enhanced brightness
    if highlight_label in unique_labels:
        highlight_pixels = mask == highlight_label
        enhanced_highlight_color = [min(255, int(c * highlight_brightness)) for c in highlight_color]
        colored_mask[highlight_pixels] = enhanced_highlight_color
        # Remove highlight label from list to avoid coloring it again
        unique_labels = unique_labels[unique_labels != highlight_label]
    
    # Generate colors for remaining labels with reduced brightness
    if len(unique_labels) > 0:
        if color_scheme == 'hsv':
            colors = generate_hsv_colors(len(unique_labels), brightness_factor=other_brightness)
        elif color_scheme == 'random':
            colors = generate_random_colors(len(unique_labels), brightness_factor=other_brightness)
        elif color_scheme == 'gradient':
            colors = generate_gradient_colors(len(unique_labels), brightness_factor=other_brightness)
        else:
            colors = generate_hsv_colors(len(unique_labels), brightness_factor=other_brightness)
        
        # Assign colors to each label
        for i, label in enumerate(unique_labels):
            label_pixels = mask == label
            colored_mask[label_pixels] = colors[i]
    
    return colored_mask

def generate_hsv_colors(n_colors, brightness_factor=1.0):
    """Generate n distinct colors using HSV color space with brightness control."""
    import colorsys
    colors = []
    for i in range(n_colors):
        hue = i / n_colors
        saturation = 0.8 + (i % 3) * 0.1
        brightness = (0.7 + (i % 2) * 0.2) * brightness_factor  # Apply brightness factor
        
        rgb = colorsys.hsv_to_rgb(hue, saturation, brightness)
        rgb_255 = [min(255, int(c * 255)) for c in rgb]
        colors.append(rgb_255)
    
    return colors

def generate_random_colors(n_colors, seed=42, brightness_factor=1.0):
    """Generate n random colors with brightness control."""
    np.random.seed(seed)
    colors = []
    
    for i in range(n_colors):
        r = np.random.randint(50, 255)
        g = np.random.randint(50, 255)
        b = np.random.randint(50, 255)
        
        # Apply brightness factor
        r = min(255, int(r * brightness_factor))
        g = min(255, int(g * brightness_factor))
        b = min(255, int(b * brightness_factor))
        
        while (r + g + b) < 150:  # Ensure minimum visibility
            r = min(255, r + 20)
            g = min(255, g + 20)
            b = min(255, b + 20)
        
        colors.append([r, g, b])
    
    return colors

def generate_gradient_colors(n_colors, brightness_factor=1.0):
    """Generate n colors in a gradient with brightness control."""
    colors = []
    
    for i in range(n_colors):
        ratio = i / max(1, n_colors - 1)
        
        if ratio < 0.25:
            r = 0
            g = int(255 * (ratio / 0.25))
            b = 255
        elif ratio < 0.5:
            r = 0
            g = 255
            b = int(255 * (1 - (ratio - 0.25) / 0.25))
        elif ratio < 0.75:
            r = int(255 * ((ratio - 0.5) / 0.25))
            g = 255
            b = 0
        else:
            r = 255
            g = int(255 * (1 - (ratio - 0.75) / 0.25))
            b = 0
        
        # Apply brightness factor
        r = min(255, int(r * brightness_factor))
        g = min(255, int(g * brightness_factor))
        b = min(255, int(b * brightness_factor))
        
        colors.append([r, g, b])
    
    return colors

#### Inputs

In [9]:
czi_file_path = '/home/dafei/data/MS2/gRNA2_12.03.25-st-13-II---.czi'
seg_maps_dir = '/home/dafei/output/MS2/3d_cell_segmentation/gRNA2_12.03.25-st-13-II---/masks'
background_sub_ms2_channel = "/home/dafei/output/MS2/3d_cell_segmentation/gRNA2_12.03.25-st-13-II---/C1-gRNA2_12.03.25-st-13-II---_ms2_channel_background_sub.tif"

In [10]:
gifs_output_dir = '/home/dafei/output/MS2/3d_cell_segmentation/gRNA2_12.03.25-st-13-II---/outputs_v1_tracking/'

In [11]:
image_data = load_czi_images(czi_file_path)

Successfully loaded /home/dafei/data/MS2/gRNA2_12.03.25-st-13-II---.czi
data shape: (1, 80, 2, 9, 1024, 1024, 1)


In [12]:
masks_paths = get_masks_paths(seg_maps_dir)

In [13]:
ms2_background_sub = tifffile.imread(background_sub_ms2_channel)

In [14]:
tracklets_path = '/home/dafei/output/MS2/3d_cell_segmentation/gRNA2_12.03.25-st-13-II---/masks/tracklets_hungarian_algorithm_distance_metric.json'

In [15]:
with open(tracklets_path, 'r') as f:
    tracklets = json.load(f)

In [16]:
from tqdm import tqdm

In [17]:
cell_labels = tracklets[str(7)]
position = []
gene_expression = []
figures = []
valid_timepoints = [t for t in range(len(cell_labels)) if cell_labels[t] != -1]
for t in tqdm(valid_timepoints):
    z_stack_t =image_data[0, t, 1, :, :, :, 0]
    mask_t = np.load(masks_paths[t])['masks']
    centers = get_cell_centers(mask_t)
    labels_to_pos = centers_array_to_label_position_map(centers)
    position.append(labels_to_pos[cell_labels[t]])
    cell_mask_t = (mask_t == cell_labels[t]).astype(np.uint8)
    coords = np.where(cell_mask_t > 0)
    z_coords = coords[0]
    y_coords = coords[1]
    x_coords = coords[2]
    ms2_values = ms2_background_sub[t,:,:,:][cell_mask_t > 0]
    th = np.percentile(ms2_values, 99.9)
    gene_expression.append(np.sum(ms2_values[ms2_values >= th]))
    fig = show_3d_segmentation_overlay_with_unique_colors(
        z_stack_t, mask_t, 
        highlight_label=cell_labels[t],
        highlight_color=[255, 0, 0],
        color_scheme='hsv',
        return_fig=True
    )
    figures.append(fig)
output_path1 = os.path.join(f'3d_tracking_id_{7}.gif')
create_gif_from_figures(figures, output_path=output_path1, fps=1, 
                       titles=[f'Time {i+1}' for i in valid_timepoints])
position_array = np.array(position)
output_path2 = os.path.join(f'cell_trajectory_id_{7}.gif')
create_trajectory_gif(position_array, output_path=output_path2, 
                     fps=1, figsize=(12, 9))
create_extending_plot_gif(
    x_data=valid_timepoints,
    y_data=gene_expression,
    output_path='gene_expression.gif',
    fps=1,
    figsize=(10, 6),
    line_color='blue',
    title='Gene expression for cell ID 7',
    xlabel='Time (frames)',
    ylabel= 'A.U')

  fig, ax = plt.subplots(1, len(frames), figsize=(4*len(frames), 8))
100%|██████████| 38/38 [05:13<00:00,  8.26s/it]


GIF saved to 3d_tracking_id_7.gif
Trajectory GIF saved to cell_trajectory_id_7.gif
Extending plot GIF saved to gene_expression.gif
