# ZARR reading

[Link to example dataset](https://imagesc.zulipchat.com/user_uploads/16804/85qPFC9O85gLhNmF5KLdqtUx/bsd_val.zarr.zip) - copy it under `./data/` and unzip it.


In [None]:
from itertools import islice
from pathlib import Path
from typing import List, Tuple, Union, Optional, Callable, Dict, Generator

import numpy as np
import zarr
import time
import sys

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

In [120]:
def extract_patches_random(arr: np.ndarray,
                           patch_size: Union[List[int], Tuple[int]]
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from an array in a random manner.

    The method calculates how many patches the image can be divided into and then
    extracts an equal number of random patches.

    Parameters
    ----------
    arr : np.ndarray
        Input image array.
    patch_size : Tuple[int]
        Patch sizes in each dimension.

    Yields
    ------
    Generator[np.ndarray, None, None]
        Generator of patches.
    """

    rng = np.random.default_rng()

    n_patches_per_slice = np.ceil(np.prod(arr.shape[1:]) / np.prod(patch_size)).astype(
        int
    )
    crop_coords = rng.integers(
        0,
        np.array(arr.shape[-len(patch_size):]) - np.array(patch_size),
        size=(arr.shape[0], n_patches_per_slice, len(patch_size)),
    )
    for slice_idx in range(crop_coords.shape[0]):
        sample = arr[slice_idx]
        for patch_idx in range(crop_coords.shape[1]):
            patch = sample[
                    crop_coords[slice_idx, patch_idx, 0]: crop_coords[
                                                              slice_idx, patch_idx, 0
                                                          ]
                                                          + patch_size[0],
                    crop_coords[slice_idx, patch_idx, 1]: crop_coords[
                                                              slice_idx, patch_idx, 1
                                                          ]
                                                          + patch_size[1],
                    ]
            yield patch

In [121]:
class ZarrDataset(IterableDataset):
    """Dataset to extract patches from a zarr storage."""

    def __init__(
            self,
            data_path: Union[str, Path],
            patch_extraction_method: str,
            patch_size: Optional[Union[List[int], Tuple[int]]] = None,
            num_patches: Optional[int] = None,
            mean: Optional[float] = None,
            std: Optional[float] = None,
            patch_transform: Optional[Callable] = None,
            patch_transform_params: Optional[Dict] = None,
    ) -> None:
        self.data_path = Path(data_path)
        self.patch_extraction_method = patch_extraction_method
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.mean = mean
        self.std = std
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params

        self.sample = zarr.open(data_path, mode="r")

    def _generate_patches(self):
        patches = extract_patches_random(
            self.sample,
            self.patch_size,
        )

        for idx, patch in enumerate(patches):

            if isinstance(patch, tuple):
                patch = (patch, *patch[1:])
            else:
                patch = patch

            if self.patch_transform is not None:
                assert self.patch_transform_params is not None
                patch = self.patch_transform(patch, **self.patch_transform_params)
            if self.num_patches is not None and idx >= self.num_patches:
                return
            else:
                yield patch

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
        """
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1
        yield from islice(self._generate_patches(), 0, None, num_workers)

def train_loop(dataloader: DataLoader):
    for i, batch in enumerate(dataloader):
        pass

In [123]:
test_path = Path('.') / 'data' / 'bsd_val.zarr' 
train_path_fast = '/localscratch/bsd_train.zarr/'

patch_size = (64, 64)

dataset = ZarrDataset(
    data_path=test_path,
    patch_extraction_method='random',
    patch_size=patch_size,
)

dl = DataLoader(dataset, batch_size=128, num_workers=0)



In [124]:
times = []

for i, batch in enumerate(dl):
    start = time.time()
    b = batch.shape
    cur_time = time.time() - start
    times.append(cur_time)
    info = f" {cur_time * 1e6:.3f}us/step"

    print(info, end='\r')

print(f"Average time: {np.mean(times) * 1e6:.3f}us/step")

Average time: 0.000us/step


In [127]:
from timeit import timeit, time

In [128]:
# turn previous for loop into function for timeit to work
def iterate_dl(dl):
    timer = time.time()
    for i, batch in enumerate(dl):
        start = time.time()
        b = batch.shape
    return (time.time() - timer)/(i + 1)

# timeit and add counter of iterations
iterate_dl(dl)

0.08461117744445801