## Training

In [None]:
from omero.gateway import BlitzGateway
import ezomero
#load dotenv for OMERO login
from dotenv import load_dotenv

from tifffile import imsave, imwrite, imread
import torch

import os
import tempfile
import pandas as pd
import zipfile
import numpy as np
import datetime
import shutil

from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.util.util import get_random_colors

import micro_sam.training as sam_training
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

### Setup connection with OMERO

In [None]:
# Load the OMERO credentials (HOST, USER_NAME, PASSWORD, GROUP) from a .env
# file; override=True lets the .env values replace variables already set in
# the process environment.
load_dotenv(override=True)

conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    group=os.environ.get("GROUP"),
    secure=True,
)
connection_status = conn.connect()
if not connection_status:
    print("Connection to OMERO Server Failed")
    # Every later cell needs a live session, so stop here with a clear error
    # instead of failing later on an unrelated-looking AttributeError.
    raise ConnectionError(
        "Could not connect to the OMERO server; check HOST, USER_NAME, "
        "PASSWORD and GROUP in the environment / .env file"
    )

print("Connected to OMERO Server")
# Only enable keep-alive once the session is actually established, so the
# session stays open during long-running downloads and training.
conn.c.enableKeepAlive(60)

### Get info from the dataset

In [None]:
datatype = "dataset" # "plate", "dataset", "image"
data_id = 1112
nucl_channel = 0


def _require_omero_object(conn, omero_type, object_id):
    """Fetch an OMERO object by type and ID, raising if it cannot be loaded.

    Args:
        conn: OMERO connection
        omero_type: OMERO object type name passed to ``conn.getObject``
        object_id: ID of the object to load

    Returns:
        The object returned by ``conn.getObject``.

    Raises:
        ValueError: if ``conn.getObject`` returns None for this type/ID.
    """
    omero_object = conn.getObject(omero_type, object_id)
    if omero_object is None:
        raise ValueError(
            f"No {omero_type} with ID {object_id} found "
            "(wrong ID, wrong datatype, or not accessible in the current group)"
        )
    return omero_object


#validate that data_id matches datatype
if datatype == "plate":
    plate = _require_omero_object(conn, "Plate", data_id)
    print('Plate Name: ', plate.getName())
elif datatype == "dataset":
    dataset = _require_omero_object(conn, "Dataset", data_id)
    print('Dataset Name: ', dataset.getName())
elif datatype == "image":
    image = _require_omero_object(conn, "Image", data_id)
    print('Image Name: ', image.getName())
else:
    # A typo in datatype previously passed silently without validating anything.
    raise ValueError(
        f"Unsupported datatype {datatype!r}; expected 'plate', 'dataset' or 'image'"
    )

### Define output folder for training

In [None]:
# Every run gets its own timestamped folder below ~/micro-sam_models.
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
folder_name = f"micro-sam-{timestamp}"

home_dir = os.path.expanduser("~")
models_dir = os.path.join(home_dir, "micro-sam_models")

# Resolve the absolute path first, then create it; makedirs also creates the
# parent models_dir when it does not exist yet.
output_directory = os.path.abspath(os.path.join(models_dir, folder_name))
os.makedirs(output_directory, exist_ok=True)

# To reuse the folder of an earlier run, point output_directory at it instead, e.g.:
#output_directory = os.path.abspath("C:\\Users\\mwpaul\\micro-sam_models\\micro-sam-20250207_095503")
print(f"Output directory: {output_directory}")

### Collecting data from OMERO using the attached table

In [None]:
def get_specific_table(conn, dataset_id, table_name="micro_sam_training_data"):
    """
    Look up a table attached to a dataset by the name of its file.

    The file annotations of the dataset are checked in the order ezomero
    returns them; the first one whose file name equals ``table_name`` and that
    loads successfully as a table is returned.

    Args:
        conn: OMERO connection
        dataset_id: ID of the dataset whose file annotations are searched
        table_name: File name of the table to look for

    Returns:
        table: The result of ``ezomero.get_table`` for the matching annotation
            (the notebook uses it as a pandas DataFrame), or None if no
            matching table could be loaded
        file_ann_id: ID of the file annotation holding the table, or None
    """
    for candidate_id in ezomero.get_file_annotation_ids(conn, "Dataset", dataset_id):
        annotation = conn.getObject("FileAnnotation", candidate_id)

        # Skip annotations that could not be fetched or that hold another file
        if annotation is None or annotation.getFile().getName() != table_name:
            continue

        try:
            return ezomero.get_table(conn, candidate_id), candidate_id
        except Exception as e:
            # Name matched but the file is not a readable table; keep searching
            print(f"Found file {table_name} but failed to load as table: {e}")

    return None, None

In [None]:
# Fetch the training table attached to the dataset and show a preview.
table_name = "micro_sam_training_data"
table, file_ann_id = get_specific_table(conn, data_id, table_name)

if table is None:
    print(f"No table named {table_name} found")
else:
    print(f"Found table {table_name} in file annotation {file_ann_id}")
    # head() assumes the table was returned as a pandas DataFrame
    print(table.head())

In [None]:
# Use the table downloaded from OMERO to collect the training and validation
# data: for every selected row, one image plane is saved as an 8-bit TIFF and
# the attached label file is downloaded next to it.
if table is None:
    raise RuntimeError(
        "No training table was loaded from OMERO; cannot collect training data"
    )


def _to_uint8(plane):
    """Rescale a pixel plane to 8-bit, using 0 as minimum and the plane maximum as 255."""
    peak = plane.max()
    if peak == 0:
        # An all-zero plane would otherwise cause a division by zero.
        return np.zeros(plane.shape, dtype=np.uint8)
    return (plane * (255.0 / peak)).astype(np.uint8)


def _export_rows(conn, rows, input_dir, label_dir, tmp_dir):
    """Write one input TIFF and one label TIFF per table row.

    Args:
        conn: OMERO connection
        rows: Table rows with the columns 'image_id', 'label_id', 'z_slice',
            'channel' and 'timepoint'
        input_dir: Folder receiving the 8-bit input images
        label_dir: Folder receiving the label files
        tmp_dir: Folder the label file annotations are downloaded into before
            they are renamed into label_dir

    Raises:
        ValueError: if an image referenced by a row cannot be loaded.
    """
    for n, row in enumerate(rows.to_dict("records")):
        image_id = int(row['image_id'])
        image = conn.getObject('Image', image_id)
        if image is None:
            raise ValueError(f"Image {image_id} listed in the table could not be loaded")

        # Read the plane coordinates from this row; cast to plain ints because
        # pandas hands back numpy integer types.
        pixels = image.getPrimaryPixels()
        img = pixels.getPlane(
            int(row['z_slice']), int(row['channel']), int(row['timepoint'])
        )  # (z, c, t)

        # Save as 8-bit tiff as required for micro-sam training
        imwrite(os.path.join(input_dir, f"input_{n:05d}.tif"), _to_uint8(img))

        # Input and label share the same zero-padded index so that sorted
        # file listings pair them up correctly.
        file_path = ezomero.get_file_annotation(conn, int(row['label_id']), tmp_dir)
        os.rename(file_path, os.path.join(label_dir, f"label_{n:05d}.tif"))


# Start from empty folders so files from an earlier run cannot leak into this one.
data_dirs = {}
for folder in ["training_input", "training_label", "val_input", "val_label", "tmp"]:
    folder_path = os.path.join(output_directory, folder)
    if os.path.isdir(folder_path):
        shutil.rmtree(folder_path)
    os.makedirs(folder_path, exist_ok=True)
    data_dirs[folder] = folder_path

training_input_dir = data_dirs["training_input"]
training_label_dir = data_dirs["training_label"]
val_input_dir = data_dirs["val_input"]
val_label_dir = data_dirs["val_label"]

#prepare training data
train_images = table[table['train'] == True]
val_images = table[table['validate'] == True]
if len(train_images) == 0 or len(val_images) == 0:
    raise ValueError(
        "The table must mark at least one row for training and one for validation "
        f"(found {len(train_images)} training and {len(val_images)} validation rows)"
    )

_export_rows(conn, train_images, training_input_dir, training_label_dir, data_dirs["tmp"])
_export_rows(conn, val_images, val_input_dir, val_label_dir, data_dirs["tmp"])

print("Training data succesfully saved to: ", output_directory)

### Prepare data loader for the training

In [None]:
# --- Loader configuration ---------------------------------------------------
batch_size = 2  # training batch size
patch_shape = (1, 512, 512)  # the size of patches for training

# Images and labels are loaded from all tif files in their folder via a pattern
raw_key = "*.tif"
label_key = "*.tif"

# Also train the additional convolutional decoder used for end-to-end
# automatic instance segmentation.
# NOTE 1: Densely annotated labels are important when training this decoder.
# NOTE 2: If you do not have labeled images, the `micro-sam` annotator tools are
#         recommended to annotate as many objects as possible per image.
train_instance_segmentation = True

# NOTE: The dataloader internally adds the label transforms, i.e. it converts the
# ground-truth labels to the instances needed for finetuning Segment Anything, or to
# the foreground and the distances to object centers and boundaries used for
# automatic segmentation.

# Inputs can be large with labeled objects unevenly distributed across the image.
# A sampler ensures that valid patches are chosen based on the paired labels: the
# one below requires at least one foreground instance and filters out small objects.
# NOTE: 'min_size' is paired with the same value in the 'min_size' filter of 'label_transform'.
sampler = MinInstanceSampler(min_size=25)

# Settings shared by the training and the validation loader; only the folders differ.
shared_loader_kwargs = dict(
    raw_key=raw_key,
    label_key=label_key,
    with_segmentation_decoder=train_instance_segmentation,
    patch_shape=patch_shape,
    batch_size=batch_size,
    is_seg_dataset=True,
    shuffle=True,
    raw_transform=sam_training.identity,
    sampler=sampler,
)

train_loader = sam_training.default_sam_loader(
    raw_paths=training_input_dir,
    label_paths=training_label_dir,
    **shared_loader_kwargs,
)

val_loader = sam_training.default_sam_loader(
    raw_paths=val_input_dir,
    label_paths=val_label_dir,
    **shared_loader_kwargs,
)

# Plot one batch from each loader as a visual sanity check
for loader in (train_loader, val_loader):
    check_loader(loader, 1, plt=True)

### Running the training

In [None]:
# --- Training configuration -------------------------------------------------
n_objects_per_batch = 2  # the number of objects per batch that will be sampled
n_epochs = 100  # how long we train (in epochs)

# Train on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print('running on: ', device)

# The model_type determines which base model initializes the weights that are finetuned.
# vit_l is selected here; per the original notes, vit_b trains faster while vit_h
# usually yields higher quality results.
model_type = "vit_l"

# The name of the checkpoint. Checkpoints are stored below save_root; the export
# step of this notebook reads '<save_root>/checkpoints/<checkpoint_name>/best.pt'.
checkpoint_name = "sam"
training_save_root = os.path.join(output_directory, "models")

# To continue training from an earlier run, additionally pass e.g.
#   checkpoint_path='C:\\Users\\mwpaul\\micro-sam_models\\micro-sam-20250207_095503\\models\\checkpoints\\sam\\best.pt'
sam_training.train_sam(
    name=checkpoint_name,
    save_root=training_save_root,
    model_type=model_type,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)

### Save model


In [None]:
# Show the folder of this run; the training call was given
# save_root=<output_directory>/models for its checkpoints.
print(output_directory)

### Save as BioImage.IO model

In [None]:
import numpy as np
from typing import Union
from micro_sam.bioimageio.model_export import export_sam_model


def _first_tif(folder):
    """Return the path of the alphabetically first TIFF file in ``folder``.

    os.listdir returns entries in arbitrary order, so the names are sorted to
    make the choice deterministic; non-TIFF entries are ignored.

    Raises:
        FileNotFoundError: if the folder contains no TIFF file.
    """
    tif_names = sorted(
        entry for entry in os.listdir(folder)
        if entry.lower().endswith((".tif", ".tiff"))
    )
    if not tif_names:
        raise FileNotFoundError(f"No TIFF files found in {folder}")
    return os.path.join(folder, tif_names[0])


# Get a test image and label to use for exporting.
# The validation inputs and labels are written with matching indices, so the
# first file of each sorted folder listing forms a matching image/label pair.
test_image_path = _first_tif(val_input_dir)
test_label_path = _first_tif(val_label_dir)

# Load the test image and label
test_image = np.array(imread(test_image_path))
test_label = np.array(imread(test_label_path))

# Fail early with a clear message when training has not produced a checkpoint
best_checkpoint_path = os.path.join(
    output_directory, "models", "checkpoints", checkpoint_name, "best.pt"
)
if not os.path.isfile(best_checkpoint_path):
    raise FileNotFoundError(
        f"Trained checkpoint not found at {best_checkpoint_path}; run the training first"
    )

# Define the folder and the package file for the bioimage.io model.
# The exporter writes a single package file, so output_path has to be a file
# path inside the folder rather than the folder itself.
bioimageio_model_path = os.path.join(output_directory, "bioimage_io_model")
os.makedirs(bioimageio_model_path, exist_ok=True)
bioimageio_model_file = os.path.join(bioimageio_model_path, f"micro_sam_{timestamp}.zip")

# Export the SAM model to bioimage.io format
# NOTE(review): bioimage.io model descriptions usually expect `documentation`
# to be a path/URL to a markdown file rather than free text - confirm against
# export_sam_model before relying on this value.
export_sam_model(
    image=test_image,
    label_image=test_label,
    model_type=model_type,  # Using the same model_type as in training
    name=f"micro_sam_{timestamp}",
    output_path=bioimageio_model_file,
    checkpoint_path=best_checkpoint_path,
    # Optional: Add additional kwargs as needed
    authors=[{"name": "Your Name", "affiliation": "Your Institution"}],
    description="Micro-SAM model trained on microscopy images for segmentation",
    license="MIT",
    documentation="Model trained with micro-sam for segmenting microscopy images",
)

print(f"BioImage.IO model exported to: {bioimageio_model_file}")
