# 2D Stardist segmentation on 2D/3D/timelapse OMERO images

This notebook runs StarDist nuclei segmentation on OMERO images. Parts of it are inspired by https://github.com/ome/omero-guide-cellprofiler/idr0002.ipynb.

## TO DO
- Make a generic function that applies 2D segmentation to every slice, independent of the image's z, c, t shape
- Include an ID in all files uploaded to OMERO to make them more traceable
- Extend to handle multiple channels AND timepoints
- Check whether existing label images or ROIs can be overwritten when necessary


### Import Packages

In [None]:
# Import OMERO Python BlitzGateway
import omero
from omero.gateway import BlitzGateway
import ezomero
# Import Numpy
import numpy as np

# Import Python System Packages
import os
import tempfile
import pandas
import warnings

#stardist related
from stardist.models import StarDist2D
from csbdeep.utils import normalize
from stardist.plot import render_label
import matplotlib.pyplot as plt
from tifffile import imsave

# Load the pretrained StarDist 2D model once; the same instance is handed to
# every ProcessImage object created in the analysis cell.
_STARDIST_PRETRAINED = '2D_versatile_fluo'
model = StarDist2D.from_pretrained(_STARDIST_PRETRAINED)

### Set Temp Output Directory

In [None]:
import datetime

# Scratch directory for files written locally before they are uploaded to OMERO.
new_output_directory = os.path.normcase(tempfile.mkdtemp())
print(new_output_directory)

# Unique job id for this run: the current local time as YYYYmmddHHMMSS.
# It is embedded in the titles of the tables uploaded to OMERO.
job_id = f"{datetime.datetime.now():%Y%m%d%H%M%S}"
print(job_id)

### Setup connection with OMERO

In [None]:
# Connection settings can be overridden through environment variables so that
# credentials do not have to be edited into (and shared with) the notebook.
# The defaults reproduce the previous hard-coded local development settings.
conn = BlitzGateway(
    host=os.environ.get('OMERO_HOST', 'localhost'),
    username=os.environ.get('OMERO_USER', 'root'),
    passwd=os.environ.get('OMERO_PASSWORD', 'omero'),
    secure=True,
)
connected = conn.connect()
print(connected)
if not connected:
    # Fail here instead of letting every later cell break with obscure errors.
    raise ConnectionError("Could not connect to the OMERO server; check host and credentials.")
# Ping the server every 60 s so the session survives long segmentation loops.
conn.c.enableKeepAlive(60)

### Get info from the dataset

In [None]:
datatype = "dataset"  # "plate", "dataset", "image"
data_id = 502
nucl_channel = 0  # channel index passed to ProcessImage.segment_nuclei

# Map the user-facing datatype to the OMERO object type, fetch the object and
# fail early with a clear message if the type is unknown or the object is missing
# (conn.getObject returns None when no accessible object has that ID).
_omero_types = {"plate": "Plate", "dataset": "Dataset", "image": "Image"}
if datatype not in _omero_types:
    raise ValueError(f"Unsupported datatype {datatype!r}; expected one of {sorted(_omero_types)}")
_selected = conn.getObject(_omero_types[datatype], data_id)
if _selected is None:
    raise ValueError(f"No {_omero_types[datatype]} with ID {data_id} found (or not accessible)")

# Later cells use the variable named after the datatype (plate / dataset / image).
if datatype == "plate":
    plate = _selected
    print('Plate Name: ', plate.getName())
elif datatype == "dataset":
    dataset = _selected
    print('Dataset Name: ', dataset.getName())
else:
    image = _selected
    print('Image Name: ', image.getName())

### Run Stardist on the dataset

#### Function definitions

In [None]:
import pyclesperanto_prototype as cle
import pandas as pd

def _single_plane_label(labels):
    """Return the 2D label image for a single-plane acquisition.

    Accepts the label image itself, or a length-1 list/tuple or a stack of
    shape (1, Y, X) wrapping it. Anything else is returned unchanged.
    """
    if isinstance(labels, (list, tuple)):
        return labels[0]
    if getattr(labels, "ndim", None) == 3 and labels.shape[0] == 1:
        return labels[0]
    return labels


def measure_intensity(pixels, labels, size_z, size_t, size_c, stats_fn=None):
    """Measure per-label statistics on every relevant plane and channel.

    Parameters
    ----------
    pixels
        Object providing ``getPlane(z, c, t)`` that returns a 2D intensity
        plane (e.g. an OMERO pixels wrapper).
    labels
        Time series: one label image per timepoint.
        Z-stack: one label image per z-slice.
        Single plane: the 2D label image (a length-1 list or stack is accepted too).
    size_z, size_t, size_c
        Number of z-slices, timepoints and channels of the image.
    stats_fn
        ``stats_fn(intensity_plane, label_image)`` returning a mapping of
        column name -> per-label values. Defaults to
        ``cle.statistics_of_labelled_pixels``.

    Returns
    -------
    pandas.DataFrame
        One row per label per measured (plane, channel), with added ``z``,
        ``t`` and ``channel`` columns.

    Raises
    ------
    ValueError
        If the image has more than one z-slice AND more than one timepoint.
    """
    if stats_fn is None:
        stats_fn = cle.statistics_of_labelled_pixels

    if size_z > 1 and size_t > 1:
        raise ValueError("Time series and z-stack data is not supported (yet)")

    # (z, t, label image) for every plane to measure. Every channel of each
    # plane is measured with the same label image.
    if size_t > 1:
        planes = [(0, t, label) for t, label in zip(range(size_t), labels)]
    elif size_z > 1:
        planes = [(z, 0, label) for z, label in zip(range(size_z), labels)]
    else:
        # Single plane: read z=0 (the only valid slice), t=0.
        planes = [(0, 0, _single_plane_label(labels))]

    all_statistics = []
    for z, t, label in planes:
        for c in range(size_c):
            statistics = pd.DataFrame(stats_fn(pixels.getPlane(z, c, t), label))
            statistics['z'] = z
            statistics['t'] = t
            statistics['channel'] = c
            all_statistics.append(statistics)

    # Concatenate all statistics into a single DataFrame
    return pd.concat(all_statistics, ignore_index=True)


#### code to run the analysis

In [None]:
import importlib
import src.ProcessImage as ProcessImage
importlib.reload(ProcessImage)  # pick up edits to src/ProcessImage.py without restarting the kernel

# Collect the images to process. For plates, also remember which well each
# image came from so the measurements stay traceable to the well.
image_wells = {}
if datatype == "plate":
    wells = list(plate.listChildren())
    well_count = len(wells)
    images = []
    for count, well in enumerate(wells):
        print('Well: %s/%s' % (count + 1, well_count), 'row:', well.row, 'column:', well.column)
        # Load all images for a well if there are multiple
        fields = well.countWellSample()
        for field in range(fields):
            print('Field:', field)
            image = well.getImage(field)
            images.append(image)
            image_wells[image.getId()] = well.getId()
elif datatype == "dataset":
    images = list(dataset.listChildren())
elif datatype == "image":
    images = [image]
else:
    raise ValueError(f"Unsupported datatype {datatype!r}; expected 'plate', 'dataset' or 'image'")

# Process images: segment, upload the segmentation (attachment + ROIs) and
# post a per-image measurement table.
desc = "Stardist nuclei segmentation"
plate_statistics = []
for image in images:
    image_id = image.getId()
    # Name for the alternative "save as new image" output (disabled below):
    # same name as the source image with _nucleisegmentation appended.
    seg_img_name = image.getName() + "_nucleisegmentation"
    img = ProcessImage.ProcessImage(conn, image, job_id, model)
    img.segment_nuclei(nucl_channel)
    img.save_segmentation_to_omero_as_attach(new_output_directory, desc)
    #img.save_segmentation_to_omero_as_new_image(seg_img_name,desc)
    img.save_segmentation_to_omero_as_roi()
    all_statistics_df = measure_intensity(img._pixels, img._labels, img._size_z, img._size_t, img._size_c)
    all_statistics_df['imageID'] = image_id
    if image_id in image_wells:
        all_statistics_df['well'] = image_wells[image_id]
    plate_statistics.append(all_statistics_df)
    tabelid = ezomero.post_table(conn, object_type="Image", object_id=image_id, table=all_statistics_df, title=f"Nuclei_measurements_{job_id}_{image_id}")
    print('Created table ID:', tabelid)

# Post one combined table on the processed container. The object type must
# match datatype (data_id is a Plate ID for plates). For a single image, the
# per-image table above already holds everything, so no summary is posted.
_container_types = {"plate": "Plate", "dataset": "Dataset"}
if not plate_statistics:
    print('No images found to process; no summary table created.')
elif datatype in _container_types:
    # Concatenate all statistics into a single DataFrame
    plate_statistics_df = pd.concat(plate_statistics, ignore_index=True)
    tabelid = ezomero.post_table(conn, object_type=_container_types[datatype], object_id=data_id, table=plate_statistics_df, title=f"Nuclei_measurements_{job_id}_{data_id}")
    print('Created summary table ID:', tabelid)


#### Legacy per-datatype pipeline (old code, remove asap).
# This duplicates the generic processing loop above. It used to run
# unconditionally right after that loop, so every image was segmented twice and
# attachments, ROIs and tables were uploaded twice. It is now disabled; set
# RUN_LEGACY_PIPELINE = True only to compare against the old behaviour.
# NOTE(review): this code calls ProcessImage.ProcessImage(image, conn, model),
# while the loop above calls ProcessImage.ProcessImage(conn, image, job_id, model).
# At most one of them can match the current ProcessImage signature — confirm
# before enabling.
RUN_LEGACY_PIPELINE = False

if RUN_LEGACY_PIPELINE and datatype == "plate":
    plate_statistics = []
    wells = list(plate.listChildren())
    # use the first 2 wells only
    wells = wells[0:2] # for testing
    well_count = len(wells)
    for count, well in enumerate(wells):
        print('Well: %s/%s' % (count + 1, well_count), 'row:', well.row, 'column:', well.column)
        # Load all images for a well if there are multiple
        fields = well.countWellSample()
        for field in range(fields):
            print('Field:', field)
            image = well.getImage(field)
            #save stack back to OMERO same project only add _nucleisegmentation to the name
            new_img_name = image.getName() + "_nucleisegmentation"
            desc = "Stardist nuclei segmentation"
            img = ProcessImage.ProcessImage(image, conn, model)
            print('image dimensions:', img._size_z, img._size_c,img._size_t)
            img.segment_nuclei(nucl_channel)
            img.save_segmentation_to_omero_as_attach(new_output_directory,desc)
            #img.save_segmentation_to_omero_as_new_image(new_img_name,desc)
            img.save_segmentation_to_omero_as_roi()
            all_statistics_df = measure_intensity(img._pixels, img._labels, img._size_z, img._size_t, img._size_c)
            all_statistics_df['well'] = well.getId()
            plate_statistics.append(all_statistics_df)
            tabelid = ezomero.post_table(conn, object_type="Image", object_id=image.getId(), table = all_statistics_df,title="Nuclei_measurements")
            print('Created table ID:', tabelid)
    # Concatenate all statistics into a single DataFrame
    plate_statistics_df = pd.concat(plate_statistics, ignore_index=True)
    # Post the concatenated DataFrame (previously the list of per-image frames was passed).
    tabelid = ezomero.post_table(conn, object_type="Plate", object_id=data_id, table = plate_statistics_df,title="Nuclei_measurements")

elif RUN_LEGACY_PIPELINE and datatype == "dataset":
    images = list(dataset.listChildren())
    # use the first 3 images only
    #images = images[0:3]
    image_count = len(images)
    plate_statistics = []
    for count in range(image_count):
        image = images[count]
        #save stack back to OMERO same project only add _nucleisegmentation to the name
        new_img_name = image.getName() + "_nucleisegmentation"
        desc = "Stardist nuclei segmentation"
        img = ProcessImage.ProcessImage(image, conn,model)
        img.segment_nuclei(nucl_channel)
        img.save_segmentation_to_omero_as_attach(new_output_directory,desc)
        #img.save_segmentation_to_omero_as_new_image(desc)
        img.save_segmentation_to_omero_as_roi()
        all_statistics_df = measure_intensity(img._pixels, img._labels, img._size_z, img._size_t, img._size_c)
        all_statistics_df['imageID'] = image.getId()
        plate_statistics.append(all_statistics_df)
        image_id = image.getId()
        tabelid = ezomero.post_table(conn, object_type="Image", object_id=image.getId(), table = all_statistics_df,title=f"Nuclei_measurements_{job_id}_{image_id}")
        print('Created table ID:', tabelid)
    # Concatenate all statistics into a single DataFrame
    plate_statistics_df = pd.concat(plate_statistics, ignore_index=True)
    tabelid = ezomero.post_table(conn, object_type="Dataset", object_id=data_id, table = plate_statistics_df, title=f"Nuclei_measurements_{job_id}_{data_id}")
            

### Delete attachments from the dataset's images

In [None]:
datatype = "dataset" # "plate", "dataset", "image"
data_id = 502

def ensure_list(obj):
    """Return *obj* as a list: [] for falsy values, a list unchanged, else [obj]."""
    if not obj:
        return []
    if isinstance(obj, list):
        return obj
    return [obj]

if datatype == "dataset":
    from omero.gateway import FileAnnotationWrapper

    # Fetch the dataset from data_id so this cell acts on the ID set above,
    # not on whatever object an earlier cell happened to select.
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"No Dataset with ID {data_id} found (or not accessible)")

    to_delete = []
    for image in dataset.listChildren():
        print('Image Name:', image.getName())
        for ann in image.listAnnotations():
            # Only unlink attachments (file annotations, which also back
            # OMERO.tables). Tags, comments and key-value pairs are kept.
            if not isinstance(ann, FileAnnotationWrapper):
                continue
            to_delete.extend(ensure_list(ann.link.id))  # sometimes single, sometimes list

    # NOTE(review): deleting only the links presumably leaves the FileAnnotation
    # objects orphaned on the server — confirm whether they should be deleted too.
    # deleteObjects expects a non-empty ID list, so skip the call when nothing matched.
    if to_delete:
        conn.deleteObjects("ImageAnnotationLink", to_delete, wait=True)
    else:
        print('No attachments to delete.')

### Delete ROIs from the dataset's images

In [None]:
datatype = "dataset" # "plate", "dataset", "image"
data_id = 502

if datatype == "dataset":
    # Fetch the dataset from data_id so this cell acts on the ID set above.
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"No Dataset with ID {data_id} found (or not accessible)")

    # Collect the ROI IDs of every image first, then delete them in one batch.
    roi_service = conn.getRoiService()
    roi_ids = []
    for image in dataset.listChildren():
        result = roi_service.findByImage(image.getId(), None)
        roi_ids.extend(roi.id.val for roi in result.rois)

    # deleteObjects expects a non-empty ID list; wait so the cell only finishes
    # once the deletion is done (same as the attachment cleanup cell).
    if roi_ids:
        conn.deleteObjects("Roi", roi_ids, wait=True)
    else:
        print('No ROIs to delete.')
