# Micro-SAM: Run Inference on OMERO Data

Example notebook for running a fine-tuned micro_sam model on image data stored in OMERO.

## 1. Setup
Run these cells to import all required packages and define the helper functions.

In [None]:
# Standard library
import datetime
import functools
from pathlib import Path

# Third-party
import ezomero
import numpy as np
import stackview
import torch
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

# Import the omero login widget
from omero_annotate_ai import (
    create_omero_connection_widget,
)

# Report which accelerator inference will most likely run on. The cells below pass
# device=None, which leaves the choice to micro_sam; NOTE(review): micro_sam's default
# presumably prefers CUDA, then Apple MPS, then CPU — mirror that order so Mac users
# are not told "CPU" when MPS is actually used. Verify against micro_sam.util.
if torch.cuda.is_available():
    device_label = "CUDA"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device_label = "MPS"
else:
    device_label = "CPU"
print(f"Device available: {device_label}")


In [None]:
@functools.lru_cache(maxsize=1)
def _load_predictor_and_segmenter(model_type, checkpoint_path, device):
    """Load (or reuse) the micro_sam predictor/segmenter pair for one configuration.

    Cached with ``maxsize=1``: repeated calls with the same arguments (e.g. when
    segmenting every time point of a stack) return the already-loaded pair instead
    of reloading the model weights. Switching to another model/checkpoint/device
    drops the cache's reference to the previous pair. All arguments must be
    hashable (str, None or a torch.device are).

    If the checkpoint file at the same path is overwritten (e.g. after retraining),
    call ``_load_predictor_and_segmenter.cache_clear()`` to force a reload.
    """
    return get_predictor_and_segmenter(
        model_type=model_type,  # choice of the Segment Anything model
        checkpoint=checkpoint_path,  # None -> original (not fine-tuned) weights
        device=device,  # the device to run the model inference
    )


def run_automatic_instance_segmentation(image, checkpoint_path, model_type="vit_b_lm", device=None, ndim=2):
    """Run Automatic Instance Segmentation (AIS) on an image with a `µsam` model.

    AIS relies on the additional instance-segmentation decoder that `µsam` models
    are trained with, so it is supported only for `µsam` models. This function
    only runs inference; it does not train anything.

    The predictor/segmenter pair is loaded once per (model_type, checkpoint_path,
    device) combination and reused on later calls, so calling this function in a
    loop over many planes does not reload the model each time.

    Args:
        image: The input image. The notebook passes a 2D NumPy array; the
            micro_sam parameter is called `input_path`, so a file path presumably
            works too.
        checkpoint_path: Path to the stored (fine-tuned) checkpoint, or None to
            use the original model weights.
        model_type: The choice of the `µsam` model.
        device: The device to run the model inference on (None leaves the choice
            to micro_sam).
        ndim: Number of spatial dimensions of `image`, forwarded to micro_sam's
            `automatic_instance_segmentation` (default 2, as before).

    Returns:
        The instance segmentation (label image).
    """
    # Step 1: Get the (cached) 'predictor' and 'segmenter' for automatic instance segmentation.
    predictor, segmenter = _load_predictor_and_segmenter(model_type, checkpoint_path, device)

    # Step 2: Get the instance segmentation for the given image.
    return automatic_instance_segmentation(
        predictor=predictor,  # the predictor for the Segment Anything model.
        segmenter=segmenter,  # the segmenter class responsible for generating predictions.
        input_path=image,
        ndim=ndim,
    )

## 2. OMERO Connection

In [None]:
# Create and display the OMERO login widget. The resulting connection is
# retrieved afterwards via conn_widget.get_connection().
conn_widget = create_omero_connection_widget()
conn_widget.display()

In [None]:
# Fetch the OMERO connection from the login widget and stop early without one.
# get_connection() may return None (checked here) — presumably when the login
# in the widget has not been completed.
if (conn := conn_widget.get_connection()) is None:
    raise ConnectionError("No OMERO connection established.")

print(f"Connected to OMERO as: {conn.getUser().getName()}")

## 3. Data Selection

In [None]:
# Pick the OMERO image and the single 2D plane to segment
# (z-slice, channel and time point indices; only 2D is supported here).
image_id = 277
z_slice = 0
channel = 0
time_point = 66

# Fine-tuned micro_sam checkpoint, relative to the notebook's working directory.
checkpoint_path = "micro_sam_training_20250829_135659_final.pt"
# To compare with the original model before fine-tuning, uncomment:
# checkpoint_path = None



In [None]:
# Get the image from OMERO; fail with a readable message if it is not accessible.
image = conn.getObject("Image", image_id)
if image is None:
    raise ValueError(f"Image {image_id} was not found or is not accessible with the current OMERO connection.")

# Validate the requested plane before fetching pixels.
for axis_name, axis_index, axis_size in (
    ("z_slice", z_slice, image.getSizeZ()),
    ("channel", channel, image.getSizeC()),
    ("time_point", time_point, image.getSizeT()),
):
    if not 0 <= axis_index < axis_size:
        raise ValueError(f"{axis_name}={axis_index} is out of range for image {image_id} (size {axis_size}).")

# ezomero takes start_coords/axis_lengths in XYZCT order and returns the pixels in TZYXC order.
_, pixels = ezomero.get_image(
    conn,
    image_id,
    start_coords=[0, 0, z_slice, channel, time_point],
    axis_lengths=[image.getSizeX(), image.getSizeY(), 1, 1, 1],
)
# Index explicitly rather than squeezing so a size-1 Y or X axis is never dropped.
plane = pixels[0, 0, :, :, 0]  # (Y, X)

# Run prediction
labels = run_automatic_instance_segmentation(plane, checkpoint_path, model_type="vit_b_lm", device=None)

In [None]:
# Show the plane with the predicted labels overlaid; use the curtain slider
# to compare the raw image against the segmentation.
stackview.curtain(
    np.squeeze(pixels),
    labels,
    alpha=0.5,
    continuous_update=True,
    zoom_factor=0.5,
)

An example of how to run the prediction on a stack of consecutive time points:

In [None]:
# Pick the OMERO image and a range of consecutive time points for one
# z-slice/channel: time points time_point .. time_point + time_point_length - 1.
image_id = 277
z_slice = 0
channel = 0
time_point = 0  # first time point of the range
time_point_length = 30  # number of consecutive time points to segment
checkpoint_path = "micro_sam_training_20250829_135659_final.pt"


In [None]:
# Get the image from OMERO; fail with a readable message if it is not accessible.
image = conn.getObject("Image", image_id)
if image is None:
    raise ValueError(f"Image {image_id} was not found or is not accessible with the current OMERO connection.")

# Validate the requested z/c indices and the time-point range before fetching pixels.
for axis_name, first, count, axis_size in (
    ("z_slice", z_slice, 1, image.getSizeZ()),
    ("channel", channel, 1, image.getSizeC()),
    ("time_point", time_point, time_point_length, image.getSizeT()),
):
    if count < 1 or first < 0 or first + count > axis_size:
        raise ValueError(
            f"{axis_name} range [{first}, {first + count}) is out of bounds for image {image_id} (size {axis_size})."
        )

# ezomero takes start_coords/axis_lengths in XYZCT order and returns the pixels in TZYXC order.
_, pixels = ezomero.get_image(
    conn,
    image_id,
    start_coords=[0, 0, z_slice, channel, time_point],
    axis_lengths=[image.getSizeX(), image.getSizeY(), 1, 1, time_point_length],
)

# Run prediction on every time point; pixels[t, 0, :, :, 0] is one (Y, X) plane.
# Stacking gives a (T, Y, X) label array that matches the image stack, even for T == 1.
labels = np.stack([
    run_automatic_instance_segmentation(pixels[t, 0, :, :, 0], checkpoint_path, model_type="vit_b_lm", device=None)
    for t in range(pixels.shape[0])
])

In [None]:
# Slice through the time points with the labels overlaid.
# pixels is TZYXC, so [:, 0, :, :, 0] keeps a (T, Y, X) stack even when only one
# time point was loaded (np.squeeze would drop T and mismatch the labels' shape);
# np.asarray turns a list of per-time-point label images into a matching array.
stackview.curtain(
    pixels[:, 0, :, :, 0],
    np.asarray(labels),
    alpha=0.5,
    continuous_update=True,
    zoom_factor=0.5,
)