In [1]:
import numpy as np
import distributed
import dask
import dask.array as da
from libertem.common import Slice, Shape

In [2]:
c = distributed.Client()

In [3]:
path = r'E:\LargeData\LargeData\ER-C-1\groups\data_science\data\reference\EMPAD\BiFeO3\scan_11_x256_y256.raw'
shape = (256, 256, 130, 128)
dtype=np.float32

In [4]:
def mmap_load_chunk(filename, shape, dtype, offset, sl):
    data = np.memmap(filename, mode='r', shape=shape, dtype=dtype, offset=offset)
    return data[sl]


def mmap_dask_array(filename, shape, dtype, offset=0, blocksize=5):
    load = dask.delayed(mmap_load_chunk)
    chunks = []
    for index in range(0, shape[0], blocksize):
        # Truncate the last chunk if necessary
        chunk_size = min(blocksize, shape[0] - index)
        chunk = dask.array.from_delayed(
            load(
                filename,
                shape=shape,
                dtype=dtype,
                offset=offset,
                sl=slice(index, index + chunk_size)
            ),
            shape=(chunk_size, ) + shape[1:],
            dtype=dtype
        )
        chunks.append(chunk)
    return da.concatenate(chunks, axis=0)


In [5]:
arr = mmap_dask_array(
    filename=path,
    shape=shape,
    dtype=dtype,
    blocksize=8
)

In [6]:
flat_masks_T = np.random.random((np.prod(shape[2:]), 2))

In [7]:
da_masks_T = c.scatter(flat_masks_T)

In [8]:
def calculate_partition(partition, sl, flat_masks_T):
    '''
    Depends on unsliced sig since it doesn't have a MaskContainer to do the slicing and transposing!
    To get a distributed scattered "MaskContainer" we may have to precompute the tiles it needs and submit them as futures?
    '''
    if hasattr(partition, 'compute'):
        partition = partition.compute()
    flat_nav = np.prod(partition.shape[:2])
    flat_sig = np.prod(partition.shape[2:])
    flat_partition = partition.reshape((flat_nav, flat_sig))
    result = flat_partition @ flat_masks_T
    debug = (type(partition), type(sl), type(flat_masks_T))
    return (sl, result.reshape((*partition.shape[:2], flat_masks_T.shape[-1])), debug)

In [9]:
assert all(len(ch) == 1 for ch in arr.chunks[1:])

In [10]:
%%time
ref = (arr.reshape((np.prod(shape[:2]), np.prod(shape[2:]))) @ flat_masks_T).reshape((*shape[:2], flat_masks_T.shape[-1])).compute()

Wall time: 6.44 s


In [11]:
result = np.zeros((*shape[:2], flat_masks_T.shape[-1]), dtype=np.result_type(dtype, flat_masks_T.dtype))

## Use Dask array blocks

They are transmitted as-is to the function. The function has to compute it itself, apparently there's no specific interoperability between Dask arrays and futures.

In [12]:
%%time
start = 0
futures = []
for index, size  in enumerate(arr.chunks[0]):
    partition = arr.blocks[index]
    sl = Slice(origin=(start, 0, 0, 0), shape=Shape((size, *shape[1:]), sig_dims=2))
    fut = c.submit(
        calculate_partition,
        partition=partition,
        sl=sl,
        flat_masks_T=da_masks_T,
    )
    futures.append(fut)
    start += size

for fut, res in distributed.as_completed(futures, with_results=True):
    res_sl, res_data, debug = res
    real_res_sl = res_sl.get(nav_only=True)
    result[real_res_sl] = res_data
    print(debug)

(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.ndarray'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memma

In [13]:
np.allclose(ref, result)

True

## Submit the partition calculation as a future

In this example it is really lightweight since it just memory-maps and slices an array.

In [14]:
%%time
start = 0
futures = []
for index, size  in enumerate(arr.chunks[0]):
    partition = c.submit(
        mmap_load_chunk,
        filename=path,
        shape=shape,
        dtype=dtype,
        offset=0,
        sl=slice(start, start+size)
    )
    sl = Slice(origin=(start, 0, 0, 0), shape=Shape((size, *shape[1:]), sig_dims=2))
    fut = c.submit(
        calculate_partition,
        partition=partition,
        sl=sl,
        flat_masks_T=da_masks_T,
    )
    futures.append(fut)
    start += size

for fut, res in distributed.as_completed(futures, with_results=True):
    res_sl, res_data, debug = res
    real_res_sl = res_sl.get(nav_only=True)
    result[real_res_sl] = res_data
    print(debug)

(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.ndarray'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.ndarray'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.ndarray'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.memmap'>, <class 'libertem.common.slice.Slice'>, <class 'numpy.ndarray'>)
(<class 'numpy.nda

In [15]:
np.allclose(ref, result)

True