In [1]:
import napari
from skimage.io import imread
import numpy as np
import pandas as pd
import time
from napari_clusters_plotter._utilities import get_nice_colormap
from vispy.color import Color
import pyclesperanto_prototype as cle
from skimage.io.collection import alphanumeric_key
from dask import delayed
import dask.array as da
from glob import glob

def make_dask_stack(folder, prefix):
    filenames = sorted(glob(folder + prefix + "*.tif"), key=alphanumeric_key)
    # read the first file to get the shape and dtype
    # ASSUMES THAT ALL FILES SHARE THE SAME SHAPE/TYPE
    sample = imread(filenames[0])

    lazy_imread = delayed(imread)  # lazy reader
    lazy_arrays = [lazy_imread(fn) for fn in filenames]
    dask_arrays = [
        da.from_delayed(delayed_reader, shape=sample.shape, dtype=sample.dtype)
        for delayed_reader in lazy_arrays
    ]
    # Stack into one large dask.array
    return da.stack(dask_arrays, axis=0)

def dask_cluster_image_timelapse(function, label_image, prediction_list_list):
    import dask.array as da
    from dask import delayed

    sample = label_image[0]

    lazy_cluster_image = delayed(function)  # lazy processor
    lazy_arrays = [
        lazy_cluster_image(frame, preds)
        for frame, preds in zip(label_image, prediction_list_list)
    ]
    dask_arrays = [
        da.from_delayed(delayed_reader, shape=sample.shape, dtype=sample.dtype)
        for delayed_reader in lazy_arrays
    ]
    # Stack into one large dask.array
    stack = da.stack(dask_arrays, axis=0)

    return stack

def generate_cluster_image(label_image, predictionlist):
    """
    Returns a label image where each label value corresponds
    to the cluster identity defined by the predictionlist.
    it is assumed that len(predictionlist) == max(label_image)

    Parameters
    ----------
    label_image: ndarray or dask array
        Label image used for cluster predictions
    predictionlist: array
        Array containing cluster identities for each label
    """
    # reforming the prediction list this is done to account
    # for cluster labels that start at 0 conviniently hdbscan
    # labelling starts at -1 for noise, removing these from
    # the labels
    predictionlist_new = np.array(predictionlist) + 1
    predictionlist_new = np.insert(predictionlist_new, 0, 0)

    # loading data into gpu
    clelist = cle.push(predictionlist_new)
    gpu_labelimage = cle.push(label_image)

    # generation of cluster label image
    parametric_image = cle.replace_intensities(gpu_labelimage, clelist)
    gpu_labelimage = None
    clelist = None

    # retrieving the gpu image
    output = cle.pull(parametric_image).astype("uint32")
    parametric_image = None

    return output

def generate_cluster_image_np(label_image, predictionlist):
    predictionlist_new = np.array(predictionlist) + 1
    predictionlist_new = np.insert(predictionlist_new, 0, 0)
    return np.take(predictionlist_new, label_image)

def generate_label_to_cluster_color_mapping(label_list, predictionlist, colormap_dict):
    predictionlist_new = np.array(predictionlist) + 1
    mapping = {0: [0,0,0,0]}
    for label,prediction in zip(label_list,predictionlist_new):
        mapping[label] = colormap_dict[prediction]
        
    return mapping

In [4]:
finsterwalde_path = 'C:/Users/ryans/Documents/output data (big)/Finsterwalde Gastrulation Labels (new timeframe)/'
regprops_name = 'finsterwalder_master_regprops_ncp.csv'
regprops = pd.read_csv(finsterwalde_path + regprops_name)

max_timepoint = regprops['frame'].max() + 1
prediction_lists_per_timepoint = [
    regprops.loc[regprops['frame'] == i]['MANUAL_CLUSTER_ID'].to_numpy()
    for i in range(max_timepoint)
]

labels_list_pre_timepoint = [
    regprops.loc[regprops['frame'] == i]['label'].to_numpy()
    for i in range(max_timepoint)
]

labels_dask = make_dask_stack(finsterwalde_path, 'workflow')

colors =get_nice_colormap()
cmap = [Color(hex_name).RGBA.astype("float") / 255 for hex_name in colors]
cmap_dict = {
    int(prediction + 1): (
        cmap[int(prediction) % len(cmap)]
        if prediction >= 0
        else [0, 0, 0, 0]
    )
    for prediction in regprops['MANUAL_CLUSTER_ID'].to_numpy()
}
# take care of background label
cmap_dict[0] = [0, 0, 0, 0]

viewer = napari.Viewer()
viewer.add_labels(labels_dask)
labels_dask

v0.5.0. It is considered an "implementation detail" of the napari
application, not part of the napari viewer model. If your use case
requires access to qt_viewer, please open an issue to discuss.
  self.tools_menu = ToolsMenu(self, self.qt_viewer.viewer)


Unnamed: 0,Array,Chunk
Bytes,8.74 GiB,279.60 MiB
Shape,"(32, 297, 964, 512)","(1, 297, 964, 512)"
Count,96 Tasks,32 Chunks
Type,uint16,numpy.ndarray
"Array Chunk Bytes 8.74 GiB 279.60 MiB Shape (32, 297, 964, 512) (1, 297, 964, 512) Count 96 Tasks 32 Chunks Type uint16 numpy.ndarray",32  1  512  964  297,

Unnamed: 0,Array,Chunk
Bytes,8.74 GiB,279.60 MiB
Shape,"(32, 297, 964, 512)","(1, 297, 964, 512)"
Count,96 Tasks,32 Chunks
Type,uint16,numpy.ndarray


## Pyclesperanto time

In [2]:
starttime = time.process_time_ns()
cl_image = dask_cluster_image_timelapse(function=generate_cluster_image,prediction_list_list=)

viewer.add_labels(cl_image, color=cmap_dict)

endtime = time.process_time_ns()
cle_time = endtime-starttime
print(f'pyclesperanto took {cle_time/1_000_000} ms')

pyclesperanto took 828.125 ms


## Numpy Time

In [3]:
starttime = time.process_time_ns()
np_image = generate_cluster_image_np(label_image,predictionlist=cluster_id)

viewer.add_labels(np_image, color=cmap_dict)

endtime = time.process_time_ns()
np_time = endtime - starttime

print(f'numpy took {np_time/1_000_000} ms')


numpy took 968.75 ms


## Colormap Time
### with adding a layer

In [4]:
'''
starttime = time.process_time_ns()

mapping_layer = viewer.add_labels(label_image,name = 'cluster id mapping')
mapping = generate_label_to_cluster_color_mapping(
    label_list=labels_list, 
    predictionlist= cluster_id, 
    colormap_dict= cmap_dict
)
mapping_layer.color = mapping

endtime = time.process_time_ns()
mapping_time = endtime - starttime

print(f'mapping took {mapping_time/1_000_000} ms')
'''

mapping took 187.5 ms
