### Fine-tuning SAM with OMERO data using a batch approach

### Features
- Supports multiple OMERO data types (single images, datasets, projects, plates, and screens)
- Batch processing with micro-SAM for segmentation
- Stores all annotations in OMERO as ROIs and attachments
- Uses dask for lazy loading of images for better memory management
- Supports 3D volumetric segmentation for z-stacks
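The batch processing listed above splits the selected images into consecutive groups of `batch_size` and opens one annotation session per group. A minimal sketch of that split, assuming plain Python sequences; `iter_batches` is a hypothetical helper, not part of micro-SAM or OMERO:

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")

def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Tuple[int, List[T]]]:
    """Yield (start_index, batch) pairs of at most `batch_size` items (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    # Ceiling division: number of batches needed to cover every item
    total_batches = (len(items) + batch_size - 1) // batch_size
    for batch_idx in range(total_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, len(items))
        yield start, list(items[start:end])

# 7 images with batch_size=3 give batches of 3, 3 and 1 images
for start, batch in iter_batches(["img0", "img1", "img2", "img3", "img4", "img5", "img6"], 3):
    print(start, batch)
```

Yielding the start index alongside each batch is what allows per-image results to be mapped back to a global position, as `global_n = start_idx + n` does in the processing loop.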

### TODOs
- Store all annotations in OMERO, see: https://github.com/computational-cell-analytics/micro-sam/issues/445; the image series annotator can be given a commit path to save the prompts, but they get overwritten
- Clean up the error and warning output from napari
- Improve ROI creation for 3D volumes to better represent volumetric masks in OMERO
- Work with Dask arrays directly in micro-sam

Instructions:
  - To make running the notebook against OMERO easier without exposing the login and password, the credentials are stored in a `.env` file (see the example file `.env_example`). Storing credentials unencrypted is still not recommended, so a better solution is being worked on.
  - This notebook supports processing images from various OMERO container types: images, datasets, projects, plates, and screens.
  - Specify the container type in the `datatype` variable and the container ID in the `data_id` variable.
  - You can choose to segment all images in the container or select a random subset for training and validation.
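Before connecting, it can help to fail early when the `.env` file did not define the variables the connection cell reads (`HOST`, `USER_NAME`, `PASSWORD`, `GROUP`). A minimal sketch, assuming those variable names; `missing_env_vars` is a hypothetical helper and only reports the problem instead of raising:

```python
import os
from typing import List, Mapping, Optional, Sequence

REQUIRED_VARS = ("HOST", "USER_NAME", "PASSWORD", "GROUP")

def missing_env_vars(required: Sequence[str] = REQUIRED_VARS,
                     env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the names of required variables that are unset or empty (hypothetical helper)."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]

# Run after load_dotenv() to get a clear message instead of a failed login
missing = missing_env_vars()
if missing:
    print(f"Missing OMERO settings in .env: {', '.join(missing)}")
```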

In [None]:
# OMERO-related imports
import omero
from omero.gateway import BlitzGateway
import ezomero

# Scientific computing and image processing
import cv2
import imageio.v3 as imageio
import numpy as np
import pandas as pd

# File and system operations
import os
import shutil
import tempfile
import zipfile
import warnings
from tifffile import imwrite, imread
from dotenv import load_dotenv

# Dask and Zarr for lazy loading and processing
import dask
import dask.array as da
import zarr

# Micro-SAM and Napari
from napari.settings import get_settings
from micro_sam.sam_annotator import image_series_annotator, annotator_2d
from micro_sam.util import precompute_image_embeddings
import napari

### Set up the connection to OMERO

In [None]:
load_dotenv(override=True)
conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    group=os.environ.get("GROUP"),
    secure=True,
)
connection_status = conn.connect()
if connection_status:
    print("Connected to OMERO Server")
    # Ping the server every 60 s so the session stays alive during long annotation sessions
    conn.c.enableKeepAlive(60)
else:
    # Stop here: every following cell needs a working connection
    raise ConnectionError("Connection to OMERO Server Failed")

### Get info about the selected OMERO object

In [None]:
datatype = "dataset" # "screen", "plate", "project", "dataset", "image"
data_id = 1112
nucl_channel = 0

def print_object_details(conn, obj, datatype):
    """Print detailed information about OMERO objects"""
    print(f"\n{datatype.capitalize()} Details:")
    print(f"- Name: {obj.getName()}")
    print(f"- ID: {obj.getId()}")
    print(f"- Owner: {obj.getOwner().getFullName()}")
    print(f"- Group: {obj.getDetails().getGroup().getName()}")
    
    if datatype == "project":
        datasets = list(obj.listChildren())
        dataset_count = len(datasets)
        total_images = sum(len(list(ds.listChildren())) for ds in datasets)
        print(f"- Number of datasets: {dataset_count}")
        print(f"- Total images: {total_images}")
        
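    elif datatype == "plate":
        wells = list(obj.listChildren())
        well_count = len(wells)
        print(f"- Number of wells: {well_count}")

    elif datatype == "screen":
        plates = list(obj.listChildren())
        print(f"- Number of plates: {len(plates)}")
        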
    elif datatype == "dataset":
        images = list(obj.listChildren())
        image_count = len(images)
        # Get project info if dataset is in a project
        projects = obj.getParent()
        if projects:
            print(f"- Project: {projects.getName()} (ID: {projects.getId()})")
        else:
            print("- Project: None (orphaned dataset)")
        print(f"- Number of images: {image_count}")
        
    elif datatype == "image":
        size_x = obj.getSizeX()
        size_y = obj.getSizeY()
        size_z = obj.getSizeZ()
        size_c = obj.getSizeC()
        size_t = obj.getSizeT()
        # Get dataset info if image is in a dataset
        datasets = obj.getParent()
        if datasets:
            print(f"- Dataset: {datasets.getName()} (ID: {datasets.getId()})")
            # Get project info if dataset is in a project
            projects = datasets.getParent()
            if projects:
                print(f"- Project: {projects.getName()} (ID: {projects.getId()})")
        else:
            print("- Dataset: None (orphaned image)")
        print(f"- Dimensions: {size_x}x{size_y}")
        print(f"- Z-stack: {size_z}")
        print(f"- Channels: {size_c}")
        print(f"- Timepoints: {size_t}")

# Validate that data_id matches datatype and print details
if datatype == "project":
    project = conn.getObject("Project", data_id)
    if project is None:
        raise ValueError(f"Project with ID {data_id} not found")
    print_object_details(conn, project, "project")
    
elif datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    if plate is None:
        raise ValueError(f"Plate with ID {data_id} not found")
    print_object_details(conn, plate, "plate")
    
elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"Dataset with ID {data_id} not found")
    print_object_details(conn, dataset, "dataset")
    
elif datatype == "image":
    image = conn.getObject("Image", data_id)
    if image is None:
        raise ValueError(f"Image with ID {data_id} not found")
    print_object_details(conn, image, "image")

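elif datatype == "screen":
    screen = conn.getObject("Screen", data_id)
    if screen is None:
        raise ValueError(f"Screen with ID {data_id} not found")
    print_object_details(conn, screen, "screen")

else: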
    raise ValueError("Invalid datatype specified")

### Create a temporary folder for the training data (uploaded to OMERO later)

In [None]:
output_directory = os.path.normcase(tempfile.mkdtemp())
print('Output Directory: ', output_directory)

In [None]:
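def zip_directory(source_path, zarr_path, zip_file):
    """Add every file under the directory `zarr_path` to `zip_file`, storing paths relative
    to `source_path` and stripping null characters from them. `zarr_path` must be a full
    path (or relative to the working directory), not a name relative to `source_path`."""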
    for root, dirs, files in os.walk(zarr_path):
        for file in files:
            try:
                # Create paths
                full_path = os.path.join(root, file)
                rel_path = os.path.relpath(full_path, source_path)
                
                # Remove null characters while preserving the path structure
                safe_full_path = full_path.replace('\x00', '')
                safe_rel_path = rel_path.replace('\x00', '')
                
                # Add file to zip if it exists
                if os.path.exists(safe_full_path):
                    zip_file.write(safe_full_path, safe_rel_path)
            except Exception as e:
                print(f"Warning: Error processing {file}: {str(e)}")
                continue

def interleave_arrays(train_images, validate_images):
    """
    Interleave two arrays of images in the pattern: train[0], validate[0], train[1], validate[1], ...
    If arrays are of unequal length, remaining elements are appended at the end.
    """
    # Create empty list to store interleaved images
    interleaved = []
    sequence = []
    # Get the length of the longer array
    max_len = max(len(train_images), len(validate_images))
    
    # Interleave the arrays
    for i in range(max_len):
        # Add train image if available
        if i < len(train_images):
            interleaved.append(train_images[i])
            sequence.append(0)
        # Add validate image if available
        if i < len(validate_images):
            interleaved.append(validate_images[i])
            sequence.append(1)
    
    return np.array(interleaved), np.array(sequence)

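def mask_to_contour(mask):
    """Converts a binary mask to a list of outer contours.

    Args:
        mask (np.ndarray): binary mask (uint8)

    Returns:
        list: OpenCV contours, each an array of shape (N, 1, 2) holding (x, y) points
    """
    # RETR_EXTERNAL keeps only outer boundaries, so holes inside an object
    # are not turned into separate polygons
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return contours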

def label_to_rois(label_img, z_slice, channel, timepoint, model_type, is_volumetric=False):
    """
    Convert a 2D or 3D label image to OMERO ROI shapes
    
    Args:
        label_img (np.ndarray): 2D labeled image or 3D labeled stack
        z_slice (int or list): Z-slice index or list/range of Z indices
        channel (int): Channel index
        timepoint (int): Time point index
        model_type (str): SAM model type used
        is_volumetric (bool): Whether the label image is 3D volumetric data
    
    Returns:
        list: List of OMERO shape objects
    """
    shapes = []
    
    if is_volumetric and label_img.ndim > 2:
        # 3D volumetric data - process each z slice
        for z_index, z_plane in enumerate(label_img):
            # If z_slice is a range or list, use the actual z-index from that range
            if isinstance(z_slice, (range, list)):
                actual_z = z_slice[z_index] if z_index < len(z_slice) else z_slice[0] + z_index
            else:
                actual_z = z_slice + z_index  # Assume z_slice is the starting index
                
            print(f"Processing volumetric ROIs for z-slice {actual_z}")
            shapes.extend(process_label_plane(z_plane, actual_z, channel, timepoint, model_type, is_volumetric=True))
    else:
        # 2D data - process single plane
        shapes.extend(process_label_plane(label_img, z_slice, channel, timepoint, model_type, is_volumetric=False))
    
    return shapes

def process_label_plane(label_plane, z_slice, channel, timepoint, model_type, is_volumetric=False):
    """Process a single 2D label plane to generate OMERO polygon shapes"""
    shapes = []
    unique_labels = np.unique(label_plane)
    # Skip background (label 0); filter by value so the first object is not dropped
    # when a plane contains no background pixels
    unique_labels = unique_labels[unique_labels != 0]
    # Tag shapes by segmentation mode instead of inferring it from the z index
    mode = "volumetric" if is_volumetric else "manual"
    
    for label in unique_labels:
        # Create binary mask for this label
        mask = (label_plane == label).astype(np.uint8)
        
        # Get contours
        contours = mask_to_contour(mask)
        
        # Convert each contour to a polygon ROI
        for contour in contours:
            contour = contour[:, 0, :]  # Reshape (N, 1, 2) -> (N, 2) array of (x, y) points
            if len(contour) < 3:
                continue  # a polygon needs at least 3 vertices
            poly = ezomero.rois.Polygon(
                points=[(float(x), float(y)) for x, y in contour],
                z=int(z_slice),
                c=channel,
                t=timepoint,
                label=f'micro_sam.{mode}_instance_segmentation.{model_type}'
            )
            shapes.append(poly)
    
    return shapes

#TODO merge these functions to a source file with functions for processing OMERO data
def upload_rois_and_labels(conn, image, label_file, z_slice, channel, timepoint, model_type, is_volumetric=False):
    """
    Upload both label map and ROIs for a segmented image
    
    Args:
        conn: OMERO connection
        image: OMERO image object
        label_file: Path to the label image file
        z_slice: Z-slice index or range of indices
        channel: Channel index
        timepoint: Time point index
        model_type: SAM model type used
        is_volumetric: Whether the data is 3D volumetric
    
    Returns:
        tuple: (label_id, roi_id)
    """
    # Upload label map as attachment
    label_id = ezomero.post_file_annotation(
        conn,
        str(label_file),
        ns='microsam.labelimage',
        object_type="Image",
        object_id=image.getId(),
        description=f'SAM {"volumetric" if is_volumetric else "manual"} segmentation ({model_type})'
    )
    
    # Create ROIs from label image
    label_img = imageio.imread(label_file)
    shapes = label_to_rois(label_img, z_slice, channel, timepoint, model_type, is_volumetric)
    
    if shapes:  # Only create ROI if shapes were found
        roi_id = ezomero.post_roi(
            conn,
            image.getId(),
            shapes,
            name=f'SAM_{model_type}{"_3D" if is_volumetric else ""}',
            description=f'micro_sam.{"volumetric" if is_volumetric else "manual"}_instance_segmentation.{model_type}'
        )
    else:
        roi_id = None
        
    return label_id, roi_id
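`interleave_arrays` alternates training and validation images and records a marker for each position (0 = train, 1 = validate). The same pattern can be written with `itertools.zip_longest`; this is a sketch of an equivalent, not the notebook's implementation, and the `_MISSING` sentinel is an assumption chosen so that `None` items are still kept:

```python
from itertools import zip_longest

_MISSING = object()  # sentinel marking the exhausted side of an unequal pair

def interleave_with_markers(train_items, validate_items):
    """Return (items, markers) alternating train/validate; 0 = train, 1 = validate."""
    items, markers = [], []
    for train_item, validate_item in zip_longest(train_items, validate_items, fillvalue=_MISSING):
        if train_item is not _MISSING:
            items.append(train_item)
            markers.append(0)
        if validate_item is not _MISSING:
            items.append(validate_item)
            markers.append(1)
    return items, markers

items, markers = interleave_with_markers(["t0", "t1", "t2"], ["v0"])
print(items)    # ['t0', 'v0', 't1', 't2']
print(markers)  # [0, 1, 0, 0]
```

Returning plain lists avoids building a NumPy array from OMERO wrapper objects; slicing and `enumerate` in the batch loop behave the same on lists.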


In [None]:
def get_images_from_container(conn, datatype, container_id):
    """
    Extract all images from a given OMERO container (Project, Dataset, Plate, Screen)
    
    Args:
        conn: OMERO connection
        datatype: Type of container ('project', 'dataset', 'plate', 'screen', 'image')
        container_id: ID of the container
        
    Returns:
        list: List of OMERO image objects
        str: Description of the source (for tracking)
    """
    images = []
    source_desc = ""
    
    if datatype == "image":
        image = conn.getObject("Image", container_id)
        if image is None:
            raise ValueError(f"Image with ID {container_id} not found")
        images = [image]
        source_desc = f"Image: {image.getName()} (ID: {container_id})"
    
    elif datatype == "dataset":
        dataset = conn.getObject("Dataset", container_id)
        if dataset is None:
            raise ValueError(f"Dataset with ID {container_id} not found")
        images = list(dataset.listChildren())
        source_desc = f"Dataset: {dataset.getName()} (ID: {container_id})"
    
    elif datatype == "project":
        project = conn.getObject("Project", container_id)
        if project is None:
            raise ValueError(f"Project with ID {container_id} not found")
        # Get all datasets in the project
        for dataset in project.listChildren():
            # Get all images in each dataset
            for image in dataset.listChildren():
                images.append(image)
        source_desc = f"Project: {project.getName()} (ID: {container_id})"
    
    elif datatype == "plate":
        plate = conn.getObject("Plate", container_id)
        if plate is None:
            raise ValueError(f"Plate with ID {container_id} not found")
        # Get all wells in the plate
        for well in plate.listChildren():
            # Get all images (fields) in each well
            for wellSample in well.listChildren():
                images.append(wellSample.getImage())
        source_desc = f"Plate: {plate.getName()} (ID: {container_id})"
    
    elif datatype == "screen":
        screen = conn.getObject("Screen", container_id)
        if screen is None:
            raise ValueError(f"Screen with ID {container_id} not found")
        # Get all plates in the screen
        for plate in screen.listChildren():
            # Get all wells in each plate
            for well in plate.listChildren():
                # Get all images (fields) in each well
                for wellSample in well.listChildren():
                    images.append(wellSample.getImage())
        source_desc = f"Screen: {screen.getName()} (ID: {container_id})"
    
    else:
        raise ValueError(f"Unsupported datatype: {datatype}")
    
    print(f"Found {len(images)} images from {source_desc}")
    return images, source_desc

### Dask Lazy Loading Functions for OMERO Data
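The loader in the next cell wraps each plane read in `dask.delayed`, turns it into an array with `da.from_delayed` (which needs the plane shape and dtype up front) and stacks the planes, so nothing is read from the server until `.compute()` is called. A minimal offline sketch of that pattern, in which `fake_get_plane` is a hypothetical stand-in for `pixels.getPlane(z, c, t)` and the (T, Z, C, Y, X) axis order mirrors the stacking order used in `get_dask_image`:

```python
import dask
import dask.array as da
import numpy as np

SIZE_Y, SIZE_X = 4, 5
calls = []  # records which (z, c, t) planes were actually read

def fake_get_plane(z, c, t):
    """Hypothetical stand-in for pixels.getPlane(z, c, t)."""
    calls.append((z, c, t))
    return np.full((SIZE_Y, SIZE_X), 100 * t + 10 * z + c, dtype=np.uint16)

def lazy_stack(z_range, c_range, t_range, dtype=np.uint16):
    """Build a lazy (T, Z, C, Y, X) dask array from per-plane reads."""
    get_plane = dask.delayed(fake_get_plane)
    return da.stack([
        da.stack([
            da.stack([
                da.from_delayed(get_plane(z, c, t), shape=(SIZE_Y, SIZE_X), dtype=dtype)
                for c in c_range
            ])
            for z in z_range
        ])
        for t in t_range
    ])

arr = lazy_stack(range(3), [0], [0])
reads_before = len(calls)        # 0: building the graph reads nothing
plane = arr[0, 1, 0].compute()   # triggers the read of plane z=1, c=0, t=0
print(arr.shape, reads_before, plane[0, 0])
```

The `dtype` given to `da.from_delayed` is only metadata, so it should match what the reader really returns.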

In [None]:
def get_dask_image(conn, image_id, z_slice=None, timepoint=None, channel=None, three_d=False):
    """
    Get a dask array representation of an OMERO image for lazy loading
    
    Args:
        conn: OMERO connection
        image_id: ID of image to load
        z_slice: Optional specific Z slice to load (int or list)
        timepoint: Optional specific timepoint to load (int or list)
        channel: Optional specific channel to load (int or list)
        three_d: Whether to load a 3D volume (all z-slices) instead of a single slice
    
    Returns:
        dask array with axes (T, Z, C, Y, X), or None if no planes were selected
    """
    image = conn.getObject("Image", image_id)
    pixels = image.getPrimaryPixels()
    
    # Get image dimensions
    size_z = image.getSizeZ()
    size_c = image.getSizeC()
    size_t = image.getSizeT()
    size_y = image.getSizeY()
    size_x = image.getSizeX()
    
    # Define specific dimensions to load if provided
    # If three_d is True, we want all z-slices, otherwise use the provided z_slice
    if three_d:
        z_range = range(size_z)  # Load all z-slices for 3D
    else:
        z_range = [z_slice] if isinstance(z_slice, int) else (range(size_z) if z_slice is None else z_slice)
    
    t_range = [timepoint] if isinstance(timepoint, int) else (range(size_t) if timepoint is None else timepoint)
    c_range = [channel] if isinstance(channel, int) else (range(size_c) if channel is None else channel)
    
    # Create empty dict to store delayed objects
    delayed_planes = {}
    
    print(f"Creating dask array for image {image_id} with lazy loading")
    print(f"Dimensions: Z={len(z_range)}, C={len(c_range)}, T={len(t_range)}, Y={size_y}, X={size_x}")
    print(f"3D mode: {three_d}")
    
    # Create lazy loading function
    @dask.delayed
    def get_plane(z, c, t):
        print(f"Loading plane: Z={z}, C={c}, T={t}")
        return pixels.getPlane(z, c, t)
    
    # Build dask arrays
    arrays = []
    for t in t_range:
        t_arrays = []
        for z in z_range:
            z_arrays = []
            for c in c_range:
                # Create a key for this plane
                key = (z, c, t)
                
                # Check if we've already created this delayed object
                if key not in delayed_planes:
                    # Create a delayed object for this plane
                    delayed_plane = get_plane(z, c, t)
                    delayed_planes[key] = delayed_plane
                else:
                    delayed_plane = delayed_planes[key]
                
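                # Convert to dask array with known shape and dtype; take the dtype from the
                # image's pixel type (e.g. "uint8", "uint16") instead of assuming 16-bit.
                # OMERO's "float" and "double" are 32-bit and 64-bit floats
                shape = (size_y, size_x)
                pixel_type = image.getPixelsType()
                dtype = np.dtype({"float": "float32", "double": "float64"}.get(pixel_type, pixel_type))
                dask_plane = da.from_delayed(delayed_plane, shape=shape, dtype=dtype)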
                z_arrays.append(dask_plane)
            if z_arrays:
                # Stack channels for this z position
                t_arrays.append(da.stack(z_arrays))
        if t_arrays:
            # Stack z-planes for this timepoint
            arrays.append(da.stack(t_arrays))
    
    if arrays:
        # Stack all timepoints
        return da.stack(arrays)
    else:
        return None

def store_annotations_in_zarr(mask_data, output_folder, image_num):
    """
    Store annotation masks in zarr format for efficient access
    
    Args:
        mask_data: Numpy array with mask data
        output_folder: Base folder to store zarr data
        image_num: Image number/identifier
        
    Returns:
        path: Path to the zarr store
    """
    # Create zarr directory if it doesn't exist
    zarr_dir = os.path.join(output_folder, "annotations")
    os.makedirs(zarr_dir, exist_ok=True)
    
    # Create zarr filename
    zarr_path = os.path.join(zarr_dir, f"annotation_{image_num:05d}.zarr")
    
    # Remove existing zarr store if it exists
    if os.path.exists(zarr_path):
        shutil.rmtree(zarr_path)
        
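    # Create a zarr group holding the mask array; chunk along (Y, X) and keep one
    # plane per chunk on any leading axis (e.g. Z for 3D masks)
    z = zarr.open(zarr_path, mode='w')
    chunks = (1,) * (mask_data.ndim - 2) + (256, 256)
    z.create_dataset('masks', data=mask_data, chunks=chunks)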
    
    # Return path to zarr store
    return zarr_path

def zarr_to_tiff(zarr_path, output_tiff_path):
    """
    Convert zarr store to TIFF file for OMERO upload
    
    Args:
        zarr_path: Path to zarr store
        output_tiff_path: Path to save TIFF file
        
    Returns:
        output_tiff_path: Path to saved TIFF file
    """
    # Load data from zarr
    z = zarr.open(zarr_path, mode='r')
    mask_data = z['masks'][:]
    
    # Save as TIFF
    imwrite(output_tiff_path, mask_data)
    
    return output_tiff_path

def cleanup_local_embeddings(output_folder):
    """
    Check for and clean up any existing embeddings from previous interrupted runs
    
    Args:
        output_folder: Path to the output folder containing embeddings
    """
    embed_path = os.path.join(output_folder, "embed")
    
    if os.path.exists(embed_path):
        # Look for embedding zarr directories and zip files
        for item in os.listdir(embed_path):
            item_path = os.path.join(embed_path, item)
            if os.path.isdir(item_path) and "embedding_" in item and item.endswith(".zarr"):
                print(f"Cleaning up leftover embedding directory: {item}")
                shutil.rmtree(item_path)
            elif os.path.isfile(item_path) and "embedding_" in item and item.endswith(".zip"):
                print(f"Cleaning up leftover embedding zip: {item}")
                os.remove(item_path)
    
    # Check output directory for segmentation files
    output_path = os.path.join(output_folder, "output")
    if os.path.exists(output_path):
        for item in os.listdir(output_path):
            item_path = os.path.join(output_path, item)
            if os.path.isfile(item_path) and "seg_" in item and (item.endswith(".tif") or item.endswith(".tiff")):
                print(f"Cleaning up leftover segmentation file: {item}")
                os.remove(item_path)

def process_omero_batch_with_dask(
    images_list,
    output_folder: str,
    container_type: str,
    container_id: int,
    source_desc: str,
    model_type: str = 'vit_l',
    batch_size: int = 3,
    channel: int = 0,
    timepoint: int = 0,
    z_slice: int = 0,
    z_range: list = None,
    segment_all: bool = True,
    train_n: int = 3,
    validate_n: int = 3,
    three_d: bool = False,
    resume_from_table: bool = False
):

    """
    Process OMERO images in batches for SAM segmentation using dask for lazy loading
    and zarr for temporary annotation storage
    
    Args:
        images_list: List of OMERO image objects
        output_folder: Path to store temporary files
        container_type: Type of OMERO container ('dataset', 'plate', 'project', 'screen', 'image')
        container_id: ID of the container
        source_desc: Description of the container (for display and tracking)
        model_type: SAM model type
        batch_size: Number of images to process at once
        channel: Channel to segment
        timepoint: Timepoint to process
        z_slice: Z-slice to process (used only when three_d=False)
        z_range: Range of Z-slices to process (optional, for fine control in 3D mode)
        segment_all: Segment all images in the container (True) or only a random train/validate subset (False)
        train_n: Number of training images if not segment_all
        validate_n: Number of validation images if not segment_all
        three_d: Whether to use 3D volumetric mode
        resume_from_table: Whether to resume annotation from an existing tracking table
    
    Returns:
        tuple: (table_id, combined_images)
    """
    # Setup output directories
    output_path = os.path.join(output_folder, "output")
    embed_path = os.path.join(output_folder, "embed")
    zarr_path = os.path.join(output_folder, "zarr")
    
    # Check for and clean up any existing embeddings from interrupted runs
    cleanup_local_embeddings(output_folder)
    
    # Remove directories if they exist
    for path in [output_path, embed_path, zarr_path]:
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
        
    # Create or retrieve tracking DataFrame
    df = pd.DataFrame(columns=[
        "image_id", "image_name", "train", "validate", 
        "channel", "z_slice", "timepoint", "sam_model", "embed_id", "label_id", "roi_id", "is_volumetric", "processed"
    ])
    
    table_id = None
    
    # Check if we should resume from an existing table
    if resume_from_table:
        try:
            # Get existing tracking table
            existing_tables = ezomero.get_table_names(conn, container_type.capitalize(), container_id)
            if "micro_sam_training_data" in existing_tables:
                # Get the table ID and data
                table_ids = ezomero.get_table_ids(conn, container_type.capitalize(), container_id)
                for tid in table_ids:
                    table_name = ezomero.get_table_names(conn, container_type.capitalize(), container_id, tid)
                    if table_name == "micro_sam_training_data":
                        table_id = tid
                        existing_df = ezomero.get_table(conn, table_id)
                        df = existing_df
                        
                        print(f"Resuming from existing table ID: {table_id}")
                        print(f"Found {len(df)} previously processed images")
                        break
        except Exception as e:
            print(f"Error retrieving existing table: {e}. Starting fresh.")
            resume_from_table = False
    
    # Get images list (already provided as argument)
    combined_images_sequence = np.zeros(len(images_list))  # Initialize sequence array
    
    # Select images based on segment_all flag
    if segment_all:
        combined_images = images_list
        combined_images_sequence = np.zeros(len(combined_images))  # All treated as training
    else:
        # Check if we have enough images
        if len(images_list) < train_n + validate_n:
            print("Not enough images in container for training and validation")
            raise ValueError(f"Need at least {train_n + validate_n} images but found {len(images_list)}")
            
        # Select random images for training and validation
        train_indices = np.random.choice(len(images_list), train_n, replace=False)
        train_images = [images_list[i] for i in train_indices]
        
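        # Get validation images from the remaining ones; sample indices rather than
        # passing the OMERO objects to np.random.choice, which would convert them to an array
        validate_candidates = [img for i, img in enumerate(images_list) if i not in train_indices]
        validate_indices = np.random.choice(len(validate_candidates), validate_n, replace=False)
        validate_images = [validate_candidates[i] for i in validate_indices]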
        
        # Interleave the arrays and create sequence markers
        combined_images, combined_images_sequence = interleave_arrays(train_images, validate_images)
    
    # If resuming, filter out already processed images
    if resume_from_table:
        # Get list of already processed image IDs
        processed_ids = set(df[df['processed'] == True]['image_id'].values)
        
        # Filter combined_images
        filtered_images = []
        filtered_sequence = []
        for i, img in enumerate(combined_images):
            if img.getId() not in processed_ids:
                filtered_images.append(img)
                filtered_sequence.append(combined_images_sequence[i])
        
        combined_images = filtered_images
        combined_images_sequence = np.array(filtered_sequence)
        
        print(f"Found {len(combined_images)} remaining images to process")
    
    # Calculate total number of batches
    total_batches = (len(combined_images) + batch_size - 1) // batch_size
    
    print(f"Processing {len(combined_images)} images in {total_batches} batches")
    print(f"3D mode: {three_d}")
    
    # Process images in batches
    for batch_idx in range(total_batches):
        print(f"\nProcessing batch {batch_idx+1}/{total_batches}")
        
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, len(combined_images))
        batch_images = combined_images[start_idx:end_idx]
        
        # Read this batch's planes as NumPy arrays (these are what micro-SAM receives) and
        # build matching lazy dask arrays; uses the global `conn` opened in the connection cell
        images = []
        dask_images = []
        image_ids = []
        
        for image in batch_images:
            image_ids.append(image.getId())
            
            if three_d:
                # For 3D mode, we need to get all z-slices
                pixels = image.getPrimaryPixels()
                # Get the 3D stack directly instead of just one plane
                if z_range is not None:
                    # Get specific z-range if provided
                    img_3d = np.stack([pixels.getPlane(z, channel, timepoint) for z in z_range])
                else:
                    # Get all z-slices
                    img_3d = np.stack([pixels.getPlane(z, channel, timepoint) for z in range(image.getSizeZ())])
                images.append(img_3d)
                
                print(f"Creating 3D dask array for image {image.getId()} with shape {img_3d.shape}")
                # Pass z_range so the lazy array covers the same planes as img_3d
                dask_img = get_dask_image(conn, image.getId(), z_slice=z_range, timepoint=timepoint,
                                          channel=channel, three_d=z_range is None)
                dask_images.append(dask_img)
            else:
                # For 2D mode, get a single plane
                pixels = image.getPrimaryPixels()
                img = pixels.getPlane(z_slice, channel, timepoint)
                images.append(img)
                
                print(f"Creating 2D dask array for image {image.getId()}")
                dask_img = get_dask_image(conn, image.getId(), z_slice=z_slice, 
                                       timepoint=timepoint, channel=channel)
                dask_images.append(dask_img)
        
        # Process batch with SAM using standard numpy arrays for now
        # Note: In the future, micro-sam could be updated to work directly with dask
        print("Starting napari viewer with SAM annotator. Close the viewer window when done.")
        
        # Create the viewer directly: the napari.gui_qt() context manager is deprecated,
        # and napari.run() below starts the event loop instead
        viewer = napari.Viewer()
        
        # Add image series annotator
        image_series_annotator(
            images, 
            model_type=model_type,
            viewer=viewer,
            embedding_path=os.path.join(output_folder, "embed"),
            output_folder=os.path.join(output_folder, "output"),
            is_volumetric=three_d  # Pass the three_d flag to use 3D mode
        )
        
        # Start the napari application - this blocks until the viewer is closed
        try:
            napari.run()
            print("Napari viewer closed.")
        except KeyboardInterrupt:
            print("Napari viewer was interrupted. Processing results anyway...")
        except Exception as e:
            print(f"Error in napari: {e}")
            
        print("Processing results from batch...")
        print("Done annotating batch, storing results in zarr and uploading to OMERO")
        
        # Initialize batch progress tracking
        batch_completed = 0
        batch_skipped = 0
        
        # Process results for batch
        batch_df = pd.DataFrame(columns=df.columns)
        
        for n, image in enumerate(batch_images):
            local_n = n  # Index within current batch
            global_n = start_idx + n  # Global index across all batches
            
            # Store segmentation mask in zarr before uploading to OMERO
            seg_file_path = os.path.join(output_folder, "output", f"seg_{local_n:05d}.tif")
            if not os.path.exists(seg_file_path):
                print(f"Warning: Segmentation file not found for image {image.getId()}, skipping")
                batch_skipped += 1
                
                # Add a row for skipped image but mark as not processed
                is_train = combined_images_sequence[global_n] == 0 if not segment_all else True
                is_validate = combined_images_sequence[global_n] == 1 if not segment_all else False
                
                # Record z-slice information - for 3D we store the range
                z_info = z_range if three_d and z_range is not None else \
                       (f"0-{image.getSizeZ()-1}" if three_d else z_slice)
                
                new_row = pd.DataFrame([{
                    "image_id": image.getId(),
                    "image_name": image.getName(),
                    "train": is_train,
                    "validate": is_validate,
                    "channel": channel,
                    "z_slice": z_info,
                    "timepoint": timepoint,
                    "sam_model": model_type,
                    "embed_id": None,
                    "label_id": None,
                    "roi_id": None,
                    "is_volumetric": three_d,
                    "processed": False
                }])
                batch_df = pd.concat([batch_df, new_row], ignore_index=True)
                continue
                
            batch_completed += 1
            
            # Read the segmentation mask
            mask_data = imageio.imread(seg_file_path)
            
            # Store in zarr format for efficient processing
            zarr_file_path = store_annotations_in_zarr(mask_data, zarr_path, global_n)
            
            # Zip the embedding zarr saved during annotation so it can be uploaded to OMERO
            embed_zarr = f"embedding_{local_n:05d}.zarr"
            embed_dir = os.path.join(output_folder, "embed")
            zip_path = os.path.join(embed_dir, f"embedding_{global_n:05d}.zip")
            
            # Check if the embedding directory exists before trying to zip it
            embed_zarr_path = os.path.join(embed_dir, embed_zarr)
            if not os.path.exists(embed_zarr_path):
                print(f"Warning: Embedding directory {embed_zarr} not found, skipping embedding upload")
                embed_id = None
            else:
                with zipfile.ZipFile(zip_path, 'w') as zip_file:
                    zip_directory(embed_dir, embed_zarr, zip_file)
                
                # Upload embedding to OMERO
                embed_id = ezomero.post_file_annotation(
                    conn,
                    str(zip_path),
                    ns='microsam.embeddings',
                    object_type="Image",
                    object_id=image.getId(),
                    description=f'SAM embedding ({model_type}), 3D={three_d}'
                )
            
            # Convert zarr annotation to TIFF for OMERO compatibility
            tiff_path = os.path.join(output_folder, "output", f"seg_{global_n:05d}.tiff")
            zarr_to_tiff(zarr_file_path, tiff_path)
            
            # Upload labels and create ROIs.
            # 3D: pass all z-slices (or only z_range, if set); 2D: pass the single z_slice.
            if three_d:
                z_for_roi = z_range if z_range is not None else range(image.getSizeZ())
            else:
                z_for_roi = z_slice
            label_id, roi_id = upload_rois_and_labels(
                conn,
                image,
                tiff_path,
                z_for_roi,
                channel,
                timepoint,
                model_type,
                is_volumetric=bool(three_d)
            )
            
            # Update tracking dataframe
            is_train = combined_images_sequence[global_n] == 0 if not segment_all else True
            is_validate = combined_images_sequence[global_n] == 1 if not segment_all else False
            
            # Record z-slice information - for 3D we store the range
            z_info = z_range if three_d and z_range is not None else \
                   (f"0-{image.getSizeZ()-1}" if three_d else z_slice)
            
            new_row = pd.DataFrame([{
                "image_id": image.getId(),
                "image_name": image.getName(),
                "train": is_train,
                "validate": is_validate,
                "channel": channel,
                "z_slice": z_info,
                "timepoint": timepoint,
                "sam_model": model_type,
                "embed_id": embed_id,
                "label_id": label_id,
                "roi_id": roi_id,
                "is_volumetric": three_d,
                "processed": True
            }])
            batch_df = pd.concat([batch_df, new_row], ignore_index=True)
        
        # Update the main DataFrame with the batch results
        df = pd.concat([df, batch_df], ignore_index=True)
        
        # Upload batch tracking table to OMERO
        if table_id is not None:
            # Delete the existing table before creating a new one
            try:
                print(f"Deleting existing table with ID: {table_id}")
                # Get the file annotation object for the table
                ann = conn.getObject("FileAnnotation", table_id)
                if ann:
                    # Delete the file annotation (which contains the table)
                    conn.deleteObjects("FileAnnotation", [table_id], wait=True)
                    print(f"Existing table deleted successfully")
                else:
                    print(f"Warning: Could not find table with ID: {table_id}")
            except Exception as e:
                print(f"Warning: Could not delete existing table: {e}")
                # Continue anyway, as we'll create a new table
        
        # Create a new table with the updated data
        table_id = ezomero.post_table(
            conn, 
            object_type=container_type.capitalize(), 
            object_id=container_id, 
            table=df,
            title="micro_sam_training_data"
        )
        if table_id is None:
            print("Warning: Failed to create tracking table")
        else:
            print(f"Created new tracking table with ID: {table_id}")
        
        print(f"Batch {batch_idx+1}/{total_batches} results:")
        print(f"  - Completed: {batch_completed}/{len(batch_images)} images")
        print(f"  - Skipped: {batch_skipped}/{len(batch_images)} images")
        
        # Clean up temporary files for this batch before deciding whether to continue,
        # so leftover files do not interfere with a later run. Files written during
        # annotation use the local (within-batch) index; the zip created above uses
        # the global index start_idx + n.
        for n in range(len(batch_images)):
            embed_zip = os.path.join(output_folder, "embed", f"embedding_{start_idx + n:05d}.zip")
            embed_zarr = os.path.join(output_folder, "embed", f"embedding_{n:05d}.zarr")
            seg_file = os.path.join(output_folder, "output", f"seg_{n:05d}.tif")

            for path in [embed_zip, seg_file]:
                if os.path.exists(path):
                    os.remove(path)

            if os.path.isdir(embed_zarr):
                shutil.rmtree(embed_zarr)

        if batch_skipped > 0 and batch_idx < total_batches - 1:
            # Ask the user whether to continue with the next batch or stop here
            try:
                response = input("Some images were skipped. Continue with next batch? (y/n): ")
                if response.lower() not in ['y', 'yes']:
                    print("Stopping processing at user request.")
                    break
            except Exception:
                # input() is unavailable in non-interactive environments; continue by default
                print("Non-interactive environment detected. Continuing with next batch.")
    
    # Final statistics
    total_processed = df[df['processed'] == True].shape[0]
    total_skipped = df[df['processed'] == False].shape[0]
    
    print(f"\nAll batches completed.")
    print(f"Total processed: {total_processed} images")
    print(f"Total skipped: {total_skipped} images")
    print(f"Final tracking table ID: {table_id} in {source_desc}")
    
    return table_id, combined_images

### Load images from OMERO and open in napari with micro-sam annotator

With 3D mode enabled (`three_d=True`), the notebook processes entire z-stacks (or only the slices in `z_range`, if it is set) instead of a single slice. This enables volumetric annotation with micro-SAM's 3D capabilities.
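How `three_d`, `z_range` and `z_slice` combine can be sketched as a small helper that mirrors the branching used when recording `z_info` and when calling `upload_rois_and_labels`. The name `resolve_z_selection` is hypothetical (it is not part of the notebook or micro-sam), and the bounds check is an extra guard added for illustration.

```python
def resolve_z_selection(three_d, z_slice, z_range, size_z):
    """Return the z indices that would be annotated for one image.

    Hypothetical helper mirroring the notebook's branching:
    - 3D with z_range set  -> the slices in z_range
    - 3D without z_range   -> every slice 0..size_z-1
    - 2D                   -> only z_slice
    """
    if three_d:
        zs = list(z_range) if z_range is not None else list(range(size_z))
    else:
        zs = [z_slice]
    # Extra guard (not in the notebook): reject indices outside the image's z extent
    bad = [z for z in zs if not 0 <= z < size_z]
    if bad:
        raise ValueError(f"z indices {bad} outside 0..{size_z - 1}")
    return zs


print(resolve_z_selection(False, 4, None, 10))        # [4]
print(resolve_z_selection(True, 4, None, 3))          # [0, 1, 2]
print(resolve_z_selection(True, 4, range(5, 8), 10))  # [5, 6, 7]
```

Note that `z_slice` is ignored entirely in 3D mode, which is why the 3D example further down omits it.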

### Running segmentation in batch

Note: napari may print some warnings in the output of this cell; these are expected and generally harmless.
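Inside `process_omero_batch_with_dask` each image has two indices: a local index within the current batch, used for the files written during annotation (`seg_00000.tif`, `embedding_00000.zarr`), and a global index `start_idx + n` across all batches, used for the zarr store and the uploaded zip file. A minimal sketch of that bookkeeping follows; `iter_batches` is a hypothetical name, and the notebook's own loop may compute `start_idx` slightly differently.

```python
import math


def iter_batches(items, batch_size):
    """Yield (batch_idx, start_idx, batch) tuples, like the notebook's batch loop."""
    total_batches = math.ceil(len(items) / batch_size)
    for batch_idx in range(total_batches):
        start_idx = batch_idx * batch_size
        yield batch_idx, start_idx, items[start_idx:start_idx + batch_size]


images = ["img_a", "img_b", "img_c", "img_d", "img_e"]
for batch_idx, start_idx, batch in iter_batches(images, batch_size=2):
    for local_n, name in enumerate(batch):
        global_n = start_idx + local_n
        # Annotation output uses the local index; uploads use the global one
        print(batch_idx, name, f"seg_{local_n:05d}.tif", f"embedding_{global_n:05d}.zip")
```

With five images and `batch_size=2` this gives three batches, the last holding a single image. The local file names repeat in every batch, which is why the temporary files have to be removed after each batch.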

In [None]:
## Input parameters
model_type = 'vit_b'
segment_all = False
train_n = 2
validate_n = 2
channel = 3  # index of the channel to segment (0-based)
timepoint = 0
z_slice = 4  # TODO: currently a single slice; add an option for a list of z-slices or random slices
z_range = None  # only used when three_d=True, e.g. range(3, 8) to restrict the z-stack
batch_size = 2  # number of images to open in napari at once
three_d = False
resume_from_table = False  # set to True to continue from a previous run

# Configure napari settings
settings = get_settings()
settings.application.ipy_interactive = False

# Run batch processing with dask lazy loading
# Get all images from the specified container
images_list, source_desc = get_images_from_container(conn, datatype, data_id)

if len(images_list) > 0:
    table_id, processed_images = process_omero_batch_with_dask(
        images_list=images_list,
        output_folder=output_directory,
        container_type=datatype,
        container_id=data_id,
        source_desc=source_desc,
        model_type=model_type,
        batch_size=batch_size,
        channel=channel,
        timepoint=timepoint,
        z_slice=z_slice,
        segment_all=segment_all,
        train_n=train_n,
        validate_n=validate_n,
        resume_from_table=resume_from_table
    )
    print(f"Finished processing with dask lazy loading. Table ID: {table_id}")
    print(f"To resume this session later, set resume_from_table=True")
else:
    print(f"No images found in the {datatype} with ID {data_id}")



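When `segment_all=False`, each selected image is tagged as training or validation: in the tracking code above, a `0` in `combined_images_sequence` marks a training image and a `1` a validation image. One way such a random selection could be built is sketched below. The helper name `pick_train_validate` and the use of `random.sample` are assumptions for illustration, not the notebook's actual implementation, which is defined outside this section.

```python
import random


def pick_train_validate(images, train_n, validate_n, seed=None):
    """Hypothetical sketch: draw train_n + validate_n distinct images at random.

    Returns (combined_images, combined_images_sequence), where the sequence
    holds 0 for training images and 1 for validation images.
    """
    if train_n + validate_n > len(images):
        raise ValueError(
            f"Requested {train_n + validate_n} images but only {len(images)} available"
        )
    rng = random.Random(seed)  # seeded generator so the split is reproducible
    chosen = rng.sample(list(images), train_n + validate_n)
    sequence = [0] * train_n + [1] * validate_n
    return chosen, sequence


combined_images, combined_images_sequence = pick_train_validate(
    ["img_1", "img_2", "img_3", "img_4", "img_5"], train_n=2, validate_n=2, seed=0
)
for image, flag in zip(combined_images, combined_images_sequence):
    print(image, "train" if flag == 0 else "validate")
```

Because `random.sample` already returns the images in random order, assigning the first `train_n` to training keeps the two subsets disjoint without a second draw.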
## Resuming Annotation Sessions

Annotation sessions can be resumed. To stop annotating and continue later:

1. Set `resume_from_table = True` in the parameters section
2. Run the notebook as usual
3. The notebook uses the tracking table in OMERO to detect previously annotated images and continues with the remaining ones

This is useful for:
- Long annotation sessions that need to be split over multiple days
- Cases where napari was closed accidentally
- Continuing after computer restarts or crashes

The tracking table (`micro_sam_training_data`, attached to the container in OMERO) records which images have been processed successfully.
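Conceptually, resuming comes down to comparing the container's images against the `image_id` and `processed` columns of the tracking table. A minimal sketch, assuming the table has already been read into a pandas DataFrame (for example with `ezomero.get_table`); `remaining_image_ids` is a hypothetical helper, and the notebook's own resume logic may differ, e.g. in how it treats images that were skipped earlier.

```python
import pandas as pd


def remaining_image_ids(all_image_ids, tracking_df):
    """Hypothetical helper: IDs not yet recorded as processed in the tracking table.

    Rows with processed=False (images skipped in an earlier run) are treated
    as not done, so those images would be offered again.
    """
    done = set(tracking_df.loc[tracking_df["processed"].eq(True), "image_id"])
    return [image_id for image_id in all_image_ids if image_id not in done]


tracking_df = pd.DataFrame(
    {"image_id": [11, 12, 13], "processed": [True, False, True]}
)
print(remaining_image_ids([11, 12, 13, 14], tracking_df))  # [12, 14]
```

Image 12 is returned because its earlier attempt was skipped, and image 14 because it has no row in the table yet.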

## Examples for processing different container types

Here are examples showing how to set up the notebook for different OMERO container types:

### Processing a plate
```python
datatype = "plate"
data_id = 101  # Your plate ID
model_type = 'vit_b'
batch_size = 10
channel = 0
z_slice = 0  # Only used when three_d=False
z_range = None  # Optional: specify a range of z-slices (e.g., range(3, 8))
timepoint = 0
segment_all = False
train_n = 20
validate_n = 10
three_d = False  # Set to True for 3D volumetric processing
```

### Processing a screen
```python
datatype = "screen"
data_id = 5  # Your screen ID
model_type = 'vit_b'
batch_size = 10
channel = 0
z_slice = 0
z_range = None
timepoint = 0
segment_all = False
train_n = 20
validate_n = 10
three_d = False
```

### Processing a project
```python
datatype = "project"
data_id = 201  # Your project ID
model_type = 'vit_b'
batch_size = 10
channel = 0
z_slice = 0
z_range = None
timepoint = 0
segment_all = False
train_n = 20
validate_n = 10
three_d = False
```

### Processing in 3D mode
```python
datatype = "dataset"
data_id = 201  # Your dataset ID
model_type = 'vit_b'
batch_size = 5  # Smaller batch size for 3D as it requires more memory
channel = 0
z_range = range(5, 15)  # Optional: process only a subset of z-slices
timepoint = 0
segment_all = False
train_n = 5
validate_n = 5
three_d = True  # Enable 3D volumetric processing
```
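The smaller `batch_size` recommended for 3D follows from rough arithmetic: the per-image cost grows with the number of slices. The sketch below is a back-of-envelope lower bound, assuming a single channel, one SAM image embedding of 256 × 64 × 64 float32 values (4 MiB) per slice with no tiling, and that a whole batch is held in memory at once. Model weights, napari layers and intermediate buffers are not counted. `estimate_batch_bytes` is a hypothetical helper, not part of micro-sam.

```python
import numpy as np

# Assumed size of one SAM image embedding per slice: 256 x 64 x 64 float32 values
EMBEDDING_BYTES_PER_SLICE = 256 * 64 * 64 * np.dtype(np.float32).itemsize


def estimate_batch_bytes(batch_size, size_z, size_y, size_x, pixel_dtype=np.uint16):
    """Rough lower bound on memory for one batch of single-channel z-stacks."""
    pixels = size_z * size_y * size_x * np.dtype(pixel_dtype).itemsize
    embeddings = size_z * EMBEDDING_BYTES_PER_SLICE
    return batch_size * (pixels + embeddings)


mib = 1024 ** 2
print(estimate_batch_bytes(5, 10, 2048, 2048) / mib)   # 600.0
print(estimate_batch_bytes(10, 10, 2048, 2048) / mib)  # 1200.0
```

For a 10-slice 2048 × 2048 uint16 stack this is about 80 MiB of pixels plus 40 MiB of embeddings per image, so halving the batch size halves the estimate.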