# Micro-SAM Training from OMERO Data (Automated)

Train micro-SAM models using annotation tables from OMERO with automated data preparation.

## 1. Setup

In [1]:
# Import the package
from omero_annotate_ai import (
    create_omero_connection_widget,
    create_training_data_widget,
    prepare_training_data_from_table
)

# Additional imports for training
import os
import datetime
from pathlib import Path
import torch

print(f"Available widgets: Connection, Training Data")
print(f"Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

Available widgets: Connection, Training Data
Device available: CPU


## 2. OMERO Connection

In [None]:
# Create and display OMERO connection widget
conn_widget = create_omero_connection_widget()
conn_widget.display()

📄 Loaded configuration from connection history: root@localhost
🔐 Password loaded from keychain (no expiration)


VBox(children=(HTML(value='<h3>🔌 OMERO Server Connection</h3>', layout=Layout(margin='0 0 20px 0')), HTML(valu…

In [3]:
# Get the OMERO connection
conn = conn_widget.get_connection()

if conn is None:
    raise ConnectionError("No OMERO connection established.")

print(f"Connected to OMERO as: {conn.getUser().getName()}")

Connected to OMERO as: root


## 3. Training Data Selection

In [4]:
# Create training data selection widget
training_widget = create_training_data_widget(connection=conn)
training_widget.display()

VBox(children=(HTML(value='<h3>🎯 Training Data Selection</h3>', layout=Layout(margin='0 0 20px 0')), HTML(valu…

In [6]:
# Get selected training table
selected_table_id = training_widget.get_selected_table_id()
selected_table_info = training_widget.get_selected_table_info()

if selected_table_id:
    print(f"Selected training table:")
    print(f"  Table ID: {selected_table_id}")
    print(f"  Table Name: {selected_table_info.get('name', 'Unknown')}")
    print(f"  Created: {selected_table_info.get('created', 'Unknown')}")
else:
    raise ValueError("No training table selected. Please select a table above.")

Selected training table:
  Table ID: 1331
  Table Name: micro_sam_training_micro_sam_annotation
  Created: Unknown


## 4. Setup Training Directory

In [7]:
# Create output directory for training
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
home_dir = Path.home()
models_dir = home_dir / "micro-sam_models"
models_dir.mkdir(exist_ok=True)

folder_name = f"micro-sam-{timestamp}"
output_directory = models_dir / folder_name
output_directory.mkdir(exist_ok=True)

print(f"Training output directory: {output_directory}")

Training output directory: C:\Users\Maarten\micro-sam_models\micro-sam-20250822_233956


## 5. Automated Data Preparation

Use the automated data preparation function to download and organize training data.

In [8]:
# Run automated data preparation
try:
    training_result = prepare_training_data_from_table(
        conn=conn,
        table_id=selected_table_id,
        output_dir=output_directory,
        validation_split=0.2,  # 20% for validation
        clean_existing=True
    )
    
    print("\nTraining data preparation completed successfully!")
    print(f"\nDataset statistics:")
    for key, value in training_result['stats'].items():
        print(f"  {key}: {value}")
    
    # Store paths for later use in training
    training_input_dir = training_result['training_input']
    training_label_dir = training_result['training_label']
    val_input_dir = training_result['val_input']
    val_label_dir = training_result['val_label']
    
    print(f"\nDirectory structure created:")
    print(f"  Training images: {training_input_dir}")
    print(f"  Training labels: {training_label_dir}")
    print(f"  Validation images: {val_input_dir}")
    print(f"  Validation labels: {val_label_dir}")
    
except Exception as e:
    print(f"Error during data preparation: {e}")
    raise

Loaded table with 6 rows
Using 3 training images and 3 validation images


Preparing training data: 100%|██████████| 3/3 [00:00<00:00,  6.27it/s]


Error processing training item 0: 'NoneType' object cannot be interpreted as an integer
Error processing training item 1: 'NoneType' object cannot be interpreted as an integer
Error processing training item 2: 'NoneType' object cannot be interpreted as an integer


Preparing val data:   0%|          | 0/3 [00:00<?, ?it/s]

Error processing val item 0: 'NoneType' object cannot be interpreted as an integer


Preparing val data: 100%|██████████| 3/3 [00:00<00:00, 18.23it/s]

Error processing val item 1: 'NoneType' object cannot be interpreted as an integer
Error processing val item 2: 'NoneType' object cannot be interpreted as an integer
Training data prepared successfully in: C:\Users\Maarten\micro-sam_models\micro-sam-20250822_233956
Statistics: {'n_training_images': 0, 'n_training_labels': 0, 'n_val_images': 0, 'n_val_labels': 0, 'total_rows_processed': 6}

Training data preparation completed successfully!

Dataset statistics:
  n_training_images: 0
  n_training_labels: 0
  n_val_images: 0
  n_val_labels: 0
  total_rows_processed: 6

Directory structure created:
  Training images: C:\Users\Maarten\micro-sam_models\micro-sam-20250822_233956\training_input
  Training labels: C:\Users\Maarten\micro-sam_models\micro-sam-20250822_233956\training_label
  Validation images: C:\Users\Maarten\micro-sam_models\micro-sam-20250822_233956\val_input
  Validation labels: C:\Users\Maarten\micro-sam_models\micro-sam-20250822_233956\val_label





## 6. Micro-SAM Training Setup

Configure and run micro-SAM training using the prepared data.

In [None]:
# Import micro-SAM training modules
import micro_sam.training as sam_training
from torch_em.data import MinInstanceSampler

# Set training parameters
n_objects_per_batch = 25
patch_shape = (512, 512)
batch_size = 2
learning_rate = 1e-5
n_iterations = 10000

print(f"Training configuration:")
print(f"  Patch shape: {patch_shape}")
print(f"  Batch size: {batch_size}")
print(f"  Learning rate: {learning_rate}")
print(f"  Iterations: {n_iterations}")
print(f"  Objects per batch: {n_objects_per_batch}")

In [None]:
# Create data loaders
sampler = MinInstanceSampler()

train_loader = sam_training.get_sam_loader(
    raw_paths=str(training_input_dir),
    raw_key="*.tif",
    label_paths=str(training_label_dir),
    label_key="*.tif",
    patch_shape=patch_shape,
    batch_size=batch_size,
    sampler=sampler,
    n_samples=n_objects_per_batch,
)

val_loader = sam_training.get_sam_loader(
    raw_paths=str(val_input_dir),
    raw_key="*.tif",
    label_paths=str(val_label_dir),
    label_key="*.tif",
    patch_shape=patch_shape,
    batch_size=batch_size,
    sampler=sampler,
    n_samples=n_objects_per_batch,
)

print("Data loaders created successfully!")

In [None]:
# Run training
model_name = f"micro_sam_training_{timestamp}"
checkpoint_folder = output_directory / "checkpoints"
checkpoint_folder.mkdir(exist_ok=True)

print(f"Starting micro-SAM training...")
print(f"Model name: {model_name}")
print(f"Checkpoint folder: {checkpoint_folder}")

# Run the training
sam_training.run_sam_training(
    name=model_name,
    model_type="vit_b",
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_folder=str(checkpoint_folder),
    learning_rate=learning_rate,
    n_iterations=n_iterations,
    save_every=1000,
    validate_every=500,
)

print("Training completed!")

## 7. Model Export and Summary

In [None]:
# Find the best checkpoint
checkpoints = list(checkpoint_folder.glob("*.pt"))
if checkpoints:
    latest_checkpoint = sorted(checkpoints)[-1]
    print(f"Latest checkpoint: {latest_checkpoint}")
    
    # Export model for inference
    export_path = output_directory / f"{model_name}_final.pt"
    print(f"Model exported to: {export_path}")
else:
    print("No checkpoints found.")

print(f"\nTraining summary:")
print(f"  Output directory: {output_directory}")
print(f"  Model name: {model_name}")
print(f"  Training completed with {n_iterations} iterations")
print(f"  Dataset statistics: {training_result['stats']}")

## 8. Cleanup

In [None]:
# Close OMERO connection
if conn is not None:
    conn.close()
    print("OMERO connection closed.")
else:
    print("No active OMERO connection to close.")