# STAPL-3D segmentation demo

This notebook demonstrates the core components of the STAPL-3D segmentation pipeline: **blockwise segmentation** and **zipping**.

If you have not yet followed the STAPL-3D README, please find STAPL-3D and the installation instructions [here](https://github.com/RiosGroup/STAPL3D) before doing this demo.

Because STAPL-3D is all about big data files, we provide small cutouts and precomputed summary data. Please download [HFK16w.zip](https://surfdrive.surf.nl/files/index.php/s/Q9wRT5cyKGERxI5) (~6GB). Note that the Preprocessing and Segmentation demos use the same zip file.

First, define where you have put the data. Please change *datadir* to point to the *HFK16w* directory that you have unzipped.

In [None]:
import os

datadir = './HFK16w'
dataset = 'HFK16w'
filestem = os.path.join(datadir, dataset)


## Parallelization

We provide a data cutout in the Imaris v5.5 file format, which is an hdf5 file with 5 dimensions. When processing the full dataset, this cutout would correspond to a single datablock; for this demo, we further subdivide this block to demonstrate the pipeline.

We use the STAPL-3D Image class to load this file and inspect its properties. We also store the dimensions, the Z-dimension and the number of channels in the convenience variables `dims`, `Z` and `C`.


In [None]:
from stapl3d import Image

image_in = '{}_bfc_block.ims'.format(filestem)

im = Image(image_in)
im.load(load_data=False)

props = im.get_props()

im.close()

dims = im.dims
Z = im.dims[im.axlab.index('z')]
C = im.dims[im.axlab.index('c')]

props


For segmentation, we use a weighted sum of the membrane channels (ch3, ch5, ch6, ch7). The weights [0.5, 0.5, 1.0, 1.0] work well for this data.
We have specified this in the parameter file HFK16w.yml:

In [None]:
import yaml

parameter_file = '{}.yml'.format(filestem)
with open(parameter_file, 'r') as ymlfile:
    cfg = yaml.safe_load(ymlfile)

cfg['channels']


The above indicates that, in addition to the membrane sum, we generate a nuclear channel mean as well as a mean over all channels (used for generating masks). Importantly, we specify that we want to output channel 0 (DAPI), because we will use it to create a nuclear mask.
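
To make the membrane channel concrete, here is a minimal numpy sketch of a weighted channel sum, assuming a (z, y, x, c) array with the channel axis last and the channel indices and weights quoted above. The helper `weighted_channel_sum` is hypothetical and is not the `process_channels` implementation; whether STAPL-3D divides by the sum of the weights (the result is stored as `memb/mean`) is an assumption, so normalization is left as an option.

```python
import numpy as np

def weighted_channel_sum(data, channels, weights, normalize=False):
    """Weighted sum over selected channels of a (z, y, x, c) array.

    Hypothetical helper for illustration; the channel axis is assumed last.
    """
    w = np.asarray(weights, dtype=float)
    # Select the channels; the weights broadcast over the last axis.
    out = (data[..., channels].astype(float) * w).sum(axis=-1)
    if normalize:
        # Optional: turn the weighted sum into a weighted mean.
        out /= w.sum()
    return out

# Example: a tiny 8-channel volume with membrane channels 3, 5, 6 and 7.
vol = np.ones((2, 4, 4, 8))
memb = weighted_channel_sum(vol, [3, 5, 6, 7], [0.5, 0.5, 1.0, 1.0])
```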

Next, we specify the shape of the processing blocks. Usually we would opt for a blocksize of ~100-200 million voxels; here we choose a blocksize of 176 in *xy*, giving 64 blocks of ~6M voxels. We keep the margin similar to what we use for big datasets, as reducing it may hinder adequate analysis.
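
The block layout along one axis can be worked out with a few lines of Python. This sketch assumes that block cores start at multiples of the blocksize and that the margin is added on both sides and clipped at the image edge; that assumption is consistent with the example block filename used later in this notebook (`00336-00664` fits a 200-voxel core from 400 to 600 with a 64-voxel margin), but it is not taken from the STAPL-3D code. The helper `block_ranges` is hypothetical, and 1408 is the *xy* size of the cutout.

```python
def block_ranges(size, bs, bm):
    """Return (start, stop) voxel ranges of blocks, including margins, along one axis."""
    ranges = []
    for start in range(0, size, bs):
        stop = min(start + bs, size)
        # Extend the core by the margin on both sides, clipped to the image.
        ranges.append((max(start - bm, 0), min(stop + bm, size)))
    return ranges

yx = block_ranges(1408, 176, 64)
print(len(yx) ** 2)          # 64
print(yx[0], yx[1], yx[-1])  # (0, 240) (112, 416) (1168, 1408)
```

Under this assumption, edge blocks are padded on one side only, which gives the 106 x 240 x 240 extent of a corner block, while interior blocks span 304 voxels in *y* and *x*.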

In [None]:
bs = 176  # blocksize
bm = 64  # blockmargin

blocksize = [Z, bs, bs, C, 1]
blockmargin = [0, bm, bm, 0, 0]

blockdir = os.path.join(datadir, 'blocks_{:04d}'.format(bs))
block_prefix = os.path.join(blockdir, '{}_bfc_block'.format(dataset))
os.makedirs(blockdir, exist_ok=True)

'Processing data in blocks of {} voxels with a margin of {} voxels'.format(blocksize, blockmargin)


Now we are ready to call the function that computes the membrane mean and splits the data into blocks at the same time. Datablocks are written to the *HFK16w/blocks_0176/* directory (`blockdir`) and are suffixed with the voxel coordinates of the block in the original datafile: HFK16w/blocks_0176/HFK16w_bfc_block_**x-X_y-Y_z-Z**.h5.

In [None]:
from stapl3d.channels import process_channels

process_channels(
    image_in,
    parameter_file,
    blocksize,
    blockmargin,
    outputprefix=block_prefix,
    )


These are some of the files that were generated:

In [None]:
from glob import glob

ipf = ''
filelist = glob(os.path.join(blockdir, '{}_*{}.h5'.format(dataset, ipf)))
filelist.sort()
len(filelist), filelist[:5]
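
Because the voxel ranges are encoded in the filenames, they can be recovered with a regular expression. A minimal sketch, assuming the `_start-stop` triplet format seen in this notebook (for example `HFK16w_bfc_block_00336-00664_00936-01264_00000-00106.h5`); the helper `parse_block_ranges` is hypothetical and returns the ranges in the order they appear in the filename.

```python
import os
import re

def parse_block_ranges(filepath):
    """Extract the three (start, stop) voxel ranges from a block filename."""
    name = os.path.basename(filepath)
    # Three 'start-stop' groups separated by underscores, right before '.h5'.
    m = re.search(r'_(\d+)-(\d+)_(\d+)-(\d+)_(\d+)-(\d+)\.h5$', name)
    if m is None:
        raise ValueError('no block ranges in {}'.format(name))
    vals = [int(v) for v in m.groups()]
    # Pair up consecutive values as (start, stop).
    return list(zip(vals[0::2], vals[1::2]))

ranges = parse_block_ranges('HFK16w_bfc_block_00336-00664_00936-01264_00000-00106.h5')
```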


We can look at the groups and datasets (internal h5 file structure) with h5py.

In [None]:
import h5py

with h5py.File(filelist[20],'r') as f:
    f.visit(print)


The resulting hdf5 files have the following internal file structure:
- `.h5/mean`
- `.h5/chan/ch00`
- `.h5/memb/mean`
- `.h5/nucl/mean`

## Membrane enhancement

Before segmentation, we perform membrane enhancement. For this demo, we do not want to depend on the third-party ACME software, so we provide the output that the ACME procedure would otherwise produce. We split it into blocks and write it as separate datasets into the same files as the channel data.

In [None]:
from stapl3d.channels import splitblocks

for ids in ['memb/preprocess', 'memb/planarity']:
    image_in = '{}_bfc_block_ACME.h5/{}'.format(filestem, ids)
    output_template = '{}_{}.h5/{}'.format(block_prefix, '{}', ids)
    splitblocks(image_in, [106, bs, bs], [0, bm, bm], output_template)


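# Reference only (not run in this demo): convert a block to NIfTI for ACME, run ACME, and convert the outputs back to hdf5.
# from stapl3d.channels import h5_nii_convert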
# filestem = 'HFK16w_bfc_block_00336-00664_00936-01264_00000-00106'
# image_in = os.path.join(blockdir, '{}.h5/memb/mean'.format(filestem))
# image_out = os.path.join(blockdir, '{}_memb-mean.nii.gz'.format(filestem))
# h5_nii_convert(image_in, image_out)

# ... ACME ...

# for vol in ['preprocess', 'planarity']:
#     image_in = os.path.join(blockdir, '{}_memb-{}.nii.gz'.format(filestem, vol))
#     image_out = os.path.join(blockdir, '{}_foo.h5/memb/{}'.format(filestem, vol))
#     h5_nii_convert(image_in, image_out)


## Segmentation

The segmentation is parallelized over the blocks we just created. Each of the 64 files is processed separately.

In [None]:
from glob import glob

ipf = ''
filepat = '{}_*{}.h5'.format(dataset, ipf)
filelist = glob(os.path.join(blockdir, filepat))
filelist.sort()

len(filelist), filelist[:5]


The segmentation routine takes a fair number of parameters. The cell below lists all segmentation parameters specified in the yml file.

In [None]:
import yaml

parameter_file = '{}.yml'.format(filestem)
with open(parameter_file, 'r') as ymlfile:
    cfg = yaml.safe_load(ymlfile)

cfg['segmentation']


A few parameters of particular note:
- input volumes:
    - 'ids_memb_mask': 'memb/planarity'
    - 'ids_memb_chan': 'memb/mean'
    - 'ids_nucl_chan': 'chan/ch00'
    - 'ids_dset_mean': 'mean'

The following parameters can be changed to optimize the segmentation or to use parameters from automated fine-tuning:
- membrane mask:
    - 'planarity_thr': 0.0005
- nuclei mask:
    - 'sauvola_window_size': [19, 75, 75]
    - 'dapi_thr': 5000
    - 'dapi_absmin': 1000
- peak detection:
    - 'peaks_size': [11, 19, 19]
    - 'compactness': 0.8
- watershed:
    - 'memb_sigma': 3.0
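
To illustrate how the membrane-mask, peak-detection and watershed parameters interact, here is a simplified seeded-watershed sketch built from scipy and scikit-image. It is not STAPL-3D's `extract_segments`: the helper `sketch_segment` is hypothetical, ignores the nuclear mask, assumes that membrane voxels are those with planarity above `planarity_thr`, and detects peaks with a maximum filter of size `peaks_size` on the distance transform.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.segmentation import watershed

def sketch_segment(memb, planarity, planarity_thr=0.0005,
                   peaks_size=(11, 19, 19), compactness=0.8, memb_sigma=3.0):
    """Simplified seeded watershed on a (z, y, x) block; hypothetical helper."""
    # Membrane mask: assume high planarity marks membrane voxels.
    memb_mask = planarity > planarity_thr
    # Distance from each non-membrane voxel to the nearest membrane voxel.
    dist = ndi.distance_transform_edt(~memb_mask)
    # Peaks: voxels equal to the local maximum within a peaks_size window.
    peaks = (dist == ndi.maximum_filter(dist, size=peaks_size)) & (dist > 0)
    markers, _ = ndi.label(peaks)
    # Flood the smoothed membrane signal from the peaks, outside the membrane mask.
    smooth = ndi.gaussian_filter(memb.astype(float), sigma=memb_sigma)
    return watershed(smooth, markers=markers, mask=~memb_mask,
                     compactness=compactness)

# Toy block: a single membrane plane at x = 20 separates two "cells".
planarity = np.zeros((5, 20, 40))
planarity[:, :, 20] = 1.0
labels = sketch_segment(planarity.copy(), planarity)
```

On this toy block, the two halves receive separate labels and the membrane plane remains background (0).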
  

Next, we calculate the segments for each block. A block of 106 x 240 x 240 voxels (including the margin) takes ~1GB of memory per process. Please set the number of processes such that you stay within your available RAM; n_proc = 8 is a fairly safe bet for modern systems. Segmenting a single block takes on the order of minutes.

In [None]:
n_proc = 16  # ~1GB RAM per process; reduce (e.g. to 8) if memory is limited


When calling the function from within a Python interpreter, we use Python's multiprocessing module for parallel processing. Its input is a list of argument tuples, so let's look at the arguments of the segmentation function:


In [None]:
import inspect
import multiprocessing
from stapl3d.segmentation.segment import extract_segments

# Show the parameter names of the segmentation function.
inspect.signature(extract_segments)

We specify 7 arguments per job; the first four point to the datasets in a particular .h5 datablock file.

In [None]:
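# One argument tuple per block; see the signature shown above for the parameter names.
arglist = []
for datafile in filelist:
    args = [
        '{}/{}/{}'.format(datafile, 'memb', 'planarity'),  # membrane mask input (ids_memb_mask)
        '{}/{}/{}'.format(datafile, 'memb', 'mean'),  # membrane channel (ids_memb_chan)
        '{}/{}/{}'.format(datafile, 'chan', 'ch00'),  # nuclear channel, DAPI (ids_nucl_chan)
        '{}/{}'.format(datafile, 'mean'),  # mean over all channels (ids_dset_mean)
        parameter_file,  # yml file with the segmentation parameters
        datafile.replace('.h5', ''),  # block path without the extension
        True,
        ]
    arglist.append(tuple(args))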

with multiprocessing.Pool(processes=n_proc) as pool:
    pool.starmap(extract_segments, arglist)


Report pages (pdf) have been written to the *HFK16w/blocks_0176/* directory. Let's look at one of them:

In [None]:
from stapl3d.segmentation.segment import generate_report

image_in = '{}/memb/mean'.format(filelist[20])
generate_report(image_in, ioff=False)


From left to right, images are shown for:
 - the DAPI channel and the membrane mean
 - the nuclear mask and the membrane mask
 - the combined mask with the detected peaks, overlaid on the distance transform image
 - the first and the final watershed results

## Zipping

Because we parallelized the segmentation for increased analysis speed and a reduced memory footprint, the blocks need to be reassembled into a single segmentation volume without seams at the block boundaries. These seams are a consequence of trivial parallelization, i.e. processing the individual blocks without communication between the processes. They manifest as partial cells on the block boundaries that have been assigned different labels in different blocks; importantly, these doubly segmented cells may not match up perfectly across the boundary. These block-boundary segments need to be resegmented to complete an accurate segmentation of the full dataset. We refer to this correct reassembly of the datablocks as ‘zipping’. In short, it consists of identifying the segments lying on the boundaries, removing them, and resegmenting that space. We aimed to design the procedure such that it requires minimal computational resources and expertise: it should be fast, have a low memory footprint, and need no communication between processes.
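
The core idea of zipping, finding the segments that lie on a block boundary and clearing them so that the space can be resegmented, can be illustrated on a toy label image. This is a conceptual numpy sketch with hypothetical helpers (`seam_labels`, `clear_labels`), not the STAPL-3D implementation in `resegment_block_boundaries`, which judging from its arguments also handles margins, seam numbering and label bookkeeping.

```python
import numpy as np

def seam_labels(labels, axis, seam):
    """Labels present in the two voxel planes on either side of a seam."""
    # Take the planes at index seam - 1 and seam along the given axis.
    planes = np.take(labels, [seam - 1, seam], axis=axis)
    found = np.unique(planes)
    return found[found != 0]  # drop background

def clear_labels(labels, to_clear):
    """Set the given labels to background, leaving space to resegment."""
    out = labels.copy()
    out[np.isin(out, to_clear)] = 0
    return out

# Two blocks meeting at x = 4: one cell was cut in two and received
# label 2 in the left block and label 3 in the right block.
labels = np.zeros((3, 8), dtype=int)
labels[:, 0:2] = 1
labels[:, 2:4] = 2
labels[:, 4:6] = 3
labels[:, 6:8] = 4
on_seam = seam_labels(labels, axis=1, seam=4)
cleared = clear_labels(labels, on_seam)
```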

### Relabel
We first perform a sequential relabeling of all the blocks to make each label unique.
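
One common way to make labels unique across blocks is to offset each block's labels by the sum of the maximum labels of all preceding blocks. This is an assumption about what `relabel_parallel` does with the maxlabel file, not a description of its code; the helper `relabel_with_offsets` is hypothetical.

```python
import numpy as np

def relabel_with_offsets(blocks, maxlabels):
    """Shift the nonzero labels of each block by the cumulative maxlabel of the blocks before it."""
    # offsets[i] = maxlabels[0] + ... + maxlabels[i - 1]
    offsets = np.concatenate([[0], np.cumsum(maxlabels)[:-1]])
    out = []
    for block, offset in zip(blocks, offsets):
        relabeled = block.copy()
        relabeled[block > 0] += offset  # keep background at 0
        out.append(relabeled)
    return out

blocks = [np.array([0, 1, 2]), np.array([0, 1, 1, 3])]
maxlabels = [int(b.max()) for b in blocks]  # [2, 3]
relabeled = relabel_with_offsets(blocks, maxlabels)
```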

In [None]:
from stapl3d.segmentation.zipping import get_maxlabels_from_attribute
from stapl3d.segmentation.zipping import relabel_parallel

grp = 'segm'
ids = 'labels_memb_del'
postfix = 'relabeled'

# Write the maximum label of each block to a file.
filename = '{}_maxlabels_{}.txt'.format(dataset, ids)
maxlabelfile = os.path.join(blockdir, filename)
maxlabels = get_maxlabels_from_attribute(filelist, '{}/{}'.format(grp, ids), maxlabelfile)

# The relabel_parallel function has four arguments:
# input dataset, block index, maxlabel file and output postfix
arglist = []
for block_idx, datafile in enumerate(filelist):
    args = [
        '{}/{}/{}'.format(datafile, grp, ids),
        block_idx,
        maxlabelfile,
        postfix,
    ]
    arglist.append(tuple(args))

with multiprocessing.Pool(processes=n_proc) as pool:
    pool.starmap(relabel_parallel, arglist)


### Copy blocks
We copy the relabeled blocks to new datasets in the same file for in-place zipping.

In [None]:
from stapl3d.segmentation.zipping import copy_blocks_parallel

grp = 'segm'
ids = 'labels_memb_del_relabeled'
postfix = 'fix'

# Write the maximum label of each block to a file.
filename = '{}_maxlabels_{}.txt'.format(dataset, ids)
maxlabelfile = os.path.join(blockdir, filename)
maxlabels = get_maxlabels_from_attribute(filelist, '{}/{}'.format(grp, ids), maxlabelfile)

# The copy_blocks_parallel function has three arguments: 
# inputdataset, block index, outputpostfix
arglist = []
for block_idx, datafile in enumerate(filelist):
    args = [
        '{}/{}/{}'.format(datafile, grp, ids),
        block_idx,
        postfix,
    ]
    arglist.append(tuple(args))

with multiprocessing.Pool(processes=n_proc) as pool:
    pool.starmap(copy_blocks_parallel, arglist)


# Write a maxlabelfile in which the maxlabels are tracked during zipping.
pf = 'relabeled_fix'
maxlabelfile = os.path.join(blockdir, '{}_maxlabels_{}.txt'.format(dataset, pf))
maxlabels = get_maxlabels_from_attribute(filelist, 'segm/labels_memb_del_{}'.format(pf), maxlabelfile)
'maxlabs after copy {}'.format(maxlabels)


### Zip
Next, we define the zipping parameters and functions. First, we set the number of processes and the block layout.


In [None]:
import numpy as np

n_proc_max = n_proc

# Set the number of seams in the data.
n_seams_yx = [7, 7]  # we have 8 x 8 blocks in the HFK16w dataset with bs=176

seams = list(range(np.prod(n_seams_yx)))
seamgrid = np.reshape(seams, n_seams_yx)
seamgrid


Now, the zipping parameters are defined as well as functions to turn the parameters into arguments for the zipping-steps.

In [None]:
from stapl3d.segmentation.zipping import resegment_block_boundaries

# Arguments to `resegment_block_boundaries`
images_in = ['{}/{}/{}_{}'.format(datafile, 'segm', 'labels_memb_del', pf)
             for datafile in filelist]
blocksize=[Z, bs, bs]
blockmargin=[0, bm, bm]
axis=0
seamnumbers=[-1, -1, -1]
mask_dataset=''
relabel=False
maxlabel=maxlabelfile
in_place=True
outputstem=os.path.join(blockdir, dataset)
save_steps=False
args = [
    images_in,
    blocksize,
    blockmargin,
    axis,
    seamnumbers,
    mask_dataset,
    relabel,
    maxlabel,
    in_place,
    outputstem,
    save_steps,
]


def get_arglist(args, axis, starts, stops, steps):
    """Replace the `axis` and `seamnumbers` arguments
    with values specific for sequential zip-steps.
    
    axis = 0: zip-quads
    axis = 1: zip-lines over Y
    axis = 2: zip-lines over X
    seamnumbers: start-stop-step triplets (with step=2)
    """

    arglist = []
    for seam_y in range(starts[0], stops[0], steps[0]):
        for seam_x in range(starts[1], stops[1], steps[1]):
            seamnumbers = [-1, seam_y, seam_x]
            args[3] = axis
            if axis == 0:
                args[4] = [seamnumbers[d] if d != axis else -1 for d in [0, 1, 2]]
            else:
                args[4] = [seam_y if d == axis else -1 for d in [0, 1, 2]]
            arglist.append(tuple(args))

    return arglist


def compute_zip_step(args, axis, seamgrid, starts, stops, steps, n_proc):
    """Compute the zip-step."""

    arglist = get_arglist(args, axis, starts, stops, steps)
    print('submitting {:3d} jobs over {:3d} processes'.format(len(arglist), n_proc))

    with multiprocessing.Pool(processes=n_proc) as pool:
        pool.starmap(resegment_block_boundaries, arglist)


We define a convenience function that merges datablocks into a single volume and returns a single z-plane for display.
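
Conceptually, merging keeps the core of each block and discards its margins. The numpy sketch below shows this round trip on a toy (z, y, x) volume; `split_yx` and `merge_yx` are hypothetical stand-ins for STAPL-3D's `splitblocks` and `mergeblocks`, and assume margins are added on both sides of each core and clipped at the volume edge.

```python
import numpy as np

def split_yx(vol, bs, bm):
    """Split a (z, y, x) volume into yx-blocks with margins.

    Returns (core_start, padded_ranges, data) per block.
    """
    blocks = []
    for y0 in range(0, vol.shape[1], bs):
        for x0 in range(0, vol.shape[2], bs):
            # Pad the core by the margin on both sides, clipped to the volume.
            ys = (max(y0 - bm, 0), min(y0 + bs + bm, vol.shape[1]))
            xs = (max(x0 - bm, 0), min(x0 + bs + bm, vol.shape[2]))
            data = vol[:, ys[0]:ys[1], xs[0]:xs[1]].copy()
            blocks.append(((y0, x0), (ys, xs), data))
    return blocks

def merge_yx(blocks, fullsize, bs):
    """Paste the core (margin-free) region of each block into a full-size volume."""
    out = np.zeros(fullsize, dtype=blocks[0][2].dtype)
    for (y0, x0), (ys, xs), data in blocks:
        y1 = min(y0 + bs, fullsize[1])
        x1 = min(x0 + bs, fullsize[2])
        # Convert core coordinates to coordinates local to the padded block.
        out[:, y0:y1, x0:x1] = data[:, y0 - ys[0]:y1 - ys[0], x0 - xs[0]:x1 - xs[0]]
    return out

vol = np.arange(2 * 10 * 10).reshape(2, 10, 10)
blocks = split_yx(vol, bs=4, bm=2)
merged = merge_yx(blocks, vol.shape, bs=4)
```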

In [None]:
from stapl3d.mergeblocks import mergeblocks

import matplotlib as mpl
import matplotlib.pyplot as plt

def merge_and_slice_dset(filelist, ids, dims, bs, bm, slc=20):

    # Merge the datablocks.
    images_in=['{}/{}'.format(datafile, ids)
               for datafile in filelist]
    filename = '{}.h5/{}'.format(dataset, ids)
    outputpath=os.path.join(datadir, filename)

    mergeblocks(
        images_in=images_in,
        outputpath=outputpath,
        blocksize=[dims[0], bs, bs],
        blockmargin=[0, bm, bm],
        fullsize=dims[:3],
    )

    # Get a slice of the merged data.
    im = Image(outputpath)
    im.load()
    im.slices[0] = slice(slc, slc + 1, 1)
    data = im.slice_dataset()
    im.close()

    return data


Let's check with the membrane mean blocks. This should output an image of 1408 x 1408.

In [None]:
ids = 'memb/mean'
img = merge_and_slice_dset(filelist, ids, dims, bs, bm)

plt.imshow(img, cmap='gray', vmax=5000)
plt.show()


In the zipping procedure, we use an ordering in which no block is handled by two processes concurrently. First, the zip-lines between blocks that overlap in the Y-dimension are processed (even and odd zip-lines separately); then the X zip-lines; then the corners where four datablocks overlap are resegmented. For demonstration purposes, we keep track of the output of each step and store it in `imgs`.
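
The ordering can be checked with a small pure-Python sketch. It assumes that a zip-line at seam index *s* touches the two rows (or columns) of blocks on either side of it, and that a zip-quad touches the four blocks around a seam intersection. The helpers are hypothetical; they only enumerate the jobs per step and verify that no block is shared by two jobs in the same step.

```python
from itertools import combinations, product

def touched_blocks(job, n_blocks_yx):
    """Set of (block_y, block_x) indices a zip job reads or writes."""
    kind, seam = job
    n_by, n_bx = n_blocks_yx
    if kind == 'y':  # zip-line between block rows seam and seam + 1
        return {(r, c) for r in (seam, seam + 1) for c in range(n_bx)}
    if kind == 'x':  # zip-line between block columns seam and seam + 1
        return {(r, c) for r in range(n_by) for c in (seam, seam + 1)}
    sy, sx = seam    # zip-quad at the intersection of two seams
    return {(sy + dy, sx + dx) for dy in (0, 1) for dx in (0, 1)}

def zip_schedule(n_seams_yx):
    """Steps of concurrent jobs: even/odd y-lines, even/odd x-lines, 4 quad groups."""
    n_sy, n_sx = n_seams_yx
    steps = [[('y', s) for s in range(off, n_sy, 2)] for off in (0, 1)]
    steps += [[('x', s) for s in range(off, n_sx, 2)] for off in (0, 1)]
    for oy, ox in product((0, 1), repeat=2):
        steps.append([('quad', (sy, sx))
                      for sy in range(oy, n_sy, 2) for sx in range(ox, n_sx, 2)])
    return steps

def is_conflict_free(step, n_blocks_yx):
    """True if no two jobs in a step touch the same block."""
    return all(touched_blocks(a, n_blocks_yx).isdisjoint(touched_blocks(b, n_blocks_yx))
               for a, b in combinations(step, 2))

n_seams_yx = [7, 7]
n_blocks_yx = [n + 1 for n in n_seams_yx]
steps = zip_schedule(n_seams_yx)
print([len(step) for step in steps])  # [4, 3, 4, 3, 16, 12, 12, 9]
print(all(is_conflict_free(s, n_blocks_yx) for s in steps))  # True
```

Processing adjacent zip-lines in the same step would violate this, because both would write to the shared row of blocks between them.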

In [None]:
ids = 'segm/labels_memb_del_relabeled_fix'

imgs = []

for axis, n_seams in zip([1, 2], n_seams_yx):

    n_proc = min(n_proc_max, int(np.ceil(n_seams / 2)))

    for offset in [0, 1]:

        # do the zip-step
        compute_zip_step(
            args, axis, seamgrid,
            starts=[offset, 0], stops=[n_seams, 1], steps=[2, 2],
            n_proc=n_proc,
        )

        # update maxlabels
        maxlabels = get_maxlabels_from_attribute(filelist, ids, maxlabelfile)

        # keep image for display
        imgs.append(merge_and_slice_dset(filelist, ids, dims, bs, bm))

f, axs = plt.subplots(1, 4, figsize=(24, 24))
for img, ax in zip(imgs, axs):
    ax.imshow(img)
plt.show()


Newly processed zip-lines receive high labels, which show up in yellow in the viridis colormap, nicely demonstrating the zipping process.

The zip-lines still have seams in the places where they intersect. Next we process zip-quads, in which the segments on these intersections are resegmented to finish the zip. 

In [None]:
# resegment zip-quads in 4 groups even/even, even/odd, odd/even, odd/odd zip-line intersections
ids = 'segm/labels_memb_del_relabeled_fix'

imgs = []

for start_y in [0, 1]:

    for start_x in [0, 1]:

        # do the zip-step
        compute_zip_step(
            args, axis=0, seamgrid=seamgrid,
            starts=[start_y, start_x], stops=n_seams_yx, steps=[2, 2],
            n_proc=n_proc_max,  # up to 16 quad jobs per group
        )

        # update maxlabels
        maxlabels = get_maxlabels_from_attribute(filelist, ids, maxlabelfile)

        # keep image for display
        imgs.append(merge_and_slice_dset(filelist, ids, dims, bs, bm))

f, axs = plt.subplots(1, 4, figsize=(24, 24))
for img, ax in zip(imgs, axs):
    ax.imshow(img)
plt.show()


To visualize the segments in the more common random colors, we relabel, shuffle and plot.

In [None]:
from skimage.segmentation import relabel_sequential
from skimage.color import label2rgb
from random import shuffle

img = merge_and_slice_dset(filelist, ids, dims, bs, bm)

img = relabel_sequential(img)[0]

ulabels = np.unique(img[:])[1:]
# Random permutation of labels 1..N, so that no segment is mapped to background (0).
relabeled = list(range(1, len(ulabels) + 1))
shuffle(relabeled)

img = np.array([0] + relabeled)[img]

f = plt.figure(figsize=(12, 12))
plt.imshow(label2rgb(img))
plt.show()


In STAPL-3D, we use rich multidimensional data to obtain a robust segmentation. We can also use this information to perform subcellular segmentation. Here, we split segments into nuclear and membrane subsegments, so that intensities can be extracted specifically from the voxels appropriate for each type of staining (nuclear or membrane). In addition, the subsegmentation opens up possibilities for defining compound features that inform on internal cell structure.
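
A minimal sketch of the idea, assuming a boolean nuclear mask (such as the DAPI masks `nucl/dapi_mask_sauvola` and `nucl/dapi_mask_absmin` used in this notebook) is used to split each segment: voxels of a segment inside the mask form its nuclear subsegment, the remaining voxels its membrane subsegment. The helper `split_by_mask` is hypothetical and is not STAPL-3D's `split_segments`, whose exact rules may differ.

```python
import numpy as np

def split_by_mask(labels, nucl_mask):
    """Split labels into nuclear and non-nuclear (membrane) parts, keeping label values."""
    nucl = np.where(nucl_mask, labels, 0)
    memb = np.where(nucl_mask, 0, labels)
    return nucl, memb

labels = np.array([[1, 1, 2, 2],
                   [1, 1, 2, 2]])
nucl_mask = np.array([[False, True, True, False],
                      [False, True, True, False]])
nucl, memb = split_by_mask(labels, nucl_mask)
```

Because both parts keep the label of their parent segment, per-cell intensities of a nuclear stain can then be measured over `nucl == label` instead of over the whole cell.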


In [None]:
ids = 'segm/labels_memb_del_relabeled_fix'
seg_path = os.path.join(datadir, '{}.h5/{}'.format(dataset, ids))

# Merge the blockwise nuclear masks into full-size datasets (the returned slice is not used here).
for ids1 in ['nucl/dapi_mask_sauvola', 'nucl/dapi_mask_absmin']:
    merge_and_slice_dset(filelist, ids1, dims, bs, bm)

from stapl3d.segmentation.segment import split_segments
split_segments(seg_path, outputstem=filestem)


Let's have a look at a corner of the section to visualize the subcellular compartments.

In [None]:
from stapl3d import LabelImage
from skimage.color import label2rgb

slc = 20

# get background images
ids_n = 'nucl/dapi_preprocess'
dapi = merge_and_slice_dset(filelist, ids_n, dims, bs, bm)
ids_m = 'memb/mean_smooth'
memb = merge_and_slice_dset(filelist, ids_m, dims, bs, bm)

f, axs = plt.subplots(1, 3, figsize=(24, 24))
segs = [
    'segm/labels_memb_del_relabeled_fix', 
    'segm/labels_memb_del_relabeled_fix_memb',
    'segm/labels_memb_del_relabeled_fix_nucl',
]
bgs = [memb, dapi, memb]

for ax, seg, bg in zip(axs, segs, bgs):
    seg_path = os.path.join(datadir, '{}.h5/{}'.format(dataset, seg))
    im = LabelImage(seg_path)
    im.load()
    im.slices[0] = slice(slc, slc + 1, 1)
    img = im.slice_dataset()
    im.close()

    img = img[:500,:500]
    bg = bg[:500,:500] * 5
    clabels = label2rgb(img, image=bg, alpha=1.0, bg_label=0)
    ax.imshow(clabels)

plt.show()
