# 2D Stardist segmentation on 2D/3D/timelapse OMERO images

This notebook runs StarDist nuclei segmentation on OMERO images. Parts of it are adapted from https://github.com/ome/omero-guide-cellprofiler/idr0002.ipynb

## To do
- Make a generic function that runs 2D segmentation on every slice, regardless of the image's z, c and t dimensions
- Include an ID in all files uploaded to OMERO to make them more traceable
- Extend to handle multiple channels and timepoints
- Check whether label images or ROIs can be overwritten if necessary


### Import Packages

In [None]:
# Import OMERO Python BlitzGateway
import omero
from omero.gateway import BlitzGateway
import ezomero
# Import Numpy
import numpy as np

# Import Python System Packages
import os
import tempfile
import pandas
import warnings
from dotenv import load_dotenv

#stardist related
from stardist.models import StarDist2D
from csbdeep.utils import normalize
from stardist.plot import render_label
import matplotlib.pyplot as plt
from tifffile import imsave

# Pretrained StarDist 2D model for fluorescence images; shared by the processing cells below
stardist_model_name = '2D_versatile_fluo'
model = StarDist2D.from_pretrained(stardist_model_name)

### Set Temp Output Directory

In [None]:
import datetime

# Scratch directory (normalised path) for files produced during this run
new_output_directory = os.path.normcase(tempfile.mkdtemp())
print(new_output_directory)

# Timestamp (YYYYmmddHHMMSS) used to tag everything this run uploads
job_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
print(job_id)

### Setup connection with OMERO

In [None]:
# Read HOST / USER_NAME / PASSWORD from a local .env file, overriding existing env vars
load_dotenv(override=True)

conn = BlitzGateway(host=os.environ.get("HOST"), username=os.environ.get("USER_NAME"), passwd=os.environ.get("PASSWORD"), secure=True)
connected = conn.connect()
print(connected)
if not connected:
    # Fail here instead of letting every later cell fail with a confusing error
    raise RuntimeError("Could not connect to OMERO; check HOST, USER_NAME and PASSWORD in the .env file")
# Keep the session alive (keep-alive delay of 60 s) during long processing runs
conn.c.enableKeepAlive(60)

### Get info on the selected plate, dataset or image

In [None]:
datatype = "dataset"  # one of "plate", "dataset", "image"
data_id = 502  # OMERO ID of the object selected by `datatype`
nucl_channel = 0  # channel index passed to the nuclear segmentation

# Fetch the selected object and validate that data_id refers to an object of type `datatype`
if datatype == "plate":
    plate = conn.getObject("Plate", data_id)
    selected = plate
elif datatype == "dataset":
    dataset = conn.getObject("Dataset", data_id)
    selected = dataset
elif datatype == "image":
    image = conn.getObject("Image", data_id)
    selected = image
else:
    raise ValueError(f'Unsupported datatype {datatype!r}; expected "plate", "dataset" or "image"')

if selected is None:
    # getObject returns None when the ID does not exist or is not visible to this user
    raise ValueError(f"No {datatype} with ID {data_id} found (or it is not accessible to this user)")
print(f'{datatype.capitalize()} Name: ', selected.getName())

### Run Stardist on the dataset

#### Get imageIDs to process

In [None]:
import importlib
import src.ProcessImage as ProcessImage
importlib.reload(ProcessImage)  # pick up edits to src/ProcessImage.py without restarting the kernel
import pandas as pd

# Build the list of OMERO images to process for the selected datatype
if datatype == "plate":
    wells = list(plate.listChildren())
    well_count = len(wells)
    images = []
    for count, well in enumerate(wells, start=1):
        print(f'Well: {count}/{well_count}', 'row:', well.row, 'column:', well.column)
        # A well may hold several fields; load the image of each one
        for field in range(well.countWellSample()):
            print('Field:', field)
            images.append(well.getImage(field))
elif datatype == "dataset":
    images = list(dataset.listChildren())
elif datatype == "image":
    images = [image]
else:
    # Without this, `images` would be undefined or silently left over from an earlier run
    raise ValueError(f'Unsupported datatype {datatype!r}; expected "plate", "dataset" or "image"')
print(f'{len(images)} image(s) to process')
     

#### Process images


In [None]:
# Per-image measurement tables, later combined into one table for the container
plate_statistics = []
for count, image in enumerate(images):
    print(f'Processing image {count+1}/{len(images)}: {image.getName()}')

    # Initialize processing
    img = ProcessImage.ProcessImage(conn, image, job_id, model)

    # Segment nuclei
    img.segment_nuclei(nucl_channel)

    # Save results
    img.save_segmentation_to_omero_as_attach(new_output_directory)
    img.save_segmentation_to_omero_as_roi()
    #img.save_segmentation_to_omero_as_new_image(seg_img_name,desc)

    # Measure intensity and tag every row with the image it came from
    img.measure_intensity(norm=False)
    all_statistics_df = img.get_measurements_to_df()
    image_id = image.getId()
    all_statistics_df['imageID'] = image_id
    plate_statistics.append(all_statistics_df)

    # Save the per-image intensity measurements to OMERO
    table_id = ezomero.post_table(conn, object_type="Image", object_id=image_id, table=all_statistics_df,
                                  title=f"Nuclei_measurements_{job_id}_{image_id}")
    print('Created table ID:', table_id)

if not plate_statistics:
    # pd.concat([]) raises ValueError, so skip the combined table when nothing was processed
    print('No images were processed; no combined table created.')
else:
    plate_statistics_df = pd.concat(plate_statistics, ignore_index=True)
    # data_id refers to an object of type `datatype`, so post to that type rather than
    # always to "Dataset". For a single image the per-image table above already holds
    # everything, so no combined table is posted.
    container_type = {"plate": "Plate", "dataset": "Dataset"}.get(datatype)
    if container_type is not None:
        table_id = ezomero.post_table(conn, object_type=container_type, object_id=data_id, table=plate_statistics_df,
                                      title=f"Nuclei_measurements_{job_id}_{data_id}")
        print('Created table ID:', table_id)

### Explore results on a single image

In [None]:
plate_statistics = []  # NOTE: resets the batch results list; not used in this cell
# Index into `images` of the image to inspect interactively
image_number = 0
count = image_number
image = images[count]
print(f'Processing image {count+1}/{len(images)}: {image.getName()}')

# Wrap the image and segment the nuclear channel with the StarDist model
img = ProcessImage.ProcessImage(conn, image, job_id, model)
img.segment_nuclei(nucl_channel)

# Writing results back to OMERO is disabled while exploring; uncomment to enable.
# img.save_segmentation_to_omero_as_attach(new_output_directory)
# img.save_segmentation_to_omero_as_roi()
# img.save_segmentation_to_omero_as_new_image(seg_img_name, desc)

# Intensity measurements (norm=False), collected into a table
img.measure_intensity(norm=False)
all_statistics_df = img.get_measurements_to_df()

# To also upload the measurements as an OMERO table, uncomment:
# image_id = image.getId()
# all_statistics_df['imageID'] = image_id
# tabelid = ezomero.post_table(conn, object_type="Image", object_id=image_id, table=all_statistics_df, title=f"Nuclei_measurements_{job_id}_{image_id}")
# print('Created table ID:', tabelid)


In [None]:
# Overlay the segmentation on one plane of the image and export both stacks as TIFF.
# (Optional: img.visualize_measurements(nucl_channel) for the measurement plots.)
import matplotlib.pyplot as plt  # (re)bind pyplot; earlier versions of this cell overwrote `plt`
import stackview
import pyclesperanto_prototype as cle  # not used in this cell
import tifffile

labels = np.asarray(img.labels)
print(np.shape(labels))

# Pixel data, indexed below as [plane, channel, y, x]
# NOTE(review): assumes the first axis is t or z depending on the image — confirm
image = img.get_image_stack()
# Interactive alternatives: stackview.slice(image, continuous_update=True)
# or stackview.curtain(image[:, 0, :, :], labels)

plane_index = 3  # plane (first axis) to display
display_channel = 0  # channel shown underneath the labels

# Draw onto an explicit axes instead of assigning imshow's return value to `plt`,
# which used to clobber the matplotlib.pyplot module imported at the top.
fig, ax = plt.subplots()
stackview.imshow(image[plane_index, display_channel, :, :], plot=ax, continue_drawing=True)
stackview.imshow(labels[plane_index, :, :], plot=ax, alpha=0.4, title='image + labels')

# NOTE: written to the current working directory
tifffile.imwrite('image.tif', image)
tifffile.imwrite('labels.tif', labels)

In [None]:
# Table/regionprops demo for napari:
# https://github.com/haesleinhuepf/napari-skimage-regionprops/blob/master/demo/tables.ipynb
import numpy as np
import napari
import pandas
from napari_skimage_regionprops import regionprops_table, add_table, get_table

# Show plane 3 / channel 0 of the pixel stack in napari with its label mask on top
viewer = napari.Viewer()
raw_plane = image[3, 0, :, :]
label_plane = np.array(labels)[3, :, :]
viewer.add_image(raw_plane)
viewer.add_labels(label_plane)

### Delete attachments from the images in a dataset

In [None]:
data_id = 502  # OMERO ID of the dataset to clean up
datatype = "dataset"  # "plate", "dataset", "image"

def ensure_list(obj):
    """Return *obj* as a list.

    ``None`` becomes ``[]``, a list is returned unchanged (the same object),
    a tuple is converted to a list, and any other value is wrapped as
    ``[obj]``. Unlike a truthiness test, falsy scalars such as ``0`` are kept.
    """
    if obj is None:
        return []
    if isinstance(obj, list):
        return obj
    if isinstance(obj, tuple):
        return list(obj)
    return [obj]

# Unlink file attachments (FileAnnotations, e.g. uploaded label images and
# measurement tables) from every image in the selected dataset.
from omero.gateway import FileAnnotationWrapper

# Annotation wrapper classes whose image links are removed. Widen this tuple to
# also remove other annotation types (tags, key-value pairs, ...).
annotation_types_to_unlink = (FileAnnotationWrapper,)

if datatype == "dataset":
    # Re-fetch so this cell honours data_id above instead of a stale `dataset` from earlier cells
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"No dataset with ID {data_id} found (or it is not accessible to this user)")
    to_delete = []
    for image in dataset.listChildren():
        print('Image Name:', image.getName())
        for ann in image.listAnnotations():
            if not isinstance(ann, annotation_types_to_unlink):
                continue
            # ann.link.id may be a single ID or a list of IDs
            to_delete.extend(ensure_list(ann.link.id))
    if to_delete:
        # NOTE(review): only the ImageAnnotationLinks are deleted; the annotations
        # themselves may remain on the server as orphans — confirm if that matters
        conn.deleteObjects("ImageAnnotationLink", to_delete, wait=True)
    print(f"Deleted {len(to_delete)} annotation link(s)")
else:
    raise NotImplementedError(f"Attachment deletion is only implemented for datatype 'dataset', not {datatype!r}")

### Delete ROIs from the images in a dataset

In [None]:
data_id = 502  # OMERO ID of the dataset whose ROIs are removed
datatype = "dataset"  # "plate", "dataset", "image"

# Delete every ROI on the images of the selected dataset
if datatype == "dataset":
    # Re-fetch so this cell honours data_id above instead of a stale `dataset` from earlier cells
    dataset = conn.getObject("Dataset", data_id)
    if dataset is None:
        raise ValueError(f"No dataset with ID {data_id} found (or it is not accessible to this user)")
    roi_service = conn.getRoiService()
    roi_ids = []
    for image in dataset.listChildren():
        result = roi_service.findByImage(image.getId(), None)
        roi_ids.extend(roi.id.val for roi in result.rois)
    if roi_ids:
        # One batched request; wait=True blocks until the server has finished the delete
        conn.deleteObjects("Roi", roi_ids, wait=True)
    print(f"Deleted {len(roi_ids)} ROI(s)")
else:
    raise NotImplementedError(f"ROI deletion is only implemented for datatype 'dataset', not {datatype!r}")
