# STAPL-3D segmentation demo

This notebook demonstrates the core components of the STAPL-3D segmentation pipeline: **blockwise segmentation** and **zipping**.

If you did not follow the STAPL-3D README: please find STAPL-3D and the installation instructions [here](https://github.com/RiosGroup/STAPL3D) before doing this demo.

Because STAPL-3D is all about big datafiles, we provide small cutouts and precomputed summary data that will be downloaded while progressing through the notebook.

Let's start with some general settings and imports.

In [None]:
# Show all output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Imports.
import os
import yaml
import urllib.request
from pprint import pprint

# Yaml printing function.
def yprint(ydict):
    """Print dictionary in yaml formatting."""
    print(yaml.dump(ydict, default_flow_style=False))


First, define where you want the data to be downloaded by changing *projectdir*; default is the current demo directory. The name of the dataset is *'HFK16w'* (for Human Fetal Kidney - 16 weeks). We create a directory for the dataset and jump to it.

In [None]:
projectdir = '.'
dataset = 'HFK16w'

datadir = os.path.join(projectdir, dataset)

os.makedirs(datadir, exist_ok=True)
os.chdir(datadir)
f'working in directory: {os.path.abspath(".")}'


We preferably define STAPL3D parameters in a [yaml](https://yaml.org) parameter file, which has a simple structure and can be parsed in both Python and `bash`. We download the example file, read it into a dictionary, list all entries, and show the entry that describes the default STAPL3D directory structure.

In [None]:
parameter_file = f'{dataset}.yml'

# Download the yml-file.
if not os.path.exists(parameter_file):
    url = 'https://surfdrive.surf.nl/files/index.php/s/Ubx9wVon5CIuIzo/download'
    urllib.request.urlretrieve(url, parameter_file)

# Load parameter file.
with open(parameter_file, 'r') as ymlfile:
    cfg = yaml.safe_load(ymlfile)

# List all entries.
cfg.keys()

# Inspect directory tree.
yml_entry = 'dirtree'
yprint(cfg[yml_entry])  # in yaml format
pprint(cfg[yml_entry])  # as a dictionary


## Dataset

We provide a preprocessed data cutout in the Imaris v5.5 file format, which is an hdf5 file with 5 dimensions. A free [Imaris Viewer](https://imaris.oxinst.com/imaris-viewer) is available, and the file format can be inspected with [HDFview](https://www.hdfgroup.org/downloads/hdfview/), `h5ls` or `h5py`.

We download the file and name it according to the default STAPL-3D pipeline conventions.

In [None]:
ims_filepath = f'{dataset}_shading_stitching.ims'  # f'{dataset}_shading_stitching_biasfield.ims'

# Download the ims-file.
if not os.path.exists(ims_filepath):
    url = 'https://surfdrive.surf.nl/files/index.php/s/NxWhUWuLQBHPMGV/download'
    urllib.request.urlretrieve(url, ims_filepath)


We use the STAPL-3D Image class to load this file and inspect its properties. We also store the dimensions, the size of the Z-dimension and the number of channels in the convenience variables `dims`, `Z` and `C`.


In [None]:
# Print image properties.
from stapl3d import Image
image_in = ims_filepath
im = Image(image_in)
im.load(load_data=False)
props = im.get_props()
im.close()
pprint(props)

# Convenience variables.
dims = im.dims
Z = im.dims[im.axlab.index('z')]
C = im.dims[im.axlab.index('c')]


In processing the full dataset, this cutout of **106 x 1408 x 1408 x 8** would equate to a single datablock, but for this demo we will further subdivide this block to demonstrate the pipeline.

## Parallelization

We have specified the shape of the processing blocks in the parameter file. For full datasets we would usually opt for a blocksize of ~100-200 million voxels; here we chose an xy-blocksize of 176, which yields 8 x 8 = 64 blocks of ~6M voxels each (including the margin). We keep the margin similar to what we use for big datasets, as reducing it may hinder adequate analysis.
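A quick check of these numbers, assuming the 106 x 1408 x 1408 cutout, the xy-blocksize of 176 and the 240 x 240 xy-extent including the margin that is quoted in the segmentation section. This is plain arithmetic and does not call STAPL3D:

```python
import math

# Demo cutout size (z, y, x) and blocksizes from the text; no STAPL3D calls involved.
nz, ny, nx = 106, 1408, 1408
blocksize_xy = 176          # xy-blocksize from the parameter file
blocksize_xy_margin = 240   # xy-extent including the margin (see the segmentation section)

# Number of blocks in the xy-grid.
n_blocks = math.ceil(ny / blocksize_xy) * math.ceil(nx / blocksize_xy)

# Voxels per block, without and with the margin (no blocking along z).
voxels_core = nz * blocksize_xy ** 2
voxels_with_margin = nz * blocksize_xy_margin ** 2

print(f'{n_blocks} blocks')
print(f'{voxels_core / 1e6:.1f}M voxels per block without margin')
print(f'{voxels_with_margin / 1e6:.1f}M voxels per block with margin')
```

The margin nearly doubles the voxel count of these small blocks, which is why the ~6M figure refers to blocks including their margin.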

In [None]:
yprint(cfg['blocks']['blockinfo'])


The full anatomy of the blocked processing can now be loaded through a *block3r* object.


In [None]:
from stapl3d import blocks
block3r = blocks.Block3r(image_in, parameter_file, prefix=dataset)
print(block3r)


In initializing the *block3r* object, the sizes of the zyxct-dimensions were read from the input data, and the blocksizes specified in the configuration file were substituted for the corresponding dimensions to determine the 5D blocksize.
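The substitution can be sketched as a dictionary lookup per axis. `full_blocksize` is a hypothetical helper for illustration, not part of the STAPL3D API, and the channel and time sizes (8 and 1) are assumptions:

```python
def full_blocksize(dims, axlab, blocksize_cfg):
    """Combine image dimensions with configured blocksizes into a per-axis blocksize.

    Axes that appear in blocksize_cfg take the configured size;
    all other axes keep the full image size.
    """
    return {ax: blocksize_cfg.get(ax, dim) for ax, dim in zip(axlab, dims)}


# Hypothetical values: the demo cutout with 8 channels and a single timepoint.
example_blocksize = full_blocksize(
    dims=(106, 1408, 1408, 8, 1),
    axlab='zyxct',
    blocksize_cfg={'y': 176, 'x': 176},
)
print(example_blocksize)
```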

For segmentation, we use a weighted sum of the membrane channels (ch3, ch5, ch6, ch7). The weights [0.5, 0.5, 1.0, 1.0] work well for this data. We set the name and internal .h5 path to 'memb/mean'.
We have specified this in the parameter file HFK16w.yml:

In [None]:
yprint(cfg['splitter']['split']['volumes'])


The above indicates that, in addition to the membrane sum, we generate a mean over all channels (used for generating masks). Importantly, we specify that we want to output channel 0 (DAPI), because we will use it to create a nuclear mask. 
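A minimal NumPy sketch of such a weighted channel combination, assuming a trailing channel axis and a normalized weighted mean computed with `np.average`; whether STAPL3D divides by the sum of the weights is an assumption. The helper `weighted_channel_mean` and the toy volume are hypothetical:

```python
import numpy as np


def weighted_channel_mean(data, channels, weights, axis=-1):
    """Weighted average over the selected channels of an array.

    data: array with a channel axis (assumed last here, as in zyxc)
    channels: channel indices to combine
    weights: one weight per selected channel
    """
    selected = np.take(data, channels, axis=axis)
    return np.average(selected, axis=axis, weights=weights)


# Toy zyxc volume with 8 channels, where every voxel of channel c has value c.
toy = np.broadcast_to(np.arange(8, dtype=float), (2, 3, 3, 8))
toy_memb = weighted_channel_mean(toy, channels=[3, 5, 6, 7], weights=[0.5, 0.5, 1.0, 1.0])
print(toy_memb.shape, toy_memb[0, 0, 0])
```

With these weights, channels 6 and 7 count twice as much as channels 3 and 5.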

Now we are ready to call the function that computes the membrane mean, and splits the data into blocks at the same time.

In [None]:
splitt3r = blocks.Splitt3r(image_in, parameter_file, prefix=dataset)
splitt3r.run()


Datablocks are written to the *HFK16w/blocks/* directory, and their filenames end with the zero-padded numeric ID of the block: HFK16w/blocks/HFK16w_**B{b:05}**.h5.
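The naming convention can be reproduced with a format string; `block_filepath` is a hypothetical helper, not a STAPL3D function:

```python
import os


def block_filepath(dataset, b, blockdir='blocks'):
    """Return the path of block b, with the block ID zero-padded to 5 digits."""
    return os.path.join(dataset, blockdir, f'{dataset}_B{b:05}.h5')


print(block_filepath('HFK16w', 7))
```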

These are some of the files that were generated:

In [None]:
from glob import glob

filelist = glob(os.path.join(os.path.abspath('.'), 'blocks', f'{dataset}_*.h5'))
filelist.sort()

f'Number of blocks: {len(filelist)}'
filelist[:5]


The resulting hdf5 files contain three datasets, named according to the `volumes` entries in the `cfg['splitter']['split']` specification; i.e. they have the following internal file structure:
- <...>.h5/mean
- <...>.h5/chan/ch00
- <...>.h5/memb/mean

It can be inspected and listed with the help of h5py:

In [None]:
import h5py

def extract_node_names(name, node):
    if isinstance(node, h5py.Dataset):
        nodes.append(name)
    return None

nodes = []
with h5py.File(filelist[0], 'r') as f:
    f.visititems(extract_node_names)
    pprint({'dataset names': nodes})
    idx = 0
    print(f'dataset {nodes[idx]} properties: ', f[nodes[idx]])
    print(f'dataset {nodes[idx]} resolution: ', f[nodes[idx]].attrs['element_size_um'])
    print(f'dataset {nodes[idx]} axes labels: ', f[nodes[idx]].attrs['DIMENSION_LABELS'])


Naturally, we also want to inspect the resulting averaged volumes visually. We use the napari viewer method provided by the *splitt3r* object, limiting the view to the first 42 blocks and showing the mean membrane channel for demonstration.

In [None]:
idxs = list(range(42))  # block indices
images = ['memb/mean']  # 'chan/ch00'
viewer_settings = {
    'title': 'STAPL3D splitt3r demo',
    'crosshairs': [int(splitt3r.fullsize[dim] / 2) for dim in 'zyx'],
    'axes_visible': False,
    'clim': [0, 5000],
}

splitt3r.view(input=idxs, images=images, settings=viewer_settings)


To get a grip on how the dataset is laid out in blocks, we can alternate the colormaps of the blocks.

In [None]:
# Alternate colormaps.
cmaps = ['cyan', 'magenta', 'yellow']
for i, lay in enumerate(splitt3r.viewer.layers):
    lay.colormap = cmaps[i % len(cmaps)]
lay.colormap = 'gray'


## Membrane enhancement

Before segmentation, we perform membrane enhancement.

For this demo, we avoid a dependency on the third-party [ACME](https://wiki.med.harvard.edu/SysBio/Megason/ACME) software and instead provide the output that the ACME procedure would otherwise produce. We split it into blocks and write it as separate datasets in the same block files as the channel data.

Alternatively, if you have ACME installed and want to run it, set the `ACME` environment variable to the directory containing the ACME binaries, or point `enhanc3r.ACMEdir` to that directory.

In [None]:
ims_filepath = f'{dataset}_shading_stitching.ims'  # f'{dataset}_shading_stitching_biasfield.ims'
image_in = ims_filepath

max_workers = 5  # NB: ACME is memory-intensive

from stapl3d.segmentation import enhance
enhanc3r = enhance.Enhanc3r(image_in, parameter_file, prefix=dataset, max_workers=max_workers)
enhanc3r.ACMEdir = os.environ.get('ACME')  # directory with the ACME binaries (None if unset)

if enhanc3r.ACMEdir:

    # Perform membrane enhancement.
    enhanc3r.run()

else:

    # Download precomputed membrane enhancement.
    acme_filepath = f'{dataset}_shading_stitching_ACME.h5'
    if not os.path.exists(acme_filepath):
        url = 'https://surfdrive.surf.nl/files/index.php/s/oQcxIocFBkaXwJe/download'
        urllib.request.urlretrieve(url, acme_filepath)

    # Split into blocks
    from stapl3d import blocks
    for ids in ['memb/preprocess', 'memb/planarity']:
        im_in = f'{acme_filepath}/{ids}'
        splitt3r = blocks.Splitt3r(im_in, parameter_file, prefix=dataset, step_id='')
        splitt3r.inputpaths['split']['data'] = im_in
        splitt3r.output_ND = ids
        splitt3r.volumes = {}  # FIXME: splitter entry is read from parameter_file despite <, step_id=''>
        splitt3r.run()


We visualize the membrane-enhanced volume with napari.

In [None]:
# Initialize viewer.
idxs = list(range(42))  # block indices
images = ['memb/planarity']
viewer_settings = {
    'title': 'STAPL3D enhanc3r demo',
    'crosshairs': [int(enhanc3r.blocksize[dim] / 2) for dim in 'zyx'],
    'axes_visible': False,
    'clim': [0, 0.05],
}

enhanc3r.view(input=idxs, images=images, settings=viewer_settings)


## Segmentation

The segmentation is parallelized over the blocks we just created: each of the 64 files is processed separately.
The segmentation routine involves a fair number of steps and parameters. The following lists all the parameters specified in the yml-file.

In [None]:
# sort_keys=False preserves the order of the steps as defined in the yml-file.
print(yaml.dump(cfg['segmentation']['estimate'], default_flow_style=False, sort_keys=False))


The blocks will be processed according to the steps defined in the parameter file, in the order in which they are listed there. Step names have to be prefixed with one of the following operation keywords:
- *prep*: filtering of volumes
- *mask*: compartment mask creation
- *combine*: mask combination
- *seed*: seed generation
- *segment*: watershed segmentation
- *filter*: size filtering and label masking
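A minimal sketch of how step names could be mapped to operations by prefix. This is not STAPL3D's dispatch code; `operation_for_step` is hypothetical, and the example step names are taken from the tuning list in this section:

```python
# The six operation keywords listed above; none is a prefix of another.
OPERATIONS = ('prep', 'mask', 'combine', 'seed', 'segment', 'filter')


def operation_for_step(step_name):
    """Return the operation keyword that prefixes a step name."""
    for op in OPERATIONS:
        if step_name.startswith(op):
            return op
    raise ValueError(f'step {step_name!r} does not start with a known operation keyword')


for step in ['prep_memb', 'mask_memb', 'mask_nucl', 'seeds', 'segment']:
    print(step, '->', operation_for_step(step))
```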

For your own data, it is advised to start with tuning the following parameters to optimize segmentation:
- mask_memb : threshold
- mask_nucl : sauvola : window_size
- mask_nucl : sauvola : threshold
- mask_nucl : sauvola : absmin
- seeds : peaks : window
- segment : watershed : compactness
- prep_memb : filter : sigma
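The `a : b : c` notation in this list refers to nested keys. Assuming these keys live under `cfg['segmentation']['estimate']`, hypothetical helpers such as the ones below can read and update them before the dictionary is written back with `yaml.safe_dump`. The toy dictionary and its values are illustrative only, not taken from HFK16w.yml:

```python
def get_param(cfg, path, sep=':'):
    """Look up a nested parameter such as 'mask_nucl : sauvola : threshold'."""
    node = cfg
    for key in (k.strip() for k in path.split(sep)):
        node = node[key]
    return node


def set_param(cfg, path, value, sep=':'):
    """Set a nested parameter in place, creating intermediate dicts if needed."""
    keys = [k.strip() for k in path.split(sep)]
    node = cfg
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value


# Hypothetical excerpt of the 'estimate' entry; real nesting and values may differ.
toy_estimate = {'mask_nucl': {'sauvola': {'window_size': 51, 'threshold': 0.2}}}
set_param(toy_estimate, 'mask_nucl : sauvola : threshold', 0.3)
set_param(toy_estimate, 'seeds : peaks : window', 7)
print(get_param(toy_estimate, 'mask_nucl : sauvola : threshold'))
```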
  

Next, we segment each block. Segmenting a single block takes on the order of minutes. Each 106 x 240 x 240 block (including the margin) takes ~1 GB of memory per process, so set the number of processes such that you stay within your available RAM. `max_workers = 8` is a fairly safe bet for modern systems; `max_workers = 0` uses all available processors.
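To pick `max_workers`, one can divide the memory you can spare by the ~1 GB per process. `safe_max_workers` is a hypothetical helper, and the 0.8 headroom factor is an arbitrary safety margin:

```python
import os


def safe_max_workers(available_ram_gb, mem_per_worker_gb=1.0, headroom=0.8, n_cpus=None):
    """Estimate how many segmentation processes fit in memory.

    available_ram_gb: RAM you are willing to use, in GB
    mem_per_worker_gb: memory per process (~1 GB for a 106 x 240 x 240 block)
    headroom: fraction of that RAM to actually plan for
    n_cpus: number of processors; defaults to os.cpu_count()
    """
    n_cpus = n_cpus or os.cpu_count() or 1
    by_memory = int(available_ram_gb * headroom // mem_per_worker_gb)
    # At least one worker, at most one per processor.
    return max(1, min(n_cpus, by_memory))


print(safe_max_workers(8, n_cpus=16))  # 6
```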

In [None]:
from stapl3d.segmentation import segment

max_workers = 0

segment3r = segment.Segment3r(image_in, parameter_file, prefix=dataset, max_workers=max_workers)
segment3r.run()


Report pages (pdf) have been written to the *HFK16w/blocks/* directory. Let's look at one of them:

In [None]:
block_idx = 1
# Get the outputpaths of the 'estimate' method for a block.
_, opaths = segment3r.fill_paths('estimate', reps={'b': block_idx})
# (Re)generate the report from the data and plot inline.
segment3r.report(outputpath=None, ioff=False, outputs=opaths)


From top to bottom, the images show:
 - the smoothed DAPI and mean membrane channels
 - the nuclear mask and the membrane mask
 - the combined mask with detected peaks, overlaid on the distance transform image
 - the first and the final watershed results


We use the `labels` argument to visualize masks and labels in napari. First, we overlay the nuclear mask on the planarity volume for the top row of blocks.

In [None]:
idxs = list(range(8))  # block indices
images = ['memb/planarity']
labels = ['nucl/mask']

viewer_settings = {
    'title': 'STAPL3D segment3r demo',
    'crosshairs': [int(segment3r.blocksize[dim] / 2) for dim in 'zyx'],
    'axes_visible': False,
    'clim': {'memb/planarity': [0, 0.05]},
    'opacity': {'nucl/mask': 0.5},
}

segment3r.view(input=idxs, images=images, labels=labels, settings=viewer_settings)


Finally, we look at the extracted segments of the block whose pdf report we viewed above.

In [None]:
images = ['nucl/prep']
labels = ['segm/labels']

viewer_settings = {
    'title': 'STAPL3D segment3r demo',
    'crosshairs': [int(segment3r.blocksize[dim] / 2) for dim in 'zyx'],
    'axes_visible': False,
    'clim': {'nucl/prep': [0, 20000]},
    'opacity': {'segm/labels': 0.8},
}

segment3r.view(input=block_idx, images=images, labels=labels, settings=viewer_settings)


## Zipping

Having parallelized the segmentation process for increased analysis speed and reduced memory footprint, the need arises to reassemble the blocks into a final combined segmentation volume without seams at the block boundaries. These seams are a consequence of trivial parallelization in processing the individual blocks (i.e. without communication between the processes). They manifest through partial cells lying on the block boundaries that have been assigned different labels in different blocks. Importantly, these doubly segmented cells may not perfectly match up over the boundary.

This can be demonstrated by loading the segmentation of the top two rows of blocks:

In [None]:
segment3r.view(input=list(range(16)), images=images, labels=labels, settings=viewer_settings)


These block-boundary-segments need to be resegmented in order to complete the accurate segmentation of the full dataset. We refer to this correct reassembly of the datablocks as ‘zipping’. In short, it consists of identifying the segments lying on the boundaries, removing them, and resegmenting that space. We aimed to design the procedure such that it requires minimal computational resources and expertise (fast, with a low memory footprint, and without the need for communication between processes).
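The identification and removal steps can be illustrated on a toy 2D label image. This is a simplified sketch, not STAPL3D's implementation: it treats every nonzero label within `margin` planes of a seam as a boundary segment, whereas the actual zip works on the overlapping margins of neighbouring blocks. Both helpers are hypothetical:

```python
import numpy as np


def boundary_labels(labels, axis, index, margin=1):
    """Return the nonzero labels found within `margin` planes on either side of a seam.

    The seam lies just before plane `index` along `axis`.
    """
    lo = max(index - margin, 0)
    hi = min(index + margin, labels.shape[axis])
    slab = np.take(labels, np.arange(lo, hi), axis=axis)
    found = np.unique(slab)
    return found[found != 0]


def remove_labels(labels, to_remove):
    """Set the given labels to background (0), freeing that space for resegmentation."""
    out = labels.copy()
    out[np.isin(out, to_remove)] = 0
    return out


# Two toy blocks side by side: columns 0-2 and 3-5, seam before column 3.
toy = np.array([
    [1, 1, 0, 2, 2, 0],
    [1, 1, 3, 3, 0, 4],
    [0, 5, 5, 0, 4, 4],
    [6, 0, 0, 0, 4, 4],
])
on_seam = boundary_labels(toy, axis=1, index=3)
cleared = remove_labels(toy, on_seam)
print(on_seam)
print(cleared)
```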

There is a one-liner for computing all the steps in the zip:
```
zipp3r.run()
```
which combines these three steps:
```
zipp3r.relabel()
zipp3r.copyblocks()
zipp3r.estimate()
```

For this demo, we will be much more verbose to illustrate the zipping process. 


We first perform a sequential relabeling of all the blocks to make each label unique.
We copy the relabeled blocks to new datasets in the same file for writing the zip-results in-place.
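The idea behind sequential relabeling is to give each block a consecutive label range that starts above the highest label of the preceding blocks. A NumPy sketch of that idea (not STAPL3D's `relabel()` implementation), assuming non-empty, non-negative integer label arrays with 0 as background; `relabel_blocks` is hypothetical:

```python
import numpy as np


def relabel_blocks(label_blocks):
    """Relabel a list of label arrays so that labels are unique across all blocks.

    Each block's nonzero labels are mapped to a consecutive range that starts
    just above the highest label assigned to the previous blocks.
    Returns the relabeled blocks and the final maximum label.
    """
    offset = 0
    relabeled = []
    for block in label_blocks:
        ulabels = np.unique(block)
        ulabels = ulabels[ulabels != 0]
        # Lookup table: old label -> new label; background stays 0.
        lut = np.zeros(int(block.max()) + 1, dtype=np.int64)
        lut[ulabels] = np.arange(offset + 1, offset + 1 + len(ulabels))
        relabeled.append(lut[block])
        offset += len(ulabels)
    return relabeled, offset


toy_blocks = [np.array([[0, 5], [5, 9]]), np.array([[3, 0], [0, 3]])]
relabeled_blocks, maxlabel = relabel_blocks(toy_blocks)
print(relabeled_blocks, maxlabel)
```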

We define a convenience function that merges datablocks into a single volume and returns a single z-plane for display.

In [None]:
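from stapl3d import blocks
from stapl3d.segmentation import zipping

# Relabel all blocks sequentially and copy the labels to 'segm/labels_zip' for in-place zipping.
zipp3r = zipping.Zipp3r(image_in, parameter_file, prefix=dataset)
zipp3r.relabel()
zipp3r.copyblocks()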

import matplotlib as mpl
import matplotlib.pyplot as plt

volumes = {
    'memb/mean': {'format': 'h5', 'suffix': None},
    'segm/labels': {'format': 'h5', 'suffix': None, 'is_labelimage': True},
    'segm/labels_zip': {'format': 'h5', 'suffix': None, 'is_labelimage': True},
    'segm/labels_zipmask': {'format': 'h5', 'suffix': None},
}
merg3r = blocks.Merg3r(image_in, parameter_file, prefix=dataset)
merg3r._volumes = [{ids: vol} for ids, vol in volumes.items()]
merg3r._init_paths_merger()


def merge_and_slice_dset(merg3r, ids, slc=20):
    """Merge volume and return sliced data."""

    # Run the block merge.
    merg3r.run()

    # Get a slice of the merged data.
    _, opaths = merg3r.fill_paths('merge')
    im = Image(opaths[ids])
    im.load()
    im.slices[0] = slice(slc, slc + 1, 1)
    data = im.slice_dataset()
    im.close()

    return data


For plotting labels, we define a label shuffling function.

In [None]:
import numpy as np
from skimage.segmentation import relabel_sequential
from random import shuffle
from skimage.color import label2rgb

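def shuffle_labels(labels):
    """Shuffle labels in a volume for better contrast between neighbouring segments."""

    # Relabel to 1..N (background 0 is kept), then map 1..N to a random permutation of 1..N.
    labels = relabel_sequential(labels)[0]
    n_labels = int(labels.max())
    relabeled = list(range(1, n_labels + 1))
    shuffle(relabeled)
    lut = np.array([0] + relabeled)
    labels = lut[labels]

    return labels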


Let's check with the membrane mean blocks. This should output an image of 1408 x 1408.

In [None]:
ids = 'memb/mean'
merg3r._volumes = [{ids: volumes[ids]}]
merg3r._init_paths_merger()
img = merge_and_slice_dset(merg3r, ids)


In [None]:
f = plt.figure(figsize=(8, 8))
plt.imshow(img, cmap='gray', vmax=5000)
plt.show()


In the same way, we can show the labels with the seams before zipping.

In [None]:
ids = 'segm/labels_zip'
merg3r._volumes = [{ids: volumes[ids]}]
merg3r._init_paths_merger()
labels = shuffle_labels(merge_and_slice_dset(merg3r, ids))


In [None]:
f = plt.figure(figsize=(8, 8))
plt.imshow(label2rgb(labels))
plt.show()


Next we set up the zipping estimation.

In [None]:
step_id = 'estimate'

# write maxlabels to file
outputs = zipp3r._prep_paths(zipp3r.outputpaths[step_id])
maxlabelfile = outputs['maxlabelfile']
kwargs = {}
arglist = zipp3r._prep_step(step_id, kwargs)
filepaths = zipp3r._get_filepaths(arglist)
zipping.get_maxlabels_from_attribute(
    filepaths,
    zipp3r.ids_labels,
    maxlabelfile,
)


In the zipping procedure, we order the steps such that no block is handled by two processes concurrently. First, blocks that overlap in the Y-dimension are processed (even and odd zip-lines separately); then the X-zip-lines; and finally the corners where four datablocks overlap are resegmented. For demonstration purposes, we keep track of the output of each step and store it in `imgs`.
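The ordering can be expressed as groups of seam indices, mirroring the `starts`/`steps` arguments in the next cell. This sketch assumes that seam `i` lies between blocks `i` and `i + 1`, so that an 8 x 8 block grid has 7 seams per axis (an assumption about `zipp3r.seamgrid`); the helpers are hypothetical and only check that no block is touched twice within a group:

```python
def zipline_groups(n_seams):
    """Split the seams along one axis into an even and an odd group."""
    return [list(range(offset, n_seams, 2)) for offset in (0, 1)]


def zipquad_groups(n_seams_y, n_seams_x):
    """Split seam intersections into four groups: even/even, even/odd, odd/even, odd/odd."""
    return [
        [(y, x) for y in range(start_y, n_seams_y, 2) for x in range(start_x, n_seams_x, 2)]
        for start_y in (0, 1)
        for start_x in (0, 1)
    ]


def seam_blocks(seam):
    """Blocks (along one axis) touched by a zip-line."""
    return {seam, seam + 1}


def quad_blocks(quad):
    """Blocks (y, x) touched by a zip-quad at seam intersection (y, x)."""
    y, x = quad
    return {(by, bx) for by in (y, y + 1) for bx in (x, x + 1)}


def is_conflict_free(group, touched_blocks):
    """True if no block is touched by two items of the same group."""
    seen = set()
    for item in group:
        touched = touched_blocks(item)
        if seen & touched:
            return False
        seen |= touched
    return True


n_seams = 7  # assumed: 8 blocks per axis give 7 internal seams
for group in zipline_groups(n_seams):
    print(group, is_conflict_free(group, seam_blocks))
for group in zipquad_groups(n_seams, n_seams):
    print(len(group), is_conflict_free(group, quad_blocks))
```

Taking every second seam is what makes the groups conflict-free: neighbouring seams share a block, seams two apart do not.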

In [None]:
import numpy as np

imgs = []

# Initialize arguments
args = [filepaths, 0, [-1, -1, -1], maxlabelfile]

# Resegment zip-lines in 4 groups:
# horizontal/even, horizontal/odd, vertical/even, vertical/odd zip-lines.
for axis, n_seams in zip([1, 2], zipp3r.seamgrid.shape):
    n_proc = min(zipp3r._n_workers, int(np.ceil(n_seams / 2)))
    for offset in [0, 1]:
        zipp3r.compute_zip_step(
            args, axis=axis,
            starts=[offset, 0], stops=[n_seams, 1], steps=[2, 2],
            n_proc=n_proc,
        )
        zipping.get_maxlabels_from_attribute(
            filepaths,
            zipp3r.ods_labels,
            maxlabelfile,
        )
        imgs.append(merge_and_slice_dset(merg3r, 'segm/labels_zip'))

# Resegment zip-quads in 4 groups:
# even/even, even/odd, odd/even, odd/odd zip-line intersections
for start_y in [0, 1]:
    for start_x in [0, 1]:
        stops = list(zipp3r.seamgrid.shape)
        zipp3r.compute_zip_step(
            args, axis=0,
            starts=[start_y, start_x], stops=stops, steps=[2, 2],
            n_proc=zipp3r._n_workers,
        )
        zipping.get_maxlabels_from_attribute(
            filepaths,
            zipp3r.ods_labels,
            maxlabelfile,
        )
        imgs.append(merge_and_slice_dset(merg3r, 'segm/labels_zip'))


In [None]:
f, axs = plt.subplots(2, 4, figsize=(24, 12))
for img, ax in zip(imgs, axs.flat):
    ax.imshow(img)
plt.show()


Newly processed zip-lines receive the highest labels, which show up in yellow in matplotlib's default viridis colormap, nicely demonstrating the zipping process.

<!---The zip-lines still have seams in the places where they intersect. Next we process zip-quads, in which the segments on these intersections are resegmented to finish the zip.-->

Now, we compare the labels before and after the zip:

In [None]:
ids = 'segm/labels_zip'
merg3r._volumes = [{ids: volumes[ids]}]
merg3r._init_paths_merger()
labels_zipped = shuffle_labels(merge_and_slice_dset(merg3r, ids))


In [None]:
f, axs = plt.subplots(1, 2, figsize=(16, 8))
for img, ax in zip([labels, labels_zipped], axs.flat):
    ax.imshow(label2rgb(shuffle_labels(img)))
plt.show()


View the zip result with napari:

In [None]:
idss = ['segm/labels', 'segm/labels_zip']

merg3r = blocks.Merg3r(image_in, parameter_file, prefix=dataset)
merg3r._volumes = [{ids: volumes[ids]} for ids in idss]
merg3r._init_paths_merger()
merg3r.run()

viewer_settings = {
    'title': 'STAPL3D merg3r demo',
    'crosshairs': [int(merg3r.blocksize[dim] / 2) for dim in 'zyx'],
    'axes_visible': False,
    'clims': [0, 1],
    'opacity': 1,
}

filepath = merg3r.outputpaths['postprocess']['aggregate']
merg3r.view(input=filepath, images=[], labels=idss, settings=viewer_settings)

# NOTE: some new seams are created on the margins because the blocks of this demo are too small


## Compartmental segmentation

In STAPL-3D, we use rich multidimensional data to obtain a robust segmentation. We can also use this information to perform subcellular segmentation. Here, we split segments into nuclear and membrane subsegments, so that we can extract intensities specifically from the voxels appropriate for the type of staining (nuclear or membrane). In addition, the subsegmentation opens up possibilities for defining compound features that inform on internal cell structure.
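A minimal sketch of such a compartment split, assuming a binary nuclear mask: voxels of a segment inside the mask form the nuclear subsegment and the remaining voxels the membrane subsegment. STAPL3D's `Subsegment3r` may define the membrane compartment differently (for example as a shell at the segment boundary). Both helpers are hypothetical; `mean_intensity` shows how per-compartment intensities could then be extracted with `np.bincount`:

```python
import numpy as np


def split_compartments(labels, nuclear_mask):
    """Split a label image into nuclear and membrane subsegments (same label IDs)."""
    nuclear_mask = nuclear_mask.astype(bool)
    labels_nucl = np.where(nuclear_mask, labels, 0)
    labels_memb = np.where(nuclear_mask, 0, labels)
    return labels_nucl, labels_memb


def mean_intensity(labels, intensity):
    """Mean intensity per label (array index = label); NaN for labels without voxels."""
    sums = np.bincount(labels.ravel(), weights=intensity.ravel())
    counts = np.bincount(labels.ravel())
    means = np.full(sums.shape, np.nan)
    np.divide(sums, counts, out=means, where=counts > 0)
    return means


# Toy 2D example: two segments, a nuclear mask and a DAPI-like intensity image.
toy_labels = np.array([[1, 1, 2], [1, 2, 2]])
toy_mask = np.array([[1, 0, 0], [0, 1, 0]])
toy_dapi = np.array([[10.0, 2.0, 3.0], [4.0, 20.0, 5.0]])

toy_nucl, toy_memb = split_compartments(toy_labels, toy_mask)
toy_means = mean_intensity(toy_nucl, toy_dapi)
print(toy_nucl, toy_memb, toy_means, sep='\n')
```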


In [None]:
from stapl3d.segmentation import segment

subsegment3r = segment.Subsegment3r(image_in, parameter_file, prefix=dataset)
subsegment3r.run()


In [None]:
vols = [f'segm/labels_{vol}' for vol in ['full', 'nucl', 'memb']]
merg3r = blocks.Merg3r(image_in, parameter_file, prefix=dataset)
merg3r._volumes = [{vol: {'format': 'h5', 'suffix': None, 'is_labelimage': True}} for vol in vols]
merg3r._init_paths_merger()
merg3r.run()

viewer_settings = {
    'title': 'STAPL3D merg3r demo',
    'crosshairs': [int(merg3r.blocksize[dim] / 2) for dim in 'zyx'],
    'axes_visible': False,
    'clims': [0, 1],
    'opacity': 1,
}

filepath = merg3r.outputpaths['postprocess']['aggregate']
merg3r.view(input=filepath, images=[], labels=vols, settings=viewer_settings)
