#### Import

In [None]:
import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io


In [None]:
# Make the parent directory (relative to the current working directory) importable
# so that `from src.model import MS2` resolves. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from src.model import MS2

#### Functions

In [None]:
def create_gif_from_overlay_sequence(base_images, overlay_images, gif_path,
                                   alpha=0.5, base_cmap='viridis', overlay_cmap='hot',
                                   duration=500, figsize=(8, 8), titles=None):
    """
    Create an animated GIF in which each frame is an overlay drawn on top of a base image.

    Each (base, overlay) pair is rendered with matplotlib, rasterised to PNG in
    memory, and the collected frames are written as a looping GIF with Pillow.

    Args:
        base_images (iterable): Base images (numpy-array-like, must expose ``.shape``).
            3-D inputs are drawn without a colormap (matplotlib treats them as
            RGB/RGBA); 2-D inputs are drawn with ``base_cmap``.
        overlay_images (iterable): Overlay images, paired with ``base_images`` by
            position. As with ``zip``, surplus items in the longer input are ignored.
        gif_path (str): Output path for the GIF file.
        alpha (float): Opacity of the overlay layer (0 = invisible, 1 = opaque).
        base_cmap (str): Colormap for 2-D base images.
        overlay_cmap (str): Colormap for 2-D overlay images.
        duration (int): Display time of each frame, in milliseconds.
        figsize (tuple): Matplotlib figure size (inches) used to render each frame.
        titles (list, optional): Per-frame titles; frames at index >= len(titles)
            are left untitled.

    Returns:
        str: ``gif_path``, the path the GIF was written to.

    Raises:
        ValueError: If no frames were produced (e.g. an input sequence is empty).
    """
    frames = []

    for i, (base_img, overlay_img) in enumerate(zip(base_images, overlay_images)):
        fig, ax = plt.subplots(figsize=figsize)
        try:
            # Base layer: 3-D arrays get no colormap, 2-D arrays use base_cmap.
            if len(base_img.shape) == 3:
                ax.imshow(base_img)
            else:
                ax.imshow(base_img, cmap=base_cmap)

            # Overlay layer, blended on top with the requested opacity.
            if len(overlay_img.shape) == 3:
                ax.imshow(overlay_img, alpha=alpha)
            else:
                ax.imshow(overlay_img, cmap=overlay_cmap, alpha=alpha)

            if titles and i < len(titles):
                ax.set_title(titles[i], fontsize=12)

            ax.axis('off')

            # Save *this* figure explicitly rather than pyplot's "current" figure.
            # NOTE: bbox_inches='tight' crops each frame independently, so frames
            # can differ slightly in pixel size (e.g. when title lengths differ).
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
                buf.seek(0)
                with Image.open(buf) as frame:
                    # Image.open is lazy; copy() forces a full decode so the frame
                    # remains valid after the in-memory buffer is closed.
                    frames.append(frame.copy())
        finally:
            # Release the figure even if rendering fails, so figures do not accumulate.
            plt.close(fig)

    if not frames:
        raise ValueError(
            "create_gif_from_overlay_sequence: no frames to write; "
            "base_images and overlay_images must both be non-empty"
        )

    # First frame carries the GIF; the rest are appended. loop=0 loops forever.
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0,
        optimize=True
    )

    print(f"Overlay GIF saved to: {gif_path}")
    return gif_path

#### Input

In [None]:
# Path to the CZI microscopy file to analyse (machine-specific; edit for your data).
czi_file_path = '/home/dafei/data/MS2/New-03_I.czi'
# Use the first GPU when CUDA is available; otherwise fall back to CPU so the
# notebook does not fail outright on machines without a GPU.
# NOTE(review): confirm MS2 works on a CPU device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the MS2 object from the CZI file path and the selected torch device.
ms2 = MS2(
    czi_path=czi_file_path,
    device=device,
)

In [None]:
# Run MS2's cell-tracking routine with z=0 (presumably the z-slice index — confirm).
# NOTE(review): the result is indexed later as tracked_cells[cell_id], whose first
# axis is iterated per frame — confirm this layout against MS2.cell_tracking_v2.
tracked_cells = ms2.cell_tracking_v2(z=0)

In [None]:
# Build one GIF frame per index along the first axis of a single tracked cell:
# the image from microscope channel index 1 is the base layer, the cell mask the overlay.
cell_id = 128
# Only used in frame titles; keep in sync with the z passed to ms2.cell_tracking_v2.
z_slice = 0
# Loop-invariant channel lookup, done once instead of per frame.
channel = ms2.microscope_channels[1]

base_images = []
overlay_images = []
titles = []
for t, cell_mask in enumerate(tracked_cells[cell_id]):
    # NOTE(review): axis 3 of image_data is indexed by t, presumably time — confirm.
    cells_image = ms2.image_data[0, 0, channel, t, :, :, 0]
    base_images.append(cells_image)
    overlay_images.append(cell_mask)
    titles.append(f'Cell id : {cell_id}, z:{z_slice} t: {t}')

create_gif_from_overlay_sequence(
    base_images=base_images,
    overlay_images=overlay_images,
    # Derived from cell_id so the filename cannot drift from the cell being rendered.
    gif_path=f'overlay_sequence_{cell_id}.gif',
    alpha=0.5,
    titles=titles,
)