### Fine tuning SAM with OMERO data

TO DO
- clean up temporary files once they are no longer necessary

In [23]:
# Import OMERO Python BlitzGateway
import omero
from omero.gateway import BlitzGateway
import ezomero
# Import Numpy
import numpy as np

# Import Python System Packages
import os
import tempfile
import pandas as pd
import warnings
from tifffile import imwrite

#micro-sam related imports
from micro_sam.sam_annotator import annotator_2d

### Setup connection with OMERO

In [24]:
# Connection settings are read from environment variables so that real
# credentials do not have to be stored in the notebook. The fallbacks are the
# values this notebook used before (local development server).
conn = BlitzGateway(
    host=os.environ.get('OMERO_HOST', 'localhost'),
    username=os.environ.get('OMERO_USER', 'root'),
    passwd=os.environ.get('OMERO_PASSWORD', 'omero'),
    secure=True,
)
connected = conn.connect()
print(connected)
if not connected:
    # Fail here with a clear message instead of a confusing error on `conn.c`
    # or in a later cell.
    raise ConnectionError('Could not connect to the OMERO server; check host and credentials.')
# Ask the client to keep the session alive (interval argument: 60).
conn.c.enableKeepAlive(60)

True


### Get info from the dataset

In [25]:
datatype = "dataset" # "plate", "dataset", "image"
data_id = 253
nucl_channel = 0

# Validate that data_id matches datatype.
# The looked-up object is bound to `plate`, `dataset` or `image` (later cells
# read `dataset`) and additionally to `selected_object` for the shared checks.
if datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    selected_object = plate
elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    selected_object = dataset
elif datatype == "image":
    image = conn.getObject("Image", data_id)
    selected_object = image
else:
    raise ValueError(
        f"Unsupported datatype {datatype!r}; expected 'plate', 'dataset' or 'image'"
    )

# getObject() yields None when nothing matches, which would otherwise show up
# as an AttributeError on .getName().
if selected_object is None:
    raise ValueError(
        f"No {datatype} with ID {data_id} was found (or it is not accessible to this user)"
    )
print(f'{datatype.capitalize()} Name: ', selected_object.getName())

Dataset Name:  day7_testing


### Load images from OMERO and open in napari with micro-sam annotator

In [26]:
# Create a fresh temporary directory for this session and keep its
# case-normalised path; later cells write their files below it.
scratch_root = tempfile.mkdtemp()
new_output_directory = os.path.normcase(scratch_root)
print('Output Directory: ', new_output_directory)

Output Directory:  c:\users\mwpaul\appdata\local\temp\tmp6oy21oa0


In [27]:
import napari
from napari.settings import get_settings
import zipfile
from micro_sam.sam_annotator import image_series_annotator
from micro_sam.util import precompute_image_embeddings

def zip_directory(folder_path, zip_file):
    """Recursively add every file below ``folder_path`` to an open zip archive.

    Entries are stored under their path relative to ``folder_path``, so the
    archive does not contain the absolute location of the folder (for example
    the user's temporary directory). Directories without files are not added.

    Args:
        folder_path: Directory whose files are added.
        zip_file: An open, writable ``zipfile.ZipFile``. It is not closed here.
    """
    for folder_name, _subfolders, filenames in os.walk(folder_path):
        for filename in filenames:
            # Complete path of the file on disk
            file_path = os.path.join(folder_name, filename)
            # Store it under a name relative to the zipped folder
            zip_file.write(file_path, arcname=os.path.relpath(file_path, folder_path))
def interleave_arrays(train_images, validate_images):
    """
    Merge two image sequences by alternating between them.

    The result starts with train[0], validate[0], train[1], validate[1], ...
    Once the shorter input is exhausted, the remaining items of the longer
    one follow in their original order.

    Returns a tuple of two numpy arrays: the merged items, and a parallel
    array of flags where 0 marks an item taken from ``train_images`` and
    1 an item taken from ``validate_images``.
    """
    sources = (train_images, validate_images)
    merged = []
    origin_flags = []

    longest = max(len(train_images), len(validate_images))
    for position in range(longest):
        # flag 0 -> train, flag 1 -> validate; skip a source that has run out
        for flag, source in enumerate(sources):
            if position < len(source):
                merged.append(source[position])
                origin_flags.append(flag)

    return np.array(merged), np.array(origin_flags)


## input parameters
model_type = 'vit_l'
train_n = 3
validate_n = 3
channel = 3 #which channel to segment starting from 0
timepoint = 0
z_slice = 5 #for now pick one slice but TODO add option to pick multiple slices by giving a list of z slices

#set napari settings
settings = get_settings()
# Turn off napari's IPython-interactive mode for the napari.run() call below.
settings.application.ipy_interactive = False
# Only datasets are handled here; every image of the dataset is loaded.
if datatype == "dataset":
    images_dataset = list(dataset.listChildren())
    if not images_dataset:
        raise ValueError("The selected dataset contains no images to annotate")
    images = []
    for omero_image in images_dataset:
        # Check the requested plane against this image's dimensions so a bad
        # parameter is reported clearly instead of as a server-side error.
        for axis_name, index, size in (
            ("z_slice", z_slice, omero_image.getSizeZ()),
            ("channel", channel, omero_image.getSizeC()),
            ("timepoint", timepoint, omero_image.getSizeT()),
        ):
            if not 0 <= index < size:
                raise IndexError(
                    f"{axis_name}={index} is out of range for image "
                    f"{omero_image.getId()} (size {size})"
                )
        pixels = omero_image.getPrimaryPixels()
        images.append(pixels.getPlane(z_slice, channel, timepoint)) #(z, c, t)
    #start napari viewer
    viewer = napari.Viewer()
    output_folder = new_output_directory
    embed_dir = os.path.join(output_folder, "embed")
    annotation_dir = os.path.join(output_folder, "output")
    os.makedirs(annotation_dir, exist_ok=True)
    os.makedirs(embed_dir, exist_ok=True)
    image_series_annotator(
        images,
        model_type=model_type,
        viewer=viewer,
        embedding_path=embed_dir,
        output_folder=annotation_dir,
    )
    napari.run()
    # Wait until the napari viewer is closed by the user
    #TODO clean up napari output in the notebook

Precomputation took 183.96702575683594 seconds (= 03:04 minutes)
The first image to annotate is image number 0


Epoch 0:   2%|▏         | 44/2000 [2:36:50<116:12:33, 213.88s/it]
Epoch 0:   5%|▌         | 100/2000 [4:51:22<92:16:08, 174.83s/it]
  napari.run()


# Upload embeddings, label images, and the training-data table to OMERO

In [7]:
# NOTE(review): `combined_images` and `combine_images_sequence` are not assigned
# in any cell visible in this notebook; they must exist before this cell runs.
# Presumably they are the two values returned by interleave_arrays() - confirm.
if datatype == "dataset":
    embed_dir = os.path.join(output_folder, "embed")
    label_dir = os.path.join(output_folder, "output")

    # Check every expected file before uploading anything, so a missing result
    # does not leave a partial upload or an empty zip attached to an image.
    upload_jobs = []
    for n, image in enumerate(combined_images):
        embed_file = os.path.join(embed_dir, f"embedding_{n:05d}.zarr")#fixed leading zeros
        label_file = os.path.join(label_dir, f"seg_{n:05d}.tif")#fixed leading zeros
        if not os.path.isdir(embed_file):
            raise FileNotFoundError(f"Embedding directory not found: {embed_file}")
        if not os.path.isfile(label_file):
            raise FileNotFoundError(f"Label image not found: {label_file}")
        upload_jobs.append((image, embed_file, label_file))

    embed_ids = []
    label_ids = []
    for n, (image, embed_file, label_file) in enumerate(upload_jobs):
        #zip zarr directory; the context manager closes the archive even on error
        zip_path = os.path.join(embed_dir, f"embedding_{n:05d}.zip")
        with zipfile.ZipFile(zip_path, 'w') as zip_file:
            zip_directory(embed_file, zip_file)
        #upload zip file as attachment to image
        embed_ids.append(ezomero.post_file_annotation(
            conn,
            str(zip_path),
            ns='microsam.embeddings',
            object_type="Image",
            object_id=image.getId(),
            description='image embedding')) #TODO add specification of type of embedding etc
        #upload label file as attachment to image
        label_ids.append(ezomero.post_file_annotation(
            conn,
            str(label_file),
            ns='microsam.labelimage',
            object_type="Image",
            object_id=image.getId(),
            description='label image')) #TODO add specification of type of label embedding etc

    #upload table with training data
    # Kept inside the dataset guard: it needs embed_ids/label_ids from above and
    # attaches the table to the dataset. Rows are collected first and the frame
    # is built once instead of being grown with pd.concat in a loop.
    records = [
        {
            "image_id": image.getId(),
            "image_name": image.getName(),
            "train": bool(combine_images_sequence[n] == 0),
            "validate": bool(combine_images_sequence[n] == 1),
            "channel": channel,
            "timepoint": timepoint,
            "sam_model": model_type,
            "embed_id": embed_ids[n],
            "label_id": label_ids[n],
        }
        for n, image in enumerate(combined_images)
    ]
    df = pd.DataFrame(
        records,
        columns=["image_id", "image_name", "train", "validate", "channel",
                 "timepoint", "sam_model", "embed_id", "label_id"],
    )
    tabelid = ezomero.post_table(conn, object_type="Dataset", object_id=data_id, table = df,title="micro_sam_training_data")
    print("Table ID: ", tabelid)

Table ID:  1343
