### Prediction with a fine-tuned micro-sam model

In [None]:
# Import OMERO Python BlitzGateway
import omero
from omero.gateway import BlitzGateway
import ezomero
# Import Numpy
import numpy as np
import datetime
# Import Python System Packages
import os
import tempfile
import pandas as pd
import warnings
from tifffile import imwrite
from dotenv import load_dotenv
import imageio
import shutil
import cv2
#micro-sam related imports
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

### Set up the connection to OMERO

In [None]:
# Load settings from a .env file; HOST, USER_NAME and PASSWORD are read from
# the environment below.
load_dotenv(override=True)

conn = BlitzGateway(
    host=os.environ.get("HOST"),
    username=os.environ.get("USER_NAME"),
    passwd=os.environ.get("PASSWORD"),
    secure=True,
)
connection_status = conn.connect()
if connection_status:
    print("Connected to OMERO Server")
    # Only start the keep-alive once a session actually exists; enabling it
    # after a failed connect() has nothing to keep alive.
    conn.c.enableKeepAlive(60)
else:
    print("Connection to OMERO Server Failed")

In [None]:
# Each run gets its own timestamped folder under ~/micro-sam_models.
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
folder_name = f"micro-sam-{timestamp}"

home_dir = os.path.expanduser("~")
models_dir = os.path.join(home_dir, "micro-sam_models")
output_directory = os.path.join(models_dir, folder_name)

# Create the shared parent first, then the per-run folder.
for directory in (models_dir, output_directory):
    os.makedirs(directory, exist_ok=True)

print(f"Output directory: {output_directory}")

In [None]:
### Get info from the dataset
datatype = "dataset"  # "plate", "dataset", "image"
data_id = 502
nucl_channel = 0

# Validate that data_id matches datatype. A lookup that finds nothing gives
# None, so fail here with a clear message instead of an AttributeError on
# .getName(); an unrecognised datatype is rejected rather than silently skipped.
if datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    if plate is None:
        raise ValueError(f"No Plate with ID {data_id} found on the OMERO server")
    print('Plate Name: ', plate.getName())
elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"No Dataset with ID {data_id} found on the OMERO server")
    print('Dataset Name: ', dataset.getName())
elif datatype == "image":
    image = conn.getObject("Image", data_id)
    if image is None:
        raise ValueError(f"No Image with ID {data_id} found on the OMERO server")
    print('Image Name: ', image.getName())
else:
    raise ValueError(
        f"Unsupported datatype {datatype!r}; expected 'plate', 'dataset' or 'image'"
    )

In [None]:
def run_automatic_instance_segmentation(image, checkpoint_path, model_type="vit_b_lm", device=None):
    """Run micro-sam automatic instance segmentation (AIS) on a single image.

    A predictor/segmenter pair is created from the given checkpoint on every
    call and then applied to ``image`` with ``ndim=2``.

    NOTE: AIS is supported only for `µsam` models.

    Args:
        image: The input image.
        checkpoint_path: The path to stored checkpoints.
        model_type: The choice of the `µsam` model.
        device: The device to run the model inference.

    Returns:
        The instance segmentation.
    """
    # Build the model pair; `checkpoint` selects the (finetuned) weights to load.
    sam_predictor, instance_segmenter = get_predictor_and_segmenter(
        model_type=model_type,
        checkpoint=checkpoint_path,
        device=device,
    )

    # Segment the image and hand the result straight back to the caller.
    return automatic_instance_segmentation(
        predictor=sam_predictor,
        segmenter=instance_segmenter,
        input_path=image,
        ndim=2,
    )

In [None]:
def process_omero_prediction_batch(
    dataset,
    output_folder: str,
    model_path: str,
    model_type: str = 'vit_l',
    batch_size: int = 3,
    channel: int = 0,
    timepoint: int = 0,
    z_slice: int = 0,
):
    """
    Process OMERO dataset in batches for automatic instance segmentation using fine-tuned SAM model

    For every image in the dataset one plane is segmented, the prediction is
    written to a temporary TIFF, uploaded together with its ROIs via
    `upload_prediction_and_rois`, and recorded in a tracking table that is
    attached to the dataset at the end.

    NOTE: the upload uses the module-level `conn` OMERO connection; it is not
    a parameter of this function.

    Args:
        dataset: OMERO dataset object
        output_folder: Path to store temporary files; a "predictions"
            sub-folder is wiped and recreated inside it
        model_path: Path to fine-tuned model checkpoint
        model_type: SAM model type
        batch_size: Number of images to process at once (must be >= 1)
        channel: Channel to segment
        timepoint: Timepoint to process
        z_slice: Z-slice to process

    Returns:
        The ID returned by `ezomero.post_table` for the tracking table.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Setup output directory: remove it if it exists and create it fresh
    output_path = os.path.join(output_folder, "predictions")
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)

    # Get all images from dataset
    images_dataset = list(dataset.listChildren())

    # Schema of the tracking table. "z_slice" is declared explicitly because
    # every row carries it; it stays last so the column order is unchanged.
    tracking_columns = [
        "image_id", "image_name", "channel", "timepoint",
        "sam_model", "label_id", "roi_id", "z_slice",
    ]
    # Rows are collected in a list and turned into a DataFrame once, instead
    # of re-concatenating a growing DataFrame for every image.
    tracking_rows = []

    # Process images in batches
    for start_idx in range(0, len(images_dataset), batch_size):
        batch_images = images_dataset[start_idx:start_idx + batch_size]

        for offset, image in enumerate(batch_images):
            # Dataset-wide index, so temporary file names (and therefore the
            # uploaded attachment names) do not repeat from batch to batch.
            image_index = start_idx + offset

            # Get image plane
            pixels = image.getPrimaryPixels()
            img = pixels.getPlane(z_slice, channel, timepoint)

            # Run automatic instance segmentation.
            # NOTE(review): run_automatic_instance_segmentation builds the
            # predictor/segmenter on every call, so the checkpoint is loaded
            # once per image.
            prediction = run_automatic_instance_segmentation(
                image=img,
                checkpoint_path=model_path,
                model_type=model_type,
                device='cuda'
            )

            pred_file = os.path.join(output_path, f"pred_{image_index:05d}.tif")
            try:
                # Save prediction
                imageio.imwrite(pred_file, prediction)

                # Upload prediction and ROIs
                label_id, roi_id = upload_prediction_and_rois(
                    conn,
                    image,
                    pred_file,
                    z_slice,
                    channel,
                    timepoint,
                    model_type
                )
            finally:
                # Clean up the prediction file even when saving/uploading fails
                if os.path.exists(pred_file):
                    os.remove(pred_file)

            tracking_rows.append({
                "image_id": image.getId(),
                "image_name": image.getName(),
                "channel": channel,
                "z_slice": z_slice,
                "timepoint": timepoint,
                "sam_model": model_type,
                "label_id": label_id,
                "roi_id": roi_id
            })

    df = pd.DataFrame(tracking_rows, columns=tracking_columns)

    # Upload tracking table
    table_id = ezomero.post_table(
        conn,
        object_type="Dataset",
        object_id=dataset.getId(),
        table=df,
        title="micro_sam_prediction_data"
    )

    return table_id


def mask_to_contour(mask):
    """Find the contours of a binary mask with OpenCV.

    Args:
        mask (np.ndarray): binary mask

    Returns:
        The contours reported by ``cv2.findContours`` (``RETR_TREE`` retrieval,
        ``CHAIN_APPROX_SIMPLE`` approximation); the hierarchy is discarded.
    """
    found_contours, _hierarchy = cv2.findContours(
        mask,
        cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE,
    )
    return found_contours

def label_to_rois(label_img, z_slice, channel, timepoint):
    """
    Convert a 2D label image to OMERO ROI shapes

    Every non-zero label is turned into a binary mask, its contours are
    extracted with `mask_to_contour`, and each contour becomes one
    `ezomero.rois.Polygon` placed on the given z/c/t plane.

    Args:
        label_img (np.ndarray): 2D labeled image; 0 is treated as background
        z_slice (int): Z-slice index
        channel (int): Channel index
        timepoint (int): Time point index

    Returns:
        list: List of OMERO shape objects (empty when there are no
        non-zero labels)
    """
    shapes = []
    unique_labels = np.unique(label_img)

    # Skip background by value, not by position: dropping the first sorted
    # entry would discard a real object whenever the image has no 0 pixels.
    foreground_labels = unique_labels[unique_labels != 0]

    for label in foreground_labels:
        # Create binary mask for this label
        mask = (label_img == label).astype(np.uint8)

        # Convert each contour to polygon ROI
        for contour in mask_to_contour(mask):
            contour = contour[:, 0, :]  # Reshape to (N, 2)
            # Create polygon without text parameter
            poly = ezomero.rois.Polygon(
                points=contour,  # explicitly name the points parameter
                z=z_slice,
                c=channel,
                t=timepoint
            )
            shapes.append(poly)

    return shapes

def upload_prediction_and_rois(conn, image, pred_file, z_slice, channel, timepoint, model_type):
    """
    Attach a prediction file to an OMERO image and post its ROIs.

    The prediction file is uploaded as a file annotation on the image, then
    read back and converted to shapes with `label_to_rois`. An ROI is only
    posted when at least one shape was produced.

    Returns:
        tuple: (label_id, roi_id); roi_id is None when no shapes were found.
    """
    image_id = image.getId()

    # Upload prediction as attachment
    label_id = ezomero.post_file_annotation(
        conn,
        str(pred_file),
        ns='microsam.automatic_prediction',
        object_type="Image",
        object_id=image_id,
        description=f'SAM automatic instance segmentation ({model_type})'
    )

    # Create ROIs from prediction
    shapes = label_to_rois(imageio.imread(pred_file), z_slice, channel, timepoint)

    # Nothing segmented: report the attachment only
    if not shapes:
        return label_id, None

    roi_id = ezomero.post_roi(
        conn,
        image_id,
        shapes,
        name=f'SAM_automatic_{model_type}',
        description='micro_sam.automatic_instance_segmentation'
    )
    return label_id, roi_id

In [None]:

# Checkpoint of the fine-tuned model to run inference with
model_folder = 'C:\\models\\checkpoints\\sam'
best_checkpoint = os.path.join(model_folder, "best.pt")

# Inference settings: model variant, batching, and the plane (c/t/z) to segment
model_type, batch_size = "vit_b", 1
channel, timepoint, z_slice = 3, 0, 5

# Arguments are passed positionally, in the order of the function signature.
table_id = process_omero_prediction_batch(
    dataset,
    output_directory,
    best_checkpoint,
    model_type,
    batch_size,
    channel,
    timepoint,
    z_slice,
)
print("Prediction Table ID:", table_id)