### Fine-tuning SAM with OMERO data using a batch approach (enhanced version)

### Features
- Supports multiple OMERO data types (single images, datasets, projects, plates, and screens)
- Batch processing with micro-SAM for segmentation
- Stores all annotations in OMERO as ROIs and attachments
- Uses dask for lazy loading of images for better memory management
- Supports 3D volumetric segmentation for z-stacks
- **NEW**: Support for multiple z-slices in 2D mode
- **NEW**: Support for time series analysis
- **NEW**: Support for patch-based extraction and annotation
- **NEW**: Improved resumption of annotation sessions

### TODOs
- Store all annotations in OMERO (see https://github.com/computational-cell-analytics/micro-sam/issues/445). In the series annotator it is possible to set a commit path that saves the prompts, but they get overwritten.
- Clean up the errors and warnings output from napari
- Improve ROI creation for 3D volumes to better represent volumetric masks in OMERO
- Work with Dask arrays directly in micro-sam
- Add recovery mode to handle cases when users abort in the middle of a batch annotation session (currently annotations made before closing are preserved, but could be improved with a dedicated recovery workflow)

Instructions:
  - To make it easier to connect to OMERO without exposing the username and password in the notebook, the connection settings (`HOST`, `USER_NAME`, `PASSWORD`, `GROUP`) are read from a `.env` file (see the example `.env_example`). Storing credentials unencrypted is still not recommended, so a better solution is being worked on.
  - This notebook supports processing images from various OMERO container types: images, datasets, projects, plates, and screens.
  - Specify the container type in the `datatype` variable and the container ID in the `data_id` variable.
  - You can choose to segment all images in the container or select a random subset for training and validation.
  - **NEW**: You can now specify multiple z-slices and timepoints to analyze.
  - **NEW**: You can extract and analyze patches from large images.

In [None]:
# OMERO-related imports
import omero
from omero.gateway import BlitzGateway
import ezomero

# Scientific computing and image processing
import cv2
import imageio.v3 as imageio
import numpy as np
import pandas as pd

# File and system operations
import os
import shutil
import tempfile
import zipfile
import warnings
from tifffile import imwrite, imread
from dotenv import load_dotenv

# Dask and Zarr for lazy loading and processing
import dask
import dask.array as da
import zarr

# Micro-SAM and Napari
from napari.settings import get_settings
from micro_sam.sam_annotator import image_series_annotator, annotator_2d
from micro_sam.util import precompute_image_embeddings
import napari

import json

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types.

    The standard ``json`` module cannot serialize NumPy types, so metadata
    dicts carrying NumPy values would otherwise raise ``TypeError``.
    Anything that is not a NumPy integer, float, bool or ndarray is passed to
    the base class, which raises ``TypeError`` as usual.
    """

    def default(self, obj):
        # np.integer / np.floating are the abstract bases of every sized
        # variant (int8..int64, uint*, float16..float64), so one check each suffices.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of Python bool and is not JSON-serializable.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

### Set up the connection to OMERO

In [None]:
# Read HOST / USER_NAME / PASSWORD / GROUP from the .env file (override=True lets
# the file win over variables already present in the environment).
load_dotenv(override=True)
conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    group=os.environ.get("GROUP"),
    secure=True,
)
connection_status = conn.connect()
if connection_status:
    print("Connected to OMERO Server")
    # Keep-alive pings are only meaningful for a live session; calling this after
    # a failed connect() would act on a client without a session.
    conn.c.enableKeepAlive(60)
else:
    print("Connection to OMERO Server Failed")

### Get info about the selected OMERO object

In [None]:
# Select the OMERO object to work on: its type and its ID.
datatype = "dataset" # "screen", "plate", "project", "dataset", "image"
data_id = 1112  # ID of the OMERO object whose type is given by `datatype`
# NOTE(review): presumably the channel index holding the nuclear stain; it is not
# read in the visible part of this notebook - confirm where it is consumed.
nucl_channel = 0

def print_object_details(conn, obj, datatype):
    """Print a short summary of an OMERO object.

    The name, ID, owner and group are always printed. Extra type-specific
    details follow (child counts, parent containers, image dimensions).

    Args:
        conn: OMERO connection (not used here; kept for interface compatibility).
        obj: OMERO gateway wrapper of the object to describe.
        datatype: One of "project", "dataset", "plate", "screen", "image";
            selects which type-specific details are printed. Unknown values
            print only the common fields.
    """
    print(f"\n{datatype.capitalize()} Details:")
    print(f"- Name: {obj.getName()}")
    print(f"- ID: {obj.getId()}")
    print(f"- Owner: {obj.getOwner().getFullName()}")
    print(f"- Group: {obj.getDetails().getGroup().getName()}")

    if datatype == "project":
        datasets = list(obj.listChildren())
        total_images = sum(len(list(ds.listChildren())) for ds in datasets)
        print(f"- Number of datasets: {len(datasets)}")
        print(f"- Total images: {total_images}")

    elif datatype == "screen":
        # Screens contain plates, just as projects contain datasets.
        plates = list(obj.listChildren())
        print(f"- Number of plates: {len(plates)}")

    elif datatype == "plate":
        wells = list(obj.listChildren())
        print(f"- Number of wells: {len(wells)}")

    elif datatype == "dataset":
        images = list(obj.listChildren())
        # Report the parent project, if the dataset belongs to one.
        project = obj.getParent()
        if project:
            print(f"- Project: {project.getName()} (ID: {project.getId()})")
        else:
            print("- Project: None (orphaned dataset)")
        print(f"- Number of images: {len(images)}")

    elif datatype == "image":
        size_x = obj.getSizeX()
        size_y = obj.getSizeY()
        size_z = obj.getSizeZ()
        size_c = obj.getSizeC()
        size_t = obj.getSizeT()
        # Report the parent dataset (and its project), if any.
        dataset = obj.getParent()
        if dataset:
            print(f"- Dataset: {dataset.getName()} (ID: {dataset.getId()})")
            project = dataset.getParent()
            if project:
                print(f"- Project: {project.getName()} (ID: {project.getId()})")
        else:
            print("- Dataset: None (orphaned image)")
        print(f"- Dimensions: {size_x}x{size_y}")
        print(f"- Z-stack: {size_z}")
        print(f"- Channels: {size_c}")
        print(f"- Timepoints: {size_t}")

# Validate that data_id matches datatype and print details.
# Each branch also binds a module-level name (project/plate/dataset/image/screen)
# that later cells may use.
if datatype == "project":
    project = conn.getObject("Project", data_id)
    if project is None:
        raise ValueError(f"Project with ID {data_id} not found")
    print_object_details(conn, project, "project")

elif datatype == "screen":
    # "screen" is one of the documented choices for `datatype`, so accept it here too.
    screen = conn.getObject("Screen", data_id)
    if screen is None:
        raise ValueError(f"Screen with ID {data_id} not found")
    print_object_details(conn, screen, "screen")

elif datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    if plate is None:
        raise ValueError(f"Plate with ID {data_id} not found")
    print_object_details(conn, plate, "plate")

elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"Dataset with ID {data_id} not found")
    print_object_details(conn, dataset, "dataset")

elif datatype == "image":
    image = conn.getObject("Image", data_id)
    if image is None:
        raise ValueError(f"Image with ID {data_id} not found")
    print_object_details(conn, image, "image")

else:
    raise ValueError("Invalid datatype specified")

### Create a temporary folder to store training data; its contents will be uploaded to OMERO later

In [None]:
# Fresh, empty scratch directory for this session. normcase normalises the path
# form (per the os.path docs it only changes anything on Windows).
output_directory = tempfile.mkdtemp()
output_directory = os.path.normcase(output_directory)
print('Output Directory: ', output_directory)

In [None]:
def zip_directory(source_path, zarr_path, zip_file):
    """Add every file below ``zarr_path`` to an already-open ``ZipFile``.

    Archive names are made relative to ``source_path``. NUL characters are
    stripped from both the on-disk path and the archive name, and entries
    whose cleaned path does not exist are skipped. An error on one file is
    reported as a warning and the walk continues with the next file.
    """
    for current_dir, _subdirs, filenames in os.walk(zarr_path):
        for name in filenames:
            try:
                disk_path = os.path.join(current_dir, name)
                archive_name = os.path.relpath(disk_path, source_path)

                # NUL bytes cannot appear in real file-system paths; drop them.
                disk_path = disk_path.replace('\x00', '')
                archive_name = archive_name.replace('\x00', '')

                if not os.path.exists(disk_path):
                    continue
                zip_file.write(disk_path, archive_name)
            except Exception as err:
                print(f"Warning: Error processing {name}: {str(err)}")

def interleave_arrays(train_images, validate_images):
    """
    Merge two image sequences by alternating between them.

    The output order is train[0], validate[0], train[1], validate[1], ...;
    once the shorter input runs out, the rest of the longer one is appended.

    Returns:
        tuple: ``(images, sequence)`` as NumPy arrays, where ``sequence[i]`` is
        0 when ``images[i]`` came from ``train_images`` and 1 when it came
        from ``validate_images``.
    """
    sources = ((0, train_images), (1, validate_images))
    longest = max(len(train_images), len(validate_images))

    merged = []
    origin = []
    for position in range(longest):
        for flag, items in sources:
            if position < len(items):
                merged.append(items[position])
                origin.append(flag)

    return np.array(merged), np.array(origin)

def mask_to_contour(mask):
    """Find the contours of a binary mask with OpenCV.

    Args:
        mask (np.ndarray): 2D binary mask, passed unchanged to
            ``cv2.findContours`` (callers in this file pass ``uint8``).

    Returns:
        sequence of np.ndarray: one array per contour, each of shape
        ``(N, 1, 2)`` holding ``(x, y)`` points. ``cv2.RETR_TREE`` is used,
        so inner (hole) boundaries are returned as separate contours as well.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); taking the second-to-last item works for both.
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    return contours

def label_to_rois(label_img, z_slice, channel, timepoint, model_type, is_volumetric=False, patch_offset=None):
    """
    Turn a 2D label image or a 3D label stack into OMERO ROI shapes.

    When ``is_volumetric`` is set and ``label_img`` has more than two
    dimensions, each plane along the first axis is converted separately. Its
    z index is taken from ``z_slice`` (a list/range of indices, or a start
    index). Otherwise ``label_img`` is handled as a single plane at ``z_slice``.

    Args:
        label_img (np.ndarray): 2D labeled image or 3D labeled stack
        z_slice (int or list): Z index, or list/range of Z indices for a stack
        channel (int): Channel index
        timepoint (int): Time point index
        model_type (str): SAM model type used
        is_volumetric (bool): Whether the label image is 3D volumetric data
        patch_offset: Optional (x, y) offset for placing ROIs in a larger image

    Returns:
        list: OMERO shape objects produced by ``process_label_plane``
    """
    if patch_offset is None:
        x_offset, y_offset = 0, 0
    else:
        x_offset, y_offset = patch_offset

    # Anything that is not a real multi-plane volume is treated as one plane.
    if not (is_volumetric and label_img.ndim > 2):
        return list(process_label_plane(label_img, z_slice, channel, timepoint, model_type,
                                        x_offset, y_offset))

    explicit_indices = isinstance(z_slice, (range, list))
    collected = []
    for plane_number, plane in enumerate(label_img):
        if not explicit_indices:
            # A scalar z_slice is the z index of the first plane in the stack.
            actual_z = z_slice + plane_number
        elif plane_number < len(z_slice):
            actual_z = z_slice[plane_number]
        else:
            # More planes than listed indices: continue counting from the first one.
            actual_z = z_slice[0] + plane_number

        print(f"Processing volumetric ROIs for z-slice {actual_z}")
        collected.extend(process_label_plane(plane, actual_z, channel, timepoint, model_type,
                                             x_offset, y_offset))
    return collected

def process_label_plane(label_plane, z_slice, channel, timepoint, model_type, x_offset=0, y_offset=0):
    """
    Process a single 2D label plane to generate OMERO shapes with optional offset.

    Every non-zero label value is turned into a binary mask and traced with
    ``mask_to_contour``. Each resulting contour becomes one
    ``ezomero.rois.Polygon``.

    Args:
        label_plane (np.ndarray): 2D label image; value 0 is treated as background.
        z_slice: Z index stored on every polygon.
        channel (int): Channel index stored on every polygon.
        timepoint (int): Time point index stored on every polygon.
        model_type (str): SAM model type, embedded in the polygon label text.
        x_offset (int): Added to every contour x coordinate.
        y_offset (int): Added to every contour y coordinate.

    Returns:
        list: ezomero Polygon shapes (empty when the plane has no labels).
    """
    shapes = []
    unique_labels = np.unique(label_plane)
    # Drop background explicitly. Slicing off the first unique value (as before)
    # silently lost a real object whenever the plane contained no 0 pixels.
    object_labels = unique_labels[unique_labels != 0]
    if object_labels.size == 0:
        return shapes

    # NOTE(review): this heuristic labels any plane with z > 0 as "volumetric"
    # (even a 2D annotation of a later slice) and a volumetric z == 0 plane as
    # "manual"; kept as-is because the caller does not pass is_volumetric here.
    kind = "volumetric" if isinstance(z_slice, (list, range)) or z_slice > 0 else "manual"
    shape_label = f'micro_sam.{kind}_instance_segmentation.{model_type}'

    needs_offset = x_offset != 0 or y_offset != 0
    offset = np.array([x_offset, y_offset])

    for label in object_labels:
        mask = (label_plane == label).astype(np.uint8)
        for contour in mask_to_contour(mask):
            points = contour[:, 0, :]  # (N, 1, 2) -> (N, 2) x/y pairs
            if needs_offset:
                points = points + offset
            shapes.append(
                ezomero.rois.Polygon(
                    points=points,
                    z=z_slice,
                    c=channel,
                    t=timepoint,
                    label=shape_label,
                )
            )

    return shapes

def _save_annotations_locally(image, label_file, shapes, z_slice, channel, timepoint,
                              model_type, is_volumetric, patch_offset, local_output_dir):
    """Copy the label file and write a JSON metadata sidecar under ``<local_output_dir>/image_<id>/``."""
    image_id = image.getId()
    image_dir = os.path.join(local_output_dir, f"image_{image_id}")
    os.makedirs(image_dir, exist_ok=True)

    # Keep a copy of the label image next to its metadata.
    local_label_path = os.path.join(image_dir, os.path.basename(label_file))
    shutil.copy2(label_file, local_label_path)

    local_roi_path = os.path.join(image_dir, f"roi_{os.path.basename(label_file).split('.')[0]}.json")

    roi_metadata = {
        "image_id": image_id,
        "image_name": image.getName(),
        "timestamp": str(pd.Timestamp.now()),
        "model_type": model_type,
        "is_volumetric": is_volumetric,
        "z_slice": z_slice if not isinstance(z_slice, range) else list(z_slice),
        "channel": channel,
        "timepoint": timepoint,
        "patch_offset": patch_offset,
        "shapes_count": len(shapes) if shapes else 0,
        # The ezomero shape objects are not JSON-serializable; the saved label
        # image can be used to recreate them.
        "label_image_path": os.path.relpath(local_label_path, local_output_dir),
    }

    with open(local_roi_path, 'w', encoding='utf-8') as f:
        json.dump(roi_metadata, f, indent=2, cls=NumpyEncoder)

    print(f"Saved annotation locally in read-only mode to {image_dir}")
    return local_label_path, local_roi_path


def _post_annotations_to_omero(conn, image, label_file, shapes, model_type, is_volumetric, patch_desc):
    """Attach the label file to the image and create one ROI from ``shapes`` if there are any."""
    mode = "volumetric" if is_volumetric else "manual"
    label_id = ezomero.post_file_annotation(
        conn,
        str(label_file),
        ns='microsam.labelimage',
        object_type="Image",
        object_id=image.getId(),
        description=f'SAM {mode} segmentation ({model_type}){patch_desc}'
    )

    roi_id = None
    if shapes:  # Only create an ROI if contours were found
        roi_id = ezomero.post_roi(
            conn,
            image.getId(),
            shapes,
            name=f'SAM_{model_type}{"_3D" if is_volumetric else ""}{patch_desc}',
            description=f'micro_sam.{mode}_instance_segmentation.{model_type}{patch_desc}'
        )
    return label_id, roi_id


def upload_rois_and_labels(conn, image, label_file, z_slice, channel, timepoint, model_type, 
                          is_volumetric=False, patch_offset=None, read_only_mode=False, local_output_dir="./omero_annotations"):
    """
    Upload both label map and ROIs for a segmented image or save them locally in read-only mode

    Args:
        conn: OMERO connection
        image: OMERO image object
        label_file: Path to the label image file
        z_slice: Z-slice index or range of indices
        channel: Channel index
        timepoint: Time point index
        model_type: SAM model type used
        is_volumetric: Whether the data is 3D volumetric
        patch_offset: Optional (x,y) offset for placing ROIs in a larger image
        read_only_mode: If True, save annotations locally instead of uploading to OMERO
        local_output_dir: Directory to save local annotations when in read-only mode

    Returns:
        tuple: (label_id, roi_id) or (local_label_path, local_roi_path) in read-only mode.
        roi_id is None when the label image contains no objects.
    """
    # `is not None` (not truthiness) matches label_to_rois and also works when
    # the offset is a NumPy array, whose truth value would be ambiguous.
    patch_desc = ""
    if patch_offset is not None:
        patch_desc = f", Patch offset: ({patch_offset[0]}, {patch_offset[1]})"

    label_img = imageio.imread(label_file)
    shapes = label_to_rois(label_img, z_slice, channel, timepoint, model_type,
                           is_volumetric, patch_offset)

    if read_only_mode:
        return _save_annotations_locally(image, label_file, shapes, z_slice, channel, timepoint,
                                         model_type, is_volumetric, patch_offset, local_output_dir)

    return _post_annotations_to_omero(conn, image, label_file, shapes, model_type,
                                      is_volumetric, patch_desc)

In [None]:
def _unique_images(images):
    """Return ``images`` without ``None`` entries or repeated image IDs, keeping first-seen order."""
    seen_ids = set()
    unique = []
    for img in images:
        if img is None:
            continue
        img_id = img.getId()
        if img_id in seen_ids:
            continue
        seen_ids.add(img_id)
        unique.append(img)
    return unique


def get_images_from_container(conn, datatype, container_id):
    """
    Extract all images from a given OMERO container (Project, Dataset, Plate, Screen)

    Each image is returned once, even when it is reachable through several
    children of the container (e.g. linked into two datasets of one project).

    Args:
        conn: OMERO connection
        datatype: Type of container ('project', 'dataset', 'plate', 'screen', 'image')
        container_id: ID of the container

    Returns:
        list: List of OMERO image objects
        str: Description of the source (for tracking)

    Raises:
        ValueError: if the object does not exist or ``datatype`` is unsupported.
    """
    images = []
    source_desc = ""

    if datatype == "image":
        image = conn.getObject("Image", container_id)
        if image is None:
            raise ValueError(f"Image with ID {container_id} not found")
        images = [image]
        source_desc = f"Image: {image.getName()} (ID: {container_id})"

    elif datatype == "dataset":
        dataset = conn.getObject("Dataset", container_id)
        if dataset is None:
            raise ValueError(f"Dataset with ID {container_id} not found")
        images = list(dataset.listChildren())
        source_desc = f"Dataset: {dataset.getName()} (ID: {container_id})"

    elif datatype == "project":
        project = conn.getObject("Project", container_id)
        if project is None:
            raise ValueError(f"Project with ID {container_id} not found")
        for dataset in project.listChildren():
            images.extend(dataset.listChildren())
        source_desc = f"Project: {project.getName()} (ID: {container_id})"

    elif datatype == "plate":
        plate = conn.getObject("Plate", container_id)
        if plate is None:
            raise ValueError(f"Plate with ID {container_id} not found")
        # Plate -> wells -> well samples (fields) -> image
        for well in plate.listChildren():
            images.extend(well_sample.getImage() for well_sample in well.listChildren())
        source_desc = f"Plate: {plate.getName()} (ID: {container_id})"

    elif datatype == "screen":
        screen = conn.getObject("Screen", container_id)
        if screen is None:
            raise ValueError(f"Screen with ID {container_id} not found")
        # Screen -> plates -> wells -> well samples (fields) -> image
        for plate in screen.listChildren():
            for well in plate.listChildren():
                images.extend(well_sample.getImage() for well_sample in well.listChildren())
        source_desc = f"Screen: {screen.getName()} (ID: {container_id})"

    else:
        raise ValueError(f"Unsupported datatype: {datatype}")

    # An OMERO image can be linked into several datasets of the same project;
    # without de-duplication it would be annotated more than once.
    images = _unique_images(images)

    print(f"Found {len(images)} images from {source_desc}")
    return images, source_desc

### New patch generation functions for extracting image regions

In [None]:
def generate_patch_coordinates(image_width, image_height, patch_size, num_patches, random_patches=True):
    """
    Generate coordinates for image patches

    The patch is clamped to the image size on each axis independently. In
    random mode every patch gets an independent, uniformly drawn top-left
    corner (using ``np.random``). In centered mode the same centered patch is
    repeated ``num_patches`` times.

    Args:
        image_width: Width of the full image
        image_height: Height of the full image
        patch_size: Tuple of (width, height) for the patch, or a single int
            for a square patch
        num_patches: Number of patches to generate (<= 0 yields an empty list)
        random_patches: If True, generate random patches; if False, generate centered patches

    Returns:
        list: List of patch coordinates as tuples (x, y, width, height)
    """
    if isinstance(patch_size, (int, np.integer)):
        patch_size = (patch_size, patch_size)
    patch_width, patch_height = patch_size

    # A patch can never be larger than the image it is cut from.
    patch_width = min(patch_width, image_width)
    patch_height = min(patch_height, image_height)

    if not random_patches:
        x = (image_width - patch_width) // 2
        y = (image_height - patch_height) // 2
        return [(x, y, patch_width, patch_height)] * num_patches

    # Largest valid top-left corner. After the clamping above both are >= 0, so
    # each axis is handled on its own: an axis the patch fully spans always gets
    # offset 0 while the other axis is still sampled (previously the whole image
    # was returned as soon as either axis had no slack).
    max_x = image_width - patch_width
    max_y = image_height - patch_height

    patches = []
    for _ in range(num_patches):
        x = np.random.randint(0, max_x + 1)
        y = np.random.randint(0, max_y + 1)
        patches.append((x, y, patch_width, patch_height))
    return patches

def extract_patch(image_array, patch_coords):
    """
    Cut a rectangular window out of the last two axes of an array.

    The second-to-last axis is sliced by y/height and the last axis by
    x/width; any leading axes (z, channel, ...) are kept whole.

    Args:
        image_array: Array supporting NumPy-style slicing
        patch_coords: Tuple of (x, y, width, height)

    Returns:
        Sliced array (the extracted patch)
    """
    col_start, row_start, n_cols, n_rows = patch_coords
    rows = slice(row_start, row_start + n_rows)
    cols = slice(col_start, col_start + n_cols)

    if image_array.ndim == 2:
        return image_array[rows, cols]
    # For 3D and higher the ellipsis keeps every leading axis intact
    # (equivalent to ``[:, rows, cols]`` when ndim == 3).
    return image_array[..., rows, cols]

### Dask Lazy Loading Functions for OMERO Data

In [None]:
def _as_index_list(selector, size):
    """Normalise a plane selector: an int -> [int], None -> every index in range(size), else unchanged."""
    if isinstance(selector, (int, np.integer)):
        return [int(selector)]
    if selector is None:
        return range(size)
    return selector


def _omero_numpy_dtype(pixels):
    """Return the NumPy dtype matching the OMERO pixel type of ``pixels``.

    Falls back to ``np.uint16`` (the previous hard-coded assumption) when the
    pixel type has no mapping here (e.g. "bit") or cannot be read.
    """
    omero_to_numpy = {
        "int8": np.int8,
        "uint8": np.uint8,
        "int16": np.int16,
        "uint16": np.uint16,
        "int32": np.int32,
        "uint32": np.uint32,
        "float": np.float32,
        "double": np.float64,
    }
    try:
        pixel_type = pixels.getPixelsType().value
    except AttributeError:
        return np.uint16
    return omero_to_numpy.get(pixel_type, np.uint16)


def get_dask_image(conn, image_id, z_slice=None, timepoint=None, channel=None, three_d=False, patch_coords=None):
    """
    Get a dask array representation of an OMERO image for lazy loading

    Planes are only fetched from OMERO when the returned array is computed.

    Args:
        conn: OMERO connection
        image_id: ID of image to load
        z_slice: Optional specific Z slice to load (int or list); None means all
        timepoint: Optional specific timepoint to load (int or list); None means all
        channel: Optional specific channel to load (int or list); None means all
        three_d: Whether to load a 3D volume (all z-slices) instead of a single slice
        patch_coords: Optional tuple of (x, y, width, height) to extract a patch

    Returns:
        dask.array.Array with axes (T, Z, C, Y, X) and the image's own pixel
        dtype, or None if any of the T/Z/C selections is empty.

    Raises:
        ValueError: if no image with ``image_id`` exists.
    """
    image = conn.getObject("Image", image_id)
    if image is None:
        raise ValueError(f"Image with ID {image_id} not found")
    pixels = image.getPrimaryPixels()

    size_z = image.getSizeZ()
    size_c = image.getSizeC()
    size_t = image.getSizeT()
    size_y = image.getSizeY()
    size_x = image.getSizeX()

    # In 3D mode every z-slice is loaded regardless of z_slice.
    z_range = range(size_z) if three_d else _as_index_list(z_slice, size_z)
    t_range = _as_index_list(timepoint, size_t)
    c_range = _as_index_list(channel, size_c)

    use_patch = patch_coords is not None
    tile = None
    if use_patch:
        x_offset, y_offset, patch_width, patch_height = patch_coords
        size_x = patch_width
        size_y = patch_height
        tile = (x_offset, y_offset, size_x, size_y)

    # Declare the real pixel type so dask metadata matches what getPlane/getTile return.
    dtype = _omero_numpy_dtype(pixels)
    plane_shape = (size_y, size_x)

    desc = "patch" if use_patch else "image"
    print(f"Creating dask array for {desc} {image_id} with lazy loading")
    print(f"Dimensions: Z={len(z_range)}, C={len(c_range)}, T={len(t_range)}, Y={size_y}, X={size_x}")
    print(f"3D mode: {three_d}")

    @dask.delayed
    def get_plane(z, c, t):
        print(f"Loading plane: Z={z}, C={c}, T={t}")
        if tile is not None:
            # Fetch only the requested region rather than the whole plane.
            return pixels.getTile(z, c, t, tile=tile)
        return pixels.getPlane(z, c, t)

    # Reuse one delayed object per (z, c, t) in case a selector repeats an index.
    delayed_planes = {}
    arrays = []
    for t in t_range:
        z_stack = []
        for z in z_range:
            c_stack = []
            for c in c_range:
                key = (z, c, t)
                if key not in delayed_planes:
                    delayed_planes[key] = get_plane(z, c, t)
                c_stack.append(da.from_delayed(delayed_planes[key], shape=plane_shape, dtype=dtype))
            if c_stack:
                z_stack.append(da.stack(c_stack))   # (C, Y, X)
        if z_stack:
            arrays.append(da.stack(z_stack))        # (Z, C, Y, X)

    if arrays:
        return da.stack(arrays)                     # (T, Z, C, Y, X)
    return None

def store_annotations_in_zarr(mask_data, output_folder, image_num):
    """
    Store annotation masks in zarr format for efficient access

    Any existing store for the same ``image_num`` is deleted first. Data are
    chunked as 256x256 tiles over the last two axes, with one chunk per
    index along every leading axis (e.g. one plane per chunk for a Z stack).

    Args:
        mask_data: Array with mask data (2D, or with extra leading axes)
        output_folder: Base folder to store zarr data
        image_num: Integer image number/identifier (zero-padded to 5 digits in the name)

    Returns:
        path: Path to the zarr store
    """
    zarr_dir = os.path.join(output_folder, "annotations")
    os.makedirs(zarr_dir, exist_ok=True)

    zarr_path = os.path.join(zarr_dir, f"annotation_{image_num:05d}.zarr")

    if os.path.exists(zarr_path):
        shutil.rmtree(zarr_path)

    # A fixed 2D chunk spec only fits 2D data: for a 3D stack it would chunk
    # the Z and Y axes instead of Y and X, and it is rejected for 1D data.
    ndim = np.ndim(mask_data)
    if ndim >= 2:
        chunks = (1,) * (ndim - 2) + (256, 256)
    else:
        chunks = None  # let zarr choose

    z = zarr.open(zarr_path, mode='w')
    z.create_dataset('masks', data=mask_data, chunks=chunks)

    return zarr_path

def zarr_to_tiff(zarr_path, output_tiff_path):
    """
    Write the ``masks`` array of a zarr store to a TIFF file.

    The whole array is read into memory and then written with ``imwrite``.

    Args:
        zarr_path: Path to a zarr store containing a ``masks`` array
        output_tiff_path: Destination TIFF path

    Returns:
        The ``output_tiff_path`` that was written
    """
    store = zarr.open(zarr_path, mode='r')
    imwrite(output_tiff_path, store['masks'][:])
    return output_tiff_path

def cleanup_local_embeddings(output_folder):
    """
    Remove leftovers of an interrupted run from ``output_folder``.

    In ``<output_folder>/embed`` this deletes ``*.zarr`` directories and
    ``*.zip`` files whose names contain ``"embedding_"``. In
    ``<output_folder>/output`` it deletes ``.tif``/``.tiff`` files whose names
    contain ``"seg_"``. Missing sub-folders are ignored.

    Args:
        output_folder: Path to the output folder containing embeddings
    """
    embed_root = os.path.join(output_folder, "embed")
    if os.path.exists(embed_root):
        for entry in os.listdir(embed_root):
            entry_path = os.path.join(embed_root, entry)
            if "embedding_" not in entry:
                continue
            if entry.endswith(".zarr") and os.path.isdir(entry_path):
                print(f"Cleaning up leftover embedding directory: {entry}")
                shutil.rmtree(entry_path)
            elif entry.endswith(".zip") and os.path.isfile(entry_path):
                print(f"Cleaning up leftover embedding zip: {entry}")
                os.remove(entry_path)

    seg_root = os.path.join(output_folder, "output")
    if not os.path.exists(seg_root):
        return
    for entry in os.listdir(seg_root):
        entry_path = os.path.join(seg_root, entry)
        is_seg_tiff = "seg_" in entry and entry.endswith((".tif", ".tiff"))
        if is_seg_tiff and os.path.isfile(entry_path):
            print(f"Cleaning up leftover segmentation file: {entry}")
            os.remove(entry_path)

In [None]:
def process_omero_batch_with_dask(
    images_list,
    output_folder: str,
    container_type: str,
    container_id: int,
    source_desc: str,
    model_type: str = 'vit_l',
    batch_size: int = 3,
    channel: int = 0,
    timepoints: list = None,
    timepoint_mode: str = "specific",
    z_slices: list = None,
    z_slice_mode: str = "specific",
    segment_all: bool = True,
    train_n: int = 3,
    validate_n: int = 3,
    three_d: bool = False,
    use_patches: bool = False,
    patch_size: tuple = (512, 512),
    patches_per_image: int = 1,
    random_patches: bool = True,
    resume_from_table: bool = False,
    read_only_mode: bool = False,
    local_output_dir: str = "./omero_annotations"
):
    """
    Process OMERO images in batches for SAM segmentation, using zarr for
    temporary annotation storage and a micro-SAM napari annotator per batch.

    Note: despite the function name, planes are currently loaded eagerly as
    NumPy arrays via ``getPlane``; dask is not used here.

    Args:
        images_list: List of OMERO image objects
        output_folder: Path to store temporary files (its ``output``, ``embed``
            and ``zarr`` subfolders are wiped at the start)
        container_type: Type of OMERO container ('dataset', 'plate', 'project', 'screen', 'image')
        container_id: ID of the container
        source_desc: Description of the container (for tracking)
        model_type: SAM model type
        batch_size: Number of images/patches to process at once (must be >= 1)
        channel: Channel to segment
        timepoints: List of timepoints to process (defaults to ``[0]``)
        timepoint_mode: How to handle timepoints ("all", "random", "specific")
        z_slices: List of Z-slices to process, used only when three_d=False
            (defaults to ``[0]``)
        z_slice_mode: How to handle z-slices ("all", "random", "specific")
        segment_all: Segment all images in the dataset or only train/validate subset
        train_n: Number of training images if not segment_all
        validate_n: Number of validation images if not segment_all
        three_d: Whether to use 3D volumetric mode
        use_patches: Whether to extract and process patches instead of full images
        patch_size: Size of patches to extract (width, height)
        patches_per_image: Number of patches to extract from each image (if random_patches=True)
        random_patches: Whether to extract random patches or centered patches
        resume_from_table: Whether to resume annotation from an existing tracking table
        read_only_mode: Forwarded to ``upload_rois_and_labels``.
            NOTE(review): embeddings and the tracking table are still posted to
            OMERO regardless of this flag; confirm this is intended.
        local_output_dir: Forwarded to ``upload_rois_and_labels``

    Returns:
        tuple: (table_id, combined_images)

    Raises:
        ValueError: if batch_size < 1, if timepoints (or z_slices in 2D mode)
            is empty, or if there are too few images for the train/validate split
    """
    # None defaults avoid shared mutable default lists between calls
    if timepoints is None:
        timepoints = [0]
    if z_slices is None:
        z_slices = [0]
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if not timepoints:
        raise ValueError("timepoints must contain at least one timepoint index")
    if not three_d and not z_slices:
        raise ValueError("z_slices must contain at least one z-slice index in 2D mode")

    # Setup output directories
    output_path = os.path.join(output_folder, "output")
    embed_path = os.path.join(output_folder, "embed")
    zarr_path = os.path.join(output_folder, "zarr")

    # Check for and clean up any existing embeddings from interrupted runs
    cleanup_local_embeddings(output_folder)

    # Start every run from empty working directories
    for path in [output_path, embed_path, zarr_path]:
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    # Tracking DataFrame; one row per processed (or skipped) image/patch
    df = pd.DataFrame(columns=[
        "image_id", "image_name", "train", "validate",
        "channel", "z_slice", "timepoint", "sam_model", "embed_id", "label_id", "roi_id",
        "is_volumetric", "processed", "is_patch", "patch_x", "patch_y", "patch_width", "patch_height",
        "schema_attachment_id"
    ])

    table_id = None

    # Check if we should resume from an existing table
    if resume_from_table:
        try:
            existing_tables = ezomero.get_table_names(conn, container_type.capitalize(), container_id)
            if "micro_sam_training_data" in existing_tables:
                table_ids = ezomero.get_table_ids(conn, container_type.capitalize(), container_id)
                for tid in table_ids:
                    table_name = ezomero.get_table_names(conn, container_type.capitalize(), container_id, tid)
                    if table_name == "micro_sam_training_data":
                        table_id = tid
                        existing_df = ezomero.get_table(conn, table_id)

                        # Backfill columns missing from older tables (includes schema_attachment_id)
                        for col in df.columns:
                            if col not in existing_df.columns:
                                existing_df[col] = None

                        df = existing_df

                        print(f"Resuming from existing table ID: {table_id}")
                        print(f"Found {len(df)} previously processed images/patches")
                        break
        except Exception as e:
            print(f"Error retrieving existing table: {e}. Starting fresh.")
            resume_from_table = False

    # Select images based on segment_all flag
    if segment_all:
        combined_images = images_list
        combined_images_sequence = np.zeros(len(combined_images))  # All treated as training
    else:
        if len(images_list) < train_n + validate_n:
            print("Not enough images in container for training and validation")
            raise ValueError(f"Need at least {train_n + validate_n} images but found {len(images_list)}")

        # Select random images for training and validation
        train_indices = np.random.choice(len(images_list), train_n, replace=False)
        train_images = [images_list[i] for i in train_indices]

        # Validation images come from the images not chosen for training
        chosen_for_training = set(train_indices.tolist())
        validate_candidates = [img for i, img in enumerate(images_list) if i not in chosen_for_training]
        validate_images = np.random.choice(validate_candidates, validate_n, replace=False)

        # Interleave the arrays and create sequence markers
        combined_images, combined_images_sequence = interleave_arrays(train_images, validate_images)

    # Each unit is (image, sequence_val, patch_or_None)
    processing_units = []

    if resume_from_table and len(df) > 0:
        if use_patches:
            # NOTE(review): with random_patches=True the coordinates are regenerated
            # randomly, so previously processed patches will rarely match here.
            processed_patches = set()
            for _, row in df[df['processed'] == True].iterrows():
                patch_key = (row['image_id'], row.get('patch_x', 0), row.get('patch_y', 0),
                             row.get('patch_width', 0), row.get('patch_height', 0))
                processed_patches.add(patch_key)

            for i, img in enumerate(combined_images):
                img_id = img.getId()
                seq_val = combined_images_sequence[i]
                img_patches = generate_patch_coordinates(
                    img.getSizeX(), img.getSizeY(), patch_size, patches_per_image, random_patches)

                # Keep only patches not yet recorded as processed
                for patch in img_patches:
                    patch_key = (img_id, patch[0], patch[1], patch[2], patch[3])
                    if patch_key not in processed_patches:
                        processing_units.append((img, seq_val, patch))

            print(f"Found {len(processing_units)} remaining patches to process")

        else:
            processed_ids = set(df[df['processed'] == True]['image_id'].values)

            for i, img in enumerate(combined_images):
                if img.getId() not in processed_ids:
                    processing_units.append((img, combined_images_sequence[i], None))

            print(f"Found {len(processing_units)} remaining images to process")
    else:
        if use_patches:
            for i, img in enumerate(combined_images):
                seq_val = combined_images_sequence[i]
                img_patches = generate_patch_coordinates(
                    img.getSizeX(), img.getSizeY(), patch_size, patches_per_image, random_patches)
                for patch in img_patches:
                    processing_units.append((img, seq_val, patch))

            print(f"Generated {len(processing_units)} patches to process")
        else:
            for i, img in enumerate(combined_images):
                processing_units.append((img, combined_images_sequence[i], None))

    # Ceiling division: number of batches needed to cover all units
    total_batches = (len(processing_units) + batch_size - 1) // batch_size

    if use_patches:
        print(f"Processing {len(processing_units)} patches in {total_batches} batches")
    else:
        print(f"Processing {len(processing_units)} images in {total_batches} batches")

    print(f"3D mode: {three_d}")

    for batch_idx in range(total_batches):
        print(f"\nProcessing batch {batch_idx+1}/{total_batches}")

        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, len(processing_units))
        batch_units = processing_units[start_idx:end_idx]

        images = []
        image_data = []  # Metadata about each loaded image/patch, parallel to `images`

        for image, seq_val, patch in batch_units:
            image_id = image.getId()

            # Determine which timepoint to use
            if timepoint_mode == "all":
                actual_timepoint = timepoints[0]
                print("Warning: 'all' timepoint mode not fully supported yet, using first timepoint")
            elif timepoint_mode == "random":
                # Cast the NumPy scalar so OMERO calls and the table get a plain int
                actual_timepoint = int(np.random.choice(timepoints))
            else:  # "specific"
                actual_timepoint = timepoints[0]

            if three_d:
                pixels = image.getPrimaryPixels()

                if patch is not None:
                    x, y, width, height = patch
                    # Crop each plane and stack; this keeps the source pixel dtype
                    # instead of forcing uint16
                    img_3d = np.stack([
                        pixels.getPlane(z, channel, actual_timepoint)[y:y + height, x:x + width]
                        for z in range(image.getSizeZ())
                    ])

                    image_data.append({
                        'image_id': image_id,
                        'sequence': seq_val,
                        'timepoint': actual_timepoint,
                        'z_slice': 'all',
                        'is_patch': True,
                        'patch_x': x,
                        'patch_y': y,
                        'patch_width': width,
                        'patch_height': height
                    })
                else:
                    img_3d = np.stack([pixels.getPlane(z, channel, actual_timepoint)
                                       for z in range(image.getSizeZ())])

                    image_data.append({
                        'image_id': image_id,
                        'sequence': seq_val,
                        'timepoint': actual_timepoint,
                        'z_slice': 'all',
                        'is_patch': False,
                        'patch_x': 0,
                        'patch_y': 0,
                        'patch_width': image.getSizeX(),
                        'patch_height': image.getSizeY()
                    })

                images.append(img_3d)
                print(f"Loaded 3D image/patch for image {image_id} with shape {img_3d.shape}")

            else:
                # 2D mode - determine which z-slice to use
                if z_slice_mode == "all":
                    actual_z_slice = z_slices[0]
                    print("Warning: 'all' z-slice mode not fully supported yet, using first z-slice")
                elif z_slice_mode == "random":
                    actual_z_slice = int(np.random.choice(z_slices))
                else:  # "specific"
                    actual_z_slice = z_slices[0]

                pixels = image.getPrimaryPixels()

                if patch is not None:
                    x, y, width, height = patch
                    full_plane = pixels.getPlane(actual_z_slice, channel, actual_timepoint)
                    img = full_plane[y:y + height, x:x + width]

                    image_data.append({
                        'image_id': image_id,
                        'sequence': seq_val,
                        'timepoint': actual_timepoint,
                        'z_slice': actual_z_slice,
                        'is_patch': True,
                        'patch_x': x,
                        'patch_y': y,
                        'patch_width': width,
                        'patch_height': height
                    })
                else:
                    img = pixels.getPlane(actual_z_slice, channel, actual_timepoint)

                    image_data.append({
                        'image_id': image_id,
                        'sequence': seq_val,
                        'timepoint': actual_timepoint,
                        'z_slice': actual_z_slice,
                        'is_patch': False,
                        'patch_x': 0,
                        'patch_y': 0,
                        'patch_width': image.getSizeX(),
                        'patch_height': image.getSizeY()
                    })

                images.append(img)
                print(f"Loaded 2D image/patch for image {image_id} with shape {img.shape}")

        print("Starting napari viewer with SAM annotator. Close the viewer window when done.")

        viewer = napari.Viewer()

        # The annotator writes seg_XXXXX.tif / embedding_XXXXX.zarr indexed by
        # position within this batch (local index)
        image_series_annotator(
            images,
            model_type=model_type,
            viewer=viewer,
            embedding_path=embed_path,
            output_folder=output_path,
            is_volumetric=three_d
        )

        # Blocks until the viewer is closed
        try:
            napari.run()
            print("Napari viewer closed.")
        except KeyboardInterrupt:
            print("Napari viewer was interrupted. Processing results anyway...")
        except Exception as e:
            print(f"Error in napari: {e}")

        print("Processing results from batch...")
        print("Done annotating batch, storing results in zarr and uploading to OMERO")

        batch_completed = 0
        batch_skipped = 0
        batch_df = pd.DataFrame(columns=df.columns)

        for local_n, unit_data in enumerate(image_data):
            global_n = start_idx + local_n  # Index across all batches (unique file names)

            image = conn.getObject("Image", unit_data['image_id'])
            is_patch = unit_data['is_patch']
            # ROI coordinates must be shifted by the patch origin in the full image
            patch_offset = (unit_data['patch_x'], unit_data['patch_y']) if is_patch else None

            is_train = unit_data['sequence'] == 0 if not segment_all else True
            is_validate = unit_data['sequence'] == 1 if not segment_all else False
            z_info = 'all' if three_d else unit_data['z_slice']

            # Row starts as "not processed"; fields are filled on success below
            row = {
                "image_id": image.getId(),
                "image_name": image.getName(),
                "train": is_train,
                "validate": is_validate,
                "channel": channel,
                "z_slice": z_info,
                "timepoint": unit_data['timepoint'],
                "sam_model": model_type,
                "embed_id": None,
                "label_id": None,
                "roi_id": None,
                "is_volumetric": three_d,
                "processed": False,
                "is_patch": is_patch,
                "patch_x": unit_data.get('patch_x', 0),
                "patch_y": unit_data.get('patch_y', 0),
                "patch_width": unit_data.get('patch_width', 0),
                "patch_height": unit_data.get('patch_height', 0)
            }

            seg_file_path = os.path.join(output_path, f"seg_{local_n:05d}.tif")
            if not os.path.exists(seg_file_path):
                print(f"Warning: Segmentation file not found for image {image.getId()}, skipping")
                batch_skipped += 1
                batch_df = pd.concat([batch_df, pd.DataFrame([row])], ignore_index=True)
                continue

            batch_completed += 1

            mask_data = imageio.imread(seg_file_path)
            zarr_file_path = store_annotations_in_zarr(mask_data, zarr_path, global_n)

            # Zip the per-batch embedding (local index) under a global name for upload
            embed_zarr = f"embedding_{local_n:05d}.zarr"
            zip_path = os.path.join(embed_path, f"embedding_{global_n:05d}.zip")
            embed_zarr_path = os.path.join(embed_path, embed_zarr)
            if not os.path.exists(embed_zarr_path):
                print(f"Warning: Embedding directory {embed_zarr} not found, skipping embedding upload")
                embed_id = None
            else:
                with zipfile.ZipFile(zip_path, 'w') as zip_file:
                    zip_directory(embed_path, embed_zarr, zip_file)

                embed_id = ezomero.post_file_annotation(
                    conn,
                    str(zip_path),
                    ns='microsam.embeddings',
                    object_type="Image",
                    object_id=image.getId(),
                    description=f'SAM embedding ({model_type}), 3D={three_d}, Patch={is_patch}'
                )

            tiff_path = os.path.join(output_path, f"seg_{global_n:05d}.tiff")
            zarr_to_tiff(zarr_file_path, tiff_path)

            # 3D masks span every z-plane; 2D masks belong to the annotated slice
            z_for_roi = range(image.getSizeZ()) if three_d else unit_data['z_slice']
            label_id, roi_id = upload_rois_and_labels(
                conn,
                image,
                tiff_path,
                z_for_roi,
                channel,
                unit_data['timepoint'],
                model_type,
                is_volumetric=three_d,
                patch_offset=patch_offset,
                read_only_mode=read_only_mode,
                local_output_dir=local_output_dir
            )

            row.update(embed_id=embed_id, label_id=label_id, roi_id=roi_id, processed=True)
            batch_df = pd.concat([batch_df, pd.DataFrame([row])], ignore_index=True)

        df = pd.concat([df, batch_df], ignore_index=True)

        # Replace the tracking table: delete the old one, then post the full df
        if table_id is not None:
            try:
                print(f"Deleting existing table with ID: {table_id}")
                ann = conn.getObject("FileAnnotation", table_id)
                if ann:
                    conn.deleteObjects("FileAnnotation", [table_id], wait=True)
                    print("Existing table deleted successfully")
                else:
                    print(f"Warning: Could not find table with ID: {table_id}")
            except Exception as e:
                print(f"Warning: Could not delete existing table: {e}")
                # Continue anyway, as we'll create a new table

        # ID columns may mix ints and None; store them as strings in the OMERO table
        df_for_omero = df.copy()
        for col in ['embed_id', 'label_id', 'roi_id', 'schema_attachment_id']:
            if col in df_for_omero.columns:
                df_for_omero[col] = df_for_omero[col].astype(str)

        table_id = ezomero.post_table(
            conn,
            object_type=container_type.capitalize(),
            object_id=container_id,
            table=df_for_omero,
            title="micro_sam_training_data"
        )
        if table_id is None:
            print("Warning: Failed to create tracking table")
        else:
            print(f"Created new tracking table with ID: {table_id}")

        print(f"Batch {batch_idx+1}/{total_batches} results:")
        print(f"  - Completed: {batch_completed}/{len(batch_units)} units")
        print(f"  - Skipped: {batch_skipped}/{len(batch_units)} units")

        # Clean up temporary files before possibly stopping, so nothing leaks.
        # Annotator outputs use the local index; zips were written with the global index.
        for local_idx in range(len(batch_units)):
            global_idx = start_idx + local_idx
            stale_files = [
                os.path.join(embed_path, f"embedding_{global_idx:05d}.zip"),
                os.path.join(output_path, f"seg_{local_idx:05d}.tif"),
            ]
            for path in stale_files:
                if os.path.exists(path):
                    os.remove(path)

            stale_zarr = os.path.join(embed_path, f"embedding_{local_idx:05d}.zarr")
            if os.path.isdir(stale_zarr):
                shutil.rmtree(stale_zarr)

        if batch_skipped > 0 and batch_idx < total_batches - 1:
            try:
                response = input("Some units were skipped. Continue with next batch? (y/n): ")
                if response.lower() not in ['y', 'yes']:
                    print("Stopping processing at user request.")
                    break
            except Exception:
                # Non-interactive stdin (EOFError, Jupyter stdin errors, ...);
                # KeyboardInterrupt is deliberately not caught here
                print("Non-interactive environment detected. Continuing with next batch.")

    total_processed = df[df['processed'] == True].shape[0]
    total_skipped = df[df['processed'] == False].shape[0]

    print("\nAll batches completed.")
    print(f"Total processed: {total_processed} units")
    print(f"Total skipped: {total_skipped} units")
    print(f"Final tracking table ID: {table_id} in {source_desc}")

    return table_id, combined_images

In [None]:
def organize_local_outputs(local_dir, container_type, container_id, image_id, timepoint=0, z_slice=0, is_patch=False, patch_coords=None):
    """
    Build (and create on disk) the local folder layout for one image's annotations.

    Intended for read-only OMERO servers where results cannot be attached remotely.
    The layout is ``<local_dir>/<container_type>_<container_id>/image_<image_id>/``
    with ``embeddings/``, ``labels/`` and ``rois/`` subfolders. All of them are
    created if missing.

    Args:
        local_dir: Base directory for local storage
        container_type: Type of OMERO container ('dataset', 'plate', etc.)
        container_id: ID of the container
        image_id: ID of the image
        timepoint: Time point index
        z_slice: Z-slice index, or 'all' for volumetric data
        is_patch: Whether this is a patch of a larger image
        patch_coords: Optional (x, y, width, height); only used when is_patch is truthy

    Returns:
        dict: Keys ``base_dir``, ``embedding_dir``, ``embedding_path``, ``label_path``,
        ``roi_path``, ``metadata_path`` and ``base_name``. ``base_name`` joins with
        underscores ``vol`` (or ``z<z_slice>``), ``t<timepoint>`` and, for patches,
        ``patch_x<x>_y<y>_w<w>_h<h>``.
    """
    container_dir = os.path.join(local_dir, f"{container_type}_{container_id}")
    image_dir = os.path.join(container_dir, f"image_{image_id}")
    embeddings_dir, labels_dir, rois_dir = (
        os.path.join(image_dir, sub) for sub in ("embeddings", "labels", "rois")
    )

    for directory in (container_dir, image_dir, embeddings_dir, labels_dir, rois_dir):
        os.makedirs(directory, exist_ok=True)

    # File-name stem: dimensionality, then timepoint, then optional patch geometry
    tags = ["vol" if z_slice == 'all' else f"z{z_slice}", f"t{timepoint}"]
    if is_patch and patch_coords:
        px, py, pw, ph = patch_coords
        tags.append(f"patch_x{px}_y{py}_w{pw}_h{ph}")
    base_name = "_".join(tags)

    return {
        "base_dir": image_dir,
        "embedding_dir": embeddings_dir,
        "embedding_path": os.path.join(embeddings_dir, f"{base_name}_embedding.zip"),
        "label_path": os.path.join(labels_dir, f"{base_name}_label.tiff"),
        "roi_path": os.path.join(rois_dir, f"{base_name}_rois.json"),
        "metadata_path": os.path.join(image_dir, f"{base_name}_metadata.json"),
        "base_name": base_name,
    }

def save_annotations_schema(metadata_path, image_id, label_path, roi_data, annotation_metadata):
    """
    Save annotation metadata and ROIs in a structured JSON schema for local storage

    The whole document is serialized to a string *before* the output file is
    opened. A serialization failure (e.g. an unsupported object inside
    ``roi_data``) therefore no longer leaves a truncated, invalid JSON file
    behind. Values exposing ``tolist()`` (NumPy scalars and arrays) are
    converted to plain Python values, so ROI data built with NumPy can be stored.

    Args:
        metadata_path: Path to save the metadata JSON file
        image_id: OMERO image ID (anything accepted by ``int()``)
        label_path: Path to the label image file; stored relative to the
            directory containing ``metadata_path``
        roi_data: List of ROI data extracted from label image; must be
            JSON-serializable (NumPy values are allowed)
        annotation_metadata: Dictionary with additional metadata; missing keys
            fall back to defaults and ``None`` is treated as an empty dict

    Returns:
        bool: True if successful, False otherwise
    """
    import json

    def _to_json_compatible(obj):
        # json only calls this for objects it cannot encode natively.
        # NumPy scalars and arrays both provide tolist(), which yields plain
        # Python numbers / nested lists.
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    try:
        metadata = annotation_metadata or {}

        # Create the JSON schema
        schema = {
            "schema_version": "1.0",
            "created_at": pd.Timestamp.now().isoformat(),
            "image": {
                "id": int(image_id),
                "name": metadata.get("image_name", ""),
                "server": metadata.get("server", "")
            },
            "annotation": {
                "model": metadata.get("model_type", ""),
                "is_volumetric": metadata.get("is_volumetric", False),
                "channel": metadata.get("channel", 0),
                "z_slice": metadata.get("z_slice", 0),
                "timepoint": metadata.get("timepoint", 0),
                "is_patch": metadata.get("is_patch", False),
                "patch_coords": metadata.get("patch_coords", None)
            },
            "files": {
                # Relative path keeps the metadata valid if the output folder is moved
                "label_path": os.path.relpath(label_path, os.path.dirname(metadata_path)),
                "embedding_path": metadata.get("embedding_path", "")
            },
            "rois": roi_data
        }

        # Serialize first: if this raises, the target file is never opened,
        # so no partially written JSON is left on disk.
        payload = json.dumps(schema, indent=2, default=_to_json_compatible)

        # Write the schema to file
        with open(metadata_path, 'w', encoding='utf-8') as f:
            f.write(payload)

        print(f"Saved annotation schema to {metadata_path}")
        return True

    except Exception as e:
        print(f"Error saving annotation schema: {e}")
        return False

### Load images from OMERO and open in napari with micro-sam annotator

When using 3D mode (`three_d=True`), the notebook will process entire Z-stacks instead of single slices. This allows for volumetric annotation using micro-SAM's 3D capabilities.

When using patch mode (`use_patches=True`), the notebook will extract smaller regions from large images for more efficient annotation.

### Running segmentation in batch

Note: napari usually prints some warnings in this cell's output. They are expected and generally harmless.

In [None]:
##input parameters
# --- Model and image selection ---
model_type = 'vit_b_lm'  # micro-SAM model used for segmentation
segment_all = False      # True: annotate every image in the container
train_n = 2              # images sampled for training when segment_all is False
validate_n = 2           # images sampled for validation when segment_all is False
channel = 3              # which channel to segment starting from 0

# --- Z-slice handling (2D mode only; ignored if three_d=True) ---
z_slices = [4, 6, 8]
z_slice_mode = "random"      # Options: "all", "random", "specific" (uses z_slices[0])

# --- Timepoint handling ---
timepoints = [0]
timepoint_mode = "specific"  # Options: "all", "random", "specific" (uses timepoints[0])

# --- Patch extraction ---
use_patches = False          # True: annotate image patches instead of full images
patch_size = (512, 512)      # (width, height) of each extracted patch
patches_per_image = 2        # patches taken from each image (if random_patches=True)
random_patches = True        # True: random locations; False: image center

# --- Batch processing ---
batch_size = 2               # images/patches opened in napari at once
three_d = False              # True: volumetric (Z-stack) processing
resume_from_table = False    # True: continue from a previous run

# --- Read-only mode ---
read_only_mode = False       # True: save annotations locally instead of uploading to OMERO
local_output_dir = "./omero_annotations"  # destination for annotations in read-only mode

# Turn off napari's ipy_interactive setting before launching the annotator
settings = get_settings()
settings.application.ipy_interactive = False

# Collect every image from the configured OMERO container
images_list, source_desc = get_images_from_container(conn, datatype, data_id)

if len(images_list) == 0:
    print(f"No images found in the {datatype} with ID {data_id}")
else:
    # Annotate the selected images in batches, loading pixel data lazily with dask
    table_id, processed_images = process_omero_batch_with_dask(
        # what to process and where results go
        images_list=images_list,
        source_desc=source_desc,
        container_type=datatype,
        container_id=data_id,
        output_folder=output_directory,
        # segmentation model and image plane selection
        model_type=model_type,
        channel=channel,
        three_d=three_d,
        z_slices=z_slices,
        z_slice_mode=z_slice_mode,
        timepoints=timepoints,
        timepoint_mode=timepoint_mode,
        # sampling and batching
        segment_all=segment_all,
        train_n=train_n,
        validate_n=validate_n,
        batch_size=batch_size,
        # patch extraction
        use_patches=use_patches,
        patch_size=patch_size,
        patches_per_image=patches_per_image,
        random_patches=random_patches,
        # session handling and local storage
        resume_from_table=resume_from_table,
        read_only_mode=read_only_mode,
        local_output_dir=local_output_dir,
    )
    print(f"Finished processing with dask lazy loading. Table ID: {table_id}")
    print("To resume this session later, set resume_from_table=True")

## Resuming Annotation Sessions

This notebook now supports resuming annotation sessions. If you need to stop annotating and continue later:

1. Set `resume_from_table = True` in the parameters section
2. Run the notebook as usual
3. The system will automatically detect previously annotated images/patches and continue with the remaining ones

This is useful for:
- Long annotation sessions that need to be split over multiple days
- Cases where napari was closed accidentally
- Continuing after computer restarts or crashes

The tracking table in OMERO keeps track of which images/patches have been successfully processed.

## Examples for processing different scenarios

Here are examples showing how to set up the notebook for different use cases:

### Processing with multiple z-slices
```python
# Z-slice handling (for 2D mode)
z_slices = [4, 8, 12, 16]  # List of multiple z-slices to use
z_slice_mode = "random"  # Randomly select one z-slice from the list for each image
three_d = False  # Use 2D mode
```

### Processing with time series
```python
# Timepoint handling
timepoints = [0, 5, 10, 15]  # List of timepoints to process
timepoint_mode = "all"  # Process all timepoints in the list
```

### Processing image patches
```python
# Patch extraction settings
use_patches = True  # Enable patch extraction
patch_size = (512, 512)  # Size of patches to extract
patches_per_image = 3  # Extract multiple patches per image
random_patches = True  # Extract patches from random locations
```

### Processing a 3D dataset
```python
# 3D settings
three_d = True  # Enable 3D volumetric processing
batch_size = 2  # Keep the batch size small for 3D, since volumes require more memory
```

### Processing a large dataset in parts
```python
# For large datasets
segment_all = False  # Don't process all images
train_n = 10  # Process only 10 images for training
validate_n = 5  # And 5 for validation
resume_from_table = True  # Enable resuming from previous sessions
```

## Using Read-Only Mode

This notebook now supports read-only mode for working with OMERO servers where you have read-only access, such as the Image Data Resource (IDR).

When working in read-only mode:
1. All annotations (ROIs, segmentation masks, embeddings) are saved locally instead of being uploaded to OMERO
2. Data is organized in a structured way on disk for easy access and further processing
3. No write permissions are required on the OMERO server

To use read-only mode:
```python
# Read-only mode settings
read_only_mode = True  # Enable read-only mode
local_output_dir = "./omero_annotations"  # Directory where annotations will be saved
```

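The annotations are organized in the following directory structure (one folder per container, then one per image):
```
local_output_dir/
  └── {container_type}_{container_id}/          # e.g. dataset_123
       ├── image_{image_id}/
       │    ├── {base_name}_metadata.json       # annotation metadata and ROIs (JSON schema)
       │    ├── embeddings/
       │    │    └── {base_name}_embedding.zip  # SAM embeddings
       │    ├── labels/
       │    │    └── {base_name}_label.tiff     # segmentation mask
       │    └── rois/
       │         └── {base_name}_rois.json      # ROI data
       ├── image_{another_image_id}/
       │    ├── ...
```

`{base_name}` encodes which plane was annotated:
- `z{z}_t{t}` for a single z-slice.
- `vol_t{t}` for volumetric (3D) data.
- For patches, `_patch_x{x}_y{y}_w{width}_h{height}` is appended.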

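The JSON metadata files contain:
- Schema version and creation timestamp
- Image ID, name, and OMERO server
- Model type used
- Channel, z-slice, and timepoint information
- Whether the annotation is volumetric
- Whether the annotated region is a patch and, if so, its coordinates `(x, y, width, height)`
- The path of the label image (relative to the metadata file) and of the embeddings
- The ROIs extracted from the label image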

### Working with a read-only OMERO server

```python
##input parameters
model_type = 'vit_b'
segment_all = False
train_n = 2   
validate_n = 2
channel = 3  # which channel to segment

# Z-slice handling
z_slices = [4, 6, 8]
z_slice_mode = "random"

# Timepoint handling
timepoints = [0]
timepoint_mode = "specific"

# Read-only mode settings
read_only_mode = True  # Enable read-only mode
local_output_dir = "./omero_idr_annotations"  # Save annotations locally

# Other settings
batch_size = 2
three_d = False
use_patches = True
patch_size = (512, 512)
patches_per_image = 2
random_patches = True
resume_from_table = False

# Run batch processing
images_list, source_desc = get_images_from_container(conn, datatype, data_id)
if len(images_list) > 0:
    table_id, processed_images = process_omero_batch_with_dask(
        images_list=images_list,
        output_folder=output_directory,
        container_type=datatype,
        container_id=data_id,
        source_desc=source_desc,
        model_type=model_type,
        batch_size=batch_size,
        channel=channel,
        timepoints=timepoints,
        timepoint_mode=timepoint_mode,
        z_slices=z_slices,
        z_slice_mode=z_slice_mode,
        segment_all=segment_all,
        train_n=train_n,
        validate_n=validate_n,
        three_d=three_d,
        use_patches=use_patches,
        patch_size=patch_size,
        patches_per_image=patches_per_image,
        random_patches=random_patches,
        resume_from_table=resume_from_table,
        read_only_mode=read_only_mode,
        local_output_dir=local_output_dir
    )
```