### Fine-tuning SAM with OMERO data using a batch approach - Enhanced Version

### Features
- Supports multiple OMERO data types (single images, datasets, projects, plates, and screens)
- Batch processing with micro-SAM for segmentation
- Stores all annotations in OMERO as ROIs and attachments
- Uses dask for lazy loading of images for better memory management
- Supports 3D volumetric segmentation for z-stacks
- **NEW**: Support for multiple z-slices in 2D mode
- **NEW**: Support for time series analysis
- **NEW**: Support for patch-based extraction and annotation
- **NEW**: Improved resumption of annotation sessions

### TODOs
- Store all annotations in OMERO (see https://github.com/computational-cell-analytics/micro-sam/issues/445). The series annotator can be given a commit path that also saves the prompts, but those prompts currently get overwritten.
- Clean up the errors and warnings output from napari
- Improve ROI creation for 3D volumes to better represent volumetric masks in OMERO
- Work with Dask arrays directly in micro-sam
- Add recovery mode to handle cases when users abort in the middle of a batch annotation session (currently annotations made before closing are preserved, but could be improved with a dedicated recovery workflow)

Instructions:
  - To make it easier to connect to OMERO without exposing your username and password in the notebook, the credentials are read from a `.env` file (`HOST`, `USER_NAME`, `PASSWORD`, `GROUP`; see the example `.env_example`). Storing credentials unencrypted is still not recommended, so a better solution is being worked on.
  - This notebook supports processing images from various OMERO container types: images, datasets, projects, plates, and screens.
  - Specify the container type in the `datatype` variable and the container ID in the `data_id` variable.
  - You can choose to segment all images in the container or select a random subset for training and validation.
  - **NEW**: You can now specify multiple z-slices and timepoints to analyze.
  - **NEW**: You can extract and analyze patches from large images.

In [None]:
# OMERO-related imports
import omero
from omero.gateway import BlitzGateway
import ezomero

# Scientific computing and image processing
import numpy as np
import pandas as pd

# File and system operations
import os
import shutil
import tempfile
import warnings
from dotenv import load_dotenv

# Import our custom modules
from src.omero_functions import (
    print_object_details, 
    get_images_from_container, 
    get_dask_image, 
    upload_rois_and_labels
)
from src.file_io_functions import (
    zip_directory, 
    store_annotations_in_zarr, 
    zarr_to_tiff,
    cleanup_local_embeddings,
    organize_local_outputs,
    save_annotations_schema
)
from src.image_functions import (
    label_to_rois,
    generate_patch_coordinates,
    extract_patch
)
from src.utils import NumpyEncoder, interleave_arrays
from src.processing_pipeline import process_omero_batch_with_dask

### Setup connection with OMERO

In [None]:
# Read OMERO credentials (HOST, USER_NAME, PASSWORD, GROUP) from the .env file so
# they never appear in the notebook itself; override=True lets .env values win
# over variables already present in the environment.
load_dotenv(override=True)
conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    group=os.environ.get("GROUP"),
    secure=True,
)
connection_status = conn.connect()
if not connection_status:
    print("Connection to OMERO Server Failed")
    # Stop here: every later cell needs a live session, and continuing would
    # surface as misleading "... not found" errors further down the notebook.
    raise ConnectionError(
        "Could not connect to the OMERO server; check HOST, USER_NAME, "
        "PASSWORD and GROUP in your .env file"
    )
print("Connected to OMERO Server")
# Enable the OMERO client keep-alive (60 s interval) so the session survives the
# long pauses typical of interactive annotation.
conn.c.enableKeepAlive(60)

### Get info about the selected OMERO container

In [None]:
datatype = "dataset" # "screen", "plate", "project", "dataset", "image"
data_id = 1112

# Map each supported `datatype` value to the OMERO object type passed to
# conn.getObject. "screen" was advertised above but previously rejected.
_OMERO_OBJECT_TYPES = {
    "screen": "Screen",
    "plate": "Plate",
    "project": "Project",
    "dataset": "Dataset",
    "image": "Image",
}

# Validate that data_id matches datatype and print details
if datatype not in _OMERO_OBJECT_TYPES:
    raise ValueError(
        f"Invalid datatype specified: {datatype!r}; "
        f"expected one of {sorted(_OMERO_OBJECT_TYPES)}"
    )

omero_object_type = _OMERO_OBJECT_TYPES[datatype]
# getObject returns None when no object of that type/ID is visible to this session.
container = conn.getObject(omero_object_type, data_id)
if container is None:
    raise ValueError(f"{omero_object_type} with ID {data_id} not found")
# NOTE(review): confirm print_object_details supports the "screen" type label.
print_object_details(conn, container, datatype)

### Create a temporary folder to store training data; it will be uploaded to OMERO later

In [None]:
# Fresh scratch directory for this session's training data (normalised for the
# host OS path conventions).
_scratch_dir = tempfile.mkdtemp()
output_directory = os.path.normcase(_scratch_dir)
print('Output Directory: ', output_directory)

### Start batch annotation

In [None]:
# Collect every image reachable from the selected OMERO container.
images, source_desc = get_images_from_container(conn, datatype, data_id)

# Optional pre-filter, e.g. keep only large images:
# images = [img for img in images if img.getSizeX() > 1000 and img.getSizeY() > 1000]
# print(f"Filtered to {len(images)} images with size > 1000x1000")

# --- Image selection ---------------------------------------------------------
segment_all = True   # Annotate every image; train_n / validate_n only apply when False
train_n = 2          # Training images to draw when segment_all is False
validate_n = 1       # Validation images to draw when segment_all is False

# --- Model and processing ----------------------------------------------------
model_type = 'vit_l'  # SAM backbone passed to micro-SAM
batch_size = 2        # Images handled per annotation batch
channel = 3           # Channel to segment (usually the nuclear/main stain channel)
three_d = False       # True = volumetric (3D) annotation

# --- Z / time selection ------------------------------------------------------
z_slices = [0]               # Z-slices to use when three_d is False
z_slice_mode = "specific"    # "all", "random", or "specific"
timepoints = [0]             # Timepoints to use
timepoint_mode = "specific"  # "all", "random", or "specific"
resume_from_table = False    # Continue from an existing tracking table

# --- Patch extraction --------------------------------------------------------
use_patches = False      # Annotate patches instead of whole images
patch_size = (512, 512)  # Patch size to extract
patches_per_image = 1    # Patches taken from each image
random_patches = True    # Random patch positions (False = centered)

# --- Output location -----------------------------------------------------------
# Read-only mode keeps annotations on disk for servers where you lack write access.
read_only_mode = False
local_output_dir = "./omero_annotations"

if read_only_mode:
    os.makedirs(local_output_dir, exist_ok=True)

# Every tunable above is forwarded to the pipeline by keyword.
batch_options = dict(
    model_type=model_type,
    batch_size=batch_size,
    channel=channel,
    timepoints=timepoints,
    timepoint_mode=timepoint_mode,
    z_slices=z_slices,
    z_slice_mode=z_slice_mode,
    segment_all=segment_all,
    train_n=train_n,
    validate_n=validate_n,
    three_d=three_d,
    use_patches=use_patches,
    patch_size=patch_size,
    patches_per_image=patches_per_image,
    random_patches=random_patches,
    resume_from_table=resume_from_table,
    read_only_mode=read_only_mode,
    local_output_dir=local_output_dir,
)

table_id, combined_images = process_omero_batch_with_dask(
    conn, images, output_directory, datatype, data_id, source_desc, **batch_options
)

print(f"Annotation complete! Table ID: {table_id}")
print(f"Processed {len(combined_images)} images")

### View annotations tracking table

In [None]:
# Retrieve and display the tracking table
# Fetch the tracking table from OMERO (when the pipeline returned one) and show it.
if table_id is None:
    print("No tracking table was created")
else:
    tracking_df = ezomero.get_table(conn, table_id)
    print(f"Tracking table contains {len(tracking_df)} rows")
    display(tracking_df)  # Jupyter/IPython built-in renderer

### Clean up and close connection

In [None]:
# Clean up temporary directory
# Delete the local scratch directory; a failure is reported but does not stop
# the connection from being closed below.
try:
    shutil.rmtree(output_directory)
except Exception as e:
    print(f"Error removing temporary directory: {e}")
else:
    print(f"Removed temporary directory: {output_directory}")

# Release the OMERO session.
conn.close()
print("OMERO connection closed")