# Micro-SAM Training from OMERO Data

Train micro-SAM models using annotation tables from OMERO with automated data preparation.

## 1. Setup

In [None]:
# Import the package with training convenience functions
from omero_annotate_ai import (
    create_omero_connection_widget,
    create_training_data_widget,
    prepare_training_data_from_table,
    setup_training,    # Convenience function for training setup
    run_training       # Convenience function for training execution
)

# Additional imports
import datetime
from pathlib import Path
import torch

# Report whether PyTorch can see a CUDA device or will fall back to the CPU
device_name = "CUDA" if torch.cuda.is_available() else "CPU"
print(f"Device available: {device_name}")


## 2. OMERO Connection

In [None]:
# Build the interactive OMERO connection widget and render it in the notebook.
# Interact with the widget before running the next cell, which reads the
# resulting connection back from `conn_widget`.
conn_widget = create_omero_connection_widget()
conn_widget.display()

In [None]:
# Retrieve the connection object produced by the connection widget
conn = conn_widget.get_connection()

# Stop the notebook here when the widget did not yield a connection,
# because every later cell depends on `conn`.
if conn is None:
    raise ConnectionError("No OMERO connection established.")

omero_user = conn.getUser()
print(f"Connected to OMERO as: {omero_user.getName()}")

## 3. Training Data Selection

In [None]:
# Build the training-table selection widget on top of the open OMERO
# connection and render it. Make a selection in the widget before running the
# next cell, which reads the chosen table back from `training_widget`.
training_widget = create_training_data_widget(connection=conn)
training_widget.display()

In [None]:
# Read the user's choice back from the selection widget
selected_table_id = training_widget.get_selected_table_id()
selected_table_info = training_widget.get_selected_table_info()

# Nothing selected: fail loudly so the following cells are not run by accident
if not selected_table_id:
    raise ValueError("No training table selected. Please select a table above.")

# Echo the selection; missing metadata fields are shown as 'Unknown'
print("Selected training table:")
print(f"  Table ID: {selected_table_id}")
for field_label, field_key in (("Table Name", "name"), ("Created", "created")):
    print(f"  {field_label}: {selected_table_info.get(field_key, 'Unknown')}")

## 4. Setup Training Directory

In [None]:
# Each run gets its own time-stamped folder below ~/micro-sam_models
timestamp = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
home_dir = Path.home()
models_dir = home_dir.joinpath("micro-sam_models")

folder_name = "micro-sam-" + timestamp
output_directory = models_dir.joinpath(folder_name)

# Create the shared parent first, then the run folder; existing folders are kept
for directory in (models_dir, output_directory):
    directory.mkdir(exist_ok=True)

print(f"Training output directory: {output_directory}")

## 5. Automated Data Preparation

Use the automated data preparation function to download and organize training data.

In [None]:
# Prepare the training data for the selected table inside the output directory.
# Any failure is reported and then re-raised so the notebook stops here.
try:
    # Name used when the table metadata carries no 'name' entry
    fallback_name = f"training_table_{selected_table_id}"
    training_result = prepare_training_data_from_table(
        conn=conn,
        table_id=selected_table_id,
        training_name=selected_table_info.get('name', fallback_name),
        output_dir=output_directory,
        clean_existing=True,
    )

    print("\nTraining data preparation completed successfully!")
    print("\nDataset statistics:")
    for stat_name, stat_value in training_result['stats'].items():
        print(f"  {stat_name}: {stat_value}")

    # Keep the four data folders as notebook variables for the training cells
    training_input_dir = training_result['training_input']
    training_label_dir = training_result['training_label']
    val_input_dir = training_result['val_input']
    val_label_dir = training_result['val_label']

    print("\nDirectory structure created:")
    for description, folder in (
        ("Training images", training_input_dir),
        ("Training labels", training_label_dir),
        ("Validation images", val_input_dir),
        ("Validation labels", val_label_dir),
    ):
        print(f"  {description}: {folder}")

except Exception as e:
    print(f"Error during data preparation: {e}")
    raise

## 6. Micro-SAM Training Setup

Configure and run micro-SAM training using the prepared data.

In [None]:
# ✨ Collect the training options first, then build the configuration with
# the setup_training convenience function
training_options = dict(
    # Table name (or a generic fallback) combined with the run timestamp
    model_name=f"{selected_table_info.get('name', 'micro_sam_training')}_{timestamp}",
    epochs=10,               # Number of epochs (use 50+ for real training)
    batch_size=1,            # Adjust based on GPU memory
    learning_rate=1e-5,      # Conservative learning rate
    patch_shape=(512, 512),  # Input patch size
    model_type="vit_b_lm",   # SAM model variant
    n_objects_per_batch=25,  # Objects per batch for sampling
)
training_config = setup_training(training_result, **training_options)

# Show the key values of the resulting configuration
print("Training configuration prepared!")
for summary_label, config_key in (
    ("Model name", "model_name"),
    ("Output directory", "output_dir"),
    ("Training epochs", "epochs"),
    ("Calculated iterations", "n_iterations"),
):
    print(f'{summary_label}: {training_config[config_key]}')


In [None]:
# ✨ Run the training with the run_training convenience function and
# summarise what it returned
print("Starting micro-SAM training...")

training_results = run_training(training_config, framework="microsam")

print('🎉 Training completed successfully!')
print('Training Results:')
print(f'  Model name: {training_results["model_name"]}')

# Optional entries: fall back to a placeholder / empty list when absent
final_model = training_results.get("final_model_path", "Not available")
print(f'  Final model: {final_model}')
saved_checkpoints = training_results.get("checkpoints", [])
print(f'  Checkpoints saved: {len(saved_checkpoints)}')

print(f'  Output directory: {training_results["output_dir"]}')


## 7. Model Export and Summary

In [None]:
# Locate the most recent checkpoint of this run, copy it to a stable file name
# in the run's output directory, and print a summary of the training.
#
# `model_name` and `n_iterations` are taken from `training_config` (set up in
# the training-setup cell) so this cell no longer relies on undefined names.
import shutil

model_name = training_config["model_name"]
n_iterations = training_config["n_iterations"]

# Path() also tolerates `output_directory` having been stored as a plain string.
run_directory = Path(output_directory)

# NOTE(review): assumes checkpoints are written below
# <output_directory>/checkpoints/<name>/*.pt, the layout the export cell of this
# notebook reads from — confirm against the training backend.
checkpoint_folder = run_directory / "checkpoints"

# Order by modification time so the last entry really is the newest file;
# rglob also finds checkpoints stored in per-model sub-folders.
checkpoints = sorted(
    checkpoint_folder.rglob("*.pt"),
    key=lambda checkpoint: checkpoint.stat().st_mtime,
)
if checkpoints:
    latest_checkpoint = checkpoints[-1]
    print(f"Latest checkpoint: {latest_checkpoint}")

    # Actually write the exported file before reporting it
    export_path = run_directory / f"{model_name}_final.pt"
    shutil.copy2(latest_checkpoint, export_path)
    print(f"Model exported to: {export_path}")
else:
    print("No checkpoints found.")

print("\nTraining summary:")
print(f"  Output directory: {run_directory}")
print(f"  Model name: {model_name}")
print(f"  Training completed with {n_iterations} iterations")
print(f"  Dataset statistics: {training_result['stats']}")

In [None]:
import numpy as np
from typing import Union
from micro_sam.bioimageio.model_export import export_sam_model
import os
from tifffile import imread
import imageio.v3 as imageio
import micro_sam.util as util
from micro_sam.bioimageio.model_export import _create_test_inputs_and_outputs

# Export the trained model in bioimage.io format, using one image/label pair
# from the validation set prepared earlier in this notebook as test sample.
# All paths are derived from the current run instead of a fixed local folder.

# Keep the run folder as-is; earlier cells rely on `output_directory`.
export_root = Path(output_directory)

# Sort the listings so the choice is reproducible.
# NOTE(review): assumes images and labels sort in the same order so that the
# first entries belong together — confirm against the data preparation step.
val_image_files = sorted(
    entry for entry in Path(training_result["val_input"]).iterdir() if entry.is_file()
)
val_label_files = sorted(
    entry for entry in Path(training_result["val_label"]).iterdir() if entry.is_file()
)
if not val_image_files or not val_label_files:
    raise FileNotFoundError(
        "No validation images/labels found; run the data preparation cell first."
    )
test_image_path = str(val_image_files[0])
test_label_path = str(val_label_files[0])

# Load the test image and label
test_image = imageio.imread(test_image_path)
test_label = np.array(imread(test_label_path))

# Must match the model_type passed to setup_training in the training-setup cell
model_type = "vit_b_lm"

# NOTE(review): assumes the checkpoint sub-folder is named after the model
# name of this run — confirm against the training backend.
checkpoint_name = training_config["model_name"]
checkpoint_path = os.path.join(export_root, "checkpoints", checkpoint_name, "best.pt")
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

# Target passed to export_sam_model as output_path
bioimageio_model_path = os.path.join(export_root, "bioimage_io_model")

# Export the SAM model to bioimage.io format
export_sam_model(
    image=test_image,
    label_image=test_label,
    model_type=model_type,
    name="micro_sam_test",
    output_path=bioimageio_model_path,
    checkpoint_path=checkpoint_path,
    # Optional metadata: replace the placeholders before sharing the model
    authors=[{"name": "Your Name", "affiliation": "Your Institution"}],
    description="Micro-SAM model trained on microscopy images for segmentation",
    license="MIT",
    documentation="Model trained with micro-sam for segmenting microscopy images",
)

print(f"BioImage.IO model exported to: {bioimageio_model_path}")


## 8. Cleanup

In [None]:
# Close the OMERO connection at the end of the session, if there is one
if conn is None:
    print("No active OMERO connection to close.")
else:
    conn.close()
    print("OMERO connection closed.")