### Annotation of OMERO data using napari-micro-sam

## Instructions:
  - To run with OMERO without exposing login details in the notebook, the username is stored in a .env file (see the example in .env_example). The password must be typed in each time.
    Saving credentials unencrypted is still not recommended, so a better solution will be worked on.
  - This notebook supports processing images from various OMERO object types: images, datasets, projects, plates, and screens.
  - Specify the container type in the `datatype` variable and the object ID in the `data_id` variable.
  - You can choose to segment all images in the container or select a random subset for training and validation.
    - You can now specify multiple z-slices and timepoints to analyze.
    - You can extract and analyze patches from large images, which reduces the number of objects to annotate per image while still letting you annotate more images.

### TODOs
See [TODO.md](./TODO.md) for the complete list of planned improvements and features.

## Load all required packages and dependencies

In [None]:
# OMERO-related imports
import omero
from omero.gateway import BlitzGateway
import ezomero
from napari.settings import get_settings

# Scientific computing and image processing
import numpy as np
import pandas as pd

# File and system operations
import os
import shutil
import tempfile
import warnings
from dotenv import load_dotenv

import importlib
# Reload specific modules
import src.omero_functions
import src.file_io_functions
import src.image_functions
import src.utils
import src.processing_pipeline

importlib.reload(src.omero_functions)
importlib.reload(src.file_io_functions)
importlib.reload(src.image_functions)
importlib.reload(src.utils)
importlib.reload(src.processing_pipeline)

# Re-import after reloading
from src.omero_functions import print_object_details, get_images_from_container, get_dask_image, upload_rois_and_labels, initialize_tracking_table, update_tracking_table_rows
from src.file_io_functions import zip_directory, store_annotations_in_zarr, zarr_to_tiff, cleanup_local_embeddings, organize_local_outputs, save_annotations_schema
from src.image_functions import label_to_rois, generate_patch_coordinates, extract_patch
from src.utils import NumpyEncoder, interleave_arrays
from src.processing_pipeline import process_omero_batch
import sys

# Add auto-reload capability for src module
import importlib.util

# Refresh the top-level src package when a previous import left it in sys.modules
_loaded_src_package = sys.modules.get("src")
if _loaded_src_package is not None:
    importlib.reload(_loaded_src_package)

# Define a helper function to reload modules more concisely
def reload_module(module_name):
    """Reload *module_name* if it is already loaded, otherwise import it.

    Args:
        module_name: Fully qualified module name, e.g. ``"src.utils"``.

    Returns:
        The module object named by ``module_name`` itself. The built-in
        ``__import__`` returns the top-level package for dotted names, so
        ``importlib`` is used to hand back the actual submodule.

    Raises:
        ImportError: If the module is not loaded yet and cannot be imported.
    """
    already_loaded = sys.modules.get(module_name)
    if already_loaded is not None:
        # importlib.reload re-executes the module and returns it.
        return importlib.reload(already_loaded)
    return importlib.import_module(module_name)

# Submodules of src that are refreshed through reload_module on every run of this cell
src_modules = [
    "src.omero_functions",
    "src.file_io_functions",
    "src.image_functions",
    "src.utils",
    "src.processing_pipeline",
]

for module in src_modules:
    reload_module(module)

# Re-import after reloading to ensure we have the latest versions

# Switch off napari's IPython-interactive setting
# NOTE(review): presumably so the viewer blocks until it is closed - confirm
napari_settings = get_settings()
napari_settings.application.ipy_interactive = False


# Scratch directory for intermediate files; the path is case-normalised for the platform
_scratch_dir = tempfile.mkdtemp()
output_directory = os.path.normcase(_scratch_dir)
print('Created temporary work directory: ', output_directory)

### Setup connection with OMERO

In [None]:
# Read the connection settings (HOST, USER_NAME, GROUP and optionally PASSWORD)
# from the .env file; values in the file take precedence over the environment.
load_dotenv(override=True)

# Ask for password if not set
if not os.environ.get("PASSWORD"):
    from getpass import getpass
    os.environ["PASSWORD"] = getpass("Enter OMERO server password: ")

conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    group=os.environ.get("GROUP"),
    secure=True,
)

connection_status = conn.connect()
if not connection_status:
    print("Connection to OMERO Server Failed")
    # Forget the password that was just used: otherwise a mistyped password
    # stays in os.environ and re-running this cell never prompts again.
    os.environ.pop("PASSWORD", None)
    # Stop here instead of letting later cells fail on an unusable connection.
    raise ConnectionError(
        "Could not connect to the OMERO server; check HOST, USER_NAME, GROUP "
        "and the password, then run this cell again."
    )

print("Connected to OMERO Server")
# Request keep-alive calls with a 60 second interval on the underlying client
conn.c.enableKeepAlive(60)

### Select your dataset and check its content

In [None]:
datatype = "dataset" # "screen", "plate", "project", "dataset", "image"
data_id = 1112

# Validate that data_id matches datatype and print details.
# Every datatype listed in the comment above has a branch here, including "screen".
if datatype == "screen":
    screen = conn.getObject("Screen", data_id)
    if screen is None:
        raise ValueError(f"Screen with ID {data_id} not found")
    # NOTE(review): confirm print_object_details handles the "screen" datatype
    print_object_details(conn, screen, "screen")

elif datatype == "project":
    project = conn.getObject("Project", data_id)
    if project is None:
        raise ValueError(f"Project with ID {data_id} not found")
    print_object_details(conn, project, "project")

elif datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    if plate is None:
        raise ValueError(f"Plate with ID {data_id} not found")
    print_object_details(conn, plate, "plate")

elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"Dataset with ID {data_id} not found")
    print_object_details(conn, dataset, "dataset")

elif datatype == "image":
    image = conn.getObject("Image", data_id)
    if image is None:
        raise ValueError(f"Image with ID {data_id} not found")
    print_object_details(conn, image, "image")

else:
    raise ValueError("Invalid datatype specified")

# Check for any training tables already present in the dataset

### Choose a training data set name
Use a specific name if you want to resume from an existing table  
Or use the datetime format for a new training set

In [None]:
# Name under which this training set is tracked.
# Keep a fixed name to continue working on an existing tracking table,
# or switch to the timestamped variant below to start a fresh training set.
trainingset_name = "training_data_20240529"  # fixed name: resumes the existing table
# trainingset_name = "training_data_" + pd.Timestamp.now().strftime("%Y%m%d_%H%M")
print('Training Set Name: ', trainingset_name)

### Start batch annotation with optimized table management

In [None]:
# Collect every image contained in the selected OMERO object
images, source_desc = get_images_from_container(conn, datatype, data_id)

# --- Image selection ---
segment_all = False  # True: process every image in the container
train_n = 3  # number of training images (only used when segment_all is False)
validate_n = 3  # number of validation images (only used when segment_all is False)

# --- Segmentation settings ---
model_type = 'vit_l_lm'  # SAM model to use
batch_size = 3  # number of images handled per batch
channel = 3  # zero-based channel to segment (usually the nuclear/main stain channel)
three_d = False  # True: use 3D mode

# --- Z-slice and timepoint selection ---
z_slices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # z-slices to process (when three_d is False)
z_slice_mode = "random"  # "all", "random", or "specific"
timepoints = [0]  # timepoints to process
timepoint_mode = "specific"  # "all", "random", or "specific"
resume_from_table = True  # True: resume from an existing tracking table

# --- Patch extraction ---
use_patches = False  # True: extract and process patches instead of full images
patch_size = (256, 256)  # size of the extracted patches
patches_per_image = 1  # number of patches taken from each image
random_patches = True  # True: random patches, False: centered patches

# --- Optional read-only mode (for OMERO servers without write permission) ---
read_only_mode = False  # True: save annotations locally instead of uploading to OMERO
local_output_dir = "./omero_annotations"  # directory for locally saved annotations

# In read-only mode, name the local folder after the training set (when one is set)
# and make sure it exists.
if read_only_mode and trainingset_name:
    local_output_dir = f"./omero_annotations_{trainingset_name}"
if read_only_mode:
    os.makedirs(local_output_dir, exist_ok=True)

# Build the configuration summary first, then print it line by line
summary_lines = [
    "Configuration Summary:",
    f"  - Segment All Images: {segment_all}",
    f"  - Training Images: {train_n}",
    f"  - Validation Images: {validate_n}",
    f"  - Model Type: {model_type}",
    f"  - Batch Size: {batch_size}",
    f"  - Channel: {channel}",
    f"  - 3D Mode: {three_d}",
    f"  - Z-Slices: {z_slices} (Mode: {z_slice_mode})",
    f"  - Timepoints: {timepoints} (Mode: {timepoint_mode})",
    f"  - Use Patches: {use_patches}",
]
if use_patches:
    # Patch details are only reported when patch extraction is enabled
    summary_lines += [
        f"  - Patch Size: {patch_size}",
        f"  - Patches per Image: {patches_per_image}",
        f"  - Random Patches: {random_patches}",
    ]
summary_lines.append(f"  - Read-Only Mode: {read_only_mode}")

for summary_line in summary_lines:
    print(summary_line)


### Run the annotation routine

In [None]:
# Keyword options forwarded unchanged from the configuration cell
batch_options = {
    "model_type": model_type,
    "batch_size": batch_size,
    "channel": channel,
    "timepoints": timepoints,
    "timepoint_mode": timepoint_mode,
    "z_slices": z_slices,
    "z_slice_mode": z_slice_mode,
    "segment_all": segment_all,
    "train_n": train_n,
    "validate_n": validate_n,
    "three_d": three_d,
    "use_patches": use_patches,
    "patch_size": patch_size,
    "patches_per_image": patches_per_image,
    "random_patches": random_patches,
    "resume_from_table": resume_from_table,
    "read_only_mode": read_only_mode,
    "local_output_dir": local_output_dir,
    "trainingset_name": trainingset_name,
}

# Run the batch annotation; it returns the tracking table ID and the processed images
table_id, combined_images = process_omero_batch(
    conn,
    images,
    output_directory,
    datatype,
    data_id,
    source_desc,
    **batch_options,
)

print(f"Annotation complete! Table ID: {table_id}")
print(f"Processed {len(combined_images)} images")

### Display annotations tracking table

In [None]:
# Show the tracking table, if the annotation run produced one
if table_id is None:
    print("No tracking table was created")
else:
    tracking_df = ezomero.get_table(conn, table_id)
    print(f"Tracking table contains {len(tracking_df)} rows")
    display(tracking_df)

### Clean up and close connection

In [None]:
# Remove the temporary work directory; a failure is reported but does not stop the cell
try:
    shutil.rmtree(output_directory)
except Exception as cleanup_error:
    print(f"Error removing temporary directory: {cleanup_error}")
else:
    print(f"Removed temporary directory: {output_directory}")

# Close the OMERO connection whether or not the directory could be removed
conn.close()
print("OMERO connection closed")