## Initial setup

In [None]:
import collections
import dask
from dask.distributed import Client, LocalCluster
import itertools
import math
import ndsafir
import numpy as np
from pylibCZIrw import czi
import zarr

A `dask.config` parameter that needs to be set early.  It prevents dask
from eagerly loading extra tasks onto workers.  The default is 1.1; when
computing how many tasks a worker should have Dask uses
`ceil(worker-saturation * n_threads)`, so even with 1 thread Dask
will try to load 2 tasks onto a worker.

For many workloads this is reasonable, since the worker can overlap
disk access (say) with computation.  For us this is problematic as our
workloads do not sleep (so the other tasks get no CPU) and are very
memory hungry.

In [None]:
# Lower worker-saturation from its default (1.1) to 1.0 so that
# ceil(worker-saturation * n_threads) does not queue an extra task on a worker.
dask.config.set({"distributed.scheduler.worker-saturation": 1.0})

We create a cluster here for testing purposes.  When running in an HPC job
you do not want these lines.

In [None]:
# Local test cluster: one worker with one thread and a 16GiB memory limit.
# Do not create this when running inside an HPC job.
cluster = LocalCluster(n_workers=1, threads_per_worker=1, memory_limit="16GiB")
client = cluster.get_client()
# Bare expression so the notebook displays the client.
client

## Some utility functions and types.

Helper classes to hold things (indices, slices, ranges)
corresponding to T, C, Z, Y, X, or just T, Z, Y, X.

In [None]:
# Four-axis record (no channel axis): holds indices, sizes, slices or ranges.
TZYX = collections.namedtuple("TZYX", "T Z Y X")


class TCZYX(collections.namedtuple("TCZYX", "T C Z Y X")):
    """Five-axis record that can be reduced to a TZYX by dropping C."""

    @property
    def TZYX(self):
        """Return the T, Z, Y and X fields as a TZYX (the C field is dropped)."""
        t, _c, z, y, x = self
        return TZYX(t, z, y, x)

A general method to determine whether a range
is wholly contained within another.
E.g.
```python
>>> r1 = range(0, 10)
>>> r2 = range(1, 3)
>>> range_contains(r2, r1)
True
>>> range_contains(r1, r2)
False
>>> r3 = range(4, 11)
>>> range_contains(r3, r1)
False
>>> r4 = range(1, 3, -1)  # an empty range: contained in any range
>>> range_contains(r4, r1)
True
```

In [None]:
def range_contains(r_inner: range, r_outer: range):
    """
    Report whether every element of r_inner is also an element of r_outer.
    Both parameters must be ranges.

    The answer is worked out from the end points and the steps only, so
    the cost does not depend on the length of either range.
    """
    count = len(r_inner)

    # An empty range is contained in anything.
    if count == 0:
        return True

    # Both the first and the actual last element must be present.  (With a
    # single element the two coincide.)
    if r_inner.start not in r_outer or r_inner[-1] not in r_outer:
        return False

    # One or two elements: the end points are the whole range.
    if count <= 2:
        return True

    # Three or more: the interior elements land on r_outer only when the
    # inner step is a whole multiple of the outer step.
    return r_inner.step % r_outer.step == 0

Trims off any overlap from a block. This is delayed.

In [None]:
@dask.delayed
def trim(block, ch_index, chunk_array, overlaps):
    """
    Trim off the overlap data from the block.

    On each of the T, Z, Y and X axes the low side is trimmed by the
    overlap unless the chunk is the first on that axis, and the high side
    is trimmed by the overlap unless the chunk is the last on that axis.
    The C axis is never trimmed.

    NOTE(review): a full `overlaps` width is always removed from the high
    side of a non-final chunk.  If the loader clipped that overlap at the
    array edge (i.e. the final chunk is narrower than the overlap) this
    would remove too much -- confirm against the loader.

    Params:
    block: np.ndarray The data to trim, in TCZYX order.
    ch_index: TZYX The position in the overall array, in chunk units.
    chunk_array: TZYX The size of the overall array, in chunk units.
    overlaps: TZYX The size of the overlaps in each axis.

    Returns:
    np.ndarray: Trimmed data.
    """
    axes = []
    for axis, c_i in enumerate(ch_index):
        pad = overlaps[axis]
        lo = pad if c_i > 0 else 0
        # A zero overlap must give None: slice(lo, -0) would select nothing.
        hi = -pad if pad > 0 and c_i < chunk_array[axis] - 1 else None
        axes.append(slice(lo, hi))
    axes.insert(1, slice(0, None))  # C
    return block[tuple(axes)]

Save a block to a Zarr file. This works out where in the zarr file this chunk belongs
then slices it directly into the Zarr array.  The file is (re-)opened for each block.

In [None]:
@dask.delayed
def save_block_to_zarr(block, zarr_filename, chunk_index, chunk_sizes):
    """
    Write one block into its place in the output zarr array.

    The file is opened (mode "r+") on every call.  On each of the T, Z, Y
    and X axes the destination starts at `chunk_index * chunk_sizes` and
    runs for the block's own length on that axis; the whole C axis is
    written.

    Params:
    block: np.ndarray Block to save, in TCZYX order.
    zarr_filename: str Filename of zarr array.
    chunk_index: TZYX The position in the overall array, in chunk units.
    chunk_sizes: TZYX Chunk size.

    Returns:
    nothing.
    """

    # for now assume zarr_filename exists and is a suitable zarr array
    target = zarr.convenience.open(zarr_filename, mode="r+")
    extents = TCZYX(*block.shape).TZYX
    section = [
        slice(index * size, index * size + extent)
        for index, size, extent in zip(chunk_index, chunk_sizes, extents)
    ]
    section.insert(1, slice(0, None))  # C
    dask.distributed.print(f"Saving section {section}")
    target[tuple(section)] = block

The Zarr file might already contain chunks that we've computed. We want to avoid
re-computing them. This function compares the Dask chunk (`da_chunk`) index
to the `existing_chunks` set which contains all existing Zarr chunks.
We assume the Dask chunk size is an integral multiple of the Zarr
chunk size in all axes. (This is checked elsewhere).

In [None]:
def check_existing_chunks(da_chunk: TZYX, chunk_ratios, existing_chunks):
    """
    Report whether every zarr chunk that makes up one dask chunk is
    already in existing_chunks.

    On each axis the dask chunk `i` with ratio `r` corresponds to the zarr
    chunk indices `i*r` up to (not including) `(i+1)*r`; every combination
    of those indices is looked up in the set.  The search stops at the
    first missing chunk.

    The chunk ratios must be integral.

    NOTE(review): the index ranges are not clipped to the array's extent,
    so for a partial trailing dask chunk this may look up zarr chunk
    indices that can never exist -- confirm whether that is intended.

    da_chunk: TZYX is the dask chunk
    chunk_ratios: TZYX dask_chunk_size / zarr_chunk size for each axis
    existing_chunks: set(TZYX) set of existing chunks
    """
    spans = [
        range(index * ratio, (index + 1) * ratio)
        for index, ratio in zip(da_chunk, chunk_ratios)
    ]
    for zarr_chunk in itertools.product(*spans):
        if zarr_chunk not in existing_chunks:
            return False
    return True

Loads a single chunk spanning the given ranges from a CZI file. Also delayed.

In [None]:
@dask.delayed
def load_chunk(input_filepath, t_range, z_range, y_range, x_range, scene):
    """
    Load slices of the input file, as given by the ranges
    from scene `scene`.

    Params:
    input_filepath: filepath of the CZI file.
    t_range, etc: ranges detailing region of image to load.
    scene: CZI scene number

    The ranges should be range objects (e.g. `range(0, 10)`).  The ROI is
    built from the start and length of `y_range` and `x_range`, so those
    two ranges are treated as contiguous (step 1).

    This uses the extremely clunky interface of pylibCZIrw.
    There might be cleaner alternatives in `bioio`.

    Per the pylibCZIrw documentation, `read` returns a single 2D plane
    shaped (Y, X, samples) for the one channel named in `plane` (channel 0
    when omitted).  Every channel inside the file's C bounds is therefore
    read with its own call; reading once and assigning across the C axis
    would copy the same channel into every output channel.

    Raises:
    ValueError: if `y_range` or `x_range` is not wholly inside the scene's
        bounding rectangle (or the total bounding rectangle when the file
        has no entry for `scene`).

    Returns a ndarray copy of the slice, in TCZYX order, dtype uint16.
    """
    with czi.open_czi(input_filepath) as input_file:
        bbox = input_file.total_bounding_box
        channels = range(bbox["C"][0], bbox["C"][1])
        # The ROI can go beyond our scene without error.
        # We will check here and raise an error if we do so.
        bound_rect = input_file.scenes_bounding_rectangle.get(scene, None)
        if not bound_rect:
            bound_rect = input_file.total_bounding_rectangle
        X = range(bound_rect.x, bound_rect.x + bound_rect.w)
        Y = range(bound_rect.y, bound_rect.y + bound_rect.h)
        if not range_contains(y_range, Y) or not range_contains(x_range, X):
            raise ValueError("input range lies outside image")

        # Can query input_file.pixel_types but assume uint16
        data = np.empty(
            shape=(len(t_range), len(channels), len(z_range), len(y_range), len(x_range)),
            dtype=np.uint16,
        )
        roi = (x_range.start, y_range.start, len(x_range), len(y_range))
        for ndarray_t, t in enumerate(t_range):
            for ndarray_z, z in enumerate(z_range):
                for ndarray_c, c in enumerate(channels):
                    # NB input_file zero indexed
                    plane = input_file.read(
                        plane={"T": t, "Z": z, "C": c},
                        scene=scene,
                        roi=roi,
                    )
                    # Drop the trailing samples axis.  squeeze raises if it
                    # is not of length 1, so unexpected multi-sample pixel
                    # types still fail loudly instead of being truncated.
                    data[ndarray_t, ndarray_c, ndarray_z, :, :] = np.squeeze(plane, axis=2)
    return data

A wrapper to call the ndsafir denoising code, with some set parameters. These could be made parameters of
`process_block` if desired. Delayed.

Note that if the code called here runs outside Python and is expected
to be long-running, you should make sure the Python GIL is released before
starting execution (e.g. for
[Pybind](https://pybind11.readthedocs.io/en/stable/advanced/misc.html)).
Failing to do so will usually result in Dask being unable to contact
the Worker to check its status.  It will then kill the Worker and reassign
the task (which will kill the next Worker, and so on).

For example in C++ this might be as simple as wrapping the long-running
function `filter.run` in a scope with a `gil_scoped_release` object.
```
  {
    py::gil_scoped_release release;
    filter.run(f, noise_std);
  }
```
Obviously the external code should not access any Python objects while
not holding the GIL.

In [None]:
@dask.delayed
def process_block(block):
    """
    Run the ndsafir denoiser over one block and return the result.

    All denoiser settings are fixed here; the block is passed with
    axes "TCZYX".
    """
    settings = dict(
        mode="poisson-gaussian",
        gains=[3.92],
        offsets=[-388],
        patch=[0, 0, 1, 3, 3],
        max_iter=4,
        pvalue=0.1,
        nthreads=12,
        axes="TCZYX",
    )
    return ndsafir.denoise(block, **settings)

The `pylibCZIrw` interface is awful. The CZI file doesn't have to have
any scenes declared, in which case the `scenes_bounding_rectangle`
method will fail. This routine is here to return the shape
of the scene requested whether or not there's a scene.

In [None]:
def get_czi_scene_shape(filename, scene=0):
    """
    Return a TCZYX containing the sizes of the CZI file in
    those axes.

    T, C and Z are the extents of the file's total bounding box.  Y and X
    are the height and width of the bounding rectangle of `scene`, or of
    the total bounding rectangle when the file has no entry for `scene`.

    Params:
    filename: CZI file
    scene: CZI scene number
    """
    with czi.open_czi(filename) as input_file:
        bbox = input_file.total_bounding_box
        T, C, Z = (bbox[axis][1] - bbox[axis][0] for axis in "TCZ")

        bound_rect = input_file.scenes_bounding_rectangle.get(scene, None)
        if not bound_rect:
            bound_rect = input_file.total_bounding_rectangle
        height = bound_rect.h
        width = bound_rect.w
    return TCZYX(T, C, Z, height, width)

The main routine to load all the chunks from the CZI file. This now includes the
overlaps in the chunks loaded. It also calls `check_existing_chunks` to
see if all the Zarr chunks that represent each Dask chunk are present
in the existing Zarr array.  If so then this Dask chunk is skipped.

In [None]:
def load_chunks_from_czi(filename, chunk_sizes: TZYX, overlap: TZYX,
                         existing_chunks, chunk_ratios, scene=0):
    """
    A function that takes a filename and a chunking regimen and returns dict
    made up of chunks of the appropriate size.  The chunks will include
    overlaps as per `overlap`, clipped to the bounds of the image.

    Checks to see if each chunk is wholly present in the existing chunks;
    chunks that are wholly present are left out of the result.

    Params:
    filename: CZI file
    chunk_sizes: TZYX Contains chunk dimensions in those axes.
    overlap: TZYX Contains overlap dimensions.
    scene: CZI scene to load.
    existing_chunks: set(TZYX) Chunks present in zarr file.
    chunk_ratios: TZYX Ratio of dask chunk / zarr chunk sizes per axis.
                       Must be integral.

    Returns:
    dict[(T_c, Z_c, Y_c, X_c)] = block, where T_c, etc, are indices in the
        overall array in chunk units.  Each block is the Delayed result
        of `load_chunk`.
    """
    # Only the image bounds are needed from the file; the pixel data is read
    # later by the delayed load_chunk tasks, which open the file themselves.
    with czi.open_czi(filename) as input_file:
        bbox = input_file.total_bounding_box
        T = bbox["T"][1] - bbox["T"][0]
        Z = bbox["Z"][1] - bbox["Z"][0]

        bound_rect = input_file.scenes_bounding_rectangle.get(scene, None)
        if not bound_rect:
            bound_rect = input_file.total_bounding_rectangle
        X_min, Y_min = bound_rect.x, bound_rect.y
        X_max = X_min + bound_rect.w
        Y_max = Y_min + bound_rect.h

    def padded(start, size, pad, lower, upper):
        # The chunk [start, start + size) widened by `pad` on both sides and
        # clipped to [lower, upper).
        return range(max(start - pad, lower), min(upper, start + size + pad))

    # (chunk index, first pixel) pairs for every chunk on each axis.
    chunk_starts = (
        enumerate(range(0, T, chunk_sizes.T)),
        enumerate(range(0, Z, chunk_sizes.Z)),
        enumerate(range(Y_min, Y_max, chunk_sizes.Y)),
        enumerate(range(X_min, X_max, chunk_sizes.X)),
    )

    chunk_dict = {}
    # product() walks the axes with X varying fastest, like four nested loops.
    for (t_ind, t), (z_ind, z), (y_ind, y), (x_ind, x) in itertools.product(*chunk_starts):
        index = (t_ind, z_ind, y_ind, x_ind)
        # Check output array here
        if check_existing_chunks(index, chunk_ratios, existing_chunks):
            continue
        chunk_dict[index] = load_chunk(
            filename,
            padded(t, chunk_sizes.T, overlap.T, 0, T),
            padded(z, chunk_sizes.Z, overlap.Z, 0, Z),
            padded(y, chunk_sizes.Y, overlap.Y, Y_min, Y_max),
            padded(x, chunk_sizes.X, overlap.X, X_min, X_max),
            scene=scene,
        )
        dask.distributed.print(f"Added {index}")
    return chunk_dict

In [None]:
def get_chunk_dims(file_values, chunk_sizes):
    """
    Count the chunks needed to cover the file on each of T, Z, Y and X.

    Each count is rounded up, so a partial trailing chunk is included.

    Param:
    file_values: TCZYX The dimensions of the file.
    chunk_sizes: TZYX Chunk sizes.

    Returns:
    TZYX: How many chunks in each dimension.
    """
    counts = (
        math.ceil(extent / size)
        for extent, size in zip(file_values.TZYX, chunk_sizes)
    )
    return TZYX(*counts)

## The client code
Everything beyond this point is expected to run in the Dask client only.

First we have all the parameters we need to set to process this image.

| Parameter | Note |
|-----------|------|
|`filename` | Path to CZI input file |
| `?_chunk_size` | Size of Dask chunks in the TZYX axes |
| `overlaps` | The overlaps, in pixels, on each axis (TZYX) |
| `output_filename` | The output Zarr array filename. By default this will be a directory |
| `output_chunks` | If creating a new Zarr array, this will be its chunk sizes |

Note that the Dask chunks must be integral multiples of the Zarr chunks for this simple workflow
(checked below).

The Zarr chunks should be small-ish.  In particular there's an issue where if the Zarr
chunk is >2GiB then the default compressor will fail.

In [None]:
# PARAMETERS
# Input file: path to the CZI input (alternative inputs kept for reference).
# filename = "/rds/project/rds-1FbiQayZlSY/data/Millie/Timepoint3-02-Embryo3-Lattice Lightsheet.czi"
# filename = "/rds/project/rds-1FbiQayZlSY/data/Millie/Millie 24Aug23 raw light sheet data/Timepoint3-02.czi"
filename = "../../../data/2021-02-25-tulip_Airyscan.czi"
# CZI scene to process.
scene = 0
# Dask chunks: chunk size on each of the T, Z, Y and X axes.
t_chunk_size = 1
z_chunk_size = 7
y_chunk_size = 250
x_chunk_size = 250
# Overlaps, in pixels, on each axis (T, Z, Y, X).
overlaps = TZYX(0, 0, 16, 16)
# Output file: the Zarr array filename.
# output_filename = "/rds/project/rds-1FbiQayZlSY/data/Millie/Timepoint3-denoised-notover"
output_filename = "/tmp/denoised_zarr"
# Zarr chunk sizes (T, C, Z, Y, X), used when a new Zarr array is created.
output_chunks=TCZYX(1, 1, 1, 250, 250)

Some diagnostic info about the image, then compute the number of chunks in
each axis. Note that the "last" chunk in each axis may be smaller than the
chunk size. Take care that the denoising algorithm can cope with these small
chunks (i.e. I think it fails if the image is only a few pixels wide in
an axis).

In [None]:
chunk_sizes = TZYX(t_chunk_size, z_chunk_size, y_chunk_size, x_chunk_size)
file_array = get_czi_scene_shape(filename, scene)
# Open the file only for as long as it takes to print its bounds, so the
# reader is closed again instead of being left open for the rest of the session.
with czi.open_czi(filename) as czi_file:
    print(f"scenes_bounding_rectangle: {czi_file.scenes_bounding_rectangle}")
    print(f"\ntotal_bounding_box: {czi_file.total_bounding_box}")

# Number of Dask chunks on each of T, Z, Y and X (the last may be partial).
chunk_array = get_chunk_dims(file_array, chunk_sizes)

We try to open the Zarr array for reading. If this fails then we
create and open a new Zarr array instead.

In [None]:
# Check for existence of output array.  Open it if it exists to get
# list of existing blocks.
# The array is opened read-only here; it is only inspected in this session
# (save_block_to_zarr opens the file itself, by name, when it writes).
try:
    output_zarr = zarr.open_array(
        output_filename,
        mode="r",
    )
except zarr.errors.ArrayNotFoundError:
    # No array there yet: create one with the shape of the CZI scene (TCZYX)
    # and the requested chunking.
    output_zarr = zarr.create(
        store=output_filename,
        shape=file_array,
        dtype=np.uint16,
        chunks=output_chunks,
    )

We make some simple checks on relative chunk sizes here. Note that these
are not *required* by the workflow, but some extra care or coding would
need to be in place if these checks are not true.

What we're trying to avoid is any Dask chunk sharing a Zarr chunk with
another Dask chunk.

Let's look at an example where we have a single axis image of 1000 pixels.
We split it into 10 Dask chunks of size 100 pixels.
We create our output Zarr array with chunks of 50 pixels.  This is fine,
because each Dask chunk completely fits into two Zarr chunks, and no Zarr
chunk holds more than one Dask chunk.  If we made our Zarr chunks 30 pixels
then we'd have a problem because Zarr chunk 3 would span Dask chunks 0 and 1.
When we come to write the Dask chunks they may overwrite each other's data
in Zarr chunk 3.

To avoid this we can ensure the Dask chunks are integral multiples
of the Zarr chunk sizes. If this is not desirable or possible we can
create the Zarr array with process synchronization as per [the Zarr
docs](https://zarr.readthedocs.io/en/support-v2/tutorial.html#parallel-computing-and-synchronization).
This will work as long as the underlying filesystem supports file
locking.

In [None]:
# Check sizes and that dask chunks are integral multiples of zarr chunks.
# Might also be worth checking that the dask array will fit in the zarr array!
output_zarr_chunks = TCZYX(*output_zarr.chunks)
for f_ch, d_ch in zip(output_zarr_chunks.TZYX, chunk_sizes):
    if f_ch > d_ch:
        print(chunk_sizes)
        print(output_zarr.chunks)
        raise ValueError("Dask chunks smaller than output chunks")
    if d_ch % f_ch != 0:
        print(chunk_sizes)
        print(output_zarr.chunks)
        raise ValueError("Dask chunks not multiples of output chunks")

# Get ratios of dask to file chunks.  Use the chunking of the array that was
# actually opened (the one just checked): a pre-existing array may have been
# created with chunks that differ from the `output_chunks` parameter.
chunk_ratios = TZYX(*[d_c // f_c for (d_c, f_c) in zip(chunk_sizes, output_zarr_chunks.TZYX)])

We examine the Zarr array for already written (Zarr) chunks. A nice feature of Zarr
is that if a chunk is present in the array then it will be complete.  That is,
there are no half-written chunks in the array.  That isn't to say that
all the Zarr chunks that comprise a Dask chunk are present! We make a `set`
of the chunk indices.

In [None]:
# Make set of existing chunks.  Chunk keys are the dot-separated chunk
# indices (T.C.Z.Y.X); the C index is dropped.  Every key starting with "."
# is metadata (".zarray", and e.g. ".zattrs" if attributes were ever set)
# and must be skipped, otherwise int() fails on it.
existing_chunks = {
    TCZYX(*map(int, key.split("."))).TZYX
    for key in output_zarr.store.keys()
    if not key.startswith(".")
}

### Start creating the Dask task graph

We first call `load_chunks_from_czi` which returns a `dict`
with keys of Dask chunk indices. The values are the Delayed
return values from `load_chunk`, which will eventually be
`ndarray`s.

In [None]:
# Load the chunks.  The scene is passed explicitly so that the chunks come
# from the same scene that was used to size the arrays.
data_array = load_chunks_from_czi(filename, chunk_sizes, overlaps,
                                  existing_chunks, chunk_ratios, scene=scene)

Now the next steps are similar. We just loop over the items in the
dict and perform computations on them. As our computations
(`process_block`, `trim`, and `save_block_to_zarr`) are all
`@dask.delayed` this builds up the task graph without running
the computation.

In [None]:
# Denoise each chunk.  process_block is delayed, so this only adds one
# task per chunk to the graph; the keys stay the Dask chunk indices.
blocked_array = {
        k: process_block(v) for (k, v) in data_array.items()
}

In [None]:
# Trim the overlap from the chunks.  trim needs each chunk's index and the
# overall chunk counts to tell which sides of the chunk carry an overlap.
trimmed_array = {
    b_ind: trim(block, b_ind, chunk_array, overlaps)
        for (b_ind, block) in blocked_array.items()
}

In [None]:
# Save the data to the output zarr.  Each entry is a Delayed write of one
# trimmed chunk; the file is passed by name, not as an open array.
final_result = [
    save_block_to_zarr(block, output_filename, b_ind, chunk_sizes)
        for (b_ind, block) in trimmed_array.items()
]

Finally we end up with a collection of Delayed objects. We pass these to `dask.compute`
_en masse_, which submits them _all_ to the workers for computation.

In [None]:
# Start the actual calculation!  Final result will be a tuple
# of `None`, one for each chunk saved.
dask.compute(*final_result)

## Extending this work

The above steps from "Start creating the Dask task graph" form the backbone of
the Dask workflow. You can insert extra steps, or substitute existing ones.

For example, if we decide to use some other input file format instead
of CZI we can make some minor edits to `load_chunks_from_czi` to remove
the reading of CZI file sizes and
re-implement `load_chunk` so it reads the new file format and returns
a Delayed `ndarray` corresponding to the given ranges.
One thing to note here is that we're passing around
a _filename_ rather than a Python `File` type (or similar). This is because
each worker may be on a different node, and one machine's file descriptor
or file handle will not mean anything on another machine. Only filepaths
are truly portable here. You have to imagine that a worker has just
been given `load_chunk` with its parameters and no other context.

The workflow doesn't have to be linear. We can branch off at any point.

For example, let's say we want to perform two computations on our chunks
once we've loaded them.  Perhaps the second job needs to know
the (Dask chunk) index of the block. We might do something like:

In [None]:
# Load the chunks.
data_array = load_chunks_from_czi(filename, chunk_sizes, overlaps,
                                  existing_chunks, chunk_ratios, scene=scene)
# Denoise each chunk.
blocked_array = {
    k: process_block(v) for (k, v) in data_array.items()
}
# Do some other work on each chunk. <<- NEW!
# `other_work` is a placeholder for your own @dask.delayed function taking
# the Dask chunk index and the (loaded, not denoised) chunk.
other_work_array = {
    k: other_work(k, v) for (k, v) in data_array.items()
}

We could then save `other_work_array` in the same way we do `final_result`, or
combine its chunks with `blocked_array`, or whatever we want.

Or we might decide that treating the input image as a `Dask.Array` is more useful. This is the approach
we initially used, with `map_overlap`. The above Delayed approach is more general, but the Array approach
has the advantage that each chunk is now explicitly
a part of an overall array. However Dask expects you to use more array-like functions to process
your data.

One approach to do this was to leave the `load_chunk` function unchanged (returning a Delayed `ndarray`)
but have the `load_chunks_from_czi` do something like:

In [None]:
import dask.array as da
def load_chunks_from_czi_alternate(filename, t_chunk_size, z_chunk_size, y_chunk_size, x_chunk_size, scene=0):
    """
    Build a single dask Array, in TCZYX order, covering a whole CZI scene.

    Every chunk is a Delayed `load_chunk` call wrapped by
    `da.from_delayed`.  The chunks are joined along X, then Y, then Z and
    finally T.  No overlap is added to the chunks; the last chunk on an
    axis is clipped to the image bounds.

    Params:
    filename: CZI file
    t_chunk_size, etc: chunk size on each of the T, Z, Y and X axes.
    scene: CZI scene to load.
    """
    with czi.open_czi(filename) as input_file:
        bbox = input_file.total_bounding_box
        n_c = bbox["C"][1] - bbox["C"][0]
        n_t = bbox["T"][1] - bbox["T"][0]
        n_z = bbox["Z"][1] - bbox["Z"][0]

        rect = input_file.scenes_bounding_rectangle.get(scene, None)
        if not rect:
            rect = input_file.total_bounding_rectangle
        x_lo, x_hi = rect.x, rect.x + rect.w
        y_lo, y_hi = rect.y, rect.y + rect.h

        def chunk_at(t, z, y, x):
            # One lazily loaded chunk whose first element is at (t, z, y, x).
            t_end = min(n_t, t + t_chunk_size)
            z_end = min(n_z, z + z_chunk_size)
            y_end = min(y_hi, y + y_chunk_size)
            x_end = min(x_hi, x + x_chunk_size)
            return da.from_delayed(
                load_chunk(
                    filename,
                    range(t, t_end),
                    range(z, z_end),
                    range(y, y_end),
                    range(x, x_end),
                    scene=scene,
                ),
                shape=(t_end - t, n_c, z_end - z, y_end - y, x_end - x),
                meta=np.array((), dtype=np.uint16),
            )

        def join_x(t, z, y):
            # All chunks sharing t, z and y, joined along X (axis 4).
            return da.concatenate(
                [chunk_at(t, z, y, x) for x in range(x_lo, x_hi, x_chunk_size)],
                axis=4,
            )

        def join_y(t, z):
            # All X-strips sharing t and z, joined along Y (axis 3).
            return da.concatenate(
                [join_x(t, z, y) for y in range(y_lo, y_hi, y_chunk_size)],
                axis=3,
            )

        def join_z(t):
            # All planes sharing t, joined along Z (axis 2).
            return da.concatenate(
                [join_y(t, z) for z in range(0, n_z, z_chunk_size)],
                axis=2,
            )

        per_t = [join_z(t) for t in range(0, n_t, t_chunk_size)]
    return da.concatenate(per_t, axis=0)

In [None]:
# Pass the scene explicitly so the array is built from the configured scene.
dask_array = load_chunks_from_czi_alternate(
    filename, t_chunk_size, z_chunk_size, y_chunk_size, x_chunk_size, scene=scene
)

In [None]:
# Bare expression so the notebook displays the array object.
dask_array

This uses `dask.array.from_delayed` to create a new dask Array from each
Delayed `ndarray`, then concatenates these into the whole array. Note this
Array is never stored in one place!

We can now use `map_overlap`, `map_blocks`, or `reduction`, and so on.  E.g.
the original workflow used `map_overlap` like this:

In [None]:
# Overlap depth on each axis (T, C, Z, Y, X): none in T or C, one in Z, Y and X.
depth = (0, 0, 1, 1, 1)
# NOTE: process_block must be the plain function here, i.e. defined
# without the @dask.delayed decorator.
blocked_array_alternate = dask_array.map_overlap(
    process_block,        # name of function to perform on each chunk
    depth=depth,          # the overlaps
    boundary="none",      # do nothing special at the boundaries
    allow_rechunk=False,  # leave chunking as is
    meta=np.array((), dtype=np.float32), # an example of the type of data that will be returned
    name="process_block", # a descriptive name to help us
)

In [None]:
# Render the task graph of the lazy array.
blocked_array_alternate.visualize()

An important warning here is that this code won't run! You will need to
redefine `process_block` _without_ the `@dask.delayed` decorator.
This is because `map_overlap` (and `map_blocks`, etc) run the
function they're given as Delayed already, so if you give them a
delayed function you get double delayed, which gives unhelpful
errors!