Use patch_size instead of chunk_size as base shape for sampling #4

Merged: 9 commits, May 7, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -50,6 +50,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
example.py

# Translations
*.mo
194 changes: 194 additions & 0 deletions docs/source/examples/advanced_example_pytorch_inference.md
@@ -0,0 +1,194 @@
---
jupytext:
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
execution:
timeout: 120
---

# Integration of ZarrDataset with PyTorch's DataLoader for inference (Advanced)

```python
import zarrdataset as zds

import torch
from torch.utils.data import DataLoader
```


```python
# These are images from the Image Data Resource (IDR)
# https://idr.openmicroscopy.org/ that are publicly available and were
# converted to the OME-NGFF (Zarr) format by the OME group. More examples
# can be found at Public OME-Zarr data (Nov. 2020)
# https://www.openmicroscopy.org/2020/11/04/zarr-data.html

filenames = [
"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr"
]
```


```python
import random
import numpy as np

# For reproducibility
np.random.seed(478963)
torch.manual_seed(478964)
random.seed(478965)
```

## Extracting patches of size 128x128 pixels from a Whole Slide Image (WSI)

Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference results.
Finally, let the PatchSampler retrieve patches at the borders of the image that would otherwise be smaller than the patch size by setting `allow_incomplete_patches=True`.


```python
# Patches of 128x128 pixels, each padded with 16 extra pixels per side
patch_size = dict(Y=128, X=128)
pad = dict(Y=16, X=16)
patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad,
                                 allow_incomplete_patches=True)
```

Create a dataset from the list of filenames. All those files should be stored within their respective group "4".

> **ClementCaporal:** I think there is a typo here for group "0"; should it be "4" in this example?
>
> **Collaborator Author:** Thanks for noticing this @ClementCaporal! I considered this change and added it to a recent PR #8 that addresses an incorrect sampling of masked regions.
>
> **ClementCaporal:** Oh, nice! I started using masked regions on Friday and noticed strange behavior, so I just have to pull now, thanks to you! Have a good week, Clément


Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly.


```python
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="4",
    source_axes="TCZYX",
    axes="YXC",
    # Keep a single time point and Z slice, and the full C, Y, and X extent
    roi="0,0,0,0,0:1,-1,1,-1,-1"
)

my_dataset = zds.ZarrDataset(image_specs,
                             patch_sampler=patch_sampler,
                             return_positions=True)
```


```python
my_dataset
```




ZarrDataset (PyTorch support:True, tqdm support :True)
Modalities: images
Transforms order: []
Using images modality as reference.
Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.



Add a pre-processing step before creating the image batches, in which the input arrays are cast from int16 to float32.


```python
import torchvision

# Cast to float32, convert the HWC arrays to CHW tensors, and normalize
# the intensities with mean 127 and standard deviation 255
img_preprocessing = torchvision.transforms.Compose([
    zds.ToDtype(dtype=np.float32),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(127, 255)
])

my_dataset.add_transform("images", img_preprocessing)
```
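
As a quick sanity check, the composed transform can be applied to a synthetic patch (the shape and fill value are assumptions for illustration): the cast produces float32, `ToTensor` moves the channels first, and `Normalize(127, 255)` maps the intensity 127 to 0.

```python
dummy_patch = np.full((128, 128, 3), 127, dtype=np.int16)
dummy_tensor = img_preprocessing(dummy_patch)

# Expected: torch.Size([3, 128, 128]) and a maximum absolute value of 0.0
print(dummy_tensor.shape, dummy_tensor.abs().max().item())
```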


```python
my_dataset
```




ZarrDataset (PyTorch support:True, tqdm support :True)
Modalities: images
Transforms order: [('images',)]
Using images modality as reference.
Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}.



## Create a DataLoader from the dataset object

ZarrDataset is compatible with PyTorch's DataLoader because it inherits from the IterableDataset class of the torch.utils.data module.


```python
my_dataloader = DataLoader(my_dataset, num_workers=0)
```
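
With `num_workers=0` all patches are extracted in the main process. For multi-worker loading, the ZarrDataset documentation suggests passing the library's worker initialization function so each worker receives its own share of the patches; a sketch, assuming `zds.zarrdataset_worker_init_fn` is available in the installed version:

```python
my_parallel_dataloader = DataLoader(
    my_dataset,
    num_workers=4,
    worker_init_fn=zds.zarrdataset_worker_init_fn
)
```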


Prepare an output Zarr array whose shape is the image size padded up to the next multiple of the patch size, so every patch, including the incomplete ones at the borders, can be written back at its position.

```python
import numpy as np
import zarr

z_arr = zarr.open("https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr/4", mode="r")

H = z_arr.shape[-2]
W = z_arr.shape[-1]

# Pad the output size up to the next multiple of the 128-pixel patch size
pad_H = (128 - H) % 128
pad_W = (128 - W) % 128
z_prediction = zarr.zeros((H + pad_H, W + pad_W), dtype=np.float32, chunks=(128, 128))
z_prediction
```




<zarr.core.Array (1152, 1408) float32>



Set up a simple model for illustration purposes.


```python
model = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1),
    torch.nn.ReLU()
)
```
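
A quick shape check clarifies the stitching step below: each sampled patch is 128 pixels plus 16 pixels of padding on each side (160 in total), and the 1x1 convolution preserves the spatial extent, so trimming 16 pixels from every border recovers the 128x128 core. A sketch with an assumed all-zeros batch:

```python
dummy_batch = torch.zeros(1, 3, 160, 160)  # one padded 160x160 patch

# The 1x1 convolution keeps Y and X unchanged: torch.Size([1, 1, 160, 160])
print(model(dummy_batch).shape)
```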


```python
for i, (pos, sample) in enumerate(my_dataloader):
    # pos holds the start and stop coordinates of the current patch;
    # shift them inwards by 16 pixels to skip the padded border
    pred_pos = (
        slice(pos[0, 0, 0].item() + 16,
              pos[0, 0, 1].item() - 16),
        slice(pos[0, 1, 0].item() + 16,
              pos[0, 1, 1].item() - 16)
    )
    pred = model(sample)

    # Trim the same 16-pixel border from the prediction before stitching
    z_prediction[pred_pos] = pred.detach().cpu().numpy()[0, 0, 16:-16, 16:-16]
```
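
Since the output array was padded up to a multiple of the patch size, a final crop with the `H` and `W` computed earlier recovers the original image extent:

```python
# Slicing the Zarr array returns an in-memory NumPy array of the
# original, unpadded size
z_prediction_cropped = z_prediction[:H, :W]
z_prediction_cropped.shape
```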

## Visualize the result


```python
import matplotlib.pyplot as plt

plt.subplot(2, 1, 1)
plt.imshow(np.moveaxis(z_arr[0, :, 0, ...], 0, -1))
plt.subplot(2, 1, 2)
plt.imshow(z_prediction)
plt.show()
```
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -30,6 +30,8 @@ Welcome to ZarrDataset's documentation!

examples/advanced_example_pytorch

examples/advanced_example_pytorch_inference


REFERENCE
=========
23 changes: 23 additions & 0 deletions tests/test_imageloaders.py
@@ -163,6 +163,29 @@ def test_ImageBase_slicing():
         f"{expected_selection_shape}, got {img_sel_2.shape} instead")


def test_ImageBase_padding():
    # Selections that extend past the image borders should be padded to
    # the requested shape
    shape = (16, 16, 3)
    axes = "YXC"
    img = zds.ImageBase(shape, chunk_size=None, source_axes=axes, mode="image")

    random.seed(44512)
    selection_1 = dict(
        (ax, slice(random.randint(-10, 0),
                   random.randint(1, r_s + 10)))
        for ax, r_s in zip(axes, shape)
    )

    expected_selection_shape = tuple(
        selection_1[ax].stop - selection_1[ax].start for ax in axes
    )

    img_sel_1 = img[selection_1]

    assert img_sel_1.shape == expected_selection_shape, \
        (f"Expected selection {selection_1} to have shape "
         f"{expected_selection_shape}, got {img_sel_1.shape} instead")


@pytest.mark.parametrize("axes, roi, expected_size", [
    (None, None, (16, 16, 3)),
    (None, slice(None), (16, 16, 3)),