diff --git a/.gitignore b/.gitignore
index 5bc0292..51e576d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,7 @@ coverage.xml
 *.py,cover
 .hypothesis/
 .pytest_cache/
+example.py
 
 # Translations
 *.mo
diff --git a/docs/source/examples/advanced_example_pytorch_inference.md b/docs/source/examples/advanced_example_pytorch_inference.md
new file mode 100644
index 0000000..d3a6cd9
--- /dev/null
+++ b/docs/source/examples/advanced_example_pytorch_inference.md
@@ -0,0 +1,194 @@
+---
+jupytext:
+  text_representation:
+    extension: .md
+    format_name: myst
+    format_version: 0.13
+  jupytext_version: 1.15.1
+kernelspec:
+  display_name: Python 3 (ipykernel)
+  language: python
+  name: python3
+execution:
+  timeout: 120
+---
+
+# Integration of ZarrDataset with PyTorch's DataLoader for inference (Advanced)

```python
import zarrdataset as zds

import torch
from torch.utils.data import DataLoader
```

```python
# These are images from the Image Data Resource (IDR)
# https://idr.openmicroscopy.org/ that are publicly available and were
# converted to the OME-NGFF (Zarr) format by the OME group. More examples
# can be found at Public OME-Zarr data (Nov. 2020)
# https://www.openmicroscopy.org/2020/11/04/zarr-data.html

filenames = [
    "https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr"
]
```

```python
import random
import numpy as np

# For reproducibility
np.random.seed(478963)
torch.manual_seed(478964)
random.seed(478965)
```

## Extracting patches of size 128x128 pixels from a Whole Slide Image (WSI)

Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference results.
Finally, let the PatchSampler retrieve patches from the edge of the image that would otherwise be smaller than the patch size by setting `allow_incomplete_patches=True`.

```python
patch_size = dict(Y=128, X=128)
pad = dict(Y=16, X=16)
patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True)
```

Create a dataset from the list of filenames. The image data used here is read from group "4" of each file, which corresponds to a downsampled level of the multiscale image.

Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.

```python
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="4",
    source_axes="TCZYX",
    axes="YXC",
    roi="0,0,0,0,0:1,-1,1,-1,-1"
)

my_dataset = zds.ZarrDataset(image_specs,
                             patch_sampler=patch_sampler,
                             return_positions=True)
```

```python
my_dataset
```

    ZarrDataset (PyTorch support:True, tqdm support :True)
    Modalities: images
    Transforms order: []
    Using images modality as reference.
    Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.

Add a pre-processing step before creating the image batches, where the input arrays are cast from int16 to float32.

```python
import torchvision

img_preprocessing = torchvision.transforms.Compose([
    zds.ToDtype(dtype=np.float32),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(127, 255)
])

my_dataset.add_transform("images", img_preprocessing)
```

```python
my_dataset
```

    ZarrDataset (PyTorch support:True, tqdm support :True)
    Modalities: images
    Transforms order: [('images',)]
    Using images modality as reference.
    Using for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.
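A quick sanity check of the sampler and transform settings (a minimal sketch added here, relying only on the `my_dataset` object defined above): with `patch_size=128` and `pad=16`, each patch should span `128 + 2 * 16 = 160` pixels per spatial axis, and the preprocessing pipeline moves the channel axis to the front.

```python
# Draw a single patch directly from the dataset. With return_positions=True
# the dataset yields a (position, sample) pair; given the settings above, the
# transformed sample is expected to have shape (3, 160, 160).
pos, sample = next(iter(my_dataset))
print(pos.shape, sample.shape)
```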

## Create a DataLoader from the dataset object

ZarrDataset is compatible with PyTorch's DataLoader since it inherits from the IterableDataset class of the torch.utils.data module.

```python
my_dataloader = DataLoader(my_dataset, num_workers=0)
```

```python
import dask.array as da
import numpy as np
import zarr

z_arr = zarr.open("https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr/4", mode="r")

H = z_arr.shape[-2]
W = z_arr.shape[-1]

pad_H = (128 - H) % 128
pad_W = (128 - W) % 128
z_prediction = zarr.zeros((H + pad_H, W + pad_W), dtype=np.float32, chunks=(128, 128))
z_prediction
```

Set up a simple model for illustration purposes.

```python
model = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1),
    torch.nn.ReLU()
)
```

```python
for i, (pos, sample) in enumerate(my_dataloader):
    pred_pos = (
        slice(pos[0, 0, 0].item() + 16,
              pos[0, 0, 1].item() - 16),
        slice(pos[0, 1, 0].item() + 16,
              pos[0, 1, 1].item() - 16)
    )
    pred = model(sample)
    z_prediction[pred_pos] = pred.detach().cpu().numpy()[0, 0, 16:-16, 16:-16]
```

## Visualize the result

```python
import matplotlib.pyplot as plt

plt.subplot(2, 1, 1)
plt.imshow(np.moveaxis(z_arr[0, :, 0, ...], 0, -1))
plt.subplot(2, 1, 2)
plt.imshow(z_prediction)
plt.show()
```
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9359764..9394850 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -30,6 +30,8 @@ Welcome to ZarrDataset's documentation!
 
    examples/advanced_example_pytorch
 
+   examples/advanced_example_pytorch_inference
+
 REFERENCE
 =========
diff --git a/tests/test_imageloaders.py b/tests/test_imageloaders.py
index 366861b..b996140 100644
--- a/tests/test_imageloaders.py
+++ b/tests/test_imageloaders.py
@@ -163,6 +163,29 @@ def test_ImageBase_slicing():
          f"{expected_selection_shape}, got {img_sel_2.shape} instead")
 
 
+def test_ImageBase_padding():
+    shape = (16, 16, 3)
+    axes = "YXC"
+    img = zds.ImageBase(shape, chunk_size=None, source_axes=axes, mode="image")
+
+    random.seed(44512)
+    selection_1 = dict(
+        (ax, slice(random.randint(-10, 0),
+                   random.randint(1, r_s + 10)))
+        for ax, r_s in zip(axes, shape)
+    )
+
+    expected_selection_shape = tuple(
+        selection_1[ax].stop - selection_1[ax].start for ax in axes
+    )
+
+    img_sel_1 = img[selection_1]
+
+    assert img_sel_1.shape == expected_selection_shape, \
+        (f"Expected selection {selection_1} to have shape "
+         f"{expected_selection_shape}, got {img_sel_1.shape} instead")
+
+
 @pytest.mark.parametrize("axes, roi, expected_size", [
     (None, None, (16, 16, 3)),
     (None, slice(None), (16, 16, 3)),
diff --git a/tests/test_samplers.py b/tests/test_samplers.py
index 24c97c7..2b14ed0 100644
--- a/tests/test_samplers.py
+++ b/tests/test_samplers.py
@@ -121,12 +121,38 @@ def test_PatchSampler_correct_patch_size(patch_size, spatial_axes,
                                          expected_patch_size):
     patch_sampler = zds.PatchSampler(patch_size=patch_size,
                                      spatial_axes=spatial_axes)
-    
+
     assert patch_sampler._patch_size == expected_patch_size, \
         (f"Expected `patch_size` to be a dictionary as {expected_patch_size}, "
          f"got {patch_sampler._patch_size} instead.")
 
 
+@pytest.mark.parametrize("stride, spatial_axes, expected_stride", [
+    (512, "X", dict(X=512)),
+    ((128, 64), "XY", dict(X=128, Y=64)),
+])
+def test_PatchSampler_correct_stride(stride, spatial_axes, expected_stride):
+    patch_sampler = zds.PatchSampler(patch_size=512, stride=stride,
+                                     spatial_axes=spatial_axes)
+
+    assert patch_sampler._stride == expected_stride, \
+        (f"Expected `stride` to be a dictionary as {expected_stride}, "
+         f"got {patch_sampler._stride} instead.")
+
+
+@pytest.mark.parametrize("pad, spatial_axes, expected_pad", [
+    (512, "X", dict(X=512)),
+    ((128, 64), "XY", dict(X=128, Y=64)),
+])
+def test_PatchSampler_correct_pad(pad, spatial_axes, expected_pad):
+    patch_sampler = zds.PatchSampler(patch_size=512, pad=pad,
+                                     spatial_axes=spatial_axes)
+
+    assert patch_sampler._pad == expected_pad, \
+        (f"Expected `pad` to be a dictionary as {expected_pad}, "
+         f"got {patch_sampler._pad} instead.")
+
+
 @pytest.mark.parametrize("patch_size, spatial_axes", [
     ((512, 128), "X"),
     ((128, ), "XY"),
@@ -138,6 +164,30 @@ def test_PatchSampler_incorrect_patch_size(patch_size, spatial_axes):
                                          spatial_axes=spatial_axes)
 
 
+@pytest.mark.parametrize("stride, spatial_axes", [
+    ((512, 128), "X"),
+    ((128, ), "XY"),
+    ("stride", "ZYX"),
+])
+def test_PatchSampler_incorrect_stride(stride, spatial_axes):
+    with pytest.raises(ValueError):
+        patch_sampler = zds.PatchSampler(patch_size=512,
+                                         stride=stride,
+                                         spatial_axes=spatial_axes)
+
+
+@pytest.mark.parametrize("pad, spatial_axes", [
+    ((512, 128), "X"),
+    ((128, ), "XY"),
+    ("pad", "ZYX"),
+])
+def test_PatchSampler_incorrect_pad(pad, spatial_axes):
+    with pytest.raises(ValueError):
+        patch_sampler = zds.PatchSampler(patch_size=512,
+                                         pad=pad,
+                                         spatial_axes=spatial_axes)
+
+
 @pytest.mark.parametrize("patch_size, image_collection", [
     (32, IMAGE_SPECS[10])
 ], indirect=["image_collection"])
@@ -146,18 +196,18 @@ def test_PatchSampler_chunk_generation(patch_size, image_collection):
 
     chunks_toplefts = patch_sampler.compute_chunks(image_collection)
 
-    chunk_size = dict(
-        (ax, cs)
+    chunk_size = {
+        ax: cs
         for ax, cs in zip(image_collection.collection["images"].axes,
                           image_collection.collection["images"].chunk_size)
-    )
+    }
 
-    scaled_chunk_size = dict(
-        (ax, int(cs * image_collection.collection["masks"].scale[ax]))
+    scaled_chunk_size = {
+        ax: int(cs * image_collection.collection["masks"].scale[ax])
         for ax, cs in zip(image_collection.collection["images"].axes,
                           image_collection.collection["images"].chunk_size)
         if ax in image_collection.collection["masks"].axes
-    )
+    }
 
     scaled_mask = transform.downscale_local_mean(
         image_collection.collection["masks"][:],
@@ -194,10 +244,92 @@ def test_PatchSampler(patch_size, image_collection):
         chunk_tlbr=chunks_toplefts[0]
     )
 
-    scaled_patch_size = dict(
-        (ax, int(patch_size * scl))
+    scaled_patch_size = {
+        ax: int(patch_size * scl)
         for ax, scl in image_collection.collection["masks"].scale.items()
-    )
+    }
+
+    scaled_mask = transform.downscale_local_mean(
+        image_collection.collection["masks"][chunks_toplefts[0]],
+        factors=(scaled_patch_size["Y"], scaled_patch_size["X"])
+    )
+    expected_patches_toplefts = np.nonzero(scaled_mask)
+
+    expected_patches_toplefts = [
+        dict(
+            [("Z", slice(0, 1, None))]
+            + [
+                (ax, slice(tl * patch_size, (tl + 1) * patch_size))
+                for ax, tl in zip("YX", tls)
+            ]
+        )
+        for tls in zip(*expected_patches_toplefts)
+    ]
+
+    assert all(map(operator.eq, patches_toplefts, expected_patches_toplefts)),\
+        (f"Expected patches to be {expected_patches_toplefts[:3]}, got "
+         f"{patches_toplefts[:3]} instead.")
+
+
+@pytest.mark.parametrize("patch_size, stride, image_collection", [
+    (32, 32, IMAGE_SPECS[10]),
+    (32, 16, IMAGE_SPECS[10]),
+    (32, 64, IMAGE_SPECS[10]),
+], indirect=["image_collection"])
+def test_PatchSampler_stride(patch_size, stride, image_collection):
+    patch_sampler = zds.PatchSampler(patch_size, stride=stride)
+
+    chunks_toplefts = patch_sampler.compute_chunks(image_collection)
+
+    patches_toplefts = patch_sampler.compute_patches(
+        image_collection,
+        chunk_tlbr=chunks_toplefts[0]
+    )
+
+    scaled_patch_size = {
+        ax: int(stride * scl)
+        for ax, scl in image_collection.collection["masks"].scale.items()
+    }
+
+    scaled_mask = transform.downscale_local_mean(
+        image_collection.collection["masks"][chunks_toplefts[0]],
+        factors=(scaled_patch_size["Y"], scaled_patch_size["X"])
+    )
+    expected_patches_toplefts = np.nonzero(scaled_mask)
+
+    expected_patches_toplefts = [
+        dict(
+            [("Z", slice(0, 1, None))]
+            + [
+                (ax, slice(tl * stride, tl * stride + patch_size))
+                for ax, tl in zip("YX", tls)
+            ]
+        )
+        for tls in zip(*expected_patches_toplefts)
+    ]
+    assert all(map(operator.eq, patches_toplefts, expected_patches_toplefts)),\
+        (f"Expected patches to be {expected_patches_toplefts[:3]}, got "
+         f"{patches_toplefts[:3]} instead.")
+
+
+@pytest.mark.parametrize("patch_size, pad, image_collection", [
+    (32, 0, IMAGE_SPECS[10]),
+    (32, 2, IMAGE_SPECS[10]),
+], indirect=["image_collection"])
+def test_PatchSampler_pad(patch_size, pad, image_collection):
+    patch_sampler = zds.PatchSampler(patch_size, pad=pad)
+
+    chunks_toplefts = patch_sampler.compute_chunks(image_collection)
+
+    patches_toplefts = patch_sampler.compute_patches(
+        image_collection,
+        chunk_tlbr=chunks_toplefts[0]
+    )
+
+    scaled_patch_size = {
+        ax: int(patch_size * scl)
+        for ax, scl in image_collection.collection["masks"].scale.items()
+    }
 
     scaled_mask = transform.downscale_local_mean(
         image_collection.collection["masks"][chunks_toplefts[0]],
@@ -205,12 +337,12 @@ def test_PatchSampler(patch_size, image_collection):
     )
     expected_patches_toplefts = np.nonzero(scaled_mask)
 
+    # Expected top-lefts include the `pad` pixels added on each side
     expected_patches_toplefts = [
         dict(
             [("Z", slice(0, 1, None))]
             + [
-                (ax, slice(int(tl * patch_size),
-                           int(math.ceil((tl + 1) * patch_size))))
+                (ax, slice(tl * patch_size - pad, (tl + 1) * patch_size + pad))
                 for ax, tl in zip("YX", tls)
             ]
         )
@@ -222,6 +354,33 @@ def test_PatchSampler(patch_size, image_collection):
          f"{patches_toplefts[:3]} instead.")
 
 
+@pytest.mark.parametrize("patch_size, allow_incomplete_patches,"
+                         "image_collection", [
+    (1024, True, IMAGE_SPECS[10]),
+    (1024, False, IMAGE_SPECS[10]),
+], indirect=["image_collection"])
+def test_PatchSampler_incomplete_patches(patch_size, allow_incomplete_patches,
+                                         image_collection):
+    patch_sampler = zds.PatchSampler(
+        patch_size,
+        allow_incomplete_patches=allow_incomplete_patches
+    )
+
+    chunks_toplefts = patch_sampler.compute_chunks(image_collection)
+
+    patches_toplefts = patch_sampler.compute_patches(
+        image_collection,
+        chunk_tlbr=chunks_toplefts[0]
+    )
+
+    expected_num_patches = 1 if allow_incomplete_patches else 0
+
+    assert len(patches_toplefts) == expected_num_patches,\
+        (f"Expected {expected_num_patches} patches when "
+         f"`allow_incomplete_patches` is {allow_incomplete_patches}, "
+         f"got {len(patches_toplefts)} instead.")
+
+
 @pytest.mark.parametrize("patch_size, axes, resample, allow_overlap,"
                          "image_collection", [
     (dict(X=32, Y=32, Z=1), "XYZ", True, True, IMAGE_SPECS[10]),
@@ -247,15 +406,6 @@ def test_BlueNoisePatchSampler(patch_size, axes, resample, allow_overlap,
         (f"Expected {len(patch_sampler._base_chunk_tls)} patches, got "
          f"{len(patches_toplefts)} instead.")
 
-    patches_toplefts = patch_sampler.compute_patches(
-        image_collection,
-        chunk_tlbr=chunks_toplefts[-1]
-    )
-
-    assert len(patches_toplefts) == len(patch_sampler._base_chunk_tls), \
-        (f"Expected {len(patch_sampler._base_chunk_tls)} patches, got "
-         f"{len(patches_toplefts)} instead.")
-
 
 @pytest.mark.parametrize("image_collection_mask_not2scale", [
     IMAGE_SPECS[10]
@@ -276,14 +426,19 @@ def test_BlueNoisePatchSampler_mask_not2scale(image_collection_mask_not2scale):
         chunk_tlbr=chunks_toplefts[0]
     )
 
-    assert len(patches_toplefts) == 0, \
-        (f"Expected 0 patches, got {len(patches_toplefts)} instead.")
+    # Samples can be retrieved from chunks that are not multiples of the
+    # patch size. The ZarrDataset class should handle these cases, either by
+    # dropping these patches, or by adding padding when allowed by the user.
+    assert len(patches_toplefts) == 1, \
+        (f"Expected 1 patch, got {len(patches_toplefts)} instead.")
 
 
-@pytest.mark.parametrize("patch_size, image_collection, specs", [
-    (512, MASKABLE_IMAGE_SPECS[0], MASKABLE_IMAGE_SPECS[0])
+@pytest.mark.parametrize("patch_size, stride, image_collection, specs", [
+    (512, 512, MASKABLE_IMAGE_SPECS[0], MASKABLE_IMAGE_SPECS[0]),
+    (512, 256, MASKABLE_IMAGE_SPECS[0], MASKABLE_IMAGE_SPECS[0])
 ], indirect=["image_collection"])
-def test_unique_sampling_PatchSampler(patch_size, image_collection, specs):
+def test_unique_sampling_PatchSampler(patch_size, stride, image_collection,
+                                      specs):
     from skimage import color, filters, morphology
     import zarr
 
@@ -304,7 +459,8 @@ def test_unique_sampling_PatchSampler(patch_size, image_collection, specs):
                            mode="masks")
     image_collection.reset_scales()
 
-    patch_sampler = zds.PatchSampler(patch_size, min_area=1/16 ** 2)
+    patch_sampler = zds.PatchSampler(patch_size, stride=stride,
+                                     min_area=1/16 ** 2)
 
     chunks_toplefts = patch_sampler.compute_chunks(image_collection)
 
diff --git a/tests/test_zarrdataset.py b/tests/test_zarrdataset.py
index 7112f63..17b0b74 100644
--- a/tests/test_zarrdataset.py
+++ b/tests/test_zarrdataset.py
@@ -116,8 +116,12 @@ def image_dataset_specs(request):
 
 @pytest.fixture(scope="function")
 def patch_sampler_specs(request):
-    patch_sampler = zds.PatchSampler(patch_size=request.param)
-    return patch_sampler, request.param
+    patch_size, allow_incomplete_patches = request.param
+    patch_sampler = zds.PatchSampler(
+        patch_size=patch_size,
+        allow_incomplete_patches=allow_incomplete_patches
+    )
+    return patch_sampler, patch_size, allow_incomplete_patches
 
 
 @pytest.mark.parametrize("image_dataset_specs", [
@@ -142,7 +146,7 @@ def test_compatibility_no_tqdm(image_dataset_specs):
 
     try:
         next(iter(dataset))
-        
+
     except Exception as e:
         raise AssertionError(f"No exceptions where expected, got {e} "
                              f"instead.")
@@ -304,9 +308,9 @@ def test_ZarrDataset(image_dataset_specs, shuffle, return_positions,
 @pytest.mark.parametrize(
     "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk", [
-        (IMAGE_SPECS[10], 32, True, False),
-        (IMAGE_SPECS[10], 32, True, True),
-        (IMAGE_SPECS[10], 32, False, True),
+        (IMAGE_SPECS[10], (32, False), True, False),
+        (IMAGE_SPECS[10], (32, False), True, True),
+        (IMAGE_SPECS[10], (32, False), False, True),
     ],
     indirect=["image_dataset_specs", "patch_sampler_specs"]
 )
@@ -314,7 +318,7 @@ def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs,
                              shuffle, draw_same_chunk):
     dataset_specs, specs = image_dataset_specs
-    patch_sampler, patch_size = patch_sampler_specs
+    patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs
 
     ds = zds.ZarrDataset(
         dataset_specs=dataset_specs,
@@ -409,33 +413,40 @@ def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs,
 @pytest.mark.parametrize(
     "image_dataset_specs, patch_sampler_specs", [
-        (IMAGE_SPECS[10], 1024),
+        (IMAGE_SPECS[10], (1024, True)),
+        (IMAGE_SPECS[10], (1024, False)),
     ],
     indirect=["image_dataset_specs", "patch_sampler_specs"]
 )
 def test_greater_patch_ZarrDataset(image_dataset_specs, patch_sampler_specs):
     dataset_specs, specs = image_dataset_specs
-    patch_sampler, patch_size = patch_sampler_specs
+    patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs
 
     ds = zds.ZarrDataset(
         dataset_specs=dataset_specs,
-        patch_sampler=patch_sampler,
+        patch_sampler=patch_sampler
     )
 
     n_samples = 0
     for _ in ds:
         n_samples += 1
 
-    assert n_samples == 0, ("Expected zero samples since requested patch size"
-                            f" is greater than the image size.")
+    if allow_incomplete_patches:
+        assert n_samples > 0, ("Expected at least one sample when patch"
+                               " size is greater than the image size, and"
+                               " `allow_incomplete_patches` is True.")
+    else:
+        assert n_samples == 0, ("Expected zero samples since requested patch"
+                                " size is greater than the image size, and"
+                                " `allow_incomplete_patches` is False.")
 
 
 @pytest.mark.parametrize(
     "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk,"
     "batch_size, num_workers", [
-        (IMAGE_SPECS[10], 32, True, False, 2, 2),
-        ([IMAGE_SPECS[10]] * 4, 32, True, True, 2, 3),
-        ([IMAGE_SPECS[10]] * 2, 32, True, True, 2, 3),
+        (IMAGE_SPECS[10], (32, False), True, False, 2, 2),
+        ([IMAGE_SPECS[10]] * 4, (32, False), True, True, 2, 3),
+        ([IMAGE_SPECS[10]] * 2, (32, False), True, True, 2, 3),
     ],
     indirect=["image_dataset_specs", "patch_sampler_specs"]
 )
@@ -446,7 +457,7 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs,
                                  num_workers):
     dataset_specs, specs = image_dataset_specs
-    patch_sampler, patch_size = patch_sampler_specs
+    patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs
 
     ds = zds.ZarrDataset(
         dataset_specs=dataset_specs,
@@ -514,21 +525,21 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs,
 @pytest.mark.parametrize(
     "image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk,"
     "batch_size, num_workers, repeat_dataset", [
-        (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 1),
-        (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 2),
-        (IMAGE_SPECS[10:12], 32, True, False, 2, 2, 3),
+        (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 1),
+        (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 2),
+        (IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 3),
     ],
     indirect=["image_dataset_specs", "patch_sampler_specs"]
 )
 def test_multithread_chained_ZarrDataset(image_dataset_specs,
-                                        patch_sampler_specs,
-                                        shuffle,
-                                        draw_same_chunk,
-                                        batch_size,
-                                        num_workers,
-                                        repeat_dataset):
+                                         patch_sampler_specs,
+                                         shuffle,
+                                         draw_same_chunk,
+                                         batch_size,
+                                         num_workers,
+                                         repeat_dataset):
     dataset_specs, specs = image_dataset_specs
-    patch_sampler, patch_size = patch_sampler_specs
+    patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs
 
     ds = [zds.ZarrDataset(dataset_specs=dataset_specs,
                           shuffle=shuffle,
diff --git a/zarrdataset/_imageloaders.py b/zarrdataset/_imageloaders.py
index 5d92dc1..06f6b56 100644
--- a/zarrdataset/_imageloaders.py
+++ b/zarrdataset/_imageloaders.py
@@ -89,7 +89,7 @@ def image2array(arr_src: Union[str, zarr.Group, zarr.Array, np.ndarray],
         arr = zarr.array(data=arr_src, shape=arr_src.shape,
                          chunks=arr_src.shape)
         return arr, None
-    
+
     # Try to create a connection with the file, to determine if it is a remote
     # resource or local file.
     s3_obj = connect_s3(arr_src)
 
@@ -98,7 +98,7 @@ def image2array(arr_src: Union[str, zarr.Group, zarr.Array, np.ndarray],
     # Try to open the input file with tifffile (if installed).
     try:
         if (data_group is None
-            or (isinstance(data_group, str) and not len(data_group))):
+                or (isinstance(data_group, str) and not len(data_group))):
             tiff_args = dict(
                 key=None,
                 level=None,
@@ -192,9 +192,9 @@ class ImageBase(object):
     _image_func = None
 
     def __init__(self, shape: Iterable[int],
-                 chunk_size: Union[Iterable[int], None]=None,
-                 source_axes: str="",
-                 mode: str=""):
+                 chunk_size: Union[Iterable[int], None] = None,
+                 source_axes: str = "",
+                 mode: str = ""):
         if chunk_size is None:
             chunk_size = shape
 
@@ -204,6 +204,7 @@ def __init__(self, shape: Iterable[int],
         self.arr = zarr.ones(shape=shape, dtype=bool, chunks=chunk_size)
         self.roi = tuple([slice(None)] * len(source_axes))
         self.mode = mode
+        self._chunk_size = chunk_size
 
     def _iscached(self, coords):
@@ -224,17 +225,30 @@ def _cache_chunk(self, index):
         if not self._iscached(index):
             self._cached_coords = tuple(
                 map(lambda i, chk, s:
-                    slice(chk * int(i.start / chk)
+                    slice(max(0, chk * int(i.start / chk))
                           if i.start is not None else 0,
                           min(s, chk * int(math.ceil(i.stop / chk)))
-                          if i.stop is not None else None,
-                          None),
+                          if i.stop is not None else s),
                     index,
                     self.arr.chunks,
                     self.arr.shape)
             )
+
+            padding = tuple(
+                (cc.start - i.start
+                 if i.start is not None and i.start < 0 else 0,
+                 i.stop - cc.stop
+                 if i.stop is not None and i.stop > s else 0)
+                for cc, i, s in zip(self._cached_coords, index,
+                                    self.arr.shape)
+            )
+
             self._cache = self.arr[self._cached_coords]
 
+            if any([any(pad) for pad in padding]):
+                self._cache = np.pad(self._cache, padding)
+                self._cached_coords = tuple(
+                    slice(cc.start - p_low, cc.stop + p_high)
+                    for (p_low, p_high), cc in zip(padding,
+                                                   self._cached_coords)
+                )
+
         cached_index = tuple(
             map(lambda cache, i:
                 slice((i.start - cache.start) if i.start is not None else 0,
@@ -257,20 +271,14 @@ def __getitem__(self, index : Union[slice, tuple, dict]) -> np.ndarray:
         if not isinstance(index, dict):
             # Arrange the indices requested using the reference image axes
             # ordering.
-            index = dict(
-                ((ax, sel)
-                 for ax, sel in zip(spatial_reference_axes, index))
-            )
+            index = {ax: sel for ax, sel in zip(spatial_reference_axes, index)}
 
         mode_index, _ = select_axes(self.axes, index)
         mode_scales = tuple(self.scale[ax] for ax in self.axes)
 
         mode_index = scale_coords(mode_index, mode_scales)
 
-        mode_index = dict(
-            ((ax, sel)
-             for ax, sel in zip(self.axes, mode_index))
-        )
+        mode_index = {ax: sel for ax, sel in zip(self.axes, mode_index)}
 
         # Locate the mode_index within the ROI:
         roi_mode_index = translate2roi(mode_index, self.roi,
                                        self.source_axes,
@@ -430,7 +438,7 @@ def __init__(self, filename: str, source_axes: str,
             parsed_roi = roi
         elif isinstance(roi, slice):
             if (len(source_axes) > 1
-                and not (roi.start is None and roi.stop is None)):
+                    and not (roi.start is None and roi.stop is None)):
                 raise ValueError(f"ROIs must specify a slice per axes. "
                                  f"Expected {len(source_axes)} slices, got "
                                  f"only {roi}")
@@ -440,11 +448,10 @@ def __init__(self, filename: str, source_axes: str,
             raise ValueError(f"Incorrect ROI format, expected a list of "
                              f"slices, or a parsable string, got {roi}")
 
-        roi_slices = list(
-            map(lambda r:
-                slice(r.start if r.start is not None else 0, r.stop, None),
-                parsed_roi)
-        )
+        roi_slices = [
+            slice(r.start if r.start is not None else 0, r.stop, None)
+            for r in parsed_roi
+        ]
 
         (self.arr,
          self._store) = image2array(filename,
                                     data_group=data_group,
@@ -512,11 +519,11 @@ def __init__(self, collection_args : dict,
 
         self.spatial_axes = spatial_axes
 
-        self.collection = dict((
-            (mode, ImageLoader(spatial_axes=spatial_axes, mode=mode,
-                               **mode_args))
+        self.collection = {
+            mode: ImageLoader(spatial_axes=spatial_axes, mode=mode,
+                              **mode_args)
             for mode, mode_args in collection_args.items()
-        ))
+        }
 
         self._generate_mask()
         self.reset_scales()
@@ -574,9 +581,7 @@ def reset_scales(self) -> None:
             img.rescale(spatial_reference_shape, spatial_reference_axes)
 
     def __getitem__(self, index):
-        collection_set = dict(
-            (mode, img[index])
-            for mode, img in self.collection.items()
-        )
+        collection_set = {mode: img[index]
+                          for mode, img in self.collection.items()}
 
         return collection_set
diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py
index bff3d6a..52df198 100644
--- a/zarrdataset/_samplers.py
+++ b/zarrdataset/_samplers.py
@@ -1,8 +1,6 @@
-from typing import Iterable, Union, Tuple
+from typing import Iterable, Union, Tuple, List
 import math
 import numpy as np
-from itertools import repeat
-from functools import reduce
 import operator
 
 import poisson_disc
@@ -23,11 +21,19 @@ class PatchSampler(object):
         patches (hyper-cuboids) are supported by now. If a single int is
         passed, that size is used for all dimensions. If an iterable (list,
         tuple) is passed, each value will be assigned to the corresponding axes
-        in `spatial_axes`, the size of `patch_size` must match the lenght of
+        in `spatial_axes`, the size of `patch_size` must match the length of
         `spatial_axes'. If a dict is passed, this should have at least the size
-        of the patch of the axes listed in `spatial_axes`. Use the same
-        convention as how Zarr structure array chunks in order to handle patch
+        of the patch of the axes listed in `spatial_axes`. Use the same
+        convention as how Zarr structures array chunks in order to handle patch
         shapes and channels correctly.
+    stride : Union[int, Iterable[int], dict, None]
+        Distance in pixels of the movement of the sampling sliding window.
+        If `stride` is less than `patch_size` for an axis, patches will have an
+        overlap between them. This is useful in inference mode for avoiding
+        edge artifacts. If None is passed, the `patch_size` will be used as
+        `stride`.
+    pad : Union[int, Iterable[int], dict, None]
+        Padding in pixels added to the extracted patch at each specified axis.
     min_area : Union[int, float]
         Minimum patch area covered by the mask to consider it samplable. A
         number in range [0, 1) will be used as percentage of the patch size. A
@@ -35,13 +41,20 @@ class PatchSampler(object):
         covered by the mask.
     spatial_axes : str
         The spatial axes from where patches can be extracted.
+    allow_incomplete_patches : bool
+        Allow retrieving patches that are smaller than the patch size. This is
+        the case of samples at the edge of the image that are usually smaller
+        than the specified patch size.
     """
     def __init__(self, patch_size: Union[int, Iterable[int], dict],
+                 stride: Union[int, Iterable[int], dict, None] = None,
+                 pad: Union[int, Iterable[int], dict, None] = None,
                  min_area: Union[int, float] = 1,
-                 spatial_axes: str ="ZYX"):
+                 spatial_axes: str = "ZYX",
+                 allow_incomplete_patches: bool = False):
         # The maximum chunk sizes are used to generate a reference sampling
         # position array used fo every sampled chunk.
-        self._max_chunk_size = dict((ax, 0) for ax in spatial_axes)
+        self._max_chunk_size = {ax: 0 for ax in spatial_axes}
 
         if isinstance(patch_size, (list, tuple)):
             if len(patch_size) != len(spatial_axes):
                 raise ValueError(f"The size of `patch_size` must match the "
                                  f"number of axes in `spatial_axes`, got "
                                  f"{len(patch_size)} for {spatial_axes}")
 
-            patch_size = dict((ax, ps)
-                              for ax, ps in zip(spatial_axes, patch_size))
+            patch_size = {ax: ps for ax, ps in zip(spatial_axes, patch_size)}
 
         elif isinstance(patch_size, int):
-            patch_size = dict((ax, patch_size) for ax in spatial_axes)
+            patch_size = {ax: patch_size for ax in spatial_axes}
 
         elif not isinstance(patch_size, dict):
             raise ValueError(f"Patch size must be a dictionary specifying the"
@@ -62,200 +74,274 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict],
                              f" or an integer for a cubic patch. Received "
                              f"{patch_size} of type {type(patch_size)}")
 
+        if isinstance(stride, (list, tuple)):
+            if len(stride) != len(spatial_axes):
+                raise ValueError(f"The size of `stride` must match the "
+                                 f"number of axes in `spatial_axes`, got "
+                                 f"{len(stride)} for {spatial_axes}")
+
+            stride = {ax: st for ax, st in zip(spatial_axes, stride)}
+
+        elif isinstance(stride, int):
+            stride = {ax: stride for ax in spatial_axes}
+
+        elif stride is None:
+            stride = patch_size
+
+        elif not isinstance(stride, dict):
+            raise ValueError(f"Stride size must be a dictionary specifying the"
+                             f" stride step size of each axis, an iterable ("
+                             f"list, tuple) with the same order as the spatial"
+                             f" axes, or an integer for a cubic patch. "
+                             f"Received {stride} of type {type(stride)}")
+
+        if pad is None:
+            pad = 0
+
+        if isinstance(pad, (list, tuple)):
+            if len(pad) != len(spatial_axes):
+                raise ValueError(f"The size of `pad` must match the "
+                                 f"number of axes in `spatial_axes`, got "
+                                 f"{len(pad)} for {spatial_axes}")
+
+            pad = {ax: st for ax, st in zip(spatial_axes, pad)}
+
+        elif isinstance(pad, int):
+            pad = {ax: pad for ax in spatial_axes}
+
+        elif not isinstance(pad, dict):
+            raise ValueError(f"Pad size must be a dictionary specifying the"
+                             f" number of pixels added to each axis, an "
+                             f"iterable (list, tuple) with the same order as "
+                             f"the spatial axes, or an integer for a cubic "
+                             f"patch. Received {pad} of type {type(pad)}")
+
         self.spatial_axes = spatial_axes
 
-        self._patch_size = dict(
-            (ax, patch_size.get(ax, 1))
-            for ax in spatial_axes
-        )
+        self._patch_size = {ax: patch_size.get(ax, 1) for ax in spatial_axes}
+        self._stride = {ax: stride.get(ax, 1) for ax in spatial_axes}
+        self._pad = {ax: pad.get(ax, 0) for ax in spatial_axes}
 
         self._min_area = min_area
+        self._allow_incomplete_patches = allow_incomplete_patches
 
-    def _compute_corners(self, non_zero_pos: tuple, axes: str) -> np.ndarray:
-        toplefts = np.stack(non_zero_pos).T
-        toplefts = toplefts.astype(np.float32)
-
+    def _compute_corners(self, coordinates: np.ndarray, scale: np.ndarray
+                         ) -> np.ndarray:
         toplefts_corners = []
-        dim = len(axes)
+        dim = coordinates.shape[-1]
         factors = 2 ** np.arange(dim + 1)
         for d in range(2 ** dim):
-            corner_value =  np.array((d % factors[1:]) // factors[:-1],
-                                     dtype=np.float32)
+            corner_value = np.array((d % factors[1:]) // factors[:-1],
+                                    dtype=np.float32)
             toplefts_corners.append(
-                toplefts + (1 - 1e-4) * corner_value
+                coordinates + scale * (1 - 1e-4) * corner_value
             )
 
         corners = np.stack(toplefts_corners)
 
         return corners
 
-    def _compute_overlap(self, corners: np.ndarray, shape: np.ndarray,
-                         ref_shape: np.ndarray) -> Tuple[np.ndarray,
-                                                         np.ndarray]:
-        scaled_corners = corners * shape[None, None]
-        tls_scaled = scaled_corners / ref_shape[None, None]
-        tls_idx = np.floor(tls_scaled)
+    def _compute_reference_indices(self, reference_coordinates: np.ndarray
+                                   ) -> Tuple[List[np.ndarray],
+                                              List[Tuple[int]]]:
+        reference_per_axis = list(map(
+            lambda coords: np.append(np.full((1, ), fill_value=-float("inf")),
+                                     np.unique(coords)),
+            reference_coordinates.T
+        ))
+
+        reference_idx = map(
+            lambda coord_axis, ref_axis:
+            np.argmax(ref_axis[None, ...]
+                      * (coord_axis[..., None] >= ref_axis[None, ...]),
+                      axis=-1),
+            reference_coordinates.T,
+            reference_per_axis
+        )
+        reference_idx = np.stack(tuple(reference_idx), axis=-1)
+        reference_idx = [
+            tuple(tls_coord - 1)
+            for tls_coord in reference_idx.reshape(-1, len(reference_per_axis))
+        ]
+
+        return reference_per_axis, reference_idx
+
+    def _compute_overlap(self, corners_coordinates: np.ndarray,
+                         reference_per_axis: np.ndarray) -> Tuple[np.ndarray,
+                                                                  np.ndarray]:
+        tls_idx = map(
+            lambda coord_axis, ref_axis:
+            np.argmax(ref_axis[None, None, ...]
+ * (coord_axis[..., None] >= ref_axis[None, None, ...]), + axis=-1), + np.moveaxis(corners_coordinates, -1, 0), + reference_per_axis + ) + tls_idx = np.stack(tuple(tls_idx), axis=-1) - corners_cut = np.maximum(tls_scaled[0], tls_idx[-1]) + tls_coordinates = map( + lambda tls_coord, ref_axis: ref_axis[tls_coord], + np.moveaxis(tls_idx, -1, 0), + reference_per_axis + ) + tls_coordinates = np.stack(tuple(tls_coordinates), axis=-1) - dist2cut = np.fabs(scaled_corners - corners_cut[None]) + corners_cut = np.maximum(corners_coordinates[0], tls_coordinates[-1]) + + dist2cut = np.fabs(corners_coordinates - corners_cut[None]) coverage = np.prod(dist2cut, axis=-1) - return coverage, tls_idx.astype(np.int64) + return coverage, tls_idx - 1 - def _compute_grid(self, chunk_mask: np.ndarray, - mask_axes: str, - mask_scale: dict, + def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_size: dict): + image_size: dict, + allow_incomplete_patches: bool = False): + mask_scale = np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + image_scale = np.array([image_size.get(ax, 1) / patch_size.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + + round_fn = math.ceil if allow_incomplete_patches else math.floor + + image_blocks = [ + round_fn( + ( + min(image_size.get(ax, 1), + chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float("inf")) + - (chunk_tlbr[ax].start + if chunk_tlbr[ax].start is not None + else 0) + ) / patch_size.get(ax, 1)) + for ax in self.spatial_axes + ] - mask_relative_shape = np.array( - [1 / m_scl - for ax, m_scl in mask_scale.items() - if ax in self.spatial_axes - ], - dtype=np.float32 - ) + if min(image_blocks) == 0: + return [] - patch_shape = np.array( - [patch_size[ax] - for ax in mask_axes - if ax in self.spatial_axes - ], - dtype=np.float32 - ) + image_scale = np.array([patch_size.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) + image_coordinates = np.array(list(np.ndindex(*image_blocks)), + dtype=np.float32) - # If the patch sizes are greater than the relative shape of the mask - # with respect to the input image, use the mask coordinates as - # reference to overlap the coordinates of the sampling patches. - # Otherwise, use the patches coordinates instead. - if all(map(operator.ge, patch_shape, mask_relative_shape)): - active_coordinates = np.nonzero(chunk_mask) - ref_axes = mask_axes - - ref_shape = patch_shape - shape = mask_relative_shape - - mask_is_greater = False - - patch_ratio = [ - image_size[ax] // ps - for ax, ps in zip(mask_axes, patch_shape.astype(np.int64)) - if ax in self.spatial_axes - ] + image_coordinates *= image_scale - if not all(patch_ratio): - return np.empty( - [0] * len(set(mask_axes).intersection(self.spatial_axes)), - dtype=np.int64 - ) + mask_scale = 1 / np.array([mask.scale.get(ax, 1) + for ax in self.spatial_axes], + dtype=np.float32) - else: - active_coordinates = np.meshgrid( - *[np.arange(image_size[ax] // ps) - for ax, ps in zip(mask_axes, patch_shape) - if ax in self.spatial_axes] - ) + mask_coordinates = list(np.nonzero(mask[:])) + for ax_i, ax in enumerate(self.spatial_axes): + if ax not in mask.axes: + mask_coordinates.insert( + ax_i, + np.zeros(mask_coordinates[0].size) + ) - active_coordinates = tuple( - coord_ax.flatten() - for coord_ax in active_coordinates - ) + mask_coordinates = np.stack(mask_coordinates, dtype=np.float32).T + mask_coordinates *= mask_scale[None, ...] 
- ref_axes = "".join([ - ax for ax in self.spatial_axes if ax in mask_axes - ]) + # Filter out mask coordinates outside the current selected chunk + chunk_tl_coordinates = np.array( + [chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None else 0 + for ax in self.spatial_axes], + dtype=np.float32 + ) + chunk_br_coordinates = np.array( + [chunk_tlbr[ax].stop + if chunk_tlbr[ax].stop is not None + else float('inf') + for ax in self.spatial_axes], + dtype=np.float32 + ) - ref_shape = mask_relative_shape - shape = patch_shape + in_chunk = np.all( + np.bitwise_and( + mask_coordinates > (chunk_tl_coordinates - mask_scale - 1e-4), + mask_coordinates < chunk_br_coordinates + 1e-4 + ), + axis=1 + ) + mask_coordinates = mask_coordinates[in_chunk] - mask_is_greater = True + if all(map(operator.ge, image_scale, mask_scale)): + mask_corners = self._compute_corners(mask_coordinates, mask_scale) - corners = self._compute_corners(active_coordinates, axes=ref_axes) + (reference_per_axis, + reference_idx) =\ + self._compute_reference_indices(image_coordinates) - (coverage, - corners_idx)= self._compute_overlap(corners, shape, ref_shape) + (coverage, + corners_idx) = self._compute_overlap(mask_corners, + reference_per_axis) - if mask_is_greater: - # The mask ratio is greater than the patches size - mask_values = chunk_mask[tuple(corners_idx.T)].T - patches_coverage = coverage * mask_values + covered_indices = [ + reference_idx.index(tuple(idx)) + for idx in corners_idx.reshape(-1, len(self.spatial_axes)) + ] - covered_tls = corners[0, ...].astype(np.int64) + patches_coverage = np.bincount(covered_indices, + weights=coverage.flatten(), + minlength=np.prod(image_blocks)) else: - # The mask ratio is less than the patches size - patch_coordinates = np.ravel_multi_index(tuple(corners_idx.T), - chunk_mask.shape) - patches_coverage = np.bincount(patch_coordinates.flatten(), - weights=coverage.flatten()) - patches_coverage = np.take(patches_coverage, patch_coordinates).T + image_corners = self._compute_corners(image_coordinates, + image_scale) - covered_tls = corners_idx[0, ...] + (reference_per_axis, + reference_idx) = self._compute_reference_indices(mask_coordinates) - patches_coverage = np.sum(patches_coverage, axis=0) + (coverage, + corners_idx) = self._compute_overlap(image_corners, + reference_per_axis) - # Compute minimum area covered by masked regions to sample a patch. - min_area = self._min_area - if min_area < 1: - min_area *= patch_shape.prod() + covered_indices = np.array([ + tuple(idx) in reference_idx + for idx in corners_idx.reshape(-1, len(self.spatial_axes)) + ]).reshape(coverage.shape) - minumum_covered_tls = covered_tls[patches_coverage > min_area] + patches_coverage = np.sum(covered_indices * coverage, axis=0) - if not mask_is_greater: - # Collapse to unique coordinates since there will be multiple - # instances of the same patch. 
- minumum_covered_tls = np.ravel_multi_index( - tuple(minumum_covered_tls.T), - patch_ratio, - mode="clip" - ) - - minumum_covered_tls = np.unique(minumum_covered_tls) - - minumum_covered_tls = np.unravel_index( - minumum_covered_tls, - patch_ratio - ) + min_area = self._min_area + if min_area < 1: + min_area *= np.prod(list(patch_size.values())) - minumum_covered_tls = np.stack(minumum_covered_tls).T + minimum_covered_tls = image_coordinates[patches_coverage > min_area] + minimum_covered_tls = minimum_covered_tls.astype(np.int64) - return minumum_covered_tls * patch_shape[None].astype(np.int64) + return minimum_covered_tls - def _compute_valid_toplefts(self, chunk_mask: np.ndarray, mask_axes: str, - mask_scale: dict, + def _compute_valid_toplefts(self, chunk_tlbr: dict, mask: ImageBase, patch_size: dict, - image_size: dict): - return self._compute_grid(chunk_mask, mask_axes, mask_scale, - patch_size, - image_size) + **kwargs): + return self._compute_grid(chunk_tlbr, mask, patch_size, **kwargs) - def _compute_toplefts_slices(self, mask: ImageBase, image_shape: dict, - patch_size: dict, + def _compute_toplefts_slices(self, chunk_tlbr: dict, valid_mask_toplefts: np.ndarray, - chunk_tlbr: dict): - toplefts = [] - for tls in valid_mask_toplefts: - curr_tl = [] - - for ax in self.spatial_axes: - if ax in mask.axes: - tl = ((chunk_tlbr[ax].start - if chunk_tlbr[ax].start is not None else 0) - + tls[mask.axes.index(ax)]) - br = ((chunk_tlbr[ax].start - if chunk_tlbr[ax].start is not None else 0) - + tls[mask.axes.index(ax)] + patch_size[ax]) - if br <= image_shape[ax]: - curr_tl.append((ax, slice(tl, br))) - else: - break - else: - curr_tl.append((ax, slice(0, 1))) - - else: - toplefts.append(dict(curr_tl)) + patch_size: dict, + pad: Union[dict, None] = None): + if pad is None: + pad = {ax: 0 for ax in self.spatial_axes} + + toplefts = [ + {ax: slice( + (chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None + else 0) + tls[self.spatial_axes.index(ax)] + - pad[ax], + (chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None + else 0) + tls[self.spatial_axes.index(ax)] + patch_size[ax] + + pad[ax]) + for ax in self.spatial_axes + } + for tls in valid_mask_toplefts + ] return toplefts @@ -284,43 +370,39 @@ def compute_chunks(self, image = image_collection.collection[image_collection.reference_mode] mask = image_collection.collection[image_collection.mask_mode] - spatial_chunk_sizes = dict( - (ax, chk) + # This computes a chunk size in terms of the patch size instead of the + # original array chunk size. 
+ spatial_chunk_sizes = { + ax: (self._stride[ax] + * max(1, math.ceil(chk / self._stride[ax]))) for ax, chk in zip(image.axes, image.chunk_size) if ax in self.spatial_axes - ) + } - image_shape = dict(map(tuple, zip(image.axes, image.shape))) + image_size = {ax: s for ax, s in zip(image.axes, image.shape)} - self._max_chunk_size = dict( - (ax, (min(max(self._max_chunk_size[ax], - spatial_chunk_sizes[ax], - self._patch_size[ax]), - image_shape[ax])) - if ax in image.axes else 1) + self._max_chunk_size = { + ax: (min(max(self._max_chunk_size[ax], + spatial_chunk_sizes[ax]), + image_size[ax])) + if ax in image.axes else 1 for ax in self.spatial_axes - ) - - chunk_tlbr = dict( - map(tuple, zip(self.spatial_axes, repeat(slice(None)))) - ) + } - chunk_mask = mask[chunk_tlbr] + chunk_tlbr = {ax: slice(None) for ax in self.spatial_axes} valid_mask_toplefts = self._compute_grid( - chunk_mask, - mask.axes, - mask.scale, + chunk_tlbr, + mask, self._max_chunk_size, - image_shape + image_size, + allow_incomplete_patches=True ) chunks_slices = self._compute_toplefts_slices( - mask, - image_shape=image_shape, - patch_size=self._max_chunk_size, + chunk_tlbr, valid_mask_toplefts=valid_mask_toplefts, - chunk_tlbr=chunk_tlbr + patch_size=self._max_chunk_size ) return chunks_slices @@ -329,31 +411,36 @@ def compute_patches(self, image_collection: ImageCollection, chunk_tlbr: dict) -> Iterable[dict]: image = image_collection.collection[image_collection.reference_mode] mask = image_collection.collection[image_collection.mask_mode] - image_shape = dict(map(tuple, zip(image.axes, image.shape))) - chunk_size = dict( - (ax, - (ctb.stop if ctb.stop is not None else image_shape[ax]) - - (ctb.start if ctb.start is not None else 0) - ) - for ax, ctb in chunk_tlbr.items() - ) + image_size = {ax: s for ax, s in zip(image.axes, image.shape)} - chunk_mask = mask[chunk_tlbr] + stride = { + ax: self._stride.get(ax, 1) if image_size.get(ax, 1) > 1 else 1 + for ax in self.spatial_axes + } + + patch_size = { + ax: self._patch_size.get(ax, 1) if image_size.get(ax, 1) > 1 else 1 + for ax in self.spatial_axes + } + + pad = { + ax: self._pad.get(ax, 0) if image_size.get(ax, 1) > 1 else 0 + for ax in self.spatial_axes + } valid_mask_toplefts = self._compute_valid_toplefts( - chunk_mask, - mask.axes, - mask.scale, - self._patch_size, - chunk_size + chunk_tlbr, + mask, + stride, + image_size=image_size, + allow_incomplete_patches=self._allow_incomplete_patches ) patches_slices = self._compute_toplefts_slices( - mask, - image_shape=image_shape, - patch_size=self._patch_size, + chunk_tlbr, valid_mask_toplefts=valid_mask_toplefts, - chunk_tlbr=chunk_tlbr + patch_size=patch_size, + pad=pad ) return patches_slices @@ -383,11 +470,11 @@ def __init__(self, patch_size: Union[int, Iterable[int], dict], resample_positions=False, allow_overlap=False, **kwargs): - super(BlueNoisePatchSampler, self).__init__(patch_size) + super(BlueNoisePatchSampler, self).__init__(patch_size, **kwargs) self._base_chunk_tls = None self._resample_positions = resample_positions self._allow_overlap = allow_overlap - + def compute_sampling_positions(self, force=False) -> None: """Compute the sampling positions using blue-noise sampling. 
@@ -425,35 +512,62 @@ def compute_sampling_positions(self, force=False) -> None:
             self._base_chunk_tls = np.zeros((1, len(self.spatial_axes)),
                                             dtype=np.int64)
 
-    def _compute_valid_toplefts(self,
-                                chunk_mask: np.ndarray,
-                                mask_axes: str,
-                                mask_scale: dict,
+    def _compute_valid_toplefts(self, chunk_tlbr: dict, mask: ImageBase,
                                 patch_size: dict,
-                                image_shape: dict):
+                                **kwargs):
         self.compute_sampling_positions(force=self._resample_positions)
 
         # Filter sampling positions that does not contain any mask portion.
-        sampling_pos = np.hstack(
-            tuple(
-                self._base_chunk_tls[:, self.spatial_axes.index(ax), None]
-                if ax in self.spatial_axes else
-                np.zeros((len(self._base_chunk_tls), 1), dtype=np.float32)
-                for ax in mask_axes
-            )
-        )
-
         spatial_patch_sizes = np.array([
-            patch_size[ax]
-            for ax in mask_axes
-            if ax in self.spatial_axes
+            patch_size.get(ax, 1) for ax in self.spatial_axes
         ])
 
-        mask_corners = self._compute_corners(np.nonzero(chunk_mask),
-                                             mask_axes)
+        mask_scale = 1 / np.array([mask.scale.get(ax, 1)
+                                   for ax in self.spatial_axes],
+                                  dtype=np.float32)
+
+        mask_coordinates = list(np.nonzero(mask[:]))
+        for ax_i, ax in enumerate(self.spatial_axes):
+            if ax not in mask.axes:
+                mask_coordinates.insert(
+                    ax_i,
+                    np.zeros(mask_coordinates[0].size)
+                )
+
+        mask_coordinates = np.stack(mask_coordinates, dtype=np.float32).T
+        mask_coordinates *= mask_scale[None, ...]
+
+        # Filter out mask coordinates outside the current selected chunk
+        chunk_tl_coordinates = np.array(
+            [chunk_tlbr[ax].start if chunk_tlbr[ax].start is not None else 0
+             for ax in self.spatial_axes],
+            dtype=np.float32
+        )
+        chunk_br_coordinates = np.array(
+            [chunk_tlbr[ax].stop
+             if chunk_tlbr[ax].stop is not None
+             else float('inf')
+             for ax in self.spatial_axes],
+            dtype=np.float32
+        )
+
+        in_chunk = np.all(
+            np.bitwise_and(
+                mask_coordinates > (chunk_tl_coordinates - mask_scale - 1e-4),
+                mask_coordinates < chunk_br_coordinates + 1e-4
+            ),
+            axis=1
+        )
+        mask_coordinates = mask_coordinates[in_chunk]
+
+        mask_corners = self._compute_corners(mask_coordinates, mask_scale)
 
         dist = (mask_corners[None, :, :, :]
-                - sampling_pos[:, None, None, :].astype(np.float32)
+                - self._base_chunk_tls[:, None, None, :].astype(np.float32)
                 - spatial_patch_sizes[None, None, None, :] / 2)
 
         mask_samplable_pos, = np.nonzero(
@@ -463,6 +577,6 @@ def _compute_valid_toplefts(self,
             )
         )
 
-        toplefts = sampling_pos[mask_samplable_pos]
+        toplefts = self._base_chunk_tls[mask_samplable_pos]
 
         return toplefts
diff --git a/zarrdataset/_zarrdataset.py b/zarrdataset/_zarrdataset.py
index 34a1082..2b6cd4c 100644
--- a/zarrdataset/_zarrdataset.py
+++ b/zarrdataset/_zarrdataset.py
@@ -509,7 +509,7 @@ def _initialize(self, force=False):
 
         modes = self._collections.keys()
         for collection in zip(*self._collections.values()):
-            collection = dict([(m, c) for m, c in zip(modes, collection)])
+            collection = {m: c for m, c in zip(modes, collection)}
 
             for mode in collection.keys():
                 collection[mode]["zarr_store"] = self._zarr_store[mode]
                 collection[mode]["image_func"] = self._image_loader_func[mode]
@@ -527,8 +527,8 @@ def _initialize(self, force=False):
                 toplefts.append(self._patch_sampler.compute_chunks(curr_img))
             else:
                 toplefts.append([
-                    dict((ax, slice(None))
-                         for ax in curr_img.collection[self._ref_mod].axes)
+                    {ax: slice(None)
+                     for ax in curr_img.collection[self._ref_mod].axes}
                 ]
                 )
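Taken together, the sampler changes above extend `PatchSampler` with `stride`, `pad`, and `allow_incomplete_patches`. A minimal sketch of how the three new parameters combine for overlapped sliding-window inference, using only the API introduced in this diff (the concrete sizes are illustrative):

```python
import zarrdataset as zds

# Overlapping 128x128 patches: the sampling window advances 64 px per step,
# a 16 px margin is added around every extracted patch, and the smaller
# patches that remain at the image borders are kept instead of dropped.
sampler = zds.PatchSampler(
    patch_size=dict(Y=128, X=128),
    stride=dict(Y=64, X=64),
    pad=dict(Y=16, X=16),
    allow_incomplete_patches=True,
)
```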