## Training

In [26]:
from omero.gateway import BlitzGateway
import ezomero
#load dotenv for OMERO login
from dotenv import load_dotenv

from tifffile import imsave, imwrite, imread
import torch

import os
import tempfile
import pandas as pd
import zipfile
import numpy as np
import datetime
import shutil

from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.util.util import get_random_colors

import micro_sam.training as sam_training
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation


import importlib
import sys

# Project-local modules to (re)load before the `from src... import` statements below.
# Each entry is a dotted module name that is handed to reload_module().
src_modules = [
    "src.omero_functions",
    "src.file_io_functions",
    "src.image_functions",
    "src.utils",
    "src.processing_pipeline",
]


def reload_module(module_name):
    """
    Import ``module_name``, reloading it first if it has been imported before.

    Args:
        module_name: Dotted module name, e.g. ``"src.utils"``.

    Returns:
        The module object named by ``module_name`` itself. (The built-in
        ``__import__`` would return the top-level package for a dotted name,
        e.g. ``src`` instead of ``src.utils``.)
    """
    if module_name in sys.modules:
        # importlib.reload re-executes the module and returns the refreshed object
        return importlib.reload(sys.modules[module_name])
    return importlib.import_module(module_name)


# Refresh every listed module so the imports that follow bind to the current code;
# the return value of reload_module() is not needed here.
for module in src_modules:
    reload_module(module)

# Re-import after reloading to ensure we have the latest versions
from src.omero_functions import (
    print_object_details,
    get_images_from_container,
    get_dask_image,
    upload_rois_and_labels,
    initialize_tracking_table,
    update_tracking_table_rows,
    get_dask_image_multiple,
    get_dask_dimensions,
)
from src.file_io_functions import (
    zip_directory,
    store_annotations_in_zarr,
    zarr_to_tiff,
    cleanup_local_embeddings,
    organize_local_outputs,
    save_annotations_schema,
)
from src.image_functions import label_to_rois, generate_patch_coordinates, extract_patch
from src.utils import NumpyEncoder, interleave_arrays
from src.processing_pipeline import process_omero_batch

from napari.settings import get_settings

# Switch off napari's `ipy_interactive` application setting for this session.
# NOTE(review): presumably this keeps napari from hooking into the IPython event loop
# so viewer windows block until closed - confirm against the napari settings docs.
get_settings().application.ipy_interactive = False



### Set up the connection to OMERO

In [27]:
# Read HOST, USER_NAME, PASSWORD and GROUP from the .env file (overriding any
# values that are already present in the environment).
load_dotenv(override=True)
# Ask for password if not set
if not os.environ.get("PASSWORD"):
    from getpass import getpass

    os.environ["PASSWORD"] = getpass("Enter OMERO server password: ")

conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    group=os.environ.get("GROUP"),
    secure=True,
)

connection_status = conn.connect()
if not connection_status:
    # Stop here: every later cell needs a live session, and enabling keep-alive
    # on a connection that never opened only hides the real problem.
    raise ConnectionError("Connection to OMERO Server Failed")

print("Connected to OMERO Server")
# Ping the server every 60 seconds so the session survives long-running cells
conn.c.enableKeepAlive(60)


INFO:omero.gateway:created connection (uuid=3f06f483-a349-4350-9782-d3fcaf74d3c6)
INFO:omero.util.Resources:Starting
INFO:omero.util.Resources:Starting


Connected to OMERO Server


### Get info from the dataset

In [28]:
# Kind of OMERO container to train on. Branches exist below for "project", "plate",
# "dataset" and "image"; any other value (including "screen", which the inline
# comment lists) ends in the final ValueError.
datatype = "project"  # "screen", "plate", "project", "dataset", "image"
# ID of the object of that kind on the OMERO server
data_id = 101
# Name of the attached table file that is looked up later with get_specific_table()
trainingset_name = "micro_sam_training_data_20240602"


# Validate that data_id matches datatype and print details
# Each branch fetches the object, fails with ValueError if the server returns
# nothing for that ID, and otherwise prints a summary via print_object_details().
if datatype == "project":
    project = conn.getObject("Project", data_id)
    if project is None:
        raise ValueError(f"Project with ID {data_id} not found")
    print_object_details(conn, project, "project")

elif datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    if plate is None:
        raise ValueError(f"Plate with ID {data_id} not found")
    print_object_details(conn, plate, "plate")

elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"Dataset with ID {data_id} not found")
    print_object_details(conn, dataset, "dataset")

elif datatype == "image":
    image = conn.getObject("Image", data_id)
    if image is None:
        raise ValueError(f"Image with ID {data_id} not found")
    print_object_details(conn, image, "image")

else:
    raise ValueError("Invalid datatype specified")




Project Details:
- Name: Senescence
- ID: 101
- Owner: root root
- Group: system
- Number of datasets: 3
- Total images: 6


### Define output folders for training

In [29]:
# Every run gets its own timestamped folder below ~/micro-sam_models.
# To continue in an existing run folder, reassign output_directory after this block.
timestamp = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
home_dir = os.path.expanduser("~")
models_dir = os.path.join(home_dir, "micro-sam_models")
folder_name = "micro-sam-" + timestamp
output_directory = os.path.abspath(os.path.join(models_dir, folder_name))

# Create the shared models folder first, then the folder for this run
for directory in (models_dir, output_directory):
    os.makedirs(directory, exist_ok=True)

print(f"Output directory: {output_directory}")

Output directory: C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234


### Collecting data from OMERO using the attached table

In [31]:

def get_specific_table(conn, datatype, dataset_id, table_name="micro_sam_training_data"):
    """
    Look up a table attached to an OMERO object by the name of its file.

    The file annotations linked to the object are visited in the order
    ezomero reports them; the first one whose original file is called
    ``table_name`` and that loads through ``ezomero.get_table`` wins. A
    matching file that fails to load is reported with ``print`` and the
    search moves on to the next annotation.

    Args:
        conn: OMERO connection
        datatype: Type of the object the table is attached to (e.g. "project", "dataset")
        dataset_id: ID of that object
        table_name: File name of the table to find

    Returns:
        (table, annotation_id): The value returned by ``ezomero.get_table`` and
        the ID of the file annotation it came from, or ``(None, None)`` when no
        matching, loadable table exists.
    """
    for annotation_id in ezomero.get_file_annotation_ids(conn, datatype, dataset_id):
        annotation = conn.getObject("FileAnnotation", annotation_id)
        if annotation is None:
            continue

        # Only annotations whose original file carries the requested name qualify
        if annotation.getFile().getName() != table_name:
            continue

        try:
            return ezomero.get_table(conn, annotation_id), annotation_id
        except Exception as e:
            print(f"Found file {table_name} but failed to load as table: {e}")

    return None, None

In [32]:
# Fetch the training table attached to the selected OMERO object
table, file_ann_id = get_specific_table(conn, datatype, data_id, trainingset_name)
if table is None:
    # Every following cell indexes into `table`; fail here with a clear message
    # instead of a "'NoneType' object is not subscriptable" error further down.
    raise ValueError(f"No table named {trainingset_name} found")

print(f"Found table {trainingset_name} in file annotation {file_ann_id}")
# If pandas DataFrame:
print(table.head())

Found table micro_sam_training_data_20240602 in file annotation 344
   image_id  channel  z_slice  timepoint  patch_x  patch_y  patch_width  \
0       254        1        0         10      367      364          256   
1       254        1        0         10     1324      956          256   
2       253        1        0         11     1517      488          256   
3       253        1        0         12      432      912          256   
4       256        1        0          1     1586     1706          256   

   patch_height  image_name sam_model embed_id label_id roi_id  \
0           256  r01c03.tif  vit_b_lm     None      320      1   
1           256  r01c03.tif  vit_b_lm      321      322      2   
2           256  r01c04.tif  vit_b_lm      323      324      3   
3           256  r01c04.tif  vit_b_lm      325      326      4   
4           256  r01c01.tif  vit_b_lm      327      328      5   

  schema_attachment_id  train  validate  is_volumetric  processed  is_patch  
0     

In [33]:
# Define a reusable function to prepare training/validation data from OMERO table
def prepare_dataset_from_table(conn, df, output_dir, subset_type="training", tmp_dir=None):
    """
    Export the images and labels listed in a tracking table as local TIFF files.

    For every row of ``df`` the referenced plane (or XY patch) is fetched with
    ``ezomero.get_image``, rescaled to 8-bit and written to
    ``<output_dir>/<subset_type>_input/input_NNNNN.tif``. The file annotation
    named by ``label_id`` is downloaded into ``tmp_dir`` and moved to
    ``<output_dir>/<subset_type>_label/label_NNNNN.tif``. ``NNNNN`` is the
    position of the row inside ``df``, so numbering starts at 0 for each call.

    Rows flagged ``is_volumetric`` are written as a (Z, Y, X) stack, all other
    rows as a single (Y, X) plane. A problem with one row is reported with
    ``print`` and the remaining rows are still processed.

    Args:
        conn: OMERO connection
        df: DataFrame with tracking info (``image_id``, ``z_slice``, ``channel``,
            ``timepoint``, ``patch_x``, ``patch_y``, ``patch_width``,
            ``patch_height``, ``label_id`` and optionally ``is_volumetric`` /
            ``is_patch``)
        output_dir: Base output directory
        subset_type: "training" or "val"; used as prefix of the folder names
        tmp_dir: Temporary directory for downloading annotations
            (defaults to ``<output_dir>/tmp``)

    Returns:
        (input_dir, label_dir): Paths to the input and label directories
    """
    if tmp_dir is None:
        tmp_dir = os.path.join(output_dir, "tmp")
    # Also create a caller-supplied directory; the label download needs it to exist
    os.makedirs(tmp_dir, exist_ok=True)

    input_dir = os.path.join(output_dir, f"{subset_type}_input")
    label_dir = os.path.join(output_dir, f"{subset_type}_label")
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    try:
        from tqdm import tqdm
        progress_fn = tqdm
    except ImportError:
        # Simple progress function if tqdm is not available
        def progress_fn(x, **kwargs):
            return x

    has_volumetric_flag = 'is_volumetric' in df.columns
    has_patch_flag = 'is_patch' in df.columns

    for n in progress_fn(range(len(df)), desc=f"Preparing {subset_type} data"):
        try:
            row = df.iloc[n]

            # Extract metadata
            image_id = int(row['image_id'])
            z_slice = _parse_z_slice(row['z_slice'])
            channel = _int_or_zero(row['channel'])
            timepoint = _int_or_zero(row['timepoint'])
            is_volumetric = _flag(row, 'is_volumetric', has_volumetric_flag)

            # Get patch information
            is_patch = _flag(row, 'is_patch', has_patch_flag)
            patch_x = _int_or_zero(row['patch_x'])
            patch_y = _int_or_zero(row['patch_y'])
            patch_width = _int_or_zero(row['patch_width'])
            patch_height = _int_or_zero(row['patch_height'])

            print(f"Item {n} - Image ID: {image_id}, Patch: {is_patch}, Dimensions: {patch_width}x{patch_height} at ({patch_x},{patch_y})")

            # A patch is only requested when the row is flagged and has a real extent;
            # otherwise the full plane is fetched.
            patch = None
            if is_patch and patch_width > 0 and patch_height > 0:
                patch = (patch_x, patch_y, patch_width, patch_height)

            if is_volumetric:
                # Determine which z-slices to load
                if isinstance(z_slice, list):
                    z_slices = z_slice
                elif z_slice == 'all':
                    # Get image object to determine size
                    omero_image, _ = ezomero.get_image(conn, image_id, no_pixels=True)
                    if not omero_image:
                        print(f"Warning: Image {image_id} not found, skipping")
                        continue
                    z_slices = range(omero_image.getSizeZ())
                else:
                    z_slices = [int(z_slice)]

                planes = [
                    _fetch_plane(conn, image_id, int(z), channel, timepoint, patch, tag="3D")
                    for z in z_slices
                ]
                # Stack the (Y, X) planes into one (Z, Y, X) volume
                img_data = np.array(planes)
                print(f"  Final 3D array shape: {img_data.shape}")
                dims = "3D"
            else:
                # 2D rows use the first entry when a list of slices was given
                z_val = z_slice[0] if isinstance(z_slice, list) else z_slice
                img_data = _fetch_plane(conn, image_id, int(z_val), channel, timepoint, patch, tag="2D")
                dims = "2D"

            img_8bit = _to_uint8(img_data)

            # Save as TIFF (a 3D array becomes a multi-page file)
            output_path = os.path.join(input_dir, f"input_{n:05d}.tif")
            imwrite(output_path, img_8bit)
            print(f"  Saved {dims} TIFF to {output_path} with shape {img_8bit.shape}")

            # NOTE(review): if the label step below fails, the input TIFF written above
            # is kept, so the input and label folders can hold different file sets -
            # confirm the training loader tolerates unpaired files.
            label_id = int(row['label_id']) if pd.notna(row['label_id']) else None
            if label_id:
                try:
                    file_path = ezomero.get_file_annotation(conn, label_id, tmp_dir)
                    if file_path:
                        label_dest = os.path.join(label_dir, f"label_{n:05d}.tif")
                        # shutil.move also works when the destination already exists
                        # on Windows or lives on another drive, where os.rename fails
                        shutil.move(file_path, label_dest)
                        # Check the size of the saved label
                        label_img = imread(label_dest)
                        print(f"  Label shape: {label_img.shape} saved to {label_dest}")
                        if label_img.shape != img_8bit.shape:
                            print(f"  Warning: Label shape {label_img.shape} differs from image shape {img_8bit.shape}")
                    else:
                        print(f"  Warning: Label file for image {image_id} not downloaded")
                except Exception as e:
                    print(f"  Error downloading label file: {e}")
            else:
                print(f"  Warning: No label ID for image {image_id}")

        except Exception as e:
            print(f"Error processing {subset_type} item {n}: {e}")

    return input_dir, label_dir


def _parse_z_slice(value):
    """Turn a ``z_slice`` table cell into an int-like value, a list of slices, or a keyword string.

    Missing values become 0. A string that looks like a list (``"[0, 1, 2]"``) is
    parsed with ``ast.literal_eval`` - never ``eval`` - and kept as a list so that
    volumetric rows retain all of their slices. Unparsable or empty lists fall
    back to 0. Any other value (a number, ``'all'``, ...) is returned unchanged.
    """
    import ast

    if isinstance(value, (list, tuple)):
        return list(value) if len(value) > 0 else 0
    if isinstance(value, str):
        if not value.startswith('['):
            return value
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            return 0
        if isinstance(parsed, list):
            return parsed if len(parsed) > 0 else 0
        return parsed
    if pd.isna(value):
        return 0
    return value


def _int_or_zero(value):
    """Return ``int(value)``, or 0 when the table cell is missing."""
    return int(value) if pd.notna(value) else 0


def _flag(row, column, column_present):
    """Read an optional boolean column; an absent column or a missing cell counts as False."""
    if not column_present:
        return False
    value = row[column]
    return bool(value) if pd.notna(value) else False


def _fetch_plane(conn, image_id, z, channel, timepoint, patch=None, tag="2D"):
    """Fetch one Z plane, or an XY patch of it, from OMERO as a 2D (Y, X) array.

    ``patch`` is ``(x, y, width, height)``; ``None`` requests the full plane.
    ``tag`` only labels the progress messages.
    """
    if patch is not None:
        x0, y0, width, height = patch
        start_coords = (x0, y0, z, channel, timepoint)
        axis_lengths = (width, height, 1, 1, 1)
        print(f"  {tag} Patch Request - start_coords: ({x0}, {y0}, {z}, {channel}, {timepoint}), dimensions: {width}x{height}")
    else:
        start_coords = (0, 0, z, channel, timepoint)
        # NOTE(review): the None entries are kept exactly as in the original call;
        # confirm that ezomero accepts None for individual axis lengths.
        axis_lengths = (None, None, 1, 1, 1)
        print(f"  {tag} Full Image Request - start_coords: (0, 0, {z}, {channel}, {timepoint})")

    _, pixels = ezomero.get_image(
        conn,
        image_id,
        start_coords=start_coords,
        axis_lengths=axis_lengths,
        xyzct=True,  # Use XYZCT ordering
    )
    print(f"  Returned array shape: {pixels.shape}")

    if pixels.ndim == 5:
        # Axes are (X, Y, Z, C, T) with Z = C = T = 1: drop the singleton axes,
        # then swap X and Y to obtain the row-major (Y, X) layout written to TIFF.
        pixels = np.swapaxes(pixels[:, :, 0, 0, 0], 0, 1)

    print(f"  Extracted 2D shape: {pixels.shape}")
    return pixels


def _to_uint8(img):
    """Rescale an array so that its maximum maps to 255 and cast it to uint8."""
    max_val = img.max()
    if max_val > 0:
        return (img * (255.0 / max_val)).astype(np.uint8)
    return img.astype(np.uint8)

# Start from a clean slate: remove whatever a previous export left in this run folder
folders = ["training_input", "training_label", "val_input", "val_label", "tmp"]
for folder in folders:
    folder_path = os.path.join(output_directory, folder)
    if os.path.isdir(folder_path):
        shutil.rmtree(folder_path)

# Scratch space for the downloaded label annotations
tmp_dir = os.path.join(output_directory, "tmp")
os.makedirs(tmp_dir, exist_ok=True)

# Split the table rows by their train / validate flags
train_images = table.loc[table['train'] == True]
val_images = table.loc[table['validate'] == True]

print(f"Found {len(train_images)} training images and {len(val_images)} validation images")

# Export the training subset
training_input_dir, training_label_dir = prepare_dataset_from_table(
    conn, train_images, output_directory, subset_type="training", tmp_dir=tmp_dir
)

# Export the validation subset
val_input_dir, val_label_dir = prepare_dataset_from_table(
    conn, val_images, output_directory, subset_type="val", tmp_dir=tmp_dir
)

print("Training data successfully saved to:", output_directory)

Found 6 training images and 6 validation images


Preparing training data:   0%|          | 0/6 [00:00<?, ?it/s]INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/876bd6b7-f39d-44ea-be9a-4a44587c93a8omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/876bd6b7-f39d-44ea-be9a-4a44587c93a8omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000


Item 0 - Image ID: 254, Patch: True, Dimensions: 256x256 at (367,364)
  2D Patch Request - start_coords: (367, 364, 0, 1, 10), dimensions: 256x256


INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/876bd6b7-f39d-44ea-be9a-4a44587c93a8omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing training data:  17%|█▋        | 1/6 [00:00<00:03,  1.39it/s]

  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_input\input_00000.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_label\label_00000.tif
Item 1 - Image ID: 254, Patch: True, Dimensions: 256x256 at (1324,956)
  2D Patch Request - start_coords: (1324, 956, 0, 1, 10), dimensions: 256x256


INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/de8c44b4-678e-4ae7-a57e-12f2f80b9095omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/de8c44b4-678e-4ae7-a57e-12f2f80b9095omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/de8c44b4-678e-4ae7-a57e-12f2f80b9095omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing training data:  33%|███▎      | 2/6 [00:01<00:02,  1.70it/s]INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/5ab43ca5-9763-4f1a-adf2-662ab7515e53omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/5ab43ca5-9763-4f1a-adf2-662ab7515e53omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000


  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_input\input_00001.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_label\label_00001.tif
Item 2 - Image ID: 256, Patch: True, Dimensions: 256x256 at (1586,1706)
  2D Patch Request - start_coords: (1586, 1706, 0, 1, 1), dimensions: 256x256


INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/5ab43ca5-9763-4f1a-adf2-662ab7515e53omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing training data:  50%|█████     | 3/6 [00:01<00:01,  1.70it/s]INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/17734cac-e59c-4a04-81fb-1378d526e99domero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/17734cac-e59c-4a04-81fb-1378d526e99domero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000


  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_input\input_00002.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_label\label_00002.tif
Item 3 - Image ID: 256, Patch: True, Dimensions: 256x256 at (141,1146)
  2D Patch Request - start_coords: (141, 1146, 0, 1, 2), dimensions: 256x256


INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/17734cac-e59c-4a04-81fb-1378d526e99domero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing training data:  67%|██████▋   | 4/6 [00:02<00:01,  1.78it/s]

  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_input\input_00003.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_label\label_00003.tif
Item 4 - Image ID: 252, Patch: True, Dimensions: 256x256 at (1637,1510)
  2D Patch Request - start_coords: (1637, 1510, 0, 1, 8), dimensions: 256x256


INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/f2836016-f551-462c-b591-9055c515fb2aomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/f2836016-f551-462c-b591-9055c515fb2aomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/f2836016-f551-462c-b591-9055c515fb2aomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing training data:  83%|████████▎ | 5/6 [00:02<00:00,  1.76it/s]INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/be0f5aea-df6b-44b4-a712-77099ca3360comero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/be0f5aea-df6b-44b4-a712-77099ca3360comero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000


  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_input\input_00004.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_label\label_00004.tif
Item 5 - Image ID: 252, Patch: True, Dimensions: 256x256 at (312,138)
  2D Patch Request - start_coords: (312, 138, 0, 1, 3), dimensions: 256x256


INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/be0f5aea-df6b-44b4-a712-77099ca3360comero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing training data: 100%|██████████| 6/6 [00:03<00:00,  1.75it/s]
Preparing training data: 100%|██████████| 6/6 [00:03<00:00,  1.75it/s]


  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_input\input_00005.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\training_label\label_00005.tif


Preparing val data:   0%|          | 0/6 [00:00<?, ?it/s]

Item 0 - Image ID: 253, Patch: True, Dimensions: 256x256 at (1517,488)
  2D Patch Request - start_coords: (1517, 488, 0, 1, 11), dimensions: 256x256


INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/7b523d1a-22bd-401b-8b27-a3d675d1a1aaomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/7b523d1a-22bd-401b-8b27-a3d675d1a1aaomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/7b523d1a-22bd-401b-8b27-a3d675d1a1aaomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing val data:  17%|█▋        | 1/6 [00:00<00:02,  1.78it/s]

  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_input\input_00000.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_label\label_00000.tif
Item 1 - Image ID: 253, Patch: True, Dimensions: 256x256 at (432,912)
  2D Patch Request - start_coords: (432, 912, 0, 1, 12), dimensions: 256x256


INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/40e3239d-d9eb-41ec-b74b-11435a29a234omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/40e3239d-d9eb-41ec-b74b-11435a29a234omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/40e3239d-d9eb-41ec-b74b-11435a29a234omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing val data:  33%|███▎      | 2/6 [00:01<00:02,  1.75it/s]INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/75ddc1ce-e252-4f5d-bfad-8fa30d47ed7eomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/75ddc1ce-e252-4f5d-bfad-8fa30d47ed7eomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000


  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_input\input_00001.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_label\label_00001.tif
Item 2 - Image ID: 251, Patch: True, Dimensions: 256x256 at (1615,1124)
  2D Patch Request - start_coords: (1615, 1124, 0, 1, 5), dimensions: 256x256


INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/75ddc1ce-e252-4f5d-bfad-8fa30d47ed7eomero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing val data:  50%|█████     | 3/6 [00:01<00:01,  1.78it/s]

  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_input\input_00002.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_label\label_00002.tif
Item 3 - Image ID: 251, Patch: True, Dimensions: 256x256 at (162,627)
  2D Patch Request - start_coords: (162, 627, 0, 1, 12), dimensions: 256x256


INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/93fd1c4c-c444-4cda-9f4a-a586a4e459b8omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/93fd1c4c-c444-4cda-9f4a-a586a4e459b8omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/93fd1c4c-c444-4cda-9f4a-a586a4e459b8omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing val data:  67%|██████▋   | 4/6 [00:02<00:01,  1.83it/s]INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/6ff312ee-c38a-47df-8596-21b1e377ff1comero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/6ff312ee-c38a-47df-8596-21b1e377ff1comero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000


  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_input\input_00003.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_label\label_00003.tif
Item 4 - Image ID: 255, Patch: True, Dimensions: 256x256 at (1158,1150)
  2D Patch Request - start_coords: (1158, 1150, 0, 1, 9), dimensions: 256x256


INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/6ff312ee-c38a-47df-8596-21b1e377ff1comero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing val data:  83%|████████▎ | 5/6 [00:02<00:00,  1.75it/s]

  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_input\input_00004.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_label\label_00004.tif
Item 5 - Image ID: 255, Patch: True, Dimensions: 256x256 at (5,506)
  2D Patch Request - start_coords: (5, 506, 0, 1, 8), dimensions: 256x256


INFO:omero.gateway:Registered 3f06f483-a349-4350-9782-d3fcaf74d3c6/c113cf5b-6830-4b7c-9d71-00a1ffacc270omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/c113cf5b-6830-4b7c-9d71-00a1ffacc270omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
INFO:omero.gateway:Unregistered 3f06f483-a349-4350-9782-d3fcaf74d3c6/c113cf5b-6830-4b7c-9d71-00a1ffacc270omero.api.RawPixelsStore -t -e 1.1:tcp -h 172.19.0.9 -p 34289 -t 60000
Preparing val data: 100%|██████████| 6/6 [00:03<00:00,  1.81it/s]

  Returned array shape: (256, 256, 1, 1, 1)
  Extracted 2D shape: (256, 256)
  Saved 2D TIFF to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_input\input_00005.tif with shape (256, 256)
  Label shape: (256, 256) saved to C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234\val_label\label_00005.tif
Training data successfully saved to: C:\Users\Maarten\micro-sam_models\micro-sam-20250604_112234





### Prepare the data loader for training

In [None]:
def determine_patch_shape_from_table(table_df, default_shape=(1, 512, 512), min_size=256):
    """
    Derive the training patch shape from the patch columns of an OMERO table.

    The median of the ``patch_width`` / ``patch_height`` columns is taken,
    raised to at least ``min_size`` and then rounded down to an even number.
    If the columns are missing, hold only NaN, yield a non-positive median,
    or any exception is raised along the way, ``default_shape`` is returned.

    Args:
        table_df: DataFrame from the OMERO table with patch information
        default_shape: Fallback patch shape (default: (1, 512, 512))
        min_size: Lower bound applied to each patch dimension (default: 256)

    Returns:
        tuple: Patch shape for training as (1, height, width), or ``default_shape``
    """
    try:
        column_names = table_df.columns
        dims_present = "patch_width" in column_names and "patch_height" in column_names

        # Columns that exist but contain nothing except NaN count as "no dimensions"
        if dims_present:
            dims_present = not table_df["patch_width"].isna().all()
        if dims_present:
            dims_present = not table_df["patch_height"].isna().all()

        if not dims_present:
            print(f"No patch dimensions found in table, using default {default_shape}")
            return default_shape

        # The median tolerates a few rows whose patch size deviates from the rest
        patch_width = int(table_df["patch_width"].median())
        patch_height = int(table_df["patch_height"].median())

        # Reject zero / negative sizes before doing any clamping
        if not (patch_width > 0 and patch_height > 0):
            print(f"Invalid patch dimensions in table: {patch_width}x{patch_height}, using default {default_shape}")
            return default_shape

        # Enforce the minimum first, then drop a trailing odd pixel on each side
        even_width, even_height = (
            size - (size % 2)
            for size in (max(min_size, patch_width), max(min_size, patch_height))
        )

        new_shape = (1, even_height, even_width)
        print(f"Using patch shape {new_shape} extracted from OMERO table")
        return new_shape
    except Exception as e:
        # Best effort: any failure while reading the table falls back to the default
        print(f"Error determining patch shape from table: {e}, using default {default_shape}")
        return default_shape

batch_size = 2  # training batch size

# Derive the training patch shape from the OMERO table that holds our annotations
patch_shape = determine_patch_shape_from_table(table, default_shape=(1, 512, 512), min_size=256)
print(f"Selected patch shape for training: {patch_shape}")

# Images and labels are loaded from folders via a glob pattern (here: all tif files),
# and the same pattern is used for the raw and the label folder.
raw_key = label_key = "*.tif"

# Train an additional convolutional decoder for end-to-end automatic instance segmentation.
# NOTE 1: Densely annotated labels are important when training the additional convolutional decoder.
# NOTE 2: If you do not have labeled images, we recommend using the `micro-sam` annotator tools
#         to annotate as many objects as possible per image for best performance.
train_instance_segmentation = True

# NOTE: The dataloader internally takes care of adding label transforms, i.e. converting the
# ground-truth labels to the desired instances for finetuning Segment Anything, or learning the
# foreground and the distances to object centers and object boundaries for automatic segmentation.

# Inputs can be large while the labeled objects are not evenly distributed across the image.
# A sampler ensures that only valid inputs are chosen, based on the paired labels.
# The sampler below requires at least one foreground instance and filters out small objects.
# NOTE: The 'min_size' value is paired with the same 'min_size' filter in 'label_transform'.
sampler = MinInstanceSampler(min_size=25)

# Settings that are identical for the training and the validation loader
shared_loader_kwargs = dict(
    raw_key=raw_key,
    label_key=label_key,
    with_segmentation_decoder=train_instance_segmentation,
    patch_shape=patch_shape,
    batch_size=batch_size,
    is_seg_dataset=True,
    shuffle=True,
    raw_transform=sam_training.identity,
    sampler=sampler,
)

train_loader = sam_training.default_sam_loader(
    raw_paths=training_input_dir,
    label_paths=training_label_dir,
    # rois=train_roi,
    **shared_loader_kwargs,
)

val_loader = sam_training.default_sam_loader(
    raw_paths=val_input_dir,
    label_paths=val_label_dir,
    # rois=val_roi,
    **shared_loader_kwargs,
)

# Visual sanity check: plot one batch from each loader
check_loader(train_loader, 1, plt=True)
check_loader(val_loader, 1, plt=True)

### Running the training

In [None]:
# The model_type determines which base model is used to initialize the weights that are finetuned.
# "vit_l" is used here; a smaller model such as vit_b can be trained faster, while larger
# models usually yield higher quality results.
model_type = "vit_l"

# The name of the checkpoint. With the save_root used below, the checkpoints are stored in
# '<output_directory>/models/checkpoints/<checkpoint_name>'.
checkpoint_name = "sam"

n_epochs = 100  # how long we train (in epochs)
n_objects_per_batch = 2  # the number of objects per batch that will be sampled

# Train on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print('running on: ', device)

sam_training.train_sam(
    name=checkpoint_name,
    save_root=os.path.join(output_directory, "models"),
    model_type=model_type,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    # checkpoint_path=<path to the best.pt of an earlier run>,  # can be used to train further
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)

### Save model


In [None]:
# Show the output directory; the training cell used '<output_directory>/models' as its save_root
print(output_directory)

### Save as bioimage.io model

In [None]:
import numpy as np
from typing import Union
from micro_sam.bioimageio.model_export import export_sam_model


def _sorted_tiffs(folder):
    """Return the TIFF file names in `folder`, sorted by name."""
    return sorted(
        file_name for file_name in os.listdir(folder)
        if file_name.lower().endswith((".tif", ".tiff"))
    )


# Pick one validation image/label pair to ship as the test sample of the exported model.
# os.listdir() returns entries in arbitrary order, so both listings are sorted (and limited
# to TIFF files) to make the image and the label come from the same validation item.
# NOTE(review): this relies on the input_XXXXX.tif / label_XXXXX.tif naming used when the
# validation data was written - confirm if the naming scheme changes.
val_input_files = _sorted_tiffs(val_input_dir)
val_label_files = _sorted_tiffs(val_label_dir)
if not val_input_files or not val_label_files:
    raise FileNotFoundError(
        f"No TIFF files found in {val_input_dir} and/or {val_label_dir}; "
        "cannot pick a test sample for the bioimage.io export"
    )

test_image_path = os.path.join(val_input_dir, val_input_files[0])
test_label_path = os.path.join(val_label_dir, val_label_files[0])

# Load the test image and label
test_image = np.array(imread(test_image_path))
test_label = np.array(imread(test_label_path))

# The export needs the best checkpoint written by the training cell; fail early if it is missing
best_checkpoint_path = os.path.join(
    output_directory, "models", "checkpoints", checkpoint_name, "best.pt"
)
if not os.path.isfile(best_checkpoint_path):
    raise FileNotFoundError(
        f"Checkpoint not found: {best_checkpoint_path}. Run the training cell first."
    )

# Define the path for saving the bioimage.io model
bioimageio_model_path = os.path.join(output_directory, "bioimage_io_model")
os.makedirs(bioimageio_model_path, exist_ok=True)

# Export the SAM model to bioimage.io format
export_sam_model(
    image=test_image,
    label_image=test_label,
    model_type=model_type,  # Using the same model_type as in training
    name=f"micro_sam_{timestamp}",
    output_path=bioimageio_model_path,
    checkpoint_path=best_checkpoint_path,
    # Optional: Add additional kwargs as needed
    authors=[{"name": "Your Name", "affiliation": "Your Institution"}],
    description="Micro-SAM model trained on microscopy images for segmentation",
    license="MIT",
    documentation="Model trained with micro-sam for segmenting microscopy images",
)

print(f"BioImage.IO model exported to: {bioimageio_model_path}")
