# Zarr reading with a Dask Client and futures

In [1]:
import dask.array as da

# Remote Zarr array served over HTTPS (IDR study idr0062A, image 6001240,
# array stored under the key "0").
online_path = "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0062A/6001240.zarr/0"
# Relative path of the local Zarr copy that the later cells read from.
local_path = "data/6001240.zarr"

In [2]:
# Lazily open the remote array, regroup it into (1, 50, 100, 100) chunks and
# cast to int16. Nothing is downloaded here; this only builds the task graph.
data = (
    da.from_zarr(online_path)
    .rechunk((1, 50, 100, 100))
    .astype('int16')
)
data

Unnamed: 0,Array,Chunk
Bytes,67.09 MiB,0.95 MiB
Shape,"(2, 236, 275, 271)","(1, 50, 100, 100)"
Dask graph,90 chunks in 5 graph layers,90 chunks in 5 graph layers
Data type,int16 numpy.ndarray,int16 numpy.ndarray
"Array Chunk Bytes 67.09 MiB 0.95 MiB Shape (2, 236, 275, 271) (1, 50, 100, 100) Dask graph 90 chunks in 5 graph layers Data type int16 numpy.ndarray",2  1  271  275  236,

Unnamed: 0,Array,Chunk
Bytes,67.09 MiB,0.95 MiB
Shape,"(2, 236, 275, 271)","(1, 50, 100, 100)"
Dask graph,90 chunks in 5 graph layers,90 chunks in 5 graph layers
Data type,int16 numpy.ndarray,int16 numpy.ndarray


In [3]:
import os

# Writing the local copy pulls the whole remote array over the network, so
# only do it when the store is not there yet. The guard also keeps the cell
# re-runnable: `to_zarr` is called without `overwrite=True`, so it should not
# be pointed at a store that already holds the array.
# NOTE: an interrupted earlier run can leave a partial store behind; delete
# `local_path` by hand in that case to force a fresh download.
if not os.path.exists(local_path):
    data.to_zarr(local_path)

In [4]:
# Extent of one patch along each of the four array axes.
patch_size = (2, 10, 64, 64)
# Index selecting the patch anchored at the origin: slice(0, extent) per axis.
small_slice = tuple(slice(0, extent) for extent in patch_size)

In [5]:

from itertools import islice
from pathlib import Path
from typing import List, Tuple, Union, Optional, Callable, Dict, Generator
import time
import numpy as np
import zarr
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from dask.distributed import Client, get_client

from timeit import timeit, time


In [6]:
import numpy as np
from typing import List, Tuple, Union
import dask


@dask.delayed
def load_zarr(arr: np.ndarray,
              patch_positions,
              patch_size: Union[List[int], Tuple[int, ...]]
              ) -> np.ndarray:
    """
    Read a single patch out of an array.

    The function is wrapped in ``dask.delayed``, so calling it only records
    the work; the indexing below runs when the result is computed.

    Parameters
    ----------
    arr : np.ndarray
        Array-like to index into.
    patch_positions : sequence
        One entry per dimension: the start offset of the patch along that
        dimension, or ``None`` to take the dimension in full.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch size in each dimension.

    Returns
    -------
    np.ndarray
        Result of indexing ``arr`` with the per-dimension slices.
    """
    # Each axis is either taken whole (None) or cut to [start, start + extent).
    index = tuple(
        slice(None) if start is None else slice(start, start + extent)
        for start, extent in zip(patch_positions, patch_size)
    )
    return arr[index]

def extract_patches_random(arr: np.ndarray,
                           patch_size: Union[List[int], Tuple[int, ...]],
                           num_patches: int) -> np.ndarray:
    """
    Extract a specified number of patches from an array in a random manner.

    Parameters
    ----------
    arr : np.ndarray
        Input array from which to extract patches.
    patch_size : Tuple[int, ...]
        Patch sizes in each dimension.
    num_patches : int
        Number of patches to return.

    Returns
    -------
    np.ndarray
        The randomly selected patches stacked along a new leading axis.

    Raises
    ------
    ValueError
        If a patch dimension is larger than the corresponding array dimension.
    """
    # One row of start offsets per dimension; transposed below into one row
    # per patch.
    patch_starts = []
    for axis, extent in enumerate(patch_size):
        axis_length = arr.shape[axis]
        if extent == axis_length:
            # Patch spans the whole axis: None tells load_zarr to take it all.
            patch_starts.append([None] * num_patches)
        elif extent > axis_length:
            raise ValueError(
                f"Patch size {extent} exceeds array size {axis_length} "
                f"along dimension {axis}."
            )
        else:
            # `high` is exclusive in randint, so add 1 to make the last valid
            # start offset (axis_length - extent) reachable. Without it the
            # final index along this axis could never be part of a patch.
            patch_starts.append(np.random.randint(low=0,
                                                  high=axis_length - extent + 1,
                                                  size=num_patches))
    patch_starts = np.array(patch_starts).T

    # Build all reads lazily first, then evaluate them in a single compute call.
    delayed_patches = [load_zarr(arr, starts, patch_size)
                       for starts in patch_starts]
    patches = dask.compute(*delayed_patches)
    return np.stack(patches)


In [7]:
# Reuse a Dask client that already exists in this process, so re-running the
# cell does not create a second one.
try:
    client = get_client()
except ValueError:
    # get_client() raised ValueError, taken here to mean that no client
    # exists yet: start one with default settings.
    client = Client()
# Cell output: URL of the client's diagnostics dashboard.
client.dashboard_link

'http://127.0.0.1:8787/status'

In [8]:
class ZarrDataset(IterableDataset):
    """Dataset to extract patches from a zarr storage.

    Patches are sampled at random positions with ``extract_patches_random``,
    ``num_load_at_once`` at a time. Every batch is computed as a task on a
    Dask distributed client, and the next batch is requested before the
    current one is yielded so that loading overlaps with consumption.

    Parameters
    ----------
    data_path : Union[str, Path]
        Location of the zarr storage; opened read-only.
    patch_size : Optional[Union[List[int], Tuple[int, ...]]]
        Patch size in each dimension. The default ``None`` is not usable for
        iteration, because the sampler iterates over this value.
    num_patches : Optional[int]
        Stored on the instance; not used by this class.
    num_load_at_once : int
        Number of patches sampled by one task.
    n_shuffle_coordinates : int
        Number of batches produced by one pass over the dataset.
    client : Optional[Client]
        Dask distributed client that runs the loading tasks. When omitted,
        the client is looked up with ``get_client()`` when iteration starts.
    """

    def __init__(
            self,
            data_path: Union[str, Path],
            patch_size: Optional[Union[List[int], Tuple[int, ...]]] = None,
            num_patches: Optional[int] = None,
            num_load_at_once: int = 20,
            n_shuffle_coordinates: int = 20,
            client: Optional["Client"] = None,
    ) -> None:
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.num_load_at_once = num_load_at_once
        self.n_shuffle_coordinates = n_shuffle_coordinates
        self.client = client

        self.sample = zarr.open(data_path, mode="r")

    def __len__(self):
        """Return the total number of patches yielded by one pass."""
        return self.n_shuffle_coordinates * self.num_load_at_once

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
        """
        if self.n_shuffle_coordinates <= 0:
            return

        # Resolve the client lazily instead of relying on a notebook global.
        client = self.client if self.client is not None else get_client()

        def submit_batch():
            # pure=False: the sampler is random, so identical arguments must
            # not be treated as an identical task. With the default
            # pure=True, submitting again while the previous future is still
            # alive could hand back the previous batch instead of a new one.
            return client.submit(extract_patches_random,
                                 self.sample,
                                 self.patch_size,
                                 self.num_load_at_once,
                                 pure=False)

        future = submit_batch()
        for batch_index in range(self.n_shuffle_coordinates):
            data_in_memory = future.result()
            # Request the next batch before yielding the current one, but do
            # not schedule a batch that nobody will consume after the last.
            if batch_index + 1 < self.n_shuffle_coordinates:
                future = submit_batch()

            # Yield the patches of the current batch one at a time.
            yield from data_in_memory


In [9]:
# Timing of the prefetching dataset: around 5 seconds for the full pass.

from tqdm import tqdm

# 100 batches of 20 patches each, i.e. len(dataset) == 2000 patches.
dataset = ZarrDataset(
    data_path=local_path,
    patch_size=patch_size,
    num_load_at_once=20,
    n_shuffle_coordinates=100
)

# num_workers=0 iterates the dataset in the main process.
dl = DataLoader(dataset, batch_size=1, num_workers=0, prefetch_factor=None)


# Pull every patch through the DataLoader and convert it to NumPy; the result
# is discarded because only the iteration speed is of interest.
for X in tqdm(dl):
    X = np.array(X)

  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [00:05<00:00, 366.88it/s]


In [10]:
%%time
# around 1 minute

import dask.array as da
complete_download = da.from_zarr(local_path)

# Baseline for comparison: one blocking `.compute()` per patch, repeated
# len(dl) times. Every iteration reads the same fixed slice (`small_slice`);
# the loaded data is discarded.
for i in  tqdm(range(len(dl))):
    complete_download[small_slice].compute()

 24%|██▍       | 488/2000 [00:14<00:39, 38.69it/s]2023-11-08 16:13:10,105 - distributed.scheduler - ERROR - Couldn't gather keys: {('getitem-46feeb637397bec3bf1e2ad3d8cce1b4', 0, 0, 0, 0): 'forgotten', ('getitem-46feeb637397bec3bf1e2ad3d8cce1b4', 1, 0, 0, 0): 'forgotten'}
 47%|████▋     | 937/2000 [00:28<00:33, 31.80it/s]2023-11-08 16:13:24,146 - distributed.scheduler - ERROR - Couldn't gather keys: {('getitem-46feeb637397bec3bf1e2ad3d8cce1b4', 0, 0, 0, 0): 'waiting', ('getitem-46feeb637397bec3bf1e2ad3d8cce1b4', 1, 0, 0, 0): 'waiting'}
100%|██████████| 2000/2000 [00:57<00:00, 34.76it/s]

CPU times: user 22.8 s, sys: 2.35 s, total: 25.2 s
Wall time: 57.5 s



