In [1]:
import os
import pprint
import shutil
import time
from itertools import islice
from pathlib import Path
from typing import List, Tuple, Union, Optional, Callable, Any, Dict, Generator

import matplotlib.pyplot as plt
import numpy as np
import tifffile
import zarr
from matplotlib.pyplot import imshow
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

from careamics.dataset.dataset_utils import read_zarr
from careamics.dataset.patching import generate_patches
from careamics.utils import RunningStats



In [2]:
# Locations of the BSD training zarr store: a copy under /localscratch (`fast_path`)
# and one under the home directory (`reg_path`). Each can be overridden with an
# environment variable so the notebook runs on other machines without editing
# this cell; the defaults are the original hardcoded paths.
fast_path = os.environ.get('BSD_ZARR_FAST_PATH', '/localscratch/bsd_train.zarr/')
reg_path = os.environ.get('BSD_ZARR_PATH', '/home/igor.zubarev/data/zarr_test/bsd_train.zarr')

# Axes string handed to `read_zarr` (presumably S = sample, Y/X = spatial, per the
# careamics axes convention — confirm) and the 2-D patch shape used for cropping.
axes = 'SYX'
patch_size = (64, 64)

In [8]:
# Attempt to download the BSD training zarr store from a remote server.
import urllib.request
# NOTE(review): the recorded output of this cell is `HTTPError: HTTP Error 404: Not Found`,
# so this URL is not reachable as written — confirm the correct location before re-running
# (as is, this cell breaks Restart & Run All).
# NOTE(review): `urlretrieve` is called without a `filename`, so the payload lands in a
# temporary file and `d` holds a `(local_path, headers)` tuple; nothing here moves it to
# `reg_path`/`fast_path`, and no later cell reads `d`. A directory-style `.zarr` store
# presumably cannot be fetched as one file either — TODO confirm the archive format.
d = urllib.request.urlretrieve('https://download.fht.org/jug/zarr_hackathon/bsd_train.zarr')


HTTPError: HTTP Error 404: Not Found

In [3]:
def extract_patches_random(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    seed: Optional[int] = None,
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from an array in a random manner.

    The first axis of ``arr`` is treated as the slice axis, and the last
    ``len(patch_size)`` axes are the axes that get cropped. For every slice, the
    method calculates how many patches the slice could be divided into,
    ``ceil(prod(arr.shape[1:]) / prod(patch_size))``. It then extracts that many
    patches at uniformly random positions, so patches may overlap.

    Parameters
    ----------
    arr : np.ndarray
        Input image array; its first axis indexes slices.
    patch_size : list or tuple of int
        Patch sizes along the last ``len(patch_size)`` axes of ``arr``.
    seed : int, optional
        Seed for the random generator. With ``None`` (the default), fresh entropy
        is drawn, so successive calls give different patches.

    Yields
    ------
    np.ndarray
        Patches of shape ``arr.shape[1:-len(patch_size)] + tuple(patch_size)``.

    Raises
    ------
    ValueError
        If any of the following holds:

        - ``patch_size`` is empty.
        - ``arr`` has no slice axis in front of the patched axes.
        - The patch is larger than the array along any patched axis.
    """
    patch_shape = tuple(int(p) for p in patch_size)
    n_dims = len(patch_shape)
    if n_dims == 0:
        raise ValueError("`patch_size` must contain at least one dimension.")
    if len(arr.shape) < n_dims + 1:
        raise ValueError(
            f"Array of shape {tuple(arr.shape)} needs a leading slice axis in "
            f"addition to the {n_dims} patched axes."
        )

    spatial_shape = np.array(arr.shape[-n_dims:])
    patch_array = np.array(patch_shape)
    if np.any(patch_array > spatial_shape):
        raise ValueError(
            f"Patch size {patch_shape} is larger than the array's trailing "
            f"dimensions {tuple(spatial_shape)}."
        )

    rng = np.random.default_rng(seed)
    n_patches_per_slice = int(np.ceil(np.prod(arr.shape[1:]) / np.prod(patch_shape)))

    # `high` is exclusive in Generator.integers, hence the +1: this lets the last
    # valid start (shape - patch) be drawn and allows patch == image size.
    crop_coords = rng.integers(
        0,
        spatial_shape - patch_array + 1,
        size=(arr.shape[0], n_patches_per_slice, n_dims),
    )
    for slice_idx in range(arr.shape[0]):
        sample = arr[slice_idx]
        for start in crop_coords[slice_idx]:
            window = tuple(
                slice(int(s), int(s) + p) for s, p in zip(start, patch_shape)
            )
            # The leading Ellipsis crops only the trailing (patched) axes and keeps
            # any other axes intact.
            yield sample[(Ellipsis, *window)]

In [4]:
class ZarrDataset(IterableDataset):
    """Dataset to extract patches from a zarr storage.

    Patches are always extracted randomly with ``extract_patches_random``.
    ``patch_extraction_method`` is stored but not consulted.

    When ``mean`` or ``std`` is not given, per-patch statistics are fed to a
    ``RunningStats`` object during a pass. At the end of the pass (or after
    ``num_patches`` patches), ``self.mean``/``self.std`` are set from it.
    User-supplied values are never overwritten. The statistics are not applied
    to the yielded patches by this class.

    NOTE(review): with ``DataLoader(num_workers > 0)``, each worker iterates its
    own copy of the dataset. Statistics computed there are therefore presumably
    not visible in the main process — confirm before relying on them.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        axes: str,
        patch_extraction_method: str,
        patch_size: Optional[Union[List[int], Tuple[int]]] = None,
        num_patches: Optional[int] = None,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        patch_transform: Optional[Callable] = None,
        patch_transform_params: Optional[Dict] = None,
    ) -> None:
        """Store the configuration and open the zarr data via ``read_zarr``.

        Parameters
        ----------
        data_path : str or Path
            Path of the zarr storage.
        axes : str
            Axes string passed to ``read_zarr``.
        patch_extraction_method : str
            Stored only; extraction is always random.
        patch_size : list or tuple of int, optional
            Patch shape. It is required before iterating.
        num_patches : int, optional
            Maximum number of patches per pass (``None`` means no limit).
        mean, std : float, optional
            Dataset statistics. If either is ``None``, both are estimated.
        patch_transform : callable, optional
            Called as ``patch_transform(patch, **patch_transform_params)``.
        patch_transform_params : dict, optional
            Keyword arguments for ``patch_transform``.
        """
        self.data_path = Path(data_path)
        self.axes = axes
        self.patch_extraction_method = patch_extraction_method
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.mean = mean
        self.std = std
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params

        self.sample = read_zarr(self.data_path, self.axes)
        self.running_stats = RunningStats()

    def _generate_patches(self) -> Generator[np.ndarray, None, None]:
        """Yield (optionally transformed) random patches for one pass.

        Raises
        ------
        ValueError
            If ``patch_size`` was not set.
        """
        if self.patch_size is None:
            raise ValueError("`patch_size` must be set to extract patches.")

        # Decided once per pass; mean/std are not modified inside the loop.
        compute_stats = self.mean is None or self.std is None

        patches = extract_patches_random(self.sample, self.patch_size)
        if self.num_patches is not None:
            # Stop *before* touching patch number `num_patches`, so no stats or
            # transform work is spent on a patch that is never yielded.
            patches = islice(patches, max(self.num_patches, 0))

        transform_params = self.patch_transform_params or {}
        for patch in patches:
            if compute_stats:
                self.running_stats.update_mean(patch.mean())
                self.running_stats.update_std(patch.std())
            if self.patch_transform is not None:
                patch = self.patch_transform(patch, **transform_params)
            yield patch

        # Only store estimates we actually computed; user-given values are kept.
        if compute_stats:
            self.mean = self.running_stats.avg_mean
            self.std = self.running_stats.avg_std

    def __iter__(self):
        """
        Iterate over data source and yield single patch.

        Inside a ``DataLoader`` worker, worker ``k`` of ``n`` keeps patches
        ``k, k + n, k + 2n, ...`` of its own patch stream. In the main process,
        every patch is yielded.

        Yields
        ------
        np.ndarray
        """
        worker_info = get_worker_info()
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers
        # Offset by worker_id so workers take disjoint positions in the stream.
        yield from islice(self._generate_patches(), worker_id, None, num_workers)


In [5]:
# Random-patch dataset over the zarr store at `reg_path`, cropped to `patch_size`.
dataset = ZarrDataset(
    patch_extraction_method='random',
    patch_size=patch_size,
    data_path=reg_path,
    axes=axes,
)

In [6]:
# Batch the patch stream with 4 worker processes; each worker loads 8 batches ahead.
dl = DataLoader(
    dataset,
    num_workers=4,
    prefetch_factor=8,
    batch_size=32,
)

In [None]:
# Smoke test: walk through every batch once and report its index and shape.
for i, batch in enumerate(dl):
    print(f"{i} {batch.shape}")