In [1]:
from elf.io import open_file
from  pathlib import Path
import numpy as np


def read_dataset(ds_path,organell_filename, scaling_factor =3):
    """Read a dataset and select the correct file

    :param ds_path: ds_path should be the path to the folder inside `images`that conatins the .n5 files.
    :type ds_path: str
    :param organell_filename: The name of the file that contains the organell data
    :type organell_filename: str
    :param scaling_factor: Can be 0,1,2 or 3. The lower the number the higher the resolution, defaults to 3.
    :type scaling_factor: int, optional

    """
    ds_path = Path(ds_path)

    data_key = f"setup0/timepoint0/s{scaling_factor}"
    data_path = ds_path/ organell_filename
    print(data_path)
    with open_file(str(data_path), 'r') as f:
        ds = f[data_key]
    return ds

In [2]:
ds_path = "/home/gwydion/SSC/cebra/mobie-data-testing/data/cebra_em_example/seg_er_5nm_mito_10nm/CebraEM/images/bdv-n5"
organell_filename = "mito-it00_b0_6_stitched.n5"
scaling_factor = 2
ds =  read_dataset(ds_path,organell_filename, scaling_factor =scaling_factor)

/home/gwydion/SSC/cebra/mobie-data-testing/data/cebra_em_example/seg_er_5nm_mito_10nm/CebraEM/images/bdv-n5/mito-it00_b0_6_stitched.n5


In [3]:
print(ds.attrs["downsamplingFactors"])
print(np.unique(ds))
# each of these should be a unique mitochondria

[4, 4, 4]
[  0   2   3   4   5   6   7   8   9  10  11  14  15  17  18  21  26  32
  33  34  38  41  43  46  47  49  51  56  57  65  69  70  72  73  75  76
  77  81  88  96  99 102 107 124 130 142 163 166 169 170 171 172 173 174
 180 187 188 219 220 221 240 251 252 258 260 288 295 297 299 301 307 332
 335 339 343 345 346 347 349 354 358 361 366 368 369 370 374 380 381 382
 383 392 400 402 406 409 414 417 420 424 437 438 445 448 463 466 471]


In [4]:
def filter_ds(ds, filter_value):
    # filter for one value
    ds_filtered = ds[:]

    mask = (ds_filtered == filter_value)
    ds_filtered[~mask] = 0
    ds_filtered.shape
    return (ds_filtered)


In [5]:
# https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
from collections import defaultdict
from skimage import measure
import pandas as pd


def get_ds_properties(ds, num_pixels_threshhold=10, resolution = (5,5,5)):
    properties = measure.regionprops(ds,spacing=resolution)

    prop_dict = defaultdict(dict)
    for prop in properties:
        if prop.num_pixels < num_pixels_threshhold:
            continue
        label = prop.label
        prop_dict[label]["label"]=prop.label

        # as far as i can tell this area attribute is the volume when given a 3d array
        prop_dict[label]["volume_voxels"]= prop.area
        bbox = np.asarray((prop.bbox[:3], prop.bbox[3:]))
        prop_dict[label]["bbox"]= bbox
        prop_dict[label]["bbox_dim_nm"]= bbox[1]*5-bbox[0]*5
        bbox_volume = np.prod(bbox[1]*5-bbox[0]*5)
        prop_dict[label]["bbox_vol_nm"]= bbox_volume


        prop_dict[label]["centroid"]= prop.centroid
        prop_dict[label]["solidity"]= prop.solidity

        prop_dict[label]["coords"]= prop.coords


        # calculate surface area from mesh
        ds_filtered = filter_ds(ds,label )
        verts, faces, _, _  = measure.marching_cubes(ds_filtered[:], spacing=(5,5,5))
        area = measure.mesh_surface_area(verts, faces)
        prop_dict[label]["area_mesh_nm"]= area
        #attach the actual mesh to the df (likely not needed)
        # prop_dict[label]["mesh_nm"]= [verts, faces]

    df = pd.DataFrame(prop_dict).T
    df.index.rename("Label", inplace=True)

    return df


In [6]:
df = get_ds_properties(ds, num_pixels_threshhold=5, resolution = (5,5,5))
df.columns

Index(['label', 'volume_voxels', 'bbox', 'bbox_dim_nm', 'bbox_vol_nm',
       'centroid', 'solidity', 'coords', 'area_mesh_nm'],
      dtype='object')

In [7]:
df[["bbox_dim_nm", "bbox_vol_nm","volume_voxels", "area_mesh_nm"]][:]

Unnamed: 0_level_0,bbox_dim_nm,bbox_vol_nm,volume_voxels,area_mesh_nm
Label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2,"[65, 75, 35]",170625,33125,5726.575581
3,"[40, 30, 30]",36000,15125,1914.016792
4,"[55, 85, 55]",257125,65500,10320.00622
5,"[115, 220, 160]",4048000,152125,28508.995929
6,"[75, 155, 185]",2150625,143375,24041.23269
...,...,...,...,...
445,"[30, 30, 20]",18000,4000,1669.86056
448,"[145, 190, 70]",1928500,145875,26740.167022
463,"[85, 120, 95]",969000,110750,17050.293562
466,"[305, 300, 345]",31567500,697750,115733.638894


In [8]:
def dash_plot(ds):
    from dash import Dash, dcc, html, Input, Output
    import plotly.express as px

    app = Dash("test")


    app.layout = html.Div([
        html.H4('Cell Stacks'),
        dcc.Graph(id="graph"),
        html.P("Slice:"),
        dcc.Slider(
            id='slices',
            min=0,
            max=ds.shape[2],
            step=1,
            value=1
        )
    ])


    @app.callback(
    Output("graph", "figure"), 
    Input("slices", "value"))
    def filter_heatmap(slice):
        ds_slice = ds[:,:,slice] # replace with your own data source
        fig = px.imshow(ds_slice)
        return fig


    app.run_server(debug=True, port = 8083)
dash_plot(ds)