Skip to content

Commit

Permalink
Refactor Dataset to use Compose for transforms (#7784)
Browse files Browse the repository at this point in the history
Fixes #7646 

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Suraj Pai <b.pai@maastrichtuniversity.nl>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Ben Murray <ben.murray@gmail.com>
  • Loading branch information
4 people committed May 31, 2024
1 parent 0d7f772 commit 4029c42
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 40 deletions.
52 changes: 15 additions & 37 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,7 @@

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import (
Compose,
Randomizable,
RandomizableTrait,
Transform,
apply_transform,
convert_to_contiguous,
reset_ops_id,
)
from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id
from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first

Expand Down Expand Up @@ -77,15 +69,19 @@ class Dataset(_TorchDataset):
}, }, }]
"""

def __init__(self, data: Sequence, transform: Callable | None = None) -> None:
def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None:
"""
Args:
data: input data to load and transform to generate dataset for model.
transform: a callable data transform on input data.
transform: a callable, sequence of callables or None. If transform is not
a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences
of callables are applied in order and if `None` is passed, the data is returned as is.
"""
self.data = data
self.transform: Any = transform
try:
self.transform = Compose(transform) if not isinstance(transform, Compose) else transform
except Exception as e:
raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e

def __len__(self) -> int:
return len(self.data)
Expand All @@ -95,7 +91,7 @@ def _transform(self, index: int):
Fetch single data item from `self.data`.
"""
data_i = self.data[index]
return apply_transform(self.transform, data_i) if self.transform is not None else data_i
return self.transform(data_i)

def __getitem__(self, index: int | slice | Sequence[int]):
"""
Expand Down Expand Up @@ -264,8 +260,6 @@ def __init__(
using the cached content and with re-created transform instances.
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
Expand Down Expand Up @@ -323,9 +317,6 @@ def _pre_transform(self, item_transformed):
random transform object
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
Expand All @@ -346,9 +337,6 @@ def _post_transform(self, item_transformed):
the transformed element through the random transforms
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
Expand Down Expand Up @@ -501,9 +489,6 @@ def _pre_transform(self, item_transformed):
Returns:
the transformed element up to the N transform object
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)

reset_ops_id(item_transformed)
Expand All @@ -519,9 +504,6 @@ def _post_transform(self, item_transformed):
Returns:
the final transformed result
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

return self.transform(item_transformed, start=self.cache_n_trans)


Expand Down Expand Up @@ -809,8 +791,6 @@ def __init__(
Not following these recommendations may lead to runtime errors or duplicated cache across processes.
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.set_num = cache_num # tracking the user-provided `cache_num` option
self.set_rate = cache_rate # tracking the user-provided `cache_rate` option
Expand Down Expand Up @@ -1282,8 +1262,10 @@ def to_list(x):
data = []
for dataset in self.data:
data.extend(to_list(dataset[index]))

if self.transform is not None:
data = apply_transform(self.transform, data, map_items=False) # transform the list data
self.transform.map_items = False # Compose object map_items to false so transform is applied to list
data = self.transform(data)
# use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
return tuple(data)

Expand Down Expand Up @@ -1432,15 +1414,11 @@ def __len__(self):

def _transform(self, index: int):
data = {k: v[index] for k, v in self.arrays.items()}

if not self.transform:
return data

result = apply_transform(self.transform, data)
result = self.transform(data) if self.transform is not None else data

if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)):
return result
raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.")
raise AssertionError("With a dict supplied to Compose, should return a dict or a list of dicts.")


class CSVDataset(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_arraydataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

class TestCompose(Compose):

def __call__(self, input_, lazy):
def __call__(self, input_, lazy=False):
img = self.transforms[0](input_)
metadata = img.meta
img = self.transforms[1](img)
Expand Down
68 changes: 67 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from parameterized import parameterized

from monai.data import Dataset
from monai.transforms import Compose, LoadImaged, SimulateDelayd
from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd
from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys

TEST_CASE_1 = [(128, 128, 128)]
Expand Down Expand Up @@ -99,6 +99,72 @@ def test_dataset_lazy_on_call(self):
data[0, 0:2, 0:2] = 1


class TestTupleDataset(unittest.TestCase):

@parameterized.expand([TEST_CASE_1])
def test_shape(self, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
with tempfile.TemporaryDirectory() as tempdir:
nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
test_data = [
(os.path.join(tempdir, "test_image1.nii.gz"), os.path.join(tempdir, "test_label1.nii.gz")),
(os.path.join(tempdir, "test_image2.nii.gz"), os.path.join(tempdir, "test_label2.nii.gz")),
]

test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)])

# Here test_transform is applied element by element for the tuple.
dataset = Dataset(data=test_data, transform=test_transform)
data1 = dataset[0]
data2 = dataset[1]

# Output is a list/tuple
self.assertTrue(isinstance(data1, (list, tuple)))
self.assertTrue(isinstance(data2, (list, tuple)))

# Number of elements are 2
self.assertEqual(len(data1), 2)
self.assertEqual(len(data2), 2)

# Output shapes are as expected
self.assertTupleEqual(data1[0].shape, expected_shape)
self.assertTupleEqual(data1[1].shape, expected_shape)
self.assertTupleEqual(data2[0].shape, expected_shape)
self.assertTupleEqual(data2[1].shape, expected_shape)

# Here test_transform is applied to the tuple as a whole.
test_transform = Compose(
[
# LoadImage creates a channel-stacked image when applied to a tuple
LoadImage(),
# Get the channel-stacked image and the label
Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])),
],
map_items=False,
)

dataset = Dataset(data=test_data, transform=test_transform)
data1 = dataset[0]
data2 = dataset[1]

# Output is a list/tuple
self.assertTrue(isinstance(data1, (list, tuple)))
self.assertTrue(isinstance(data2, (list, tuple)))

# Number of elements are 2
self.assertEqual(len(data1), 2)
self.assertEqual(len(data2), 2)

# Output shapes are as expected
self.assertTupleEqual(data1[0].shape, expected_shape)
self.assertTupleEqual(data1[1].shape, expected_shape)
self.assertTupleEqual(data2[0].shape, expected_shape)
self.assertTupleEqual(data2[1].shape, expected_shape)


class TestDatsesetWithLazy(unittest.TestCase):
LOGGER_NAME = "a_logger_name"

Expand Down
4 changes: 3 additions & 1 deletion tests/test_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def setUp(self):

self.scale = mt.ScaleIntensity()
self.scale_call_name = "ScaleIntensity.__call__"
self.compose_call_name = "Compose.__call__"
self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)])
self.test_image = torch.rand(1, 16, 16, 16)
self.pid = os.getpid()
Expand Down Expand Up @@ -82,7 +83,7 @@ def test_profile_multithread(self):
self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16))

results = wp.get_results()
self.assertSequenceEqual(list(results), [self.scale_call_name])
self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name])

prs = results[self.scale_call_name]

Expand All @@ -98,6 +99,7 @@ def test_profile_context(self):
self.scale(self.test_image)

results = wp.get_results()

self.assertSequenceEqual(set(results), {"ScaleIntensity.__call__", "context"})

prs = results["context"]
Expand Down

0 comments on commit 4029c42

Please sign in to comment.