Skip to content

Commit

Permalink
Standardized patch sampling method to handle smaller and bigger mask …
Browse files Browse the repository at this point in the history
…scales than image scale
  • Loading branch information
fercer committed May 7, 2024
1 parent aaa51f6 commit 7885749
Show file tree
Hide file tree
Showing 4 changed files with 498 additions and 252 deletions.
177 changes: 177 additions & 0 deletions docs/source/examples/advanced_example_pytorch_inference.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
```python
import zarrdataset as zds

import torch
from torch.utils.data import DataLoader
```


```python
# These are images from the Image Data Resource (IDR)
# https://idr.openmicroscopy.org/ that are publicly available and were
# converted to the OME-NGFF (Zarr) format by the OME group. More examples
# can be found at Public OME-Zarr data (Nov. 2020)
# https://www.openmicroscopy.org/2020/11/04/zarr-data.html

filenames = [
"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr"
]
```


```python
import random
import numpy as np

# For reproducibility
np.random.seed(478963)
torch.manual_seed(478964)
random.seed(478965)
```

## Extracting patches of size 128x128 pixels from a Whole Slide Image (WSI)

Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference result.
Finally, let the PatchSampler retrieve patches from the edge of the image that would be otherwise smaller than the patch size.


```python
# Patches are 128x128 pixels; each one is padded by 16 pixels along Y and X
# to avoid edge artifacts when the predictions are stitched together.
patch_size = {"Y": 128, "X": 128}
pad = {"Y": 16, "X": 16}

# Keep the patches at the image border that are smaller than the patch size.
patch_sampler = zds.PatchSampler(
    patch_size=patch_size,
    pad=pad,
    allow_incomplete_patches=True,
)
```

Create a dataset from the list of filenames. All those files should be stored within their respective group "4".

Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.


```python
# Image modality: read group "4" of every file, stored with axes TCZYX, and
# request the samples with axes ordered as YXC.
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="4",
    source_axes="TCZYX",
    axes="YXC",
    roi="0,0,0,0,0:1,-1,1,-1,-1",
)

# return_positions=True makes the dataset yield the position of each patch
# together with the patch itself; the positions are used later to place the
# predictions in the output array.
my_dataset = zds.ZarrDataset(
    image_specs,
    patch_sampler=patch_sampler,
    return_positions=True,
)
```


```python
my_dataset
```




ZarrDataset (PyTorch support:True, tqdm support :True)
Modalities: images
Transforms order: []
Using images modality as reference.
Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.



Add a pre-processing step before creating the image batches, where the input arrays are cast from int16 to float32.


```python
import torchvision

# Steps are applied in this order: cast to float32, convert to a torch
# tensor, then normalize with mean 127 and standard deviation 255.
preprocessing_steps = [
    zds.ToDtype(dtype=np.float32),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(127, 255),
]
img_preprocessing = torchvision.transforms.Compose(preprocessing_steps)

# Register the pipeline for the "images" modality of the dataset.
my_dataset.add_transform("images", img_preprocessing)
```


```python
my_dataset
```




ZarrDataset (PyTorch support:True, tqdm support :True)
Modalities: images
Transforms order: [('images',)]
Using images modality as reference.
Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.



## Create a DataLoader from the dataset object

ZarrDataset is compatible with DataLoader from PyTorch because it inherits from the IterableDataset class of the torch.utils.data module.


```python
my_dataloader = DataLoader(my_dataset, num_workers=0)
```


```python
import dask.array as da
import numpy as np
import zarr

# Open the same array the dataset reads from: group "4" of the first file.
z_arr = zarr.open(f"{filenames[0]}/4", mode="r")

# The stored axes are TCZYX, so the last two dimensions are height and width.
H, W = z_arr.shape[-2:]

# Take the patch shape from the sampler configuration instead of repeating
# the literal, so the output array stays consistent if `patch_size` changes.
patch_H = patch_size["Y"]
patch_W = patch_size["X"]

# Round the output shape up to the next multiple of the patch size.
pad_H = (-H) % patch_H
pad_W = (-W) % patch_W

# One chunk per patch, so each prediction is written to its own chunk.
z_prediction = zarr.zeros(
    (H + pad_H, W + pad_W),
    dtype=np.float32,
    chunks=(patch_H, patch_W),
)
z_prediction
```




<zarr.core.Array (1152, 1408) float32>



Set up a simple model for illustration purposes.


```python
# Minimal stand-in model: a 1x1 convolution that maps the 3 input channels to
# a single output channel, followed by a ReLU.
layers = [
    torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1),
    torch.nn.ReLU(),
]
model = torch.nn.Sequential(*layers)
```


```python
# Width of the margin that the sampler added around each patch. It is cropped
# from the prediction before the result is written to the output array.
pad_y, pad_x = pad["Y"], pad["X"]

# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for pos, sample in my_dataloader:
        # The DataLoader uses its default batch size of 1, hence index 0.
        # NOTE(review): pos is indexed as [batch, axis, (start, stop)] with
        # axis 0 = Y and axis 1 = X, covering the padded patch - confirm
        # against the sampler's documentation.
        pred_pos = (
            slice(pos[0, 0, 0].item() + pad_y, pos[0, 0, 1].item() - pad_y),
            slice(pos[0, 1, 0].item() + pad_x, pos[0, 1, 1].item() - pad_x),
        )
        pred = model(sample)

        # Drop the padded margin and store the single-channel prediction.
        z_prediction[pred_pos] = (
            pred.cpu().numpy()[0, 0, pad_y:-pad_y, pad_x:-pad_x]
        )
```

## Visualize the result


```python
import matplotlib.pyplot as plt

# Take the first index of the first and third axes, then move the remaining
# leading axis to the end before displaying the image.
input_image = np.moveaxis(z_arr[0, :, 0, ...], 0, -1)

# Top panel: input image. Bottom panel: stitched prediction.
plt.subplot(2, 1, 1)
plt.imshow(input_image)
plt.subplot(2, 1, 2)
plt.imshow(z_prediction)
plt.show()
```
36 changes: 27 additions & 9 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,33 @@ def test_PatchSampler_pad(patch_size, pad, image_collection):
f"{patches_toplefts[:3]} instead.")


@pytest.mark.parametrize("patch_size, allow_incomplete_patches,"
                         "image_collection", [
    (1024, True, IMAGE_SPECS[10]),
    (1024, False, IMAGE_SPECS[10]),
], indirect=["image_collection"])
def test_PatchSampler_incomplete_patches(patch_size, allow_incomplete_patches,
                                         image_collection):
    """Check the number of patches sampled from the first chunk.

    One patch is expected when `allow_incomplete_patches` is True, and none
    when it is False.
    """
    sampler = zds.PatchSampler(
        patch_size,
        allow_incomplete_patches=allow_incomplete_patches,
    )

    # Only the first chunk of the image is inspected.
    first_chunk = sampler.compute_chunks(image_collection)[0]
    patches_toplefts = sampler.compute_patches(image_collection,
                                               chunk_tlbr=first_chunk)
    num_patches = len(patches_toplefts)

    if allow_incomplete_patches:
        expected_num_patches = 1
    else:
        expected_num_patches = 0

    assert num_patches == expected_num_patches, (
        f"Expected to have {expected_num_patches}, when "
        f"`allow_incomplete_patches` is {allow_incomplete_patches} "
        f"got {num_patches} instead.")


@pytest.mark.parametrize("patch_size, axes, resample, allow_overlap,"
"image_collection", [
(dict(X=32, Y=32, Z=1), "XYZ", True, True, IMAGE_SPECS[10]),
Expand All @@ -379,15 +406,6 @@ def test_BlueNoisePatchSampler(patch_size, axes, resample, allow_overlap,
(f"Expected {len(patch_sampler._base_chunk_tls)} patches, got "
f"{len(patches_toplefts)} instead.")

patches_toplefts = patch_sampler.compute_patches(
image_collection,
chunk_tlbr=chunks_toplefts[-1]
)

assert len(patches_toplefts) == len(patch_sampler._base_chunk_tls), \
(f"Expected {len(patch_sampler._base_chunk_tls)} patches, got "
f"{len(patches_toplefts)} instead.")


@pytest.mark.parametrize("image_collection_mask_not2scale", [
IMAGE_SPECS[10]
Expand Down
63 changes: 37 additions & 26 deletions tests/test_zarrdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,12 @@ def image_dataset_specs(request):

@pytest.fixture(scope="function")
def patch_sampler_specs(request):
patch_sampler = zds.PatchSampler(patch_size=request.param)
return patch_sampler, request.param
patch_size, allow_incomplete_patches = request.param
patch_sampler = zds.PatchSampler(
patch_size=patch_size,
allow_incomplete_patches=allow_incomplete_patches
)
return patch_sampler, patch_size, allow_incomplete_patches


@pytest.mark.parametrize("image_dataset_specs", [
Expand All @@ -142,7 +146,7 @@ def test_compatibility_no_tqdm(image_dataset_specs):

try:
next(iter(dataset))

except Exception as e:
raise AssertionError(f"No exceptions where expected, got {e} "
f"instead.")
Expand Down Expand Up @@ -304,17 +308,17 @@ def test_ZarrDataset(image_dataset_specs, shuffle, return_positions,

@pytest.mark.parametrize(
"image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk", [
(IMAGE_SPECS[10], 32, True, False),
(IMAGE_SPECS[10], 32, True, True),
(IMAGE_SPECS[10], 32, False, True),
(IMAGE_SPECS[10], (32, False), True, False),
(IMAGE_SPECS[10], (32, False), True, True),
(IMAGE_SPECS[10], (32, False), False, True),
],
indirect=["image_dataset_specs", "patch_sampler_specs"]
)
def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs,
shuffle,
draw_same_chunk):
dataset_specs, specs = image_dataset_specs
patch_sampler, patch_size = patch_sampler_specs
patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs

ds = zds.ZarrDataset(
dataset_specs=dataset_specs,
Expand Down Expand Up @@ -409,33 +413,40 @@ def test_patched_ZarrDataset(image_dataset_specs, patch_sampler_specs,

@pytest.mark.parametrize(
"image_dataset_specs, patch_sampler_specs", [
(IMAGE_SPECS[10], 1024),
(IMAGE_SPECS[10], (1024, True)),
(IMAGE_SPECS[10], (1024, False)),
],
indirect=["image_dataset_specs", "patch_sampler_specs"]
)
def test_greater_patch_ZarrDataset(image_dataset_specs, patch_sampler_specs):
dataset_specs, specs = image_dataset_specs
patch_sampler, patch_size = patch_sampler_specs
patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs

ds = zds.ZarrDataset(
dataset_specs=dataset_specs,
patch_sampler=patch_sampler,
patch_sampler=patch_sampler
)

n_samples = 0
for _ in ds:
n_samples += 1

assert n_samples == 0, ("Expected zero samples since requested patch size"
f" is greater than the image size.")
# The requested patch is larger than the image, so samples are only expected
# when incomplete patches are allowed.
if allow_incomplete_patches:
    assert n_samples > 0, ("Expected at least one sample when patch"
                           " size is greater than the image size, and"
                           " `allow_incomplete_patches` is True.")
else:
    assert n_samples == 0, ("Expected zero samples since requested patch"
                            " size is greater than the image size, and"
                            " `allow_incomplete_patches` is False.")


@pytest.mark.parametrize(
"image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk,"
"batch_size, num_workers", [
(IMAGE_SPECS[10], 32, True, False, 2, 2),
([IMAGE_SPECS[10]] * 4, 32, True, True, 2, 3),
([IMAGE_SPECS[10]] * 2, 32, True, True, 2, 3),
(IMAGE_SPECS[10], (32, False), True, False, 2, 2),
([IMAGE_SPECS[10]] * 4, (32, False), True, True, 2, 3),
([IMAGE_SPECS[10]] * 2, (32, False), True, True, 2, 3),
],
indirect=["image_dataset_specs", "patch_sampler_specs"]
)
Expand All @@ -446,7 +457,7 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs,
num_workers):
dataset_specs, specs = image_dataset_specs

patch_sampler, patch_size = patch_sampler_specs
patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs

ds = zds.ZarrDataset(
dataset_specs=dataset_specs,
Expand Down Expand Up @@ -514,21 +525,21 @@ def test_multithread_ZarrDataset(image_dataset_specs, patch_sampler_specs,
@pytest.mark.parametrize(
"image_dataset_specs, patch_sampler_specs, shuffle, draw_same_chunk,"
"batch_size, num_workers, repeat_dataset", [
(IMAGE_SPECS[10:12], 32, True, False, 2, 2, 1),
(IMAGE_SPECS[10:12], 32, True, False, 2, 2, 2),
(IMAGE_SPECS[10:12], 32, True, False, 2, 2, 3),
(IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 1),
(IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 2),
(IMAGE_SPECS[10:12], (32, False), True, False, 2, 2, 3),
],
indirect=["image_dataset_specs", "patch_sampler_specs"]
)
def test_multithread_chained_ZarrDataset(image_dataset_specs,
patch_sampler_specs,
shuffle,
draw_same_chunk,
batch_size,
num_workers,
repeat_dataset):
patch_sampler_specs,
shuffle,
draw_same_chunk,
batch_size,
num_workers,
repeat_dataset):
dataset_specs, specs = image_dataset_specs
patch_sampler, patch_size = patch_sampler_specs
patch_sampler, patch_size, allow_incomplete_patches = patch_sampler_specs

ds = [zds.ZarrDataset(dataset_specs=dataset_specs,
shuffle=shuffle,
Expand Down
Loading

0 comments on commit 7885749

Please sign in to comment.