#### Import

In [None]:
import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io

from cellpose.utils import masks_to_outlines
from cellpose.plot import image_to_rgb


In [None]:
sys.path.append("..")

In [None]:
from src.model import MS2

#### Functions

In [None]:
def create_gif_from_overlay_sequence(base_images, gif_path,
                                     duration=500, figsize=(8, 8), titles=None,
                                     overlay_images=None, alpha=0.5,
                                     base_cmap=None, overlay_cmap='jet'):
    """
    Render a sequence of images (optionally with overlays) into an animated GIF.

    Each frame is drawn with matplotlib, rasterised to an in-memory PNG, and
    the collected frames are written to ``gif_path`` with Pillow (looping
    forever).

    Args:
        base_images (list): Base images (numpy arrays), one per frame.
        gif_path (str): Output path for the GIF file.
        duration (int): Display time of each frame in milliseconds.
        figsize (tuple): Matplotlib figure size for each frame.
        titles (list): Optional per-frame titles; frames past ``len(titles)``
            are left untitled.
        overlay_images (list): Optional overlay images, one per frame, drawn
            on top of the base image. ``None`` (default) draws no overlay;
            frames past ``len(overlay_images)`` get no overlay.
        alpha (float): Opacity of the overlay; only used with
            ``overlay_images``.
        base_cmap (str): Colormap for single-channel base images (matplotlib
            ignores it for RGB(A) images). ``None`` keeps matplotlib's default.
        overlay_cmap (str): Colormap for overlay images.

    Returns:
        str: ``gif_path``.

    Raises:
        ValueError: If ``base_images`` is empty.
    """
    # len() rather than truthiness so a stacked numpy array is also accepted.
    if len(base_images) == 0:
        raise ValueError("base_images must contain at least one image")

    frames = []
    for i, base_img in enumerate(base_images):
        fig, ax = plt.subplots(figsize=figsize)
        try:
            # imshow handles both 2-D and RGB(A) arrays; cmap is ignored for RGB.
            ax.imshow(base_img, cmap=base_cmap)
            if overlay_images is not None and i < len(overlay_images):
                ax.imshow(overlay_images[i], cmap=overlay_cmap, alpha=alpha)

            if titles and i < len(titles):
                ax.set_title(titles[i], fontsize=12)
            ax.axis('off')

            # Rasterise *this* figure (not pyplot's "current" one) and keep a
            # fully loaded copy so the buffer can be released immediately.
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
                buf.seek(0)
                with Image.open(buf) as frame:
                    frames.append(frame.copy())
        finally:
            # Always release the figure, even if drawing/saving fails.
            plt.close(fig)

    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0,
        optimize=True
    )

    print(f"Overlay GIF saved to: {gif_path}")
    return gif_path

In [None]:
def enhance_cell_image_contrast(image):
    """
    Convert an image into an RGB-style array for display / outline drawing.

    Steps:
      * A 3-D image whose first axis has fewer than 4 entries is treated as
        channel-first and transposed to channel-last.
      * A 2-D image, or one with fewer than 3 channels, is converted with
        cellpose's ``image_to_rgb`` (``channels=[0, 0]``).
      * Otherwise, if the maximum value is <= 50 the image is clipped to
        [0, 1] and scaled to uint8 [0, 255]. Values in (1, 50] therefore
        saturate at 255. Images with larger values are returned unchanged.

    Args:
        image (np.ndarray): 2-D (H, W) or 3-D (channel-first or channel-last)
            image.

    Returns:
        np.ndarray: Channel-last image. For the ``image_to_rgb`` path the
        exact shape/dtype is whatever that helper returns (presumably
        (H, W, 3) uint8 — confirm against the cellpose version in use).
    """
    # Only a 3-D array can be channel-first. Checking ndim first keeps a short
    # 2-D image (fewer than 4 rows) from hitting a 3-axis transpose and raising.
    if image.ndim == 3 and image.shape[0] < 4:
        image = np.transpose(image, (1, 2, 0))

    if image.ndim < 3 or image.shape[-1] < 3:
        return image_to_rgb(image, channels=[0, 0])

    if image.max() <= 50.0:
        return np.uint8(np.clip(image, 0, 1) * 255)
    return image

def draw_cell_outline_on_image(mask, image, color=(255, 0, 0)):
    """
    Return a copy of ``image`` with the outline of ``mask`` painted in ``color``.

    The outline pixels come from cellpose's ``masks_to_outlines``. The input
    ``image`` is never modified, and a new array is returned even when
    ``mask`` is empty, so callers can always mutate the result safely.

    Args:
        mask (np.ndarray): 2-D mask with the same height/width as ``image``;
            any non-zero pixel counts as foreground.
        image (np.ndarray): Image to draw on; presumably (H, W, 3) so that
            ``color`` broadcasts per pixel — a 2-D image will not accept it.
        color (tuple): Value written to every outline pixel (default pure red).

    Returns:
        np.ndarray: A copy of ``image`` with the outline drawn in.
    """
    imgout = image.copy()
    if not np.any(mask):
        # Nothing to outline; still hand back a copy for consistent aliasing.
        return imgout
    outlines = masks_to_outlines(mask)
    rows, cols = np.nonzero(outlines)
    imgout[rows, cols] = np.asarray(color)
    return imgout

#### Input

In [None]:
# Path to the CZI file to analyse (machine-specific; edit for your data).
czi_file_path = '/home/dafei/data/MS2/New-03_I.czi'
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so
# the notebook can still be run elsewhere (assumes MS2 supports CPU — confirm).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# z-plane index passed to segmentation and tracking.
z = 0

In [None]:
# Build the MS2 analysis object for the selected CZI file on `device`.
ms2 = MS2(
    czi_path=czi_file_path,
    device=device,
)

In [None]:
# Instance segmentation of one time point (t=38) at plane z; plot_flag=True
# presumably displays the result inline.
mask = ms2.segment_cells(
    z=z,
    t=38,
    plot_flag=True,
)

In [None]:
# Cell tracking at plane z; the result is indexed by cell id.
tracked_cells = ms2.cell_tracking(
    z=z,
)

In [None]:
# Outline one tracked cell on every time point and save the frames as a GIF.
cell_id = 56
images = []
titles = []
# tracked_cells[cell_id] holds one mask per time point along its first axis.
for t, cell_mask in enumerate(tracked_cells[cell_id]):
    # NOTE(review): the axis layout of image_data is defined by MS2; this
    # assumes axis 3 is time — confirm, and whether the z plane should be
    # selected here instead of the fixed indices.
    cells_image = ms2.image_data[0, 0, ms2.microscope_channels[1], t, :, :, 0]
    cells_image = enhance_cell_image_contrast(cells_image)
    images.append(draw_cell_outline_on_image(cell_mask, cells_image))
    titles.append(f'Cell id : {cell_id}, z:{z} t: {t}')

# Name the output after the cell actually rendered.
create_gif_from_overlay_sequence(
    base_images=images,
    gif_path=f'overlay_sequence_{cell_id}.gif',
    titles=titles)