diff --git a/docs/source/examples/advanced_example_pytorch_inference.md b/docs/source/examples/advanced_example_pytorch_inference.md new file mode 100644 index 0000000..fdeb3f9 --- /dev/null +++ b/docs/source/examples/advanced_example_pytorch_inference.md @@ -0,0 +1,177 @@ +```python +import zarrdataset as zds + +import torch +from torch.utils.data import DataLoader +``` + + +```python +# These are images from the Image Data Resource (IDR) +# https://idr.openmicroscopy.org/ that are publicly available and were +# converted to the OME-NGFF (Zarr) format by the OME group. More examples +# can be found at Public OME-Zarr data (Nov. 2020) +# https://www.openmicroscopy.org/2020/11/04/zarr-data.html + +filenames = [ + "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr" +] +``` + + +```python +import random +import numpy as np + +# For reproducibility +np.random.seed(478963) +torch.manual_seed(478964) +random.seed(478965) +``` + +## Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI) + +Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference result. +Finally, let the PatchSampler retrieve patches from the edge of the image that would be otherwise smaller than the patch size. + + +```python +patch_size = dict(Y=128, X=128) +pad = dict(Y=16, X=16) +patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True) +``` + +Create a dataset from the list of filenames. All those files should be stored within their respective group "0". + +Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly + + +```python +image_specs = zds.ImagesDatasetSpecs( + filenames=filenames, + data_group="4", + source_axes="TCZYX", + axes="YXC", + roi="0,0,0,0,0:1,-1,1,-1,-1" +) + +my_dataset = zds.ZarrDataset(image_specs, + patch_sampler=patch_sampler, + return_positions=True) +``` + + +```python +my_dataset +``` + + + + + ZarrDataset (PyTorch support:True, tqdm support :True) + Modalities: images + Transforms order: [] + Using images modality as reference. + Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}. + + + +Add a pre-processing step before creating the image batches, where the input arrays are casted from int16 to float32. + + +```python +import torchvision + +img_preprocessing = torchvision.transforms.Compose([ + zds.ToDtype(dtype=np.float32), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize(127, 255) +]) + +my_dataset.add_transform("images", img_preprocessing) +``` + + +```python +my_dataset +``` + + + + + ZarrDataset (PyTorch support:True, tqdm support :True) + Modalities: images + Transforms order: [('images',)] + Using images modality as reference. + Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}. + + + +## Create a DataLoader from the dataset object + +ZarrDataset is compatible with DataLoader from PyTorch since it is inherited from the IterableDataset class of the torch.utils.data module. + + +```python +my_dataloader = DataLoader(my_dataset, num_workers=0) +``` + + +```python +import dask.array as da +import numpy as np +import zarr + +z_arr = zarr.open("https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr/4", mode="r") + +H = z_arr.shape[-2] +W = z_arr.shape[-1] + +pad_H = (128 - H) % 128 +pad_W = (128 - W) % 128 +z_prediction = zarr.zeros((H + pad_H, W + pad_W), dtype=np.float32, chunks=(128, 128)) +z_prediction +``` + + + + + + + + +Set up a simple model for illustration purpose + + +```python +model = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1), + torch.nn.ReLU() +) +``` + + +```python +for i, (pos, sample) in enumerate(my_dataloader): + pred_pos = ( + slice(pos[0, 0, 0].item() + 16, + pos[0, 0, 1].item() - 16), + slice(pos[0, 1, 0].item() + 16, + pos[0, 1, 1].item() - 16) + ) + pred = model(sample) + z_prediction[pred_pos] = pred.detach().cpu().numpy()[0, 0, 16:-16, 16:-16] +``` + +## Visualize the result + + +```python +import matplotlib.pyplot as plt + +plt.subplot(2, 1, 1) +plt.imshow(np.moveaxis(z_arr[0, :, 0, ...], 0, -1)) +plt.subplot(2, 1, 2) +plt.imshow(z_prediction) +plt.show() +``` diff --git a/tests/test_samplers.py b/tests/test_samplers.py index d38ca9d..2b14ed0 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -354,6 +354,33 @@ def test_PatchSampler_pad(patch_size, pad, image_collection): f"{patches_toplefts[:3]} instead.") +@pytest.mark.parametrize("patch_size, allow_incomplete_patches," + "image_collection", [ + (1024, True, IMAGE_SPECS[10]), + (1024, False, IMAGE_SPECS[10]), +], indirect=["image_collection"]) +def test_PatchSampler_incomplete_patches(patch_size, allow_incomplete_patches, + image_collection): + patch_sampler = zds.PatchSampler( + patch_size, + allow_incomplete_patches=allow_incomplete_patches + ) + + chunks_toplefts = patch_sampler.compute_chunks(image_collection) + + patches_toplefts = patch_sampler.compute_patches( + image_collection, + chunk_tlbr=chunks_toplefts[0] + ) + + expected_num_patches = 1 if allow_incomplete_patches else 0 + + assert len(patches_toplefts) == expected_num_patches,\ + (f"Expected to have {expected_num_patches}, when " + f"`allow_incomplete_patches` is {allow_incomplete_patches} " + f"got {len(patches_toplefts)} instead.") + + @pytest.mark.parametrize("patch_size, axes, resample, allow_overlap," "image_collection", [ (dict(X=32, Y=32, Z=1), "XYZ", True, True, IMAGE_SPECS[10]), @@ -379,15 +406,6 @@ def test_BlueNoisePatchSampler(patch_size, axes, resample, allow_overlap, (f"Expected {len(patch_sampler._base_chunk_tls)} patches, got " f"{len(patches_toplefts)} instead.") - patches_toplefts = patch_sampler.compute_patches( - image_collection, - chunk_tlbr=chunks_toplefts[-1] - ) - - assert len(patches_toplefts) == len(patch_sampler._base_chunk_tls), \ - (f"Expected {len(patch_sampler._base_chunk_tls)} patches, got " - f"{len(patches_toplefts)} instead.") - @pytest.mark.parametrize("image_collection_mask_not2scale", [ IMAGE_SPECS[10] diff --git a/tests/test_zarrdataset.py b/tests/test_zarrdataset.py index 7112f63..17b0b74 100644 --- a/tests/test_zarrdataset.py +++ b/tests/test_zarrdataset.py @@ -116,8 +116,12 @@ def image_dataset_specs(request): @pytest.fixture(scope="function") def patch_sampler_specs(request): - patch_sampler = zds.PatchSampler(patch_size=request.param) - return patch_sampler, request.param + patch_size, allow_incomplete_patches = request.param + patch_sampler = zds.PatchSampler( + patch_size=patch_size, + allow_incomplete_patches=allow_incomplete_patches + ) + return patch_sampler, patch_size, allow_incomplete_patches @pytest.mark.parametrize("image_dataset_specs", [ @@ -142,7 +146,7 @@ def test_compatibility_no_tqdm(image_dataset_specs): try: next(iter(dataset)) - + except Exception as e: raise AssertionError(f"No exceptions where expected, got {e} " f"instead.") @@ -304,9 +308,9 @@ def test_ZarrDataset(image_dataset_specs, shuffle, return_positions, @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk", [ - (IMAGE_SPECS[10], 32, True, False), - (IMAGE_SPECS[10], 32, True, True), - (IMAGE_SPECS[10], 32, False, True), + (IMAGE_SPECS[10], (32, False), True, False), + (IMAGE_SPECS[10], (32, False), True, True), + (IMAGE_SPECS[10], (32, False), False, True), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) @@ -314,7 +318,7 @@ def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = zds.ZarrDataset( dataset_specs=dataset_specs, @@ -409,33 +413,40 @@ def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs, @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs", [ - (IMAGE_SPECS[10], 1024), + (IMAGE_SPECS[10], (1024, True)), + (IMAGE_SPECS[10], (1024, False)), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) def test_greater_patch_ZarrDataset(image_dataset_specs, patch_sampler_specs): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = zds.ZarrDataset( dataset_specs=dataset_specs, - patch_sampler=patch_sampler, + patch_sampler=patch_sampler ) n_samples = 0 for _ in ds: n_samples += 1 - assert n_samples == 0, ("Expected zero samples since requested patch size" - f" is greater than the image size.") + if allow_incomplete_patches: + assert n_samples > 0, ("Expected at elast one sample when patch" + " size is greater than the image size, and" + " `allow_incomplete_patches` is True.") + else: + assert n_samples == 0, ("Expected zero samples since requested patch" + " size is greater than the image size, and" + " `allow_incomplete_patches` is False.") @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk," "batch_size, num_workers", [ - (IMAGE_SPECS[10], 32, True, False, 2, 2), - ([IMAGE_SPECS[10]] * 4, 32, True, True, 2, 3), - ([IMAGE_SPECS[10]] * 2, 32, True, True, 2, 3), + (IMAGE_SPECS[10], (32, False), True, False, 2, 2), + ([IMAGE_SPECS[10]] * 4, (32, False), True, True, 2, 3), + ([IMAGE_SPECS[10]] * 2, (32, False), True, True, 2, 3), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) @@ -446,7 +457,7 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs, num_workers): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = zds.ZarrDataset( dataset_specs=dataset_specs, @@ -514,21 +525,21 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs, @pytest.mark.parametrize( "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk," "batch_size, num_workers, repeat_dataset", [ - (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 1), - (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 2), - (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 3), + (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 1), + (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 2), + (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 3), ], indirect=["image_dataset_specs", "patch_sampler_specs"] ) def test_multithread_chained_ZarrDataset(image_dataset_specs, - patch_sampler_specs, - shuffle, - draw_same_chunk, - batch_size, - num_workers, - repeat_dataset): + patch_sampler_specs, + shuffle, + draw_same_chunk, + batch_size, + num_workers, + repeat_dataset): dataset_specs, specs = image_dataset_specs - patch_sampler, patch_size = patch_sampler_specs + patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs ds = [zds.ZarrDataset(dataset_specs=dataset_specs, shuffle=shuffle, diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py index f876a74..52df198 100644 --- a/zarrdataset/_samplers.py +++ b/zarrdataset/_samplers.py @@ -1,8 +1,6 @@ -from typing import Iterable, Union, Tuple +from typing import Iterable, Union, Tuple, List import math import numpy as np -from itertools import repeat -from functools import reduce import operator import poisson_disc @@ -23,16 +21,16 @@ class PatchSampler(object): patches (hyper-cuboids) are supported by now. If a single int is passed, that size is used for all dimensions. If an iterable (list, tuple) is passed, each value will be assigned to the corresponding axes - in `spatial_axes`, the size of `patch_size` must match the lenght of + in `spatial_axes`, the size of `patch_size` must match the lenght of `spatial_axes'. If a dict is passed, this should have at least the size - of the patch of the axes listed in `spatial_axes`. Use the same - convention as how Zarr structure array chunks in order to handle patch + of the patch of the axes listed in `spatial_axes`. Use the same + convention as how Zarr structure array chunks in order to handle patch shapes and channels correctly. stride : Union[int, Iterable[int], dict, None] Distance in pixels of the movement of the sampling sliding window. If `stride` is less than `patch_size` for an axis, patches will have an overlap between them. This is usuful in inference mode for avoiding - edge artifacts. If None is passed, the `patch_size` will be used as + edge artifacts. If None is passed, the `patch_size` will be used as `stride`. pad : Union[int, Iterable[int], dict, None] Padding in pixels added to the extracted patch at each specificed axis. @@ -43,12 +41,17 @@ class PatchSampler(object): covered by the mask. spatial_axes : str The spatial axes from where patches can be extracted. + allow_incomplete_patches : bool + Allow to retrieve patches that are smaller than the patch size. This is + the case of samples at the edge of the image that are usually smaller + than the specified patch size. """ def __init__(self, patch_size: Union[int, Iterable[int], dict], stride: Union[int, Iterable[int], dict, None] = None, pad: Union[int, Iterable[int], dict, None] = None, min_area: Union[int, float] = 1, - spatial_axes: str = "ZYX"): + spatial_axes: str = "ZYX", + allow_incomplete_patches: bool = False): # The maximum chunk sizes are used to generate a reference sampling # position array used fo every sampled chunk. self._max_chunk_size = {ax: 0 for ax in spatial_axes} @@ -120,212 +123,225 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], self._pad = {ax: pad.get(ax, 0) for ax in spatial_axes} self._min_area = min_area + self._allow_incomplete_patches = allow_incomplete_patches - def _compute_corners(self, non_zero_pos: tuple, axes: str, - limits_per_dim: Union[np.ndarray, None] = None, + def _compute_corners(self, coordinates: np.ndarray, scale: np.ndarray ) -> np.ndarray: - toplefts = np.stack(non_zero_pos).T - toplefts = toplefts.astype(np.float32) - toplefts_corners = [] - dim = len(axes) + dim = coordinates.shape[-1] factors = 2 ** np.arange(dim + 1) for d in range(2 ** dim): corner_value = np.array((d % factors[1:]) // factors[:-1], dtype=np.float32) toplefts_corners.append( - toplefts + (1 - 1e-4) * corner_value + coordinates + scale * (1 - 1e-4) * corner_value ) corners = np.stack(toplefts_corners) - if limits_per_dim is not None: - corners = np.minimum(corners, - limits_per_dim[None, None, ...] - 1e-4) - return corners - def _compute_overlap(self, corners: np.ndarray, shape: np.ndarray, - ref_shape: np.ndarray) -> Tuple[np.ndarray, - np.ndarray]: - scaled_corners = corners * shape[None, None] - tls_scaled = scaled_corners / ref_shape[None, None] - tls_idx = np.floor(tls_scaled) + def _compute_reference_indices(self, reference_coordinates: np.ndarray + ) -> Tuple[List[np.ndarray], + List[Tuple[int]]]: + reference_per_axis = list(map( + lambda coords: np.append(np.full((1, ), fill_value=-float("inf")), + np.unique(coords)), + reference_coordinates.T + )) + + reference_idx = map( + lambda coord_axis, ref_axis: + np.argmax(ref_axis[None, ...] + * (coord_axis[..., None] >= ref_axis[None, ...]), + axis=-1), + reference_coordinates.T, + reference_per_axis + ) + reference_idx = np.stack(tuple(reference_idx), axis=-1) + reference_idx = [ + tuple(tls_coord - 1) + for tls_coord in reference_idx.reshape(-1, len(reference_per_axis)) + ] + + return reference_per_axis, reference_idx + + def _compute_overlap(self, corners_coordinates: np.ndarray, + reference_per_axis: np.ndarray) -> Tuple[np.ndarray, + np.ndarray]: + tls_idx = map( + lambda coord_axis, ref_axis: + np.argmax(ref_axis[None, None, ...] + * (coord_axis[..., None] >= ref_axis[None, None, ...]), + axis=-1), + np.moveaxis(corners_coordinates, -1, 0), + reference_per_axis + ) + tls_idx = np.stack(tuple(tls_idx), axis=-1) - corners_cut = np.maximum(tls_scaled[0], tls_idx[-1]) + tls_coordinates = map( + lambda tls_coord, ref_axis: ref_axis[tls_coord], + np.moveaxis(tls_idx, -1, 0), + reference_per_axis + ) + tls_coordinates = np.stack(tuple(tls_coordinates), axis=-1) - dist2cut = np.fabs(corners - corners_cut[None]) - coverage = np.prod(dist2cut, axis=-1) + corners_cut = np.maximum(corners_coordinates[0], tls_coordinates[-1]) - # Scale the coverage to the size of the input shape - coverage *= np.prod(shape) + dist2cut = np.fabs(corners_coordinates - corners_cut[None]) + coverage = np.prod(dist2cut, axis=-1) - return coverage, tls_idx.astype(np.int64) + return coverage, tls_idx - 1 - def _compute_grid(self, chunk_mask: np.ndarray, - mask_axes: str, - mask_scale: dict, + def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_size: dict): - mask_relative_shape = np.array( - [1 / m_scl - for ax, m_scl in mask_scale.items() - if ax in self.spatial_axes - ], - dtype=np.float32 - ) - - patch_shape = np.array( - [patch_size[ax] - for ax in mask_axes - if ax in self.spatial_axes - ], - dtype=np.float32 - ) - - # If the patch sizes are greater than the relative shape of the mask - # with respect to the input image, use the mask coordinates as - # reference to overlap the coordinates of the sampling patches. - # Otherwise, use the patches coordinates instead. - if all(map(operator.gt, patch_shape, mask_relative_shape)): - active_coordinates = np.nonzero(chunk_mask) - limits_per_dim = np.array(chunk_mask.shape) + 1 + image_size: dict, + allow_incomplete_patches: bool = False): + mask_scale = np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + image_scale = np.array([image_size.get(ax, 1) / patch_size.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + round_fn = math.ceil if allow_incomplete_patches else math.floor + + image_blocks = [ + round_fn( + ( + min(image_size.get(ax, 1), + chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float("inf")) + - (chunk_tlbr[ax].start + if chunk_tlbr[ax].start is not None + else 0) + ) / patch_size.get(ax, 1)) + for ax in self.spatial_axes + ] - ref_axes = mask_axes + if min(image_blocks) == 0: + return [] - ref_shape = patch_shape - shape = mask_relative_shape + image_scale = np.array([patch_size.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + image_coordinates = np.array(list(np.ndindex(*image_blocks)), + dtype=np.float32) - mask_is_greater = False + image_coordinates *= image_scale - patch_ratio = [ - round(image_size[ax] / ps) - for ax, ps in zip(mask_axes, patch_shape.astype(np.int64)) - if ax in self.spatial_axes - ] + mask_scale = 1 / np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) - if not all(patch_ratio): - return np.empty( - [0] * len(set(mask_axes).intersection(self.spatial_axes)), - dtype=np.int64 + mask_coordinates = list(np.nonzero(mask[:])) + for ax_i, ax in enumerate(self.spatial_axes): + if ax not in mask.axes: + mask_coordinates.insert( + ax_i, + np.zeros(mask_coordinates[0].size) ) - else: - active_coordinates = np.meshgrid( - *[np.arange(math.ceil(image_size[ax] / ps)) - for ax, ps in zip(mask_axes, patch_shape) - if ax in self.spatial_axes] - ) + mask_coordinates = np.stack(mask_coordinates, dtype=np.float32).T + mask_coordinates *= mask_scale[None, ...] - limits_per_dim = np.array([ - image_size[ax] / ps - for ax, ps in zip(mask_axes, patch_shape) - if ax in self.spatial_axes - ]) + # Filter out mask coordinates outside the current selected chunk + chunk_tl_coordinates = np.array( + [chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None else 0 + for ax in self.spatial_axes], + dtype=np.float32 + ) + chunk_br_coordinates = np.array( + [chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float('inf') + for ax in self.spatial_axes], + dtype=np.float32 + ) - active_coordinates = tuple( - coord_ax.flatten() - for coord_ax in active_coordinates - ) + in_chunk = np.all( + np.bitwise_and( + mask_coordinates > (chunk_tl_coordinates - mask_scale - 1e-4), + mask_coordinates < chunk_br_coordinates + 1e-4 + ), + axis=1 + ) + mask_coordinates = mask_coordinates[in_chunk] - ref_axes = "".join([ - ax for ax in self.spatial_axes if ax in mask_axes - ]) + if all(map(operator.ge, image_scale, mask_scale)): + mask_corners = self._compute_corners(mask_coordinates, mask_scale) - ref_shape = mask_relative_shape - shape = patch_shape + (reference_per_axis, + reference_idx) =\ + self._compute_reference_indices(image_coordinates) - mask_is_greater = True + (coverage, + corners_idx) = self._compute_overlap(mask_corners, + reference_per_axis) - corners = self._compute_corners(active_coordinates, axes=ref_axes, - limits_per_dim=limits_per_dim) + covered_indices = [ + reference_idx.index(tuple(idx)) + for idx in corners_idx.reshape(-1, len(self.spatial_axes)) + ] - (coverage, - corners_idx) = self._compute_overlap(corners, shape, ref_shape) + patches_coverage = np.bincount(covered_indices, + weights=coverage.flatten(), + minlength=np.prod(image_blocks)) - if mask_is_greater: - # The mask ratio is greater than the patches size - mask_values = chunk_mask[tuple(corners_idx.T)].T - patches_coverage = coverage * mask_values + else: + image_corners = self._compute_corners(image_coordinates, + image_scale) - covered_tls = corners[0, ...].astype(np.int64) + (reference_per_axis, + reference_idx) = self._compute_reference_indices(mask_coordinates) - else: - # The mask ratio is less than the patches size - patch_coordinates = np.ravel_multi_index(tuple(corners_idx.T), - chunk_mask.shape) - patches_coverage = np.bincount(patch_coordinates.flatten(), - weights=coverage.flatten()) - patches_coverage = np.take(patches_coverage, patch_coordinates).T + (coverage, + corners_idx) = self._compute_overlap(image_corners, + reference_per_axis) - covered_tls = corners_idx[0, ...] + covered_indices = np.array([ + tuple(idx) in reference_idx + for idx in corners_idx.reshape(-1, len(self.spatial_axes)) + ]).reshape(coverage.shape) - patches_coverage = np.sum(patches_coverage, axis=0) + patches_coverage = np.sum(covered_indices * coverage, axis=0) - # Compute minimum area covered by masked regions to sample a patch. min_area = self._min_area if min_area < 1: - min_area *= patch_shape.prod() + min_area *= np.prod(list(patch_size.values())) - minumum_covered_tls = covered_tls[patches_coverage > min_area] + minimum_covered_tls = image_coordinates[patches_coverage > min_area] + minimum_covered_tls = minimum_covered_tls.astype(np.int64) - if not mask_is_greater: - # Collapse to unique coordinates since there will be multiple - # instances of the same patch. - minumum_covered_tls = np.ravel_multi_index( - tuple(minumum_covered_tls.T), - patch_ratio, - mode="clip" - ) - - minumum_covered_tls = np.unique(minumum_covered_tls) - - minumum_covered_tls = np.unravel_index( - minumum_covered_tls, - patch_ratio - ) + return minimum_covered_tls - minumum_covered_tls = np.stack(minumum_covered_tls).T - - return minumum_covered_tls * patch_shape[None].astype(np.int64) - - def _compute_valid_toplefts(self, chunk_mask: np.ndarray, mask_axes: str, - mask_scale: dict, + def _compute_valid_toplefts(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_size: dict): - return self._compute_grid(chunk_mask, mask_axes, mask_scale, - patch_size, - image_size) + **kwargs): + return self._compute_grid(chunk_tlbr, mask, patch_size, **kwargs) - def _compute_toplefts_slices(self, mask: ImageBase, image_shape: dict, - patch_size: dict, + def _compute_toplefts_slices(self, chunk_tlbr: dict, valid_mask_toplefts: np.ndarray, - chunk_tlbr: dict, + patch_size: dict, pad: Union[dict, None] = None): if pad is None: pad = {ax: 0 for ax in self.spatial_axes} - toplefts = [] - for tls in valid_mask_toplefts: - curr_tl = [] - - for ax in self.spatial_axes: - if ax in mask.axes: - tl = ((chunk_tlbr[ax].start - if chunk_tlbr[ax].start is not None else 0) - + tls[mask.axes.index(ax)]) - br = ((chunk_tlbr[ax].start - if chunk_tlbr[ax].start is not None else 0) - + tls[mask.axes.index(ax)] + patch_size[ax]) - - curr_tl.append((ax, slice(tl - pad[ax], - br + pad[ax]))) - - else: - curr_tl.append((ax, slice(0, 1))) - - toplefts.append(dict(curr_tl)) + toplefts = [ + {ax: slice( + (chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None + else 0) + tls[self.spatial_axes.index(ax)] + - pad[ax], + (chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None + else 0) + tls[self.spatial_axes.index(ax)] + patch_size[ax] + + pad[ax]) + for ax in self.spatial_axes + } + for tls in valid_mask_toplefts + ] return toplefts @@ -362,41 +378,31 @@ def compute_chunks(self, for ax, chk in zip(image.axes, image.chunk_size) if ax in self.spatial_axes } - # spatial_chunk_sizes = { - # ax: (self._patch_size[ax] - # * max(1, math.ceil(chk / self._patch_size[ax]))) - # for ax, chk in zip(image.axes, image.chunk_size) - # if ax in self.spatial_axes - # } - image_shape = {ax: s for ax, s in zip(image.axes, image.shape)} + image_size = {ax: s for ax, s in zip(image.axes, image.shape)} self._max_chunk_size = { ax: (min(max(self._max_chunk_size[ax], spatial_chunk_sizes[ax]), - image_shape[ax])) + image_size[ax])) if ax in image.axes else 1 for ax in self.spatial_axes } chunk_tlbr = {ax: slice(None) for ax in self.spatial_axes} - chunk_mask = mask[chunk_tlbr] - valid_mask_toplefts = self._compute_grid( - chunk_mask, - mask.axes, - mask.scale, + chunk_tlbr, + mask, self._max_chunk_size, - image_shape + image_size, + allow_incomplete_patches=True ) chunks_slices = self._compute_toplefts_slices( - mask, - image_shape=image_shape, - patch_size=self._max_chunk_size, + chunk_tlbr, valid_mask_toplefts=valid_mask_toplefts, - chunk_tlbr=chunk_tlbr + patch_size=self._max_chunk_size ) return chunks_slices @@ -405,29 +411,36 @@ def compute_patches(self, image_collection: ImageCollection, chunk_tlbr: dict) -> Iterable[dict]: image = image_collection.collection[image_collection.reference_mode] mask = image_collection.collection[image_collection.mask_mode] - image_shape = {ax: s for ax, s in zip(image.axes, image.shape)} - chunk_size = { - ax: ((ctb.stop if ctb.stop is not None else image_shape[ax]) - - (ctb.start if ctb.start is not None else 0)) - for ax, ctb in chunk_tlbr.items() + image_size = {ax: s for ax, s in zip(image.axes, image.shape)} + + stride = { + ax: self._stride.get(ax, 1) if image_size.get(ax, 1) > 1 else 1 + for ax in self.spatial_axes } - chunk_mask = mask[chunk_tlbr] + patch_size = { + ax: self._patch_size.get(ax, 1) if image_size.get(ax, 1) > 1 else 1 + for ax in self.spatial_axes + } + + pad = { + ax: self._pad.get(ax, 0) if image_size.get(ax, 1) > 1 else 0 + for ax in self.spatial_axes + } valid_mask_toplefts = self._compute_valid_toplefts( - chunk_mask, - mask.axes, - mask.scale, - self._stride, - chunk_size) + chunk_tlbr, + mask, + stride, + image_size=image_size, + allow_incomplete_patches=self._allow_incomplete_patches + ) patches_slices = self._compute_toplefts_slices( - mask, - image_shape=image_shape, - patch_size=self._patch_size, + chunk_tlbr, valid_mask_toplefts=valid_mask_toplefts, - chunk_tlbr=chunk_tlbr, - pad=self._pad + patch_size=patch_size, + pad=pad ) return patches_slices @@ -457,7 +470,7 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], resample_positions=False, allow_overlap=False, **kwargs): - super(BlueNoisePatchSampler, self).__init__(patch_size) + super(BlueNoisePatchSampler, self).__init__(patch_size, **kwargs) self._base_chunk_tls = None self._resample_positions = resample_positions self._allow_overlap = allow_overlap @@ -499,35 +512,62 @@ def compute_sampling_positions(self, force=False) -> None: self._base_chunk_tls = np.zeros((1, len(self.spatial_axes)), dtype=np.int64) - def _compute_valid_toplefts(self, - chunk_mask: np.ndarray, - mask_axes: str, - mask_scale: dict, + def _compute_valid_toplefts(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_shape: dict): + **kwargs): self.compute_sampling_positions(force=self._resample_positions) # Filter sampling positions that does not contain any mask portion. - sampling_pos = np.hstack( - tuple( - self._base_chunk_tls[:, self.spatial_axes.index(ax), None] - if ax in self.spatial_axes else - np.zeros((len(self._base_chunk_tls), 1), dtype=np.float32) - for ax in mask_axes - ) - ) - spatial_patch_sizes = np.array([ - patch_size[ax] - for ax in mask_axes - if ax in self.spatial_axes + patch_size.get(ax, 1) for ax in self.spatial_axes ]) - mask_corners = self._compute_corners(np.nonzero(chunk_mask), - mask_axes) + mask_scale = np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + mask_scale = 1 / np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + mask_coordinates = list(np.nonzero(mask[:])) + for ax_i, ax in enumerate(self.spatial_axes): + if ax not in mask.axes: + mask_coordinates.insert( + ax_i, + np.zeros(mask_coordinates[0].size) + ) + + mask_coordinates = np.stack(mask_coordinates, dtype=np.float32).T + mask_coordinates *= mask_scale[None, ...] + + # Filter out mask coordinates outside the current selected chunk + chunk_tl_coordinates = np.array( + [chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None else 0 + for ax in self.spatial_axes], + dtype=np.float32 + ) + chunk_br_coordinates = np.array( + [chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float('inf') + for ax in self.spatial_axes], + dtype=np.float32 + ) + + in_chunk = np.all( + np.bitwise_and( + mask_coordinates > (chunk_tl_coordinates - mask_scale - 1e-4), + mask_coordinates < chunk_br_coordinates + 1e-4 + ), + axis=1 + ) + mask_coordinates = mask_coordinates[in_chunk] + + mask_corners = self._compute_corners(mask_coordinates, mask_scale) dist = (mask_corners[None, :, :, :] - - sampling_pos[:, None, None, :].astype(np.float32) + - self._base_chunk_tls[:, None, None, :].astype(np.float32) - spatial_patch_sizes[None, None, None, :] / 2) mask_samplable_pos, = np.nonzero( @@ -537,6 +577,6 @@ def _compute_valid_toplefts(self, ) ) - toplefts = sampling_pos[mask_samplable_pos] + toplefts = self._base_chunk_tls[mask_samplable_pos] return toplefts