In [1]:

from itertools import islice
from pathlib import Path
from typing import List, Tuple, Union, Optional, Callable, Dict, Generator

import numpy as np
import zarr
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from dask.distributed import Client, get_client


In [2]:
test_path = Path('.') / 'data' / 'test_ngff_image.zarr/0' 
patch_size = (64, 64)

In [3]:
def read_zarr(file_path: Path) -> Union[zarr.core.Array, zarr.storage.DirectoryStore, zarr.hierarchy.Group]:
    """Reads a file and returns a pointer.

    Parameters
    ----------
    file_path : Path
        pathlib.Path object containing a path to a file

    Returns
    -------
    np.ndarray
        Pointer to zarr storage

    Raises
    ------
    ValueError, OSError
        if a file is not a valid tiff or damaged
    ValueError
        if data dimensions are not 2, 3 or 4
    ValueError
        if axes parameter from config is not consistent with data dimensions
    """
    zarr_source = zarr.open(Path(file_path), mode="r")
    
    """
    if isinstance(zarr_source, zarr.hierarchy.Group):
        raise NotImplementedError("Group not supported yet")

    elif isinstance(zarr_source, zarr.storage.DirectoryStore):
        raise NotImplementedError("DirectoryStore not supported yet")

    elif isinstance(zarr_source, zarr.core.Array):
        # array should be of shape (S, (C), (Z), Y, X), iterating over S ?
        # TODO what if array is not of that shape and/or chunks aren't defined and
        if zarr_source.dtype == "O":
            raise NotImplementedError("Object type not supported yet")
        else:
            array = zarr_source
    else:
        raise ValueError(f"Unsupported zarr object type {type(zarr_source)}")

    # TODO how to fix dimensions? Or just raise error?
    # sanity check on dimensions
    if len(array.shape) < 2 or len(array.shape) > 4:
        raise ValueError(
            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape})."
        )
    """
    return zarr_source

In [64]:
def extract_patches_random(arr: np.ndarray,
                           patch_size: Union[List[int], Tuple[int]]
) -> np.ndarray:
    """
    Generate patches from an array in a random manner.

    The method calculates how many patches the image can be divided into and then
    extracts an equal number of random patches.

    Parameters
    ----------
    arr : np.ndarray
        Input image array.
    patch_size : Tuple[int]
        Patch sizes in each dimension.

    Yields
    ------
    Generator[np.ndarray, None, None]
        Generator of patches.
    """

    rng = np.random.default_rng()

    n_patches_per_slice = np.ceil(np.prod(arr.shape[1:]) / np.prod(patch_size)).astype(
        int
    )
    crop_coords = rng.integers(
        0,
        np.array(arr.shape[-len(patch_size):]) - np.array(patch_size),
        size=(arr.shape[0], n_patches_per_slice, len(patch_size)),
    )
    slice_idx  = np.random.randint(0, arr.shape[0])
    patch_idx = np.random.randint(0, crop_coords.shape[1])
    #sample = [slice_idx]

    patch = arr[
        crop_coords[slice_idx, patch_idx, 0]: crop_coords[
                                                    slice_idx, patch_idx, 0
                                                ]
                                                + patch_size[0],
        crop_coords[slice_idx, patch_idx, 1]: crop_coords[
                                                    slice_idx, patch_idx, 1
                                                ]
                                                + patch_size[1],
        ]
    return patch

def _generate_patches(sample, to_load_at_same_time=10):
    patches = extract_patches_random(
        sample,
        patch_size,
    )

    list_of_patches = []
    
    for idx, patch in enumerate(patches):
        
        list_of_patches.append(patch)
        if idx > to_load_at_same_time:
            break
    
    return list_of_patches

In [45]:
client = Client()

Perhaps you already have a cluster running?
Hosting the HTTP server on port 52821 instead


In [46]:
client.dashboard_link

'http://127.0.0.1:52821/status'

In [86]:
class ZarrDataset(IterableDataset):
    """Dataset to extract patches from a zarr storage."""

    def __init__(
            self,
            data_path: Union[str, Path],
            patch_extraction_method: str,
            patch_size: Optional[Union[List[int], Tuple[int]]] = None,
            num_patches: Optional[int] = None,
            mean: Optional[float] = None,
            std: Optional[float] = None,
            patch_transform: Optional[Callable] = None,
            patch_transform_params: Optional[Dict] = None,
            to_load_at_same_time: int = 10,
            loads_per_epoch: int = 2,
    ) -> None:
        self.data_path = Path(data_path)
        self.patch_extraction_method = patch_extraction_method
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.mean = mean
        self.std = std
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params
        self.to_load_at_same_time = to_load_at_same_time
        self.loads_per_epoch = loads_per_epoch

        self.sample = read_zarr(self.data_path)

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
        """
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1
        
        future = client.submit(_generate_patches, self.sample, self.to_load_at_same_time)

        for _ in range(self.loads_per_epoch):
            data_in_memory = future.result()
            future = client.submit(_generate_patches, self.sample, self.to_load_at_same_time)

            for j in range(len(data_in_memory)):
                yield data_in_memory


In [87]:
dataset = ZarrDataset(
    data_path=test_path,
    patch_extraction_method='random',
    patch_size=patch_size,
)

In [88]:
limit = 100
for x in dataset:
    print(len(x))
    limit -= 1
    if limit == 0:
        break

12
12
12
12
12
12
12
12
12
12
12
12
