In [1]:

from itertools import islice
from pathlib import Path
from typing import List, Tuple, Union, Optional, Callable, Dict, Generator
import time
import numpy as np
import zarr
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from dask.distributed import Client, get_client

from timeit import timeit, time


In [2]:
# Input zarr array used throughout this notebook, and the (X, Y) patch size.
test_path = Path('data', 'test_ngff_image.zarr', '0')
patch_size = (64, 64)

In [3]:
def read_zarr(file_path: Path) -> Union[zarr.core.Array, zarr.storage.DirectoryStore, zarr.hierarchy.Group]:
    """Open a zarr store read-only and return the object ``zarr.open`` produces.

    Parameters
    ----------
    file_path : Path or str
        Path to a zarr array or group on disk. It is normalised with
        ``pathlib.Path`` before being passed to ``zarr.open``.

    Returns
    -------
    zarr.core.Array or zarr.hierarchy.Group
        Whatever ``zarr.open(file_path, mode="r")`` returns for this path.

    Raises
    ------
    Exception
        Any exception raised by ``zarr.open`` is propagated unchanged.

    Notes
    -----
    No validation of the returned object is performed. It is not checked
    for being an array rather than a group, for having a non-object dtype,
    or for having 2-4 dimensions. Callers must check these themselves.
    """
    # TODO: add validation (reject groups and object dtype, check that
    # dimensionality is 2-4) once the supported layouts are decided.
    return zarr.open(Path(file_path), mode="r")

In [4]:
import numpy as np
from typing import List, Tuple, Union

def extract_patches_random(arr: np.ndarray,
                           patch_size: Union[List[int], Tuple[int, ...]],
                           num_patches: int,
                           rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """
    Extract a specified number of patches from an array in a random manner.

    Each patch comes from a randomly chosen index along the first axis. It is
    a randomly placed window over the last two axes that lies fully inside
    the array.

    Parameters
    ----------
    arr : np.ndarray
        Input array with at least three dimensions, e.g. (S, Y, X). A zarr
        array also works, since only ``.shape``, ``.dtype`` and basic slicing
        are used. Patches are cut from the last two axes. Any axes between
        the first and the last two are kept whole.
    patch_size : Tuple[int, ...]
        Patch size as (size along the last axis, size along the
        second-to-last axis), i.e. (X, Y).
    num_patches : int
        Number of patches to return.
    rng : np.random.Generator, optional
        Generator used for all random draws. If None, a fresh
        ``np.random.default_rng()`` is used.

    Returns
    -------
    np.ndarray
        Stacked patches of shape
        (num_patches, *arr.shape[1:-2], patch_size[1], patch_size[0]).

    Raises
    ------
    ValueError
        If ``arr`` has fewer than three dimensions, or the patch is larger
        than the last two axes of ``arr``.
    """
    if rng is None:
        rng = np.random.default_rng()

    if len(arr.shape) < 3:
        raise ValueError(
            f"Expected an array with at least 3 dimensions, got shape {arr.shape}."
        )

    size_x, size_y = int(patch_size[0]), int(patch_size[1])
    n_samples, height, width = arr.shape[0], arr.shape[-2], arr.shape[-1]
    if size_x > width or size_y > height:
        raise ValueError(
            f"Patch size (X={size_x}, Y={size_y}) exceeds array extent "
            f"(X={width}, Y={height})."
        )

    # Draw top-left corners. `integers` excludes `high`, so the +1 keeps the
    # last fully in-bounds window reachable (and allows full-size patches).
    starts_x = rng.integers(0, width - size_x + 1, size=num_patches)
    starts_y = rng.integers(0, height - size_y + 1, size=num_patches)
    sample_indices = rng.integers(0, n_samples, size=num_patches)

    patches = [
        arr[s, ..., y:y + size_y, x:x + size_x]
        for s, y, x in zip(sample_indices, starts_y, starts_x)
    ]

    if not patches:
        # np.stack needs at least one array; return an empty, correctly shaped result.
        return np.empty((0, *arr.shape[1:-2], size_y, size_x), dtype=arr.dtype)
    return np.stack(patches)


In [5]:
def _get_or_create_client() -> "Client":
    """Return the session's dask client, starting a new local one if none exists."""
    try:
        return get_client()
    except ValueError:
        # get_client signals "no client in this session" with ValueError.
        return Client()


client = _get_or_create_client()
client.dashboard_link

Perhaps you already have a cluster running?
Hosting the HTTP server on port 65013 instead


'http://127.0.0.1:65013/status'

In [6]:
class ZarrDataset(IterableDataset):
    """Dataset to extract patches from a zarr storage.

    Patches are produced by ``extract_patches_random`` running as dask tasks.
    Each task returns ``num_load_at_once`` patches. While one task's patches
    are being yielded, the next task is already submitted, giving one round
    of look-ahead.

    Parameters
    ----------
    data_path : str or Path
        Path to the zarr data, opened with ``read_zarr``.
    patch_extraction_method : str
        Stored only. Patches are always extracted with
        ``extract_patches_random`` regardless of this value.
    patch_size : list or tuple of int, optional
        Forwarded to ``extract_patches_random``. It must be set before iterating.
    num_patches : int, optional
        Stored only; not used by ``__iter__``.
    mean, std : float, optional
        Stored only; no normalization is applied here.
    patch_transform : callable, optional
        Stored only; not applied by ``__iter__``.
    patch_transform_params : dict, optional
        Stored only; not applied by ``__iter__``.
    num_load_at_once : int
        Number of patches extracted per dask task.
    n_shuffle_coordinates : int
        Number of dask tasks (rounds) per pass over the dataset. One pass
        yields ``n_shuffle_coordinates * num_load_at_once`` patches in total,
        split across DataLoader workers.
    client : dask.distributed.Client, optional
        Client used to submit tasks. If None, ``get_client()`` is called at
        iteration time to pick up the session's current client.
    """

    def __init__(
            self,
            data_path: Union[str, Path],
            patch_extraction_method: str,
            patch_size: Optional[Union[List[int], Tuple[int]]] = None,
            num_patches: Optional[int] = None,
            mean: Optional[float] = None,
            std: Optional[float] = None,
            patch_transform: Optional[Callable] = None,
            patch_transform_params: Optional[Dict] = None,
            num_load_at_once: int = 20,
            n_shuffle_coordinates: int = 20,
            client: Optional["Client"] = None,
    ) -> None:
        self.data_path = Path(data_path)
        self.patch_extraction_method = patch_extraction_method
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.mean = mean
        self.std = std
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params
        self.num_load_at_once = num_load_at_once
        self.n_shuffle_coordinates = n_shuffle_coordinates
        # None -> resolve the session's default dask client when iterating.
        self.client = client

        # Handle returned by read_zarr; it is passed to every dask task.
        self.sample = read_zarr(self.data_path)

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        With a multi-worker DataLoader, the ``n_shuffle_coordinates`` rounds
        are divided between the workers. This means one pass over the
        DataLoader yields the configured number of patches in total, not
        that number per worker.

        Yields
        ------
        np.ndarray
            One patch, taken along the first axis of a batch returned by
            ``extract_patches_random``.
        """
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1

        # Worker k handles rounds k, k + num_workers, ...; only the count
        # matters because every round draws fresh random patches.
        n_rounds = len(range(worker_id, self.n_shuffle_coordinates, num_workers))
        if n_rounds == 0:
            return

        dask_client = self.client if self.client is not None else get_client()

        def submit_round():
            # pure=False: extraction is random. With dask's default pure=True,
            # identical arguments can hash to the same task key, and dask would
            # then hand back the already-computed result instead of new patches.
            return dask_client.submit(
                extract_patches_random,
                self.sample,
                self.patch_size,
                self.num_load_at_once,
                pure=False,
            )

        future = submit_round()
        try:
            for round_idx in range(n_rounds):
                data_in_memory = future.result()
                # Prefetch the next round while this one is consumed, but do
                # not schedule a task whose result would never be read.
                future = submit_round() if round_idx + 1 < n_rounds else None
                yield from data_in_memory
        finally:
            # If the consumer stops early, drop the still-pending prefetch.
            if future is not None:
                future.cancel()


In [7]:
# 100 rounds x 20 patches per dask task = 2000 random patches per pass.
dataset = ZarrDataset(
    test_path,
    'random',
    patch_size=patch_size,
    n_shuffle_coordinates=100,
    num_load_at_once=20,
)

# num_workers=0: batches are assembled in the main process.
dl = DataLoader(dataset, num_workers=0, batch_size=128)

In [8]:
def iterate_dl(dl) -> float:
    """Time one full pass over ``dl`` and return the mean seconds per batch.

    Parameters
    ----------
    dl : iterable
        Any iterable of batches, typically a ``torch.utils.data.DataLoader``.

    Returns
    -------
    float
        Wall-clock time for the whole pass (measured with
        ``time.perf_counter``), divided by the number of batches.

    Raises
    ------
    ValueError
        If ``dl`` yields no batches, since there is nothing to average over.
    """
    start = time.perf_counter()
    n_batches = sum(1 for _ in dl)
    elapsed = time.perf_counter() - start
    if n_batches == 0:
        raise ValueError("The DataLoader yielded no batches; cannot compute a per-batch time.")
    return elapsed / n_batches

# Mean wall-clock seconds per batch for one full pass over `dl`.
iterate_dl(dl)

0.3090398907661438