# Micro-SAM Training from OMERO Data

Train micro-SAM models using annotation tables from OMERO.

## 1. Setup

In [None]:
# Import the package
from omero_annotate_ai import (
    create_omero_connection_widget,
    create_training_data_widget
)

# Additional imports for training
import os
import datetime
from pathlib import Path
import torch

# Report which widgets this notebook uses and whether torch can see a CUDA GPU.
print("Available widgets: Connection, Training Data")
print(f"Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 2. OMERO Connection

In [None]:
# Create and display OMERO connection widget.
# The connection itself is read back from this widget in the next cell via
# conn_widget.get_connection(), so log in here before running that cell.
conn_widget = create_omero_connection_widget()
conn_widget.display()

In [None]:
# Get the OMERO connection from the widget above.
conn = conn_widget.get_connection()

# get_connection() presumably returns None when no login has happened yet;
# stop here instead of failing later in a less obvious place.
if conn is None:
    raise ConnectionError("No OMERO connection established.")

# Confirm which OMERO user the rest of the notebook will act as.
print(f"Connected to OMERO as: {conn.getUser().getName()}")

## 3. Training Data Selection

In [None]:
# Create training data selection widget bound to the open OMERO connection.
# The chosen table is read back in the next cell via
# get_selected_table_id() / get_selected_table_info().
training_widget = create_training_data_widget(connection=conn)
training_widget.display()

In [None]:
# Get selected training table.
# Read back whatever was picked in the selection widget above.
selected_table_id = training_widget.get_selected_table_id()
selected_table_info = training_widget.get_selected_table_info()

# Guard clause: without a selected table there is nothing to train on.
if not selected_table_id:
    raise ValueError("No training table selected. Please select a table above.")

print("Selected training table:")
print(f"  Table ID: {selected_table_id}")
print(f"  Table Name: {selected_table_info.get('name', 'Unknown')}")
print(f"  Created: {selected_table_info.get('created', 'Unknown')}")

## 4. Setup Training Directory

In [None]:
# Create output directory for training.
# Each run writes into its own timestamped folder below ~/micro-sam_models,
# keeping the outputs of separate runs apart.
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
home_dir = Path.home()
folder_name = f"micro-sam-{timestamp}"

models_dir = home_dir / "micro-sam_models"
output_directory = models_dir / folder_name

# Parent first, then the per-run folder.
models_dir.mkdir(exist_ok=True)
output_directory.mkdir(exist_ok=True)

print(f"Training output directory: {output_directory}")

## 5. Data Preparation (Manual Implementation)

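This section implements the data preparation manually. For every row of the selected table, it downloads the referenced image plane (or patch) from OMERO, converts it to 8-bit, and stores it next to the matching label file. This logic is meant to be replaced later by the `prepare_training_data_from_table()` function.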

In [None]:
# Import required modules for manual data preparation
import ezomero
import pandas as pd
import numpy as np
import shutil
from tifffile import imread, imwrite
from tqdm import tqdm

# Get the table data.
# The selected OMERO table is used as a pandas DataFrame below. Later cells
# read columns such as image_id, label_id, train and validate from it.
table = ezomero.get_table(conn, selected_table_id)
# Quick sanity check of what was downloaded.
print(f"Table contains {len(table)} rows")
print(f"Columns: {list(table.columns)}")
print(table.head())

In [None]:
def _row_value(row, column, default=None):
    """Return ``row[column]``, or *default* if the column is absent or the cell is missing (NaN/None)."""
    if column not in row.index:
        return default
    value = row[column]
    # Only scalar cells can be "missing"; pd.isna on a list would return an array.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return default
    return value


def _row_int(row, column, default=0):
    """Return ``row[column]`` as an int, or *default* if it is absent or missing."""
    value = _row_value(row, column)
    return default if value is None else int(value)


def _to_bool(value):
    """Interpret a table cell as a boolean flag.

    ``bool("False")`` is True, so text cells are compared against common
    truthy spellings in case the table stores flags as strings.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes")
    return bool(value)


def _parse_z_slice(value):
    """Return the z index stored in a table cell as an int.

    Missing cells map to 0. Text such as ``"[3, 4]"`` is parsed with
    ``ast.literal_eval`` (never ``eval``, since the table content is external
    data) and its first element is used. Unparseable text or an empty list
    also maps to 0.
    """
    import ast

    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return 0
    if isinstance(value, str) and value.strip().startswith("["):
        try:
            parsed = ast.literal_eval(value.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return 0
        if not isinstance(parsed, (list, tuple)) or not parsed:
            return 0
        value = parsed[0]
    return int(value)


def _to_uint8(img):
    """Scale *img* so that its maximum maps to 255 and return it as uint8.

    Values are clipped to [0, 255] so negative intensities cannot wrap around
    during the uint8 cast.
    """
    max_val = img.max()
    scaled = img * (255.0 / max_val) if max_val > 0 else img
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _fetch_plane(conn, row, image_id):
    """Download the single 2D plane (or patch) described by *row*, returned as (y, x)."""
    z_slice = _parse_z_slice(_row_value(row, "z_slice"))
    channel = _row_int(row, "channel")
    timepoint = _row_int(row, "timepoint")

    is_patch = _to_bool(_row_value(row, "is_patch", False))
    patch_x = _row_int(row, "patch_x")
    patch_y = _row_int(row, "patch_y")
    patch_width = _row_int(row, "patch_width")
    patch_height = _row_int(row, "patch_height")

    if is_patch and patch_width > 0 and patch_height > 0:
        start_coords = (patch_x, patch_y, z_slice, channel, timepoint)
        axis_lengths = (patch_width, patch_height, 1, 1, 1)
    else:
        # Whole XY plane at the requested z / channel / timepoint.
        start_coords = (0, 0, z_slice, channel, timepoint)
        axis_lengths = (None, None, 1, 1, 1)

    _, img_data = ezomero.get_image(
        conn,
        image_id,
        start_coords=start_coords,
        axis_lengths=axis_lengths,
        xyzct=True,
    )

    if img_data.ndim == 5:
        # With xyzct=True the array is indexed (x, y, z, c, t): keep the single
        # plane, then swap to the (y, x) layout written to disk.
        img_data = np.swapaxes(img_data[:, :, 0, 0, 0], 0, 1)
    return img_data


def prepare_dataset_from_table(conn, df, output_dir, subset_type="training", tmp_dir=None):
    """
    Prepare an image/label dataset from an OMERO tracking table.

    Each row of *df* becomes one pair of files:
    ``<subset_type>_input/input_NNNNN.tif`` holds the plane or patch scaled to
    8-bit, and ``<subset_type>_label/label_NNNNN.tif`` holds the label file
    annotation downloaded from OMERO. NNNNN is the row position in *df*.
    A pair is written only when both the image and its label were fetched,
    so the two folders always contain matching file names. Failing rows are
    reported and skipped.

    Args:
        conn: OMERO connection
        df: DataFrame with tracking info. Needs ``image_id`` and ``label_id``.
            The z_slice/channel/timepoint/patch columns default to 0 (and
            is_patch to False) when absent or empty.
        output_dir: Base output directory (Path or str)
        subset_type: "training" or "val"; used as the prefix of the output folders
        tmp_dir: Temporary directory for downloading annotations. Defaults to
            ``output_dir / "tmp"`` and is created if needed.

    Returns:
        (input_dir, label_dir): Paths to the input and label directories
    """
    output_dir = Path(output_dir)
    tmp_dir = output_dir / "tmp" if tmp_dir is None else Path(tmp_dir)
    tmp_dir.mkdir(parents=True, exist_ok=True)

    input_dir = output_dir / f"{subset_type}_input"
    label_dir = output_dir / f"{subset_type}_label"
    input_dir.mkdir(exist_ok=True)
    label_dir.mkdir(exist_ok=True)

    rows = tqdm(df.iterrows(), total=len(df), desc=f"Preparing {subset_type} data")
    for n, (_, row) in enumerate(rows):
        try:
            image_id = int(row["image_id"])
            label_id = _row_int(row, "label_id", default=None)
            if label_id is None:
                print(f"Skipping {subset_type} item {n}: no label_id")
                continue

            img_8bit = _to_uint8(_fetch_plane(conn, row, image_id))

            try:
                file_path = ezomero.get_file_annotation(conn, label_id, str(tmp_dir))
            except Exception as e:
                print(f"Error downloading label file: {e}")
                continue
            if not file_path:
                print(f"Skipping {subset_type} item {n}: label file {label_id} was not downloaded")
                continue

            # Write the image only once its label is on disk. An unlabeled image
            # would otherwise shift the sorted input/label pairing in the loaders.
            image_dest = input_dir / f"input_{n:05d}.tif"
            label_dest = label_dir / f"label_{n:05d}.tif"
            imwrite(image_dest, img_8bit)
            try:
                shutil.move(file_path, label_dest)
            except Exception:
                # Keep the folders consistent: drop the image if its label could not be placed.
                image_dest.unlink(missing_ok=True)
                raise

        except Exception as e:
            print(f"Error processing {subset_type} item {n}: {e}")

    return input_dir, label_dir

In [None]:
# Clean up existing folders.
# Remove output left over from a previous run of this cell so stale files
# cannot end up in the new training/validation sets.
folders = ["training_input", "training_label", "val_input", "val_label", "tmp"]
for folder in folders:
    folder_path = output_directory / folder
    if folder_path.exists():
        shutil.rmtree(folder_path)

# Split rows by their train / validate flags (a row may be in neither set).
train_images = table[table['train'].eq(True)]
val_images = table[table['validate'].eq(True)]

print(f"Found {len(train_images)} training images and {len(val_images)} validation images")

# Both data loaders and the export step need at least one example per split;
# fail here with a clear message instead of deep inside the training code.
if train_images.empty or val_images.empty:
    raise ValueError(
        "The selected table must contain at least one row with train=True "
        "and at least one row with validate=True."
    )

# Process training data
training_input_dir, training_label_dir = prepare_dataset_from_table(
    conn, train_images, output_directory, subset_type="training"
)

# Process validation data
val_input_dir, val_label_dir = prepare_dataset_from_table(
    conn, val_images, output_directory, subset_type="val"
)

# The tmp/ folder only served as a download area for label files.
shutil.rmtree(output_directory / "tmp", ignore_errors=True)

# prepare_dataset_from_table only prints per-item failures, so verify that
# every split folder actually received files before building loaders on it.
for split_dir in (training_input_dir, training_label_dir, val_input_dir, val_label_dir):
    if not any(split_dir.glob("*.tif")):
        raise RuntimeError(f"No .tif files were written to {split_dir}; see the errors printed above.")

print(f"Training data prepared in: {output_directory}")

## 6. Data Loaders

In [None]:
import micro_sam.training as sam_training
from torch_em.data import MinInstanceSampler
from torch_em.util.debug import check_loader

def determine_patch_shape_from_table(table_df, default_shape=(1, 512, 512), min_size=256):
    """Extract optimal patch shape from OMERO table.

    Uses the median patch size over rows that actually describe a patch, i.e.
    rows with positive width and height. Rows for whole images presumably
    store 0 or NaN there, and including them would drag the median down.
    Each side is then raised to at least *min_size* and rounded down to an
    even number.

    Args:
        table_df: annotation table as a pandas DataFrame
        default_shape: shape returned when the table has no usable patch sizes
        min_size: lower bound for each spatial side

    Returns:
        A ``(1, height, width)`` tuple, or *default_shape*.
    """
    try:
        if 'patch_width' in table_df.columns and 'patch_height' in table_df.columns:
            # Non-numeric cells become NaN, and NaN > 0 is False, so they are dropped.
            widths = pd.to_numeric(table_df['patch_width'], errors='coerce')
            heights = pd.to_numeric(table_df['patch_height'], errors='coerce')
            widths = widths[widths > 0]
            heights = heights[heights > 0]

            if not widths.empty and not heights.empty:
                patch_width = max(min_size, int(widths.median()))
                patch_height = max(min_size, int(heights.median()))

                # Ensure even dimensions
                patch_width -= patch_width % 2
                patch_height -= patch_height % 2

                new_shape = (1, patch_height, patch_width)
                print(f"Using patch shape {new_shape} from table")
                return new_shape

        print(f"Using default patch shape {default_shape}")
        return default_shape
    except Exception as e:
        print(f"Error determining patch shape: {e}, using default {default_shape}")
        return default_shape

# Training parameters
batch_size = 2
patch_shape = determine_patch_shape_from_table(table)
train_instance_segmentation = True
sampler = MinInstanceSampler(min_size=25)

# Build the training loader, then the validation loader, with identical
# settings; only the raw/label directories differ. Using a generator
# expression keeps the per-split loop names out of the notebook namespace.
train_loader, val_loader = (
    sam_training.default_sam_loader(
        raw_paths=str(raw_dir),
        raw_key="*.tif",
        label_paths=str(lbl_dir),
        label_key="*.tif",
        with_segmentation_decoder=train_instance_segmentation,
        patch_shape=patch_shape,
        batch_size=batch_size,
        is_seg_dataset=True,
        shuffle=True,
        raw_transform=sam_training.identity,
        sampler=sampler,
    )
    for raw_dir, lbl_dir in (
        (training_input_dir, training_label_dir),
        (val_input_dir, val_label_dir),
    )
)

# Plot one batch from each loader as a visual sanity check.
check_loader(train_loader, 1, plt=True)
check_loader(val_loader, 1, plt=True)

print("Data loaders created successfully")

## 7. Training

In [None]:
# Training parameters
n_objects_per_batch = 2  # forwarded to train_sam as n_objects_per_batch
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer the GPU when torch can see one
n_epochs = 100
model_type = "vit_b_lm"  # micro-sam model type to fine-tune; the export cell reuses it
checkpoint_name = "sam"  # run name; the export cell loads models/checkpoints/<checkpoint_name>/best.pt

print(f"Training on: {device}")
print(f"Model type: {model_type}")
print(f"Epochs: {n_epochs}")

# Run training. Checkpoints go below output_directory / "models" (save_root).
# with_segmentation_decoder uses the same flag as the data loaders above so
# the two stay consistent.
sam_training.train_sam(
    name=checkpoint_name,
    save_root=str(output_directory / "models"),
    model_type=model_type,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)

print(f"Training completed. Model saved to: {output_directory / 'models'}")

## 8. Export Model

In [None]:
from micro_sam.bioimageio.model_export import export_sam_model

# Get test images for export.
# Pick one validation image together with ITS OWN label. Path.glob() yields
# files in arbitrary order, so inputs and labels are paired by file name
# (input_NNNNN.tif <-> label_NNNNN.tif) rather than by taking the first
# entry of each directory listing.
test_image_path = test_label_path = None
for candidate in sorted(val_input_dir.glob("*.tif")):
    matching_label = val_label_dir / candidate.name.replace("input_", "label_", 1)
    if matching_label.exists():
        test_image_path, test_label_path = candidate, matching_label
        break

if test_image_path is None:
    raise FileNotFoundError(
        f"No matching input/label pair found in {val_input_dir} and {val_label_dir}"
    )

# Load test data
test_image = np.array(imread(test_image_path))
test_label = np.array(imread(test_label_path))

# Export to bioimage.io format
bioimageio_model_path = output_directory / "bioimage_io_model"
bioimageio_model_path.mkdir(exist_ok=True)

# The checkpoint path mirrors where train_sam wrote the best checkpoint
# (save_root / "checkpoints" / name / "best.pt").
# NOTE(review): confirm whether export_sam_model expects a directory or a
# file path (e.g. a .zip archive) for output_path.
export_sam_model(
    image=test_image,
    label_image=test_label,
    model_type=model_type,
    name=f"micro_sam_{timestamp}",
    output_path=str(bioimageio_model_path),
    checkpoint_path=str(output_directory / "models" / "checkpoints" / checkpoint_name / "best.pt"),
    authors=[{"name": "User", "affiliation": "Institution"}],
    description="Micro-SAM model trained on OMERO data",
    license="MIT",
    documentation="Model trained with micro-sam",
)

print(f"Model exported to: {bioimageio_model_path}")

## 9. Cleanup

In [None]:
# Close the OMERO connection; this is the last cell, so nothing needs it afterwards.
if conn:
    conn.close()
    print("OMERO connection closed")

# Summary of where this run's outputs were written.
print("Training completed successfully!")
print(f"Output directory: {output_directory}")
print(f"Trained model: {output_directory / 'models' / 'checkpoints' / checkpoint_name / 'best.pt'}")
print(f"BioImage.IO model: {bioimageio_model_path}")