# DaskDataSet Example

Creates a ~2.3 GB raw dataset (70 × 128 frames of 256 × 256 float32 values), loads it as a chunked Dask array, and then runs sum UDFs over the sig and nav axes on it.

The only new code in LiberTEM is the Dask array wrapper; it wasn't necessary to modify the UDFRunner or other endpoints.

In [None]:
import os
import math
from time import perf_counter
from contextlib import contextmanager
import pathlib
import numpy as np
import distributed
import dask
import dask.array as da
import matplotlib.pyplot as plt

In [None]:
import libertem.api as lt
from libertem.executor.dask import DaskJobExecutor
from libertem.io.dataset.dask import DaskDataSet
from libertem.udf.sum import SumUDF
from libertem.udf.sumsigudf import SumSigUDF

### Util functions

In [None]:
def make_raw_ds(fpath, dtype, sig_shape, nav_shape, div=0):
    """
    Write a synthetic, header-less raw dataset of square frames to ``fpath``.

    The base frame is the outer product of a linear 0..1 ramp with itself.
    Frame ``idx`` (flat navigation index) is the base frame multiplied by
    ``idx``, so frame 0 is all zeros. Frames are written back to back.

    Parameters
    ----------
    fpath : str or pathlib.Path
        Destination file; it is opened in ``'wb'`` mode and thus overwritten.
    dtype : numpy dtype
        Dtype the ramp (and therefore the base frame) is generated in.
    sig_shape : tuple of int
        Frame shape; the first two entries must be equal (square frames).
    nav_shape : tuple of int
        Navigation shape; ``prod(nav_shape)`` frames are written.
    div : int, optional
        If positive, rows and columns are zeroed in stripes of width
        ``sig_shape[0] // div``, starting half a stripe in and repeating
        every two stripe widths. ``0`` (default) disables the stripes.

    Raises
    ------
    ValueError
        If the frame is not square, or if ``div`` is larger than the frame
        size so that the stripe width would be zero.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if sig_shape[0] != sig_shape[1]:
        raise ValueError(f'sig_shape must be square, got {sig_shape!r}')
    dim = sig_shape[0]
    ramp = np.linspace(0, 1., num=dim, endpoint=True, dtype=dtype)
    base_array = ramp[:, np.newaxis] * ramp[np.newaxis, :]
    if div > 0:
        stripe = dim // div
        if stripe == 0:
            # A zero stripe width would otherwise surface as an opaque
            # "range() arg 3 must not be zero" error below.
            raise ValueError(
                f'div={div!r} exceeds the frame size {dim!r}; '
                'stripe width would be zero'
            )
        linear_mask = np.zeros((dim,), dtype=bool)
        for s in range(stripe // 2, dim, 2 * stripe):
            linear_mask[s:s + stripe] = True
        # Apply the same stripe pattern along both signal axes.
        base_array[linear_mask, :] = 0
        base_array[:, linear_mask] = 0
    # np.prod returns a NumPy scalar (a float for an empty shape); range()
    # needs a plain int.
    n_frames = int(np.prod(nav_shape))
    with pathlib.Path(fpath).open('wb') as fp:
        for idx in range(n_frames):
            # .data exposes the freshly computed frame's buffer directly.
            fp.write((base_array * idx).data)
            
@contextmanager
def timer(msg=None):
    """
    Context manager that prints the wall-clock duration of its body.

    The elapsed time is measured with ``perf_counter`` and printed with
    millisecond resolution once the body finishes normally; if ``msg`` is
    given it is prepended as ``'<msg> - '``.
    """
    began = perf_counter()
    yield
    elapsed = perf_counter() - began
    if msg is not None:
        print(f'{msg} - {elapsed:.3f} s')
        return
    print(f'{elapsed:.3f} s')

### Setup Dask/Distributed

In [None]:
# Create the distributed client used by every later cell and print how long
# that took. No scheduler address is passed here.
# NOTE(review): presumably this spins up a local cluster with default
# worker settings - confirm against the distributed documentation.
with timer('Create Dask Scheduler'):
    client = distributed.Client()

### Parameters and make dataset

In [None]:
# Raw file location, resolved against the current working directory.
path = pathlib.Path('./ds.raw').absolute()
nav_shape = (70, 128)
sig_shape = (256, 256)
# Full 4D shape of the dataset: navigation axes followed by signal axes.
shape = nav_shape + sig_shape
dtype = np.float32
# Number of entries along the first axis of ``shape`` per chunk.
blocksize = 8
# Only generate the file once; later runs reuse it.
if not path.is_file():
    # Pass ``dtype`` rather than a literal so the file contents always match
    # what the loaders below are told to expect.
    make_raw_ds(path, dtype, sig_shape, nav_shape, div=8)

### da.array creation

In [None]:
def mmap_load_chunk(filename, shape, dtype, offset, sl):
    """
    Return the ``sl`` selection of a read-only memory map over ``filename``.

    The file is mapped as an array of the given ``shape`` and ``dtype``
    starting ``offset`` bytes into the file, and then indexed with ``sl``.
    """
    return np.memmap(
        filename,
        dtype=dtype,
        mode='r',
        offset=offset,
        shape=shape,
    )[sl]


def load_chunk(filename, shape, dtype, offset, sl):
    """
    Read the ``sl`` selection along the first axis of a raw file into memory.

    Counterpart of ``mmap_load_chunk`` based on ``np.fromfile``: the file is
    interpreted as an array of the given ``shape`` and ``dtype`` that starts
    ``offset`` bytes into the file, and only the requested entries along the
    first axis are read.

    Parameters
    ----------
    filename : str or pathlib.Path
        Raw file to read from.
    shape : tuple of int
        Shape of the complete array stored in the file.
    dtype : numpy dtype
        Dtype of the stored data.
    offset : int
        Byte offset of the first array element within the file.
    sl : slice
        Selection along the first axis. ``None`` bounds, negative indices
        and out-of-range bounds are normalised against ``shape[0]``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_selected,) + shape[1:]``.

    Raises
    ------
    ValueError
        If ``sl`` has a negative step.
    """
    start, stop, step = sl.indices(shape[0])
    if step < 1:
        raise ValueError('load_chunk only supports slices with a positive step')
    tail_shape = tuple(shape[1:])
    n_frames = max(stop - start, 0)
    if n_frames == 0:
        # Nothing selected: no need to touch the file at all.
        return np.empty((0,) + tail_shape, dtype=dtype)
    dtype_bytes = np.dtype(dtype).itemsize
    macroframe_itemsize = math.prod(tail_shape)
    # np.fromfile takes its offset in bytes, so the caller's header offset
    # and the position of the first requested entry are simply added up.
    byte_offset = offset + start * macroframe_itemsize * dtype_bytes
    with open(filename, 'rb') as fp:
        frames = np.fromfile(
            fp,
            dtype=dtype,
            count=n_frames * macroframe_itemsize,
            offset=byte_offset,
        ).reshape((n_frames,) + tail_shape)
    # The contiguous range was read in one go; apply any stride afterwards.
    return frames if step == 1 else frames[::step]


def mmap_dask_array(filename, shape, dtype, offset=0, blocksize=8):
    """
    Build a lazily loaded Dask array over a raw file, chunked on axis 0.

    One delayed load task is created per block of ``blocksize`` entries along
    the first axis; the last block is shortened if ``shape[0]`` is not a
    multiple of ``blocksize``.

    Parameters
    ----------
    filename : str or pathlib.Path
        Raw file backing the array.
    shape : tuple of int
        Shape of the complete array stored in the file.
    dtype : numpy dtype
        Dtype of the stored data.
    offset : int, optional
        Byte offset of the array data within the file, forwarded to the
        chunk loader.
    blocksize : int, optional
        Number of entries along the first axis per chunk.

    Returns
    -------
    chunks : list of dask.array.Array
        The individual per-block arrays, in file order.
    array : dask.array.Array
        Concatenation of ``chunks`` along axis 0.

    Raises
    ------
    ValueError
        If ``blocksize`` is not positive.
    """
    if blocksize < 1:
        raise ValueError(f'blocksize must be a positive integer, got {blocksize!r}')
    # On Windows the plain file read is used instead of np.memmap.
    # NOTE(review): the reason for avoiding memmap on 'nt' is not stated in
    # this file - confirm before changing the selection.
    loader = load_chunk if os.name == 'nt' else mmap_load_chunk
    load = dask.delayed(loader)
    chunks = []
    for index in range(0, shape[0], blocksize):
        # Truncate the last chunk if necessary
        chunk_size = min(blocksize, shape[0] - index)
        delayed_block = load(
            filename,
            shape=shape,
            dtype=dtype,
            offset=offset,
            sl=slice(index, index + chunk_size),
        )
        chunks.append(
            da.from_delayed(
                delayed_block,
                shape=(chunk_size,) + shape[1:],
                dtype=dtype,
            )
        )
    return chunks, da.concatenate(chunks, axis=0)

In [None]:
# Build the chunked Dask array over the raw file. ``offset`` is left at its
# default of 0, i.e. the array data starts at the beginning of the file.
chunks, d_arr = mmap_dask_array(
    filename=path,
    shape=shape,
    dtype=dtype,
    blocksize=blocksize
)
# Last expression of the cell: shows the array's notebook representation.
d_arr

### Load LiberTEM context and the DaskDataSet

In [None]:
# Reuse the already running distributed client for LiberTEM instead of
# letting the Context create its own executor.
executor = DaskJobExecutor(client)
ctx = lt.Context(executor=executor)

In [None]:
# Wrap the Dask array as a LiberTEM dataset, passing the navigation and
# signal shapes explicitly.
ds = DaskDataSet(d_arr, nav_shape=nav_shape, sig_shape=sig_shape)
# NOTE(review): presumably required before the dataset can be used, since
# it is constructed directly rather than through the Context - confirm.
ds.initialize(executor)

#### Try to warm up the filecache

In [None]:
# Read the whole file once so that later timed reads can be served from the
# operating system's file cache where possible. The full array stays bound
# to ``array`` until a later cell reassigns that name.
with path.open('rb') as fp:
    array = np.fromfile(fp, dtype=dtype).reshape(shape)

#### Run UDFs

In [None]:
# Instantiate both UDFs with their default parameters; they are run together
# in a single pass in the next cell.
sum_udf = SumUDF()
sigsum_udf = SumSigUDF()

In [None]:
# Time one run_udf call that evaluates both UDFs on the Dask-backed dataset.
# ``res`` is indexed per UDF in the plotting cell below.
with timer('Run UDFs'):
    res = ctx.run_udf(ds, [sum_udf, sigsum_udf])

In [None]:
# Show both results side by side. The trailing semicolons suppress the
# notebook's echo of each call's return value.
# NOTE(review): assumes ``res`` follows the order of the UDF list passed to
# run_udf (SumUDF first, SumSigUDF second) - confirm.
fig, axs = plt.subplots(1, 2, figsize=(10, 6))
axs[0].imshow(res[0]['intensity'].data);
axs[0].set_title('Sum over nav axes');
axs[1].imshow(res[1]['intensity'].data);
axs[1].set_title('Sum over sig axes');

### Numpy-only, whole file read

In [None]:
# Baseline: load the complete file into memory in one np.fromfile call and
# compute both reductions with plain NumPy; the timing covers read and sums.
with timer('Numpy single read'):
    with path.open('rb') as fp:
        array = np.fromfile(fp, dtype=dtype).reshape(shape)
    # Sum over the two navigation axes -> one array of sig_shape.
    xx = array.sum(axis=(0, 1))
    # Sum over the two signal axes -> one array of nav_shape.
    yy = array.sum(axis=(2, 3))

### Numpy-only, partitioned reads

In [None]:
# Baseline: read the file block by block with the same partitioning as the
# Dask array and combine the per-block reductions, so that ``xx`` and ``yy``
# end up covering the whole dataset like in the single-read cell.
with timer('Numpy partitioned read'):
    # Running sum over the navigation axes, one value per signal pixel.
    xx = np.zeros(shape[2:], dtype=dtype)
    # Per-block sums over the signal axes, stitched together afterwards.
    yy_parts = []
    for index in range(0, shape[0], blocksize):
        # Truncate the last block if shape[0] is not a multiple of blocksize.
        chunk_size = min(blocksize, shape[0] - index)
        array = load_chunk(
            path,
            shape=shape,
            dtype=dtype,
            offset=0,
            sl=slice(index, index + chunk_size))
        xx += array.sum(axis=(0, 1))
        yy_parts.append(array.sum(axis=(2, 3)))
    yy = np.concatenate(yy_parts, axis=0)