Skip to content
Merged
66 changes: 26 additions & 40 deletions monai/data/grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, Optional, Sequence, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset
from typing import Callable, Dict, Iterable, Optional, Sequence, Union

from monai.data.dataset import Dataset
from monai.data.iterable_dataset import IterableDataset
from monai.data.utils import iter_patch
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple, look_up_option
from monai.utils import NumpyPadMode, deprecated_arg, ensure_tuple, look_up_option

__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"]

Expand Down Expand Up @@ -96,7 +93,7 @@ class GridPatchDataset(IterableDataset):
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)

# construct the dataset
ds = GridPatchDataset(dataset=images,
ds = GridPatchDataset(data=images,
patch_iter=patch_iter,
transform=patch_intensity)
# use the grid patch dataset
Expand All @@ -108,49 +105,34 @@ class GridPatchDataset(IterableDataset):
# coordinates: tensor([[[0, 1], [0, 2], [0, 2]],
# [[0, 1], [2, 4], [0, 2]]])

Args:
data: the data source to read image data from.
patch_iter: converts an input image (item from dataset) into a iterable of image patches.
`patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates).
see also: :py:class:`monai.data.PatchIter`.
transform: a callable data transform operates on the patches.
with_coordinates: whether to yield the coordinates of each patch, default to `True`.

.. deprecated:: 0.8.0
``dataset`` is deprecated, use ``data`` instead.

"""

@deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.")
def __init__(
self,
dataset: Sequence,
data: Union[Iterable, Sequence],
patch_iter: Callable,
transform: Optional[Callable] = None,
with_coordinates: bool = True,
) -> None:
"""
Initializes this dataset in terms of the image dataset, patch generator, and an optional transform.

Args:
dataset: the dataset to read image data from.
patch_iter: converts an input image (item from dataset) into a iterable of image patches.
`patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates).
see also: :py:class:`monai.data.PatchIter`.
transform: a callable data transform operates on the patches.
with_coordinates: whether to yield the coordinates of each patch, default to `True`.

"""

self.dataset = dataset
super().__init__(data=data, transform=None)
self.patch_iter = patch_iter
self.transform = transform
self.with_coordinates = with_coordinates

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
iter_start, iter_end = 0, 1
try:
iter_end = len(self.dataset) # TODO: support iterable self.dataset
except TypeError:
raise NotImplementedError("image dataset must implement `len()`.") from None

if worker_info is not None:
# split workload
per_worker = int(np.ceil((iter_end - iter_start) / float(worker_info.num_workers)))
iter_start += worker_info.id * per_worker
iter_end = min(iter_start + per_worker, iter_end)

for index in range(iter_start, iter_end):
image = self.dataset[index]
for image in super().__iter__():
if not self.with_coordinates:
for patch, *_ in self.patch_iter(image): # patch_iter to yield at least 1 item: patch
out_patch = (
Expand Down Expand Up @@ -204,20 +186,24 @@ class PatchDataset(Dataset):

>>> torch.Size([2, 1, 3, 3])

.. deprecated:: 0.8.0
``dataset`` is deprecated, use ``data`` instead.

"""

@deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.")
def __init__(
self, dataset: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None
self, data: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None
) -> None:
"""
Args:
dataset: an image dataset to extract patches from.
data: an image dataset to extract patches from.
patch_func: converts an input image (item from dataset) into a sequence of image patches.
patch_func(dataset[idx]) must return a sequence of patches (length `samples_per_image`).
samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
transform: transform applied to each patch.
"""
super().__init__(data=dataset, transform=transform)
super().__init__(data=data, transform=transform)

self.patch_func = patch_func
if samples_per_image <= 0:
Expand Down
22 changes: 13 additions & 9 deletions tests/test_grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,39 +33,43 @@ def tearDown(self):
set_determinism(None)

def test_shape(self):
test_dataset = ["vwxyz", "helloworld", "worldfoobar"]
result = GridPatchDataset(dataset=test_dataset, patch_iter=identity_generator, with_coordinates=False)
# test Iterable input data
test_dataset = iter(["vwxyz", "helloworld", "worldfoobar"])
result = GridPatchDataset(data=test_dataset, patch_iter=identity_generator, with_coordinates=False)
output = []
n_workers = 0 if sys.platform == "win32" else 2
for item in DataLoader(result, batch_size=3, num_workers=n_workers):
output.append("".join(item))
expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"]
if sys.platform == "win32":
expected = ["ar", "ell", "ldf", "oob", "owo", "rld", "vwx", "wor", "yzh"]
else:
expected = ["d", "dfo", "hel", "low", "oba", "orl", "orl", "r", "vwx", "yzw"]
self.assertEqual(len("".join(expected)), len("".join(list(test_dataset))))
self.assertEqual(sorted(output), sorted(expected))
self.assertEqual(len("".join(expected)), len("".join(test_dataset)))

def test_loading_array(self):
set_determinism(seed=1234)
# image dataset
# test sequence input data with images
images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
# image level
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
ds = GridPatchDataset(dataset=images, patch_iter=patch_iter, transform=patch_intensity)
ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity)
# use the grid patch dataset
for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
np.array([[[[2.0577, 3.0577], [6.0577, 7.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]),
rtol=1e-5,
np.array([[[[1.4965, 2.4965], [5.4965, 6.4965]]], [[[11.3584, 12.3584], [15.3584, 16.3584]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
if sys.platform != "win32":
for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
np.array([[[[1.6533, 2.6533], [5.6533, 6.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]),
np.array([[[[1.2548, 2.2548], [5.2548, 6.2548]]], [[[9.1106, 10.1106], [13.1106, 14.1106]]]]),
rtol=1e-3,
)
np.testing.assert_allclose(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_patch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_shape(self):
test_dataset = ["vwxyz", "hello", "world"]
n_per_image = len(test_dataset[0])

result = PatchDataset(dataset=test_dataset, patch_func=identity, samples_per_image=n_per_image)
result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image)

output = []
n_workers = 0 if sys.platform == "win32" else 2
Expand All @@ -50,7 +50,7 @@ def test_loading_array(self):
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
image_ds = Dataset(images, transform=patch_intensity)
# patch level
ds = PatchDataset(dataset=image_ds, patch_func=sampler, samples_per_image=n_samples, transform=patch_intensity)
ds = PatchDataset(data=image_ds, patch_func=sampler, samples_per_image=n_samples, transform=patch_intensity)

np.testing.assert_equal(len(ds), n_samples * len(images))
# use the patch dataset, length: len(images) x samplers_per_image
Expand Down