### Fine-tuning SAM with OMERO data using a batch approach

TODO
- make it work for different OMERO data types (single images, plates, etc.)
- store all annotations in OMERO, see: https://github.com/computational-cell-analytics/micro-sam/issues/445; the series annotator accepts a commit path for saving prompts, but they get overwritten
- clean up the error and warning output from napari

Instructions:
  - To make it easier to run with OMERO without exposing the login and password, the credentials (`HOST`, `USER_NAME`, `PASSWORD`) are read from a `.env` file (see the example `.env_example`). Storing credentials unencrypted is still not recommended, so a better solution is being worked on.

In [None]:
# OMERO-related imports
import omero
from omero.gateway import BlitzGateway
import ezomero

# Scientific computing and image processing
import cv2
import imageio.v3 as imageio
import numpy as np
import pandas as pd

# File and system operations
import os
import shutil
import tempfile
import zipfile
import warnings
from tifffile import imwrite
from dotenv import load_dotenv

# Micro-SAM and Napari
from napari.settings import get_settings
from micro_sam.sam_annotator import image_series_annotator, annotator_2d
from micro_sam.util import precompute_image_embeddings
import napari

### Setup connection with OMERO

In [None]:
# Read HOST / USER_NAME / PASSWORD from the .env file (override any values already in the environment).
load_dotenv(override=True)
missing_vars = [name for name in ("HOST", "USER_NAME", "PASSWORD") if not os.environ.get(name)]
if missing_vars:
    raise RuntimeError(f"Missing OMERO settings in .env: {', '.join(missing_vars)}")

conn = BlitzGateway(host=os.environ.get("HOST"), username=os.environ.get("USER_NAME"), passwd=os.environ.get("PASSWORD"), secure=True)
connection_status = conn.connect()
if not connection_status:
    print("Connection to OMERO Server Failed")
    # Every later cell needs a live connection, so stop here instead of failing further down.
    raise ConnectionError("Could not connect to OMERO; check HOST, USER_NAME and PASSWORD in .env")
print("Connected to OMERO Server")
# Ping the server every 60 s; presumably so the session does not expire during long napari annotation sessions.
conn.c.enableKeepAlive(60)

### Get info from the dataset

In [None]:
datatype = "dataset" # "plate", "dataset", "image", "project"
data_id = 24143
nucl_channel = 0

# Fetch the OMERO object matching datatype/data_id. Later cells expect it in `dataset`,
# whatever its actual type.
object_types = {"plate": "Plate", "dataset": "Dataset", "image": "Image", "project": "Project"}
if datatype not in object_types:
    raise ValueError(f"Unsupported datatype {datatype!r}; expected one of {sorted(object_types)}")
omero_type = object_types[datatype]
dataset = conn.getObject(omero_type, data_id)
# getObject returns None when the ID does not exist or is not accessible to this user.
if dataset is None:
    raise ValueError(f"No {omero_type} with ID {data_id} found (or no access); check data_id and datatype")
print(f"{omero_type} Name: ", dataset.getName())

### Create a temporary folder to store the training data; it will be uploaded to OMERO later

In [None]:
# Fresh scratch folder (normalised for the host OS) for embeddings and segmentations.
scratch_root = tempfile.mkdtemp()
output_directory = os.path.normcase(scratch_root)
print('Output Directory: ', output_directory)

In [None]:

def zip_directory(source_path, zarr_path, zip_file):
    """Add every file below ``zarr_path`` to an already-open ``zipfile.ZipFile``.

    Archive names are made relative to ``source_path``. NUL characters are
    removed from both the on-disk path and the archive name before writing.
    Files whose cleaned path does not exist are skipped. A failure on an
    individual file is reported as a printed warning, and the walk moves on
    to the next file.

    Args:
        source_path: Directory that archive member names are relative to.
        zarr_path: Directory tree to walk. A relative path is resolved by
            ``os.walk`` against the current working directory.
        zip_file: An open, writable ``zipfile.ZipFile``.
    """
    for dirpath, _subdirs, filenames in os.walk(zarr_path):
        for name in filenames:
            try:
                on_disk = os.path.join(dirpath, name)
                # Compute the archive name from the raw path first, then strip NULs from both.
                arcname = os.path.relpath(on_disk, source_path).replace('\x00', '')
                on_disk = on_disk.replace('\x00', '')
                if not os.path.exists(on_disk):
                    continue
                zip_file.write(on_disk, arcname)
            except Exception as e:
                print(f"Warning: Error processing {name}: {str(e)}")

def interleave_arrays(train_images, validate_images):
    """
    Merge two sequences alternately: train[0], validate[0], train[1], validate[1], ...

    When one sequence is longer, its leftover items follow at the end in order.

    Returns:
        tuple: (merged items as np.ndarray, np.ndarray of origin flags where
        0 marks an item from ``train_images`` and 1 an item from ``validate_images``)
    """
    n_train = len(train_images)
    n_validate = len(validate_images)
    sources = ((0, train_images, n_train), (1, validate_images, n_validate))

    merged, origin = [], []
    for idx in range(max(n_train, n_validate)):
        for flag, source, count in sources:
            if idx < count:
                merged.append(source[idx])
                origin.append(flag)

    return np.array(merged), np.array(origin)


def mask_to_contour(mask):
    """Trace the outlines of the foreground regions of a binary mask.

    Uses ``cv2.RETR_TREE``, so contours of holes inside objects are returned
    too, and ``cv2.CHAIN_APPROX_SIMPLE``, which compresses straight runs to
    their end points.

    Args:
        mask (np.ndarray): 2D binary mask (non-zero = foreground). Boolean or
            non-uint8 masks are converted to contiguous uint8, as required by
            ``cv2.findContours``.

    Returns:
        sequence of np.ndarray: one array of shape (N, 1, 2) per contour,
        holding (x, y) pixel coordinates.
    """
    mask_u8 = np.ascontiguousarray(mask, dtype=np.uint8)
    found = cv2.findContours(mask_u8, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); the contours are the second-to-last item in both.
    return found[-2]

def label_to_rois(label_img, z_slice, channel, timepoint):
    """
    Convert a 2D label image to OMERO polygon shapes (via ezomero).

    Each non-zero label is turned into one or more polygons, one per contour
    returned by ``mask_to_contour``. Label 0 is treated as background and
    skipped.

    Args:
        label_img (np.ndarray): 2D labeled image
        z_slice (int): Z-slice index
        channel (int): Channel index
        timepoint (int): Time point index

    Returns:
        list: List of ``ezomero.rois.Polygon`` shapes
    """
    shapes = []
    labels = np.unique(label_img)
    # Drop background explicitly rather than slicing off the first value: a
    # label image without any 0 pixels would otherwise lose a real object.
    labels = labels[labels != 0]

    for label in labels:
        mask = (label_img == label).astype(np.uint8)
        for contour in mask_to_contour(mask):
            # OpenCV contours are (N, 1, 2) arrays of (x, y); flatten to plain tuples.
            points = [(int(x), int(y)) for x, y in contour[:, 0, :]]
            shapes.append(
                ezomero.rois.Polygon(
                    points=points,
                    z=z_slice,
                    c=channel,
                    t=timepoint,
                )
            )

    return shapes

def upload_rois_and_labels(conn, image, label_file, z_slice, channel, timepoint, model_type):
    """
    Attach a segmentation's label map to an OMERO image and add its ROIs.

    The label file is posted as a file annotation (namespace
    ``microsam.labelimage``). Its polygons, built by ``label_to_rois``, are
    posted as a single ROI. No ROI is created when the label image yields
    no shapes.

    Args:
        conn: OMERO connection
        image: OMERO image object
        label_file: Path to the label image file
        z_slice: Z-slice index
        channel: Channel index
        timepoint: Time point index
        model_type: SAM model type used (recorded in names/descriptions)

    Returns:
        tuple: (label_id, roi_id), where roi_id is None if no shapes were found
    """
    image_id = image.getId()

    label_id = ezomero.post_file_annotation(
        conn,
        str(label_file),
        ns='microsam.labelimage',
        object_type="Image",
        object_id=image_id,
        description=f'SAM segmentation ({model_type})'
    )

    polygons = label_to_rois(imageio.imread(label_file), z_slice, channel, timepoint)

    roi_id = None
    if polygons:
        roi_id = ezomero.post_roi(
            conn,
            image_id,
            polygons,
            name=f'SAM_{model_type}',
            description=f'Segmentation using SAM model {model_type}'
        )

    return label_id, roi_id

# OMERO object type names used for lookups and for attaching the tracking table.
_OMERO_OBJECT_TYPES = {"plate": "Plate", "dataset": "Dataset", "image": "Image", "project": "Project"}

# Columns of the tracking table posted to OMERO at the end of a run.
_TRACKING_COLUMNS = [
    "image_id", "image_name", "train", "validate",
    "channel", "z_slice", "timepoint", "sam_model", "embed_id", "label_id", "roi_id",
]


def _container_id(container):
    """Return the OMERO ID of ``container``; accepts a gateway wrapper or a plain integer ID."""
    if isinstance(container, (int, np.integer)):
        return int(container)
    return container.getId()


def _get_image_ids(container, datatype):
    """List the IDs of the images held by ``container`` (datatype must be a key of _OMERO_OBJECT_TYPES)."""
    container_id = _container_id(container)
    if datatype == "image":
        return [container_id]
    # ezomero.get_image_ids takes the container kind as a keyword: plate=, dataset= or project=.
    return ezomero.get_image_ids(conn, **{datatype: container_id})


def _select_image_ids(image_ids, segment_all, train_n, validate_n):
    """Choose the images to annotate; return (ids, flags) with flag 0 = train and 1 = validate."""
    image_ids = [int(i) for i in image_ids]
    if segment_all:
        # Every image is annotated; flags alternate, as they always have in this mode.
        return image_ids, [n % 2 for n in range(len(image_ids))]
    if len(image_ids) < train_n + validate_n:
        raise ValueError("Not enough images in dataset for training and validation")

    train_ids = np.random.choice(image_ids, train_n, replace=False)
    chosen = set(train_ids.tolist())
    validate_ids = np.random.choice(
        [i for i in image_ids if i not in chosen], validate_n, replace=False
    )
    combined, sequence = interleave_arrays(train_ids, validate_ids)
    # Keep each image's true origin: with train_n != validate_n the tail does not alternate.
    return [int(i) for i in combined], [int(s) for s in sequence]


def _upload_image_results(image, local_n, output_folder, model_type, z_slice, channel, timepoint):
    """Upload one image's zipped embedding, label map and ROIs; return (embed_id, label_id, roi_id)."""
    embed_dir = os.path.join(output_folder, "embed")
    embed_zarr = os.path.join(embed_dir, f"embedding_{local_n:05d}.zarr")

    embed_id = None
    if os.path.isdir(embed_zarr):
        zip_path = os.path.join(embed_dir, f"embedding_{local_n:05d}.zip")
        with zipfile.ZipFile(zip_path, 'w') as zip_file:
            # Walk the full zarr path (not just its name, which would resolve against
            # the CWD); archive names stay relative to embed_dir.
            zip_directory(embed_dir, embed_zarr, zip_file)
        embed_id = ezomero.post_file_annotation(
            conn,
            str(zip_path),
            ns='microsam.embeddings',
            object_type="Image",
            object_id=image.getId(),
            description=f'SAM embedding ({model_type})'
        )
    else:
        warnings.warn(f"No embedding found for image {image.getId()} at {embed_zarr}; not uploaded")

    label_file = os.path.join(output_folder, "output", f"seg_{local_n:05d}.tif")
    if os.path.exists(label_file):
        label_id, roi_id = upload_rois_and_labels(
            conn, image, label_file, z_slice, channel, timepoint, model_type
        )
    else:
        # Presumably the viewer was closed before this image was annotated.
        warnings.warn(f"No segmentation found for image {image.getId()} at {label_file}; not uploaded")
        label_id, roi_id = None, None

    return embed_id, label_id, roi_id


def _cleanup_batch_files(output_folder, n_files):
    """Delete per-batch embedding/segmentation files so the next batch can reuse local indices."""
    for n in range(n_files):
        embed_zip = os.path.join(output_folder, "embed", f"embedding_{n:05d}.zip")
        embed_zarr = os.path.join(output_folder, "embed", f"embedding_{n:05d}.zarr")
        seg_file = os.path.join(output_folder, "output", f"seg_{n:05d}.tif")

        if os.path.exists(embed_zip):
            os.remove(embed_zip)
        if os.path.exists(embed_zarr):
            shutil.rmtree(embed_zarr)
        if os.path.exists(seg_file):
            os.remove(seg_file)


def process_omero_batch(
    dataset=None,
    datatype: str = "dataset",
    output_folder: str = None,
    model_type: str = 'vit_l',
    batch_size: int = 3,
    channel: int = 0,
    timepoint: int = 0,
    z_slice: int = 0,
    segment_all: bool = True,
    train_n: int = 3,
    validate_n: int = 3,
):
    """
    Annotate OMERO images in napari batches with micro-SAM and upload results back to OMERO.

    For every batch, the selected 2D planes are opened in micro-SAM's image
    series annotator. After the viewer is closed, each image's embedding
    (zipped), label map and polygon ROIs are attached to it in OMERO. At the
    end, a tracking table is attached to ``dataset``. Uses the module-level
    ``conn`` connection.

    Args:
        dataset: OMERO container object (or its integer ID) holding the images.
        datatype: Kind of ``dataset``: "dataset", "plate", "project" or "image".
        output_folder: Folder for temporary files ("embed" and "output"
            subfolders are recreated empty).
        model_type: SAM model type
        batch_size: Number of images opened in napari at once
        channel: Channel to segment
        timepoint: Timepoint to process
        z_slice: Z-slice to process
        segment_all: If True, annotate every image. If False, annotate a random
            selection of ``train_n`` training and ``validate_n`` validation images.
        train_n: Number of training images when ``segment_all`` is False.
        validate_n: Number of validation images when ``segment_all`` is False.

    Returns:
        ID of the tracking table posted to OMERO.

    Raises:
        ValueError: on a missing argument, unsupported datatype, too few images,
            an empty selection, or an image that cannot be loaded.
    """
    if dataset is None or output_folder is None:
        raise ValueError("Both dataset and output_folder must be provided")
    if datatype not in _OMERO_OBJECT_TYPES:
        raise ValueError(
            f"Unsupported datatype {datatype!r}; expected one of {sorted(_OMERO_OBJECT_TYPES)}"
        )

    # Start from empty folders so stale files from a previous run are never uploaded.
    for path in (os.path.join(output_folder, "output"), os.path.join(output_folder, "embed")):
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    image_ids = _get_image_ids(dataset, datatype)
    selected_image_ids, sequence = _select_image_ids(image_ids, segment_all, train_n, validate_n)
    if not selected_image_ids:
        raise ValueError("No images selected for processing")

    batch_size = max(1, min(batch_size, len(selected_image_ids)))
    total_batches = (len(selected_image_ids) + batch_size - 1) // batch_size
    rows = []

    print(f"\nStarting processing of {len(selected_image_ids)} images in batches of {batch_size}")
    for batch_idx in range(total_batches):
        start_idx = batch_idx * batch_size
        batch_image_ids = selected_image_ids[start_idx:start_idx + batch_size]

        print(f"\n{'='*50}")
        print(f"Processing batch {batch_idx + 1} of {total_batches}")
        print(f"Current batch contains {len(batch_image_ids)} images")
        print(f"{'='*50}")

        # Load batch images only when needed
        print("Loading images from OMERO...")
        batch_images = []
        for image_id in batch_image_ids:
            image = conn.getObject("Image", image_id)
            if image is None:
                raise ValueError(f"Image {image_id} not found or not accessible")
            batch_images.append(image)
        images = [
            image.getPrimaryPixels().getPlane(z_slice, channel, timepoint)
            for image in batch_images
        ]

        # Process batch with SAM; napari.run() returns once the viewer is closed.
        print("\nOpening napari viewer for segmentation...")
        print("Please complete your annotations and close the viewer when done")
        viewer = napari.Viewer()
        image_series_annotator(
            images,
            model_type=model_type,
            viewer=viewer,
            embedding_path=os.path.join(output_folder, "embed"),
            output_folder=os.path.join(output_folder, "output")
        )

        napari.run()
        print("Done annotating batch, uploading to OMERO now")
        print("\nUploading results to OMERO...")
        for local_n, image in enumerate(batch_images):
            global_n = start_idx + local_n
            embed_id, label_id, roi_id = _upload_image_results(
                image, local_n, output_folder, model_type, z_slice, channel, timepoint
            )
            rows.append({
                "image_id": image.getId(),
                "image_name": image.getName(),
                "train": sequence[global_n] == 0,
                "validate": sequence[global_n] == 1,
                "channel": channel,
                "z_slice": z_slice,
                "timepoint": timepoint,
                "sam_model": model_type,
                "embed_id": embed_id,
                "label_id": label_id,
                "roi_id": roi_id,
            })

        print("Cleaning up temporary files...")
        _cleanup_batch_files(output_folder, batch_size)
        print(f"\nCompleted batch {batch_idx + 1}")

    # Upload final tracking table, attached to the container that was processed.
    df = pd.DataFrame(rows, columns=_TRACKING_COLUMNS)
    table_id = ezomero.post_table(
        conn,
        object_type=_OMERO_OBJECT_TYPES[datatype],
        object_id=_container_id(dataset),
        table=df,
        title="micro_sam_training_data"
    )
    print("\nAll batches processed successfully!")
    print(f"Created table with ID: {table_id}")
    return table_id


### Load images from OMERO and open in napari with micro-sam annotator

### save model


### Running segmentation in batch

Note: some napari warnings are expected in the output here; they are generally not a problem.

In [None]:
## input parameters
model_type = 'vit_b'
segment_all = False
train_n = 20
validate_n = 20
channel = 3  # which channel to segment, starting from 0
timepoint = 0
z_slice = 4  # for now pick one slice; TODO add option to pick multiple slices by giving a list of z slices, or random slices
batch_size = 10  # the number of images to process at once in napari

# Disable napari's IPython GUI integration; presumably so napari.run() blocks
# until each batch's viewer is closed.
settings = get_settings()
settings.application.ipy_interactive = False

if datatype == "dataset":
    table_id = process_omero_batch(
        dataset=dataset,
        datatype=datatype,
        output_folder=output_directory,
        model_type=model_type,
        batch_size=batch_size,
        channel=channel,
        timepoint=timepoint,
        z_slice=z_slice,
        segment_all=segment_all,
        train_n=train_n,
        validate_n=validate_n,
    )
    print("Table ID:", table_id)
else:
    # Make the no-op explicit instead of silently skipping the run.
    print(f"Batch processing for datatype '{datatype}' is not supported yet; only 'dataset' is.")