# 2D Stardist segmentation on 2D/3D/timelapse OMERO images

This notebook runs StarDist nuclei segmentation on images stored in OMERO. It is partly inspired by the `idr0002.ipynb` notebook in https://github.com/ome/omero-guide-cellprofiler.

## TO DO
- Make a generic function for 2D segmentation of all slices, independent of the image's z, c, t shape
- Include an ID in all files uploaded to OMERO to make them more traceable
- Extend to handle multiple channels AND timepoints
- Check whether label images or ROIs can be overwritten if necessary


### Import Packages

In [None]:
# Import OMERO Python BlitzGateway
import omero
from omero.gateway import BlitzGateway
import ezomero
# Import Numpy
import numpy as np
import datetime


# Import Python System Packages
import os
import tempfile
from dotenv import load_dotenv
import pandas as pd
#stardist related
from stardist.models import StarDist2D
from csbdeep.utils import normalize
from stardist.plot import render_label
from tifffile import imsave

# Pretrained StarDist 2D model ("versatile" fluorescence model). It is loaded
# once here and the same instance is handed to every ProcessImage object below.
_STARDIST_PRETRAINED_NAME = '2D_versatile_fluo'
model = StarDist2D.from_pretrained(_STARDIST_PRETRAINED_NAME)

### Set Temp Output Directory

In [None]:

# Fresh scratch directory (a new one on every run) for files written locally
# before they are uploaded to OMERO.
_scratch_dir = tempfile.mkdtemp()
new_output_directory = os.path.normcase(_scratch_dir)
print(new_output_directory)

# Local-time timestamp (YYYYmmddHHMMSS) used as a job id so that everything
# this run uploads can be traced back to it.
job_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
print(job_id)

### Setup connection with OMERO

In [None]:
# Read OMERO credentials / server settings from the .env file into os.environ.
load_dotenv(override=True)

_port = os.environ.get("PORT")
conn = ezomero.connect(
    user=os.environ.get("USER_NAME"),
    password=os.environ.get("PASSWORD"),
    group=os.environ.get("GROUP"),
    host=os.environ.get("HOST"),
    # Environment values are strings; hand the port over as an int (or None).
    port=int(_port) if _port else None,
    secure=True,
)
# ezomero.connect opens the session itself and (per the ezomero docs) returns
# None when it fails, so inspect the returned gateway instead of calling
# conn.connect() on it a second time. Stop here on failure: every later cell
# needs a live connection.
if conn is None or not conn.isConnected():
    raise RuntimeError("Connection to OMERO Server Failed")
print("Connected to OMERO Server")

# Keep the session alive during long segmentation runs (60 s ping interval).
conn.c.enableKeepAlive(60)

### Get info from the dataset

In [None]:
# Channel index of the nuclear stain passed to segment_nuclei()
# (presumably 0-based — confirm in ProcessImage).
nucl_channel = 0
# OMERO container to process: one of "project", "plate", "dataset" or "image".
datatype = "dataset"
# OMERO ID of that container.
data_id = 1112

def _print_project_summary(project_obj):
    """Print how many datasets a project has and how many images they hold."""
    child_datasets = list(project_obj.listChildren())
    image_total = sum(len(list(ds.listChildren())) for ds in child_datasets)
    print(f"- Number of datasets: {len(child_datasets)}")
    print(f"- Total images: {image_total}")


def _print_plate_summary(plate_obj):
    """Print the number of wells in a plate."""
    n_wells = len(list(plate_obj.listChildren()))
    print(f"- Number of wells: {n_wells}")


def _print_dataset_summary(dataset_obj):
    """Print the parent project (if any) and the image count of a dataset."""
    n_images = len(list(dataset_obj.listChildren()))
    parent_project = dataset_obj.getParent()
    if not parent_project:
        print("- Project: None (orphaned dataset)")
    else:
        print(f"- Project: {parent_project.getName()} (ID: {parent_project.getId()})")
    print(f"- Number of images: {n_images}")


def _print_image_summary(image_obj):
    """Print the parent dataset/project (if any) and the image dimensions."""
    sx, sy, sz, sc, st = (
        image_obj.getSizeX(),
        image_obj.getSizeY(),
        image_obj.getSizeZ(),
        image_obj.getSizeC(),
        image_obj.getSizeT(),
    )
    parent_dataset = image_obj.getParent()
    if not parent_dataset:
        print("- Dataset: None (orphaned image)")
    else:
        print(f"- Dataset: {parent_dataset.getName()} (ID: {parent_dataset.getId()})")
        parent_project = parent_dataset.getParent()
        if parent_project:
            print(f"- Project: {parent_project.getName()} (ID: {parent_project.getId()})")
    print(f"- Dimensions: {sx}x{sy}")
    print(f"- Z-stack: {sz}")
    print(f"- Channels: {sc}")
    print(f"- Timepoints: {st}")


# datatype string -> printer for the type-specific part of the report.
_DETAIL_PRINTERS = {
    "project": _print_project_summary,
    "plate": _print_plate_summary,
    "dataset": _print_dataset_summary,
    "image": _print_image_summary,
}


def print_object_details(conn, obj, datatype):
    """Print detailed information about OMERO objects

    Parameters
    ----------
    conn : BlitzGateway
        OMERO connection. Not used by this function; kept for call compatibility.
    obj : OMERO object wrapper
        The project, plate, dataset or image to describe.
    datatype : str
        "project", "plate", "dataset" or "image". Any other value prints only
        the common header (name, ID, owner, group).
    """
    print(f"\n{datatype.capitalize()} Details:")
    print(f"- Name: {obj.getName()}")
    print(f"- ID: {obj.getId()}")
    print(f"- Owner: {obj.getOwner().getFullName()}")
    print(f"- Group: {obj.getDetails().getGroup().getName()}")

    type_specific = _DETAIL_PRINTERS.get(datatype)
    if type_specific is not None:
        type_specific(obj)

def _require_and_describe(obj, omero_class, kind, object_id):
    """Raise ValueError if *obj* is None, otherwise print its details.

    `omero_class` and `object_id` are only used in the error message; `kind`
    is the datatype string forwarded to print_object_details.
    """
    if obj is None:
        raise ValueError(f"{omero_class} with ID {object_id} not found")
    print_object_details(conn, obj, kind)


# Fetch the object selected by (datatype, data_id), fail fast if that ID does
# not exist, and keep it under a datatype-specific name for later cells.
if datatype == "project":
    project = conn.getObject("Project", data_id)
    _require_and_describe(project, "Project", "project", data_id)
elif datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    _require_and_describe(plate, "Plate", "plate", data_id)
elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    _require_and_describe(dataset, "Dataset", "dataset", data_id)
elif datatype == "image":
    image = conn.getObject("Image", data_id)
    _require_and_describe(image, "Image", "image", data_id)
else:
    raise ValueError("Invalid datatype specified")


### Run StarDist on the selected images

#### Get imageIDs to process
- Make sure that all images in your project/dataset/plate are of the same type, e.g. all time series, all z-stacks and/or all multichannel images.

In [None]:
# For development of the src.ProcessImage module: reload it so edits to the
# module are picked up without restarting the kernel.
import importlib
import src.ProcessImage as ProcessImage
importlib.reload(ProcessImage)

# Get list of images to process based on datatype.
if datatype in ("plate", "dataset", "project"):
    # ezomero.get_image_ids selects the container through a keyword argument
    # named after its type (plate=..., dataset=..., project=...).
    images_ids = ezomero.get_image_ids(conn, **{datatype: data_id})
    images = [conn.getObject("Image", image_id) for image_id in images_ids]
    print(f"Processing {len(images)} images from {datatype} {data_id}")
elif datatype == "image":
    images_ids = [data_id]
    images = [conn.getObject("Image", data_id)]
    print(f"Processing 1 image with ID {data_id}")
else:
    # Fail loudly rather than silently reusing a stale `images` list.
    raise ValueError(f"Invalid datatype specified: {datatype!r}")

#### Process images


In [None]:
# OMERO object type that receives the combined measurements table, per datatype.
# For datatype "image" the per-image table posted in the loop already holds
# every row, so no combined table is posted.
_SUMMARY_TABLE_OBJECT_TYPES = {
    "project": "Project",
    "plate": "Plate",
    "dataset": "Dataset",
}

plate_statistics = []  # one measurements DataFrame per processed image
for count, image in enumerate(images):
    print(f'Processing image {count+1}/{len(images)}: {image.getName()}')
    image_id = image.getId()

    # Initialize processing
    img = ProcessImage.ProcessImage(conn, image, job_id, model)

    # Segment nuclei
    img.segment_nuclei(nucl_channel)

    # Save results
    img.save_segmentation_to_omero_as_attach(new_output_directory)
    img.save_segmentation_to_omero_as_roi()
    # img.save_segmentation_to_omero_as_new_image(seg_img_name, desc)

    # Measure intensity
    img.measure_intensity(norm=False)
    all_statistics_df = img.get_measurements_to_df()

    # Tag every row with its source image so the combined table stays traceable.
    all_statistics_df['imageID'] = image_id
    plate_statistics.append(all_statistics_df)

    # Attach this image's measurements to the image itself.
    tabelid = ezomero.post_table(
        conn,
        object_type="Image",
        object_id=image_id,
        table=all_statistics_df,
        title=f"Nuclei_measurements_{job_id}_{image_id}",
    )
    print('Created table ID:', tabelid)

# Concatenate all statistics into a single DataFrame and attach it to the
# container that was actually processed (not always a Dataset).
if plate_statistics:
    plate_statistics_df = pd.concat(plate_statistics, ignore_index=True)
    summary_object_type = _SUMMARY_TABLE_OBJECT_TYPES.get(datatype)
    if summary_object_type is not None:
        tabelid = ezomero.post_table(
            conn,
            object_type=summary_object_type,
            object_id=data_id,
            table=plate_statistics_df,
            title=f"Nuclei_measurements_{job_id}_{data_id}",
        )
        print('Created combined table ID:', tabelid)
else:
    # pd.concat raises on an empty list; report instead of crashing.
    print("No images were processed; no combined measurements table was created.")

### Explore StarDist segmentation on a single image

In [None]:
# Segment and measure ONE image so the results can be inspected interactively;
# nothing is uploaded to OMERO from this cell.
plate_statistics = []
image_number = 0  # position in `images` of the image to explore
count = image_number
image = images[count]
print(f'Processing image {count+1}/{len(images)}: {image.getName()}')

img = ProcessImage.ProcessImage(conn, image, job_id, model)
img.segment_nuclei(nucl_channel)

# Uploads are disabled while exploring; uncomment to save the segmentation:
# img.save_segmentation_to_omero_as_attach(new_output_directory)
# img.save_segmentation_to_omero_as_roi()
# img.save_segmentation_to_omero_as_new_image(seg_img_name, desc)

img.measure_intensity(norm=False)
all_statistics_df = img.get_measurements_to_df()

# Uncomment to also push the measurements to OMERO as a table:
# all_statistics_df['imageID'] = image.getId()
# image_id = image.getId()
# tabelid = ezomero.post_table(conn, object_type="Image", object_id=image.getId(), table = all_statistics_df,title=f"Nuclei_measurements_{job_id}_{image_id}")
# print('Created table ID:', tabelid)


In [None]:
# Quick visual check of the segmentation: overlay the labels on one plane of
# the raw stack and export both arrays as TIFF files.
#img.measure_intensity(norm=False)
#all_statistics_df = img.get_measurements_to_df()
#img.visualize_measurements(nucl_channel)
import stackview

labels = img.labels
labels_array = np.asarray(labels)
print(np.shape(labels))
import pyclesperanto_prototype as cle  # NOTE(review): not used in this cell

image = img.get_image_stack()
#stackview.slice(image, continuous_update=True)
#stackview.curtain(image[:,0,:,:], img.labels)

# Plane to preview along axis 0 of both arrays (time point or z-plane depending
# on how ProcessImage orders the stack — TODO confirm). Clamped so stacks with
# fewer than 4 planes do not raise IndexError.
preview_index = min(3, image.shape[0] - 1, labels_array.shape[0] - 1)
preview_plot = stackview.imshow(image[preview_index, 0, :, :], continue_drawing=True)
stackview.imshow(labels_array[preview_index, :, :], plot=preview_plot, alpha=0.4, title='image + labels')

import tifffile
tifffile.imwrite('image.tif', image)
tifffile.imwrite('labels.tif', labels)

In [None]:
#https://github.com/haesleinhuepf/napari-skimage-regionprops/blob/master/demo/tables.ipynb
import numpy as np
import napari
import pandas
from napari_skimage_regionprops import regionprops_table, add_table, get_table


# Open an interactive napari viewer with one plane of `image` (presumably the
# array returned by img.get_image_stack() in the cell above — confirm) and the
# matching plane of the label mask overlaid.
viewer = napari.Viewer()
# Index 3 on axis 0 and index 0 on axis 1 (presumably the channel — TODO
# confirm). The slices are passed inline on purpose: napari may infer a layer's
# name from a plain variable argument, so binding them to variables first could
# change the displayed layer names.
viewer.add_image(image[3,0,:,:])
viewer.add_labels(np.array(labels)[3,:,:])

### Delete attachments (annotation links) from all images in the selected container

In [None]:
# Container whose images will have ALL their annotation links removed: every
# annotation returned by listAnnotations() is unlinked (no namespace/type filter).
datatype = "dataset"  # "project", "plate", "dataset" or "image"
data_id = 502


def ensure_list(obj):
    """Normalise *obj* to a list: falsy -> [], a list -> itself, anything else -> [obj]."""
    if not obj:
        return []
    if isinstance(obj, list):
        return obj
    return [obj]


# Resolve the target images from data_id itself, not from a `dataset` object
# left over from an earlier cell that may refer to a different ID.
if datatype in ("project", "dataset", "plate"):
    target_image_ids = ezomero.get_image_ids(conn, **{datatype: data_id})
elif datatype == "image":
    target_image_ids = [data_id]
else:
    raise ValueError(f"Invalid datatype specified: {datatype!r}")

to_delete = []
for target_image_id in target_image_ids:
    omero_image = conn.getObject("Image", target_image_id)
    if omero_image is None:
        print(f"Image {target_image_id} not found; skipping")
        continue
    print('Image Name:', omero_image.getName())
    for ann in omero_image.listAnnotations():
        # The link id is sometimes a single value, sometimes a list.
        to_delete.extend(ensure_list(ann.link.id))

if to_delete:
    # wait=True blocks until the server has finished deleting.
    conn.deleteObjects("ImageAnnotationLink", to_delete, wait=True)
    print(f"Deleted {len(to_delete)} image annotation links")
else:
    print("No image annotation links found; nothing deleted")

### Delete ROIs from all images in the selected container

In [None]:
# Container whose images will have ALL of their ROIs deleted.
datatype = "dataset"  # "project", "plate", "dataset" or "image"
data_id = 502

# Resolve the target images from data_id itself, not from a `dataset` object
# left over from an earlier cell that may refer to a different ID.
if datatype in ("project", "dataset", "plate"):
    target_image_ids = ezomero.get_image_ids(conn, **{datatype: data_id})
elif datatype == "image":
    target_image_ids = [data_id]
else:
    raise ValueError(f"Invalid datatype specified: {datatype!r}")

roi_service = conn.getRoiService()  # one service proxy reused for every image
deleted_roi_count = 0
for target_image_id in target_image_ids:
    result = roi_service.findByImage(target_image_id, None)
    roi_ids = [roi.id.val for roi in result.rois]
    if not roi_ids:
        continue  # nothing to delete for this image
    # wait=True so the deletion has finished before the cell returns.
    conn.deleteObjects("Roi", roi_ids, wait=True)
    deleted_roi_count += len(roi_ids)
print(f"Deleted {deleted_roi_count} ROIs from {len(target_image_ids)} images")
