# ZARR reading with Dask Client

[Link to example dataset](https://imagesc.zulipchat.com/user_uploads/16804/85qPFC9O85gLhNmF5KLdqtUx/bsd_val.zarr.zip) - download it, copy the archive into `./data/`, and unzip it there.


In [2]:

from itertools import islice
from pathlib import Path
from typing import List, Tuple, Union, Optional, Callable, Dict, Generator
import time
import numpy as np
import zarr
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from dask.distributed import Client, get_client

from timeit import timeit, time


In [3]:
import numpy as np
from typing import List, Tuple, Union
import dask

@dask.delayed
def load_zarr(arr: np.ndarray,
              patch_size: Union[List[int], Tuple[int, ...]],
              i: int,
              x: int,
              y: int) -> np.ndarray:
    """
    Lazily read one 2D patch centred on ``(x, y)`` from slice ``i`` of ``arr``.

    Because of the ``dask.delayed`` decorator, calling this function only
    builds a task; the indexing below runs when the task is computed.

    Parameters
    ----------
    arr : np.ndarray
        Array indexed as ``arr[i, y_range, x_range]``.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch extent; ``patch_size[0]`` is used along x (last axis) and
        ``patch_size[1]`` along y (second-to-last axis).
    i : int
        Index along the first axis.
    x : int
        Patch centre along the last axis.
    y : int
        Patch centre along the second-to-last axis.

    Returns
    -------
    np.ndarray
        Whatever ``arr`` returns for the requested window.
    """
    size_x, size_y = patch_size[0], patch_size[1]
    x_start = x - size_x // 2
    y_start = y - size_y // 2
    # Anchor the stop at start + size rather than centre + size // 2: for odd
    # sizes the latter drops one element, for even sizes both are identical.
    return arr[i,
               y_start:y_start + size_y,
               x_start:x_start + size_x]


def extract_patches_random(arr: np.ndarray,
                           patch_size: Union[List[int], Tuple[int, ...]],
                           num_patches: int) -> np.ndarray:
    """
    Extract a specified number of patches from an array in a random manner.

    Patch centres are drawn uniformly so that a window of half-width
    ``patch_size[k] // 2`` around each centre stays inside the last two axes
    of ``arr``; the index along the first axis is drawn uniformly as well.
    Each patch is read through ``load_zarr`` (a delayed task) and all tasks
    are computed together.

    Parameters
    ----------
    arr : np.ndarray
        Input array from which to extract patches. Only ``arr.shape`` is read
        here; the actual indexing happens in ``load_zarr``.
    patch_size : Tuple[int, ...]
        Patch sizes; ``patch_size[0]`` applies to the last axis (x) and
        ``patch_size[1]`` to the second-to-last axis (y).
    num_patches : int
        Number of patches to return.

    Returns
    -------
    np.ndarray
        The patches stacked along a new first axis.

    Raises
    ------
    ValueError
        Raised by the random generator when a patch does not fit inside
        ``arr`` (empty sampling range).
    """
    # A fresh Generator is seeded from OS entropy on every call, so the draw
    # does not depend on (or disturb) the global ``np.random`` state.
    rng = np.random.default_rng()
    half_x = patch_size[0] // 2
    half_y = patch_size[1] // 2

    # ``high`` is exclusive, so every centre leaves room for a full window.
    patch_centers_x = rng.integers(low=half_x,
                                   high=arr.shape[-1] - half_x,
                                   size=num_patches)
    patch_centers_y = rng.integers(low=half_y,
                                   high=arr.shape[-2] - half_y,
                                   size=num_patches)
    slice_indices = rng.integers(low=0, high=arr.shape[0], size=num_patches)

    # Build one delayed read per patch ...
    patches = [
        load_zarr(arr, patch_size, i, x, y)
        for i, x, y in zip(slice_indices, patch_centers_x, patch_centers_y)
    ]

    # ... and compute them in a single call.
    patches = dask.compute(*patches)
    return np.stack(patches)


In [4]:
# Reuse a Dask client if one is already available; get_client() signals the
# absence of a client with ValueError, in which case a new Client() is created
# with default arguments.
try:
    client = get_client()
except ValueError:
    client = Client()
# Last expression of the cell: displays the dashboard URL of the client.
client.dashboard_link

'http://127.0.0.1:8787/status'

In [9]:
class ZarrDataset(IterableDataset):
    """
    Dataset to extract patches from a zarr storage.

    Patches are produced in rounds. Each round is one call to
    ``extract_patches_random`` submitted to the module-level Dask ``client``;
    while the patches of one round are being yielded, the next round is
    already running.

    Parameters
    ----------
    data_path : Union[str, Path]
        Location passed to ``zarr.open`` (opened read-only).
    patch_extraction_method : str
        Accepted for interface compatibility; not used by this class.
    patch_size : Optional[Union[List[int], Tuple[int]]]
        Patch sizes, forwarded to ``extract_patches_random``.
    num_patches : Optional[int]
        Stored, but not used by this class.
    mean, std, patch_transform, patch_transform_params
        Accepted for interface compatibility; not used by this class.
    num_load_at_once : int
        Number of patches extracted per round.
    n_shuffle_coordinates : int
        Number of rounds per pass over the dataset.
    """

    def __init__(
            self,
            data_path: Union[str, Path],
            patch_extraction_method: str,
            patch_size: Optional[Union[List[int], Tuple[int]]] = None,
            num_patches: Optional[int] = None,
            mean: Optional[float] = None,
            std: Optional[float] = None,
            patch_transform: Optional[Callable] = None,
            patch_transform_params: Optional[Dict] = None,
            num_load_at_once: int = 20,
            n_shuffle_coordinates: int = 20,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.num_load_at_once = num_load_at_once
        self.n_shuffle_coordinates = n_shuffle_coordinates

        self.sample = zarr.open(data_path, mode="r")

    def __len__(self):
        """Return the number of patches yielded by one pass of ``__iter__``."""
        return self.n_shuffle_coordinates * self.num_load_at_once

    def _submit_round(self):
        """Schedule the extraction of one round of patches on the Dask client."""
        # pure=False is required: the extraction is random. With the default
        # pure=True, Dask derives the task key from the function and its
        # arguments, so identical submissions share a single result and every
        # round would return the same patches.
        return client.submit(extract_patches_random,
                             self.sample,
                             self.patch_size,
                             self.num_load_at_once,
                             pure=False)

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
        """
        n_rounds = self.n_shuffle_coordinates
        if n_rounds <= 0:
            return

        future = self._submit_round()
        for round_idx in range(n_rounds):
            data_in_memory = future.result()

            # Start the next round before yielding so it runs while the
            # consumer works on the current patches. Nothing is submitted
            # after the last round, so no result is computed and thrown away.
            if round_idx + 1 < n_rounds:
                future = self._submit_round()

            yield from data_in_memory

# Location of the zarr array to sample from, relative to the working directory.
test_path = Path('.') / 'data' / 'huge.zarr/0'
# Patch sizes handed to the dataset (forwarded to the patch extraction helper).
patch_size = (64, 64)

# 100 rounds of 20 patches each.
dataset = ZarrDataset(
    data_path=test_path,
    patch_extraction_method='random',
    patch_size=patch_size,
    num_load_at_once=20,
    n_shuffle_coordinates=100
)

# num_workers=0: the dataset is iterated in the main process, one patch per batch.
dl = DataLoader(dataset, batch_size=1, num_workers=0, prefetch_factor=None)

In [10]:
from tqdm import tqdm

# Number of batches the loader reports.
print(len(dl))
# Pull every batch once with a progress bar; the matmul result is discarded,
# it only stands in for some per-batch work.
for X in tqdm(dl):
    np.matmul(X, X)

2000


100%|██████████| 2000/2000 [00:21<00:00, 91.03it/s]


In [39]:
# Builds an enumerate object over the loader; nothing is consumed from it
# here, the cell output is only its repr.
enumerate(dl)

<enumerate at 0x7f4dd9348270>

In [40]:
# turn previous for loop into function for timeit to work
def iterate_dl(dl):
    """
    Iterate once over ``dl`` and measure the mean time per batch.

    For every batch the iteration index and the running mean (seconds per
    batch since the start of the pass) are printed.

    Parameters
    ----------
    dl : iterable
        Iterable of batches; each batch must expose a ``shape`` attribute.

    Returns
    -------
    float
        Mean wall-clock seconds per batch over the whole pass, or ``0.0``
        when ``dl`` yields no batches.
    """
    # perf_counter is monotonic, so the measurement cannot go negative if the
    # system clock is adjusted during the pass.
    started = time.perf_counter()
    n_batches = 0
    for i, batch in enumerate(dl):
        print(i)
        # Touch the batch so that it is actually materialised.
        batch.shape
        n_batches = i + 1

        print((time.perf_counter() - started) / n_batches)

    # An empty loader would otherwise leave the loop index undefined.
    if n_batches == 0:
        return 0.0
    return (time.perf_counter() - started) / n_batches

# run one timed pass over the loader (prints the iteration counter and the running mean per batch)
iterate_dl(dl)

KeyboardInterrupt: 

In [65]:
class ZarrDataset(IterableDataset):
    """
    Dataset to extract patches from a zarr storage.

    Patches are produced in rounds. Each round is one call to
    ``extract_patches_random`` submitted to the module-level Dask ``client``;
    while the patches of one round are being yielded, the next round is
    already running. When the dataset is iterated by several ``DataLoader``
    workers, the rounds are divided between them.

    Parameters
    ----------
    data_path : Union[str, Path]
        Location of the zarr storage, passed to ``read_zarr``.
    patch_extraction_method : str
        Stored; not used by this class.
    patch_size : Optional[Union[List[int], Tuple[int]]]
        Patch sizes, forwarded to ``extract_patches_random``.
    num_patches : Optional[int]
        Stored; not used by this class.
    mean, std : Optional[float]
        Stored; not used by this class.
    patch_transform : Optional[Callable]
        Stored; not used by this class.
    patch_transform_params : Optional[Dict]
        Stored; not used by this class.
    num_load_at_once : int
        Number of patches extracted per round.
    n_shuffle_coordinates : int
        Number of rounds per pass over the dataset.
    """

    def __init__(
            self,
            data_path: Union[str, Path],
            patch_extraction_method: str,
            patch_size: Optional[Union[List[int], Tuple[int]]] = None,
            num_patches: Optional[int] = None,
            mean: Optional[float] = None,
            std: Optional[float] = None,
            patch_transform: Optional[Callable] = None,
            patch_transform_params: Optional[Dict] = None,
            num_load_at_once: int = 20,
            n_shuffle_coordinates: int = 20,
    ) -> None:
        super().__init__()
        self.data_path = Path(data_path)
        self.patch_extraction_method = patch_extraction_method
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.mean = mean
        self.std = std
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params
        self.num_load_at_once = num_load_at_once
        self.n_shuffle_coordinates = n_shuffle_coordinates

        # NOTE(review): read_zarr is not defined in the visible part of this
        # file - confirm it is available before this cell runs.
        self.sample = read_zarr(self.data_path)

    def __len__(self):
        """Return the total number of patches yielded per pass.

        Each round yields ``num_load_at_once`` patches; ``num_patches`` is
        optional (``None`` by default) and is not what ``__iter__`` yields,
        so it must not be used here.
        """
        return self.n_shuffle_coordinates * self.num_load_at_once

    def _submit_round(self):
        """Schedule the extraction of one round of patches on the Dask client."""
        # pure=False is required: the extraction is random. With the default
        # pure=True, Dask derives the task key from the function and its
        # arguments, so identical submissions share a single result and every
        # round would return the same patches.
        return client.submit(extract_patches_random,
                             self.sample,
                             self.patch_size,
                             self.num_load_at_once,
                             pure=False)

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
        """
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1

        # Deal the rounds out round-robin so that all workers together yield
        # exactly len(self) patches instead of each replaying every round.
        n_rounds = len(range(worker_id, self.n_shuffle_coordinates, num_workers))

        # Debug marker kept from the original; printed once per iterator.
        print("here")

        if n_rounds == 0:
            return

        # NOTE(review): `client` is the module-level Dask client created in
        # the parent process. The recorded multi-worker runs of this notebook
        # were interrupted before the first batch; presumably that client is
        # not usable inside DataLoader worker processes - TODO confirm, and
        # create a per-worker client here if so.
        future = self._submit_round()
        for round_idx in range(n_rounds):
            data_in_memory = future.result()

            # Start the next round before yielding so it runs while the
            # consumer works on the current patches. Nothing is submitted
            # after the last round, so no result is computed and thrown away.
            if round_idx + 1 < n_rounds:
                future = self._submit_round()

            yield from data_in_memory
# Rebuild the dataset from the class defined in this cell: 100 rounds of
# 20 patches each, using the path and patch size defined earlier.
dataset = ZarrDataset(
    data_path=test_path,
    patch_extraction_method='random',
    patch_size=patch_size,
    num_load_at_once=20,
    n_shuffle_coordinates=100
)

# Same loader settings as before, but iterated by 5 worker processes.
dl = DataLoader(dataset, batch_size=1, num_workers=5, prefetch_factor=None)

In [66]:
# Request the first batch from a fresh iterator over the loader.
next(iter(dl))

here
here
here
here
here


KeyboardInterrupt: 

In [62]:
from tqdm import tqdm

# Pull every batch once with a progress bar; the matmul result is discarded,
# it only stands in for some per-batch work.
for X in tqdm(dl):
    np.matmul(X, X)

0it [02:15, ?it/s]


KeyboardInterrupt: 