### Prediction with a fine-tuned micro-sam model

In [None]:
# Import Numpy
import numpy as np
import datetime
# Import Python System Packages
import os
import tempfile
import pandas as pd
import warnings
from tifffile import imwrite
import imageio
import shutil
import cv2
#micro-sam related imports
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

In [None]:
# Choose a local working directory for the prediction files. Each prediction
# is written here, uploaded to OMERO and then deleted again, so a fresh
# temporary directory is sufficient. Replace it with a fixed path if you want
# the working files in a known location.
output_directory = os.path.abspath(tempfile.mkdtemp(prefix="micro_sam_"))
print(f"Data directory: {output_directory}")

In [None]:
def run_automatic_instance_segmentation(image, checkpoint_path, model_type="vit_b_lm", device=None):
    """Segment all instances in a 2D image with micro-sam automatic instance segmentation (AIS).

    AIS uses the additional instance decoder that only `µsam` models ship with,
    so ``model_type`` / ``checkpoint_path`` must refer to such a model.

    Args:
        image: The 2D input image, passed to micro-sam as ``input_path``.
        checkpoint_path: Path of the (fine-tuned) checkpoint to load.
        model_type: Name of the `µsam` model architecture.
        device: Device for model inference (None lets micro-sam decide).

    Returns:
        The instance segmentation produced by ``automatic_instance_segmentation``.
    """
    # Build the SAM predictor and the AIS segmenter from the given checkpoint.
    sam_predictor, ais_segmenter = get_predictor_and_segmenter(
        model_type=model_type,
        checkpoint=checkpoint_path,
        device=device,
    )

    # Run the segmenter on the single 2D image and hand the labels straight back.
    return automatic_instance_segmentation(
        predictor=sam_predictor,
        segmenter=ais_segmenter,
        input_path=image,
        ndim=2,
    )

In [None]:

def show_label_plots(image, labels):
    """Show an image, its label image and an overlay of both side by side.

    Args:
        image: The intensity image.
        labels: The label image to display next to / on top of ``image``.
    """
    # Imported lazily so the module does not require the plotting stack unless
    # results are actually displayed.
    import matplotlib.pyplot as plt
    import stackview

    # One row of three panels; a 15x5 figure gives each panel a square area
    # instead of leaving most of a 15x15 canvas empty.
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    stackview.imshow(image, plot=axs[0], title='image', axes=True)
    stackview.imshow(labels, plot=axs[1], title='labels')

    # Third panel: draw the image first, then the labels semi-transparent on top.
    stackview.imshow(image, plot=axs[2], continue_drawing=True)
    stackview.imshow(labels, plot=axs[2], alpha=0.4, title='image + labels')
    plt.show()


def process_omero_prediction_batch(
    dataset,
    output_folder: str,
    model_path: str,
    model_type: str = 'vit_l',
    model_id=None,
    batch_size: int = 3,
    channel: int = 0,
    timepoint: int = 0,
    z_slice: int = 0,
    show_results: bool = False,
    test_mode: bool = False,
    device: str = 'cuda',
    ):
    """Segment every image of an OMERO dataset with a fine-tuned micro-sam model.

    For each image in ``dataset`` the plane ``(z_slice, channel, timepoint)`` is
    segmented with micro-sam automatic instance segmentation. The label image is
    written to a temporary TIFF, uploaded together with polygon ROIs via
    ``upload_prediction_and_rois``, and the temporary file is removed again.
    Finally, a table with one row per processed image is posted to the dataset.

    NOTE: the OMERO connection ``conn`` and the ``ezomero`` module are read from
    module scope and must be defined before this function is called.

    Args:
        dataset: OMERO dataset object whose children (images) are processed.
        output_folder (str): Folder in which a ``predictions`` subfolder is
            created for the temporary prediction files. An existing
            ``predictions`` subfolder is deleted first.
        model_path (str): Path to the fine-tuned SAM model checkpoint file.
        model_type (str, optional): SAM model type. Defaults to 'vit_l'.
        model_id (str or int, optional): Identifier of the model; recorded in
            the annotation description, ROI name and tracking table.
            Defaults to None.
        batch_size (int, optional): Number of images per batch; must be >= 1.
            Defaults to 3.
        channel (int, optional): Channel index to segment. Defaults to 0.
        timepoint (int, optional): Timepoint to process. Defaults to 0.
        z_slice (int, optional): Z-slice index to process. Defaults to 0.
        show_results (bool, optional): Plot each image with its labels
            (via ``show_label_plots``). Defaults to False.
        test_mode (bool, optional): Ask the user after each image whether to
            continue. Defaults to False.
        device (str, optional): Device used for model inference.
            Defaults to 'cuda'.

    Returns:
        int: Table ID of the uploaded tracking table in OMERO.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        SystemExit: If the user types 'stop' while ``test_mode`` is enabled.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Start from an empty "predictions" folder so files from an earlier run
    # cannot be mixed up with the current one.
    output_path = os.path.join(output_folder, "predictions")
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)

    images_dataset = list(dataset.listChildren())
    total_batches = (len(images_dataset) + batch_size - 1) // batch_size

    # Build predictor and segmenter once and reuse them for every image;
    # loading the checkpoint per image repeats the expensive model setup.
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type,
        checkpoint=model_path,
        device=device,
    )

    rows = []
    for batch_idx in range(total_batches):
        start_idx = batch_idx * batch_size
        # Slicing clamps at the end of the list, so the last batch may be shorter.
        batch_images = images_dataset[start_idx:start_idx + batch_size]

        for image in batch_images:
            img = image.getPrimaryPixels().getPlane(z_slice, channel, timepoint)
            print(img.shape)

            prediction = automatic_instance_segmentation(
                predictor=predictor,
                segmenter=segmenter,
                input_path=img,
                ndim=2,
            )

            # The file name becomes the name of the OMERO attachment, so key it
            # by image ID rather than by position within the batch (which gave
            # every attachment the same name when batch_size == 1).
            pred_file = os.path.join(output_path, f"pred_{image.getId()}.tif")
            imageio.imwrite(pred_file, prediction)
            try:
                if show_results:
                    show_label_plots(img, prediction)

                if test_mode:
                    user_input = input("Press Enter to continue, or type 'stop' to halt execution: ")
                    if user_input.lower() == 'stop':
                        raise SystemExit("User requested to stop execution")

                label_id, roi_id = upload_prediction_and_rois(
                    conn,
                    image,
                    pred_file,
                    z_slice,
                    channel,
                    timepoint,
                    model_type,
                    model_id,
                )
            finally:
                # Remove the temporary file even when plotting, the user prompt
                # or the upload aborts the run.
                if os.path.exists(pred_file):
                    os.remove(pred_file)

            rows.append({
                "image_id": image.getId(),
                "image_name": image.getName(),
                "channel": channel,
                "z_slice": z_slice,
                "timepoint": timepoint,
                "sam_model": model_type,
                "model_id": model_id,
                "label_id": label_id,
                "roi_id": roi_id,
            })

    # Build the tracking table once from the collected rows (repeated
    # pd.concat in the loop is quadratic). All nine columns are declared up
    # front; the order keeps the established table layout for existing readers.
    df = pd.DataFrame(rows, columns=[
        "image_id", "image_name", "channel", "timepoint",
        "sam_model", "label_id", "roi_id", "z_slice", "model_id",
    ])

    table_id = ezomero.post_table(
        conn,
        object_type="Dataset",
        object_id=dataset.getId(),
        table=df,
        title="micro_sam_prediction_data"
    )

    return table_id


def mask_to_contour(mask, mode=None):
    """Find the contours of a binary mask with OpenCV.

    Args:
        mask (np.ndarray): 2D mask; every non-zero pixel counts as foreground.
        mode (int, optional): OpenCV contour retrieval mode. Defaults to
            ``cv2.RETR_TREE`` (all contours, including hole boundaries); pass
            ``cv2.RETR_EXTERNAL`` to get outer boundaries only.

    Returns:
        Sequence of contours as returned by ``cv2.findContours``: each contour
        is an array of shape (N, 1, 2) holding (x, y) vertex coordinates.
    """
    if mode is None:
        mode = cv2.RETR_TREE

    # findContours only accepts single-channel 8-bit input; binarizing also
    # works on a copy, so the caller's mask is never modified.
    binary = (np.asarray(mask) != 0).astype(np.uint8)

    found = cv2.findContours(binary, mode, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x and 4.x return
    # (contours, hierarchy); the contours are second to last in all of them.
    return found[-2]

def label_to_rois(label_img, z_slice, channel, timepoint):
    """
    Convert a 2D label image to OMERO ROI shapes.

    Every non-zero label is outlined with ``mask_to_contour`` and each resulting
    contour becomes one ``ezomero.rois.Polygon`` on the given plane. Label 0 is
    treated as background and skipped.

    Args:
        label_img (np.ndarray): 2D labeled image (0 = background).
        z_slice (int): Z-slice index assigned to every shape.
        channel (int): Channel index assigned to every shape.
        timepoint (int): Time point index assigned to every shape.

    Returns:
        list: ezomero Polygon shapes, grouped by ascending label value.
    """
    shapes = []
    unique_labels = np.unique(label_img)

    # Drop the background by value, not by position: slicing off the first
    # unique value discarded the smallest real label whenever the image had
    # no 0-valued pixels (e.g. one object covering the whole plane).
    for label in unique_labels[unique_labels != 0]:
        mask = (label_img == label).astype(np.uint8)

        for contour in mask_to_contour(mask):
            # OpenCV contours are (N, 1, 2) arrays of (x, y) vertices; flatten
            # them into the list of (x, y) tuples ezomero's Polygon expects.
            points = [tuple(point) for point in contour.reshape(-1, 2).tolist()]
            shapes.append(ezomero.rois.Polygon(
                points=points,
                z=z_slice,
                c=channel,
                t=timepoint,
            ))

    return shapes

def upload_prediction_and_rois(conn, image, pred_file, z_slice, channel, timepoint, model_type, model_id):
    """
    Attach a saved prediction to an OMERO image and add its outlines as ROIs.

    The prediction file is posted as a file annotation on ``image``. The file is
    then read back, converted to polygon shapes with ``label_to_rois`` and, if
    any shapes were found, posted as one ROI on the same image.

    Args:
        conn: OMERO connection used for the uploads.
        image: OMERO image object the prediction belongs to.
        pred_file: Path of the saved label image.
        z_slice (int): Z-slice index assigned to the shapes.
        channel (int): Channel index assigned to the shapes.
        timepoint (int): Time point index assigned to the shapes.
        model_type (str): SAM model type, used in description and ROI name.
        model_id: Model identifier, used in description and ROI name.

    Returns:
        tuple: ``(label_id, roi_id)`` - the file annotation ID and the ROI ID;
        ``roi_id`` is None when the prediction yields no shapes.
    """
    image_id = image.getId()

    # 1) The raw prediction goes up as an attachment.
    label_id = ezomero.post_file_annotation(
        conn,
        str(pred_file),
        ns='microsam.automatic_prediction',
        object_type="Image",
        object_id=image_id,
        description=f'SAM automatic instance segmentation ({model_id}) ({model_type})'
    )

    # 2) Its label outlines go up as a single ROI, but only if there are any.
    shapes = label_to_rois(imageio.imread(pred_file), z_slice, channel, timepoint)
    roi_id = None
    if shapes:
        roi_id = ezomero.post_roi(
            conn,
            image_id,
            shapes,
            name=f'SAM_automatic_{model_id}_{model_type}',
            description='micro_sam.automatic_instance_segmentation'
        )

    return label_id, roi_id

In [None]:
# Fine-tuned model to use. The checkpoint folder is derived from the model ID
# (under the current user's home directory) so the two cannot drift apart and
# the cell runs unchanged for other users.
model_id = 'micro-sam-20250207_095503'
model_type = "vit_l"
model_folder = os.path.join(
    os.path.expanduser("~"), "micro-sam_models", model_id, "models", "checkpoints", "sam"
)
best_checkpoint = os.path.join(model_folder, "best.pt")
# Fail before touching OMERO if the checkpoint is not where we expect it.
if not os.path.isfile(best_checkpoint):
    raise FileNotFoundError(f"Model checkpoint not found: {best_checkpoint}")

# Plane of each image to segment, and images per batch.
channel = 3
batch_size = 1
timepoint = 0
z_slice = 5

table_id = process_omero_prediction_batch(
    dataset=dataset,
    output_folder=output_directory,
    model_path=best_checkpoint,
    model_type=model_type,
    model_id=model_id,
    batch_size=batch_size,
    channel=channel,
    timepoint=timepoint,
    z_slice=z_slice,
    test_mode=False,
    show_results=False,
)
print("Prediction Table ID:", table_id)