-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Standardized patch sampling method to handle smaller and bigger mask …
…scales than image scale
- Loading branch information
Showing
4 changed files
with
498 additions
and
252 deletions.
There are no files selected for viewing
177 changes: 177 additions & 0 deletions
177
docs/source/examples/advanced_example_pytorch_inference.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
```python | ||
import zarrdataset as zds | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
``` | ||
|
||
|
||
```python | ||
# These are images from the Image Data Resource (IDR) | ||
# https://idr.openmicroscopy.org/ that are publicly available and were | ||
# converted to the OME-NGFF (Zarr) format by the OME group. More examples | ||
# can be found at Public OME-Zarr data (Nov. 2020) | ||
# https://www.openmicroscopy.org/2020/11/04/zarr-data.html | ||
|
||
filenames = [ | ||
"https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr" | ||
] | ||
``` | ||
|
||
|
||
```python | ||
import random | ||
import numpy as np | ||
|
||
# For reproducibility | ||
np.random.seed(478963) | ||
torch.manual_seed(478964) | ||
random.seed(478965) | ||
``` | ||
|
||
## Extracting patches of size 1024x1024 pixels from a Whole Slide Image (WSI) | ||
|
||
Retrieve samples for inference. Add padding to each patch to avoid edge artifacts when stitching the inference result. | ||
Finally, let the PatchSampler retrieve patches from the edge of the image that would be otherwise smaller than the patch size. | ||
|
||
|
||
```python | ||
patch_size = dict(Y=128, X=128) | ||
pad = dict(Y=16, X=16) | ||
patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True) | ||
``` | ||
|
||
Create a dataset from the list of filenames. All those files should be stored within their respective group "0". | ||
|
||
Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly | ||
|
||
|
||
```python | ||
image_specs = zds.ImagesDatasetSpecs( | ||
filenames=filenames, | ||
data_group="4", | ||
source_axes="TCZYX", | ||
axes="YXC", | ||
roi="0,0,0,0,0:1,-1,1,-1,-1" | ||
) | ||
|
||
my_dataset = zds.ZarrDataset(image_specs, | ||
patch_sampler=patch_sampler, | ||
return_positions=True) | ||
``` | ||
|
||
|
||
```python | ||
my_dataset | ||
``` | ||
|
||
|
||
|
||
|
||
ZarrDataset (PyTorch support:True, tqdm support :True) | ||
Modalities: images | ||
Transforms order: [] | ||
Using images modality as reference. | ||
Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}. | ||
|
||
|
||
|
||
Add a pre-processing step before creating the image batches, where the input arrays are casted from int16 to float32. | ||
|
||
|
||
```python | ||
import torchvision | ||
|
||
img_preprocessing = torchvision.transforms.Compose([ | ||
zds.ToDtype(dtype=np.float32), | ||
torchvision.transforms.ToTensor(), | ||
torchvision.transforms.Normalize(127, 255) | ||
]) | ||
|
||
my_dataset.add_transform("images", img_preprocessing) | ||
``` | ||
|
||
|
||
```python | ||
my_dataset | ||
``` | ||
|
||
|
||
|
||
|
||
ZarrDataset (PyTorch support:True, tqdm support :True) | ||
Modalities: images | ||
Transforms order: [('images',)] | ||
Using images modality as reference. | ||
Using <class 'zarrdataset._samplers.PatchSampler'> for sampling patches of size {'Z': 1, 'Y': 128, 'X': 128}. | ||
|
||
|
||
|
||
## Create a DataLoader from the dataset object | ||
|
||
ZarrDataset is compatible with DataLoader from PyTorch since it is inherited from the IterableDataset class of the torch.utils.data module. | ||
|
||
|
||
```python | ||
my_dataloader = DataLoader(my_dataset, num_workers=0) | ||
``` | ||
|
||
|
||
```python | ||
import dask.array as da | ||
import numpy as np | ||
import zarr | ||
|
||
z_arr = zarr.open("https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.4/idr0073A/9798462.zarr/4", mode="r") | ||
|
||
H = z_arr.shape[-2] | ||
W = z_arr.shape[-1] | ||
|
||
pad_H = (128 - H) % 128 | ||
pad_W = (128 - W) % 128 | ||
z_prediction = zarr.zeros((H + pad_H, W + pad_W), dtype=np.float32, chunks=(128, 128)) | ||
z_prediction | ||
``` | ||
|
||
|
||
|
||
|
||
<zarr.core.Array (1152, 1408) float32> | ||
|
||
|
||
|
||
Set up a simple model for illustration purpose | ||
|
||
|
||
```python | ||
model = torch.nn.Sequential( | ||
torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1), | ||
torch.nn.ReLU() | ||
) | ||
``` | ||
|
||
|
||
```python | ||
for i, (pos, sample) in enumerate(my_dataloader): | ||
pred_pos = ( | ||
slice(pos[0, 0, 0].item() + 16, | ||
pos[0, 0, 1].item() - 16), | ||
slice(pos[0, 1, 0].item() + 16, | ||
pos[0, 1, 1].item() - 16) | ||
) | ||
pred = model(sample) | ||
z_prediction[pred_pos] = pred.detach().cpu().numpy()[0, 0, 16:-16, 16:-16] | ||
``` | ||
|
||
## Visualize the result | ||
|
||
|
||
```python | ||
import matplotlib.pyplot as plt | ||
|
||
plt.subplot(2, 1, 1) | ||
plt.imshow(np.moveaxis(z_arr[0, :, 0, ...], 0, -1)) | ||
plt.subplot(2, 1, 2) | ||
plt.imshow(z_prediction) | ||
plt.show() | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.