Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Dataset to use Compose for transforms #7784

Merged
merged 12 commits into from
May 31, 2024
52 changes: 15 additions & 37 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,7 @@

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import (
Compose,
Randomizable,
RandomizableTrait,
Transform,
apply_transform,
convert_to_contiguous,
reset_ops_id,
)
from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id
from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first

Expand Down Expand Up @@ -77,15 +69,19 @@ class Dataset(_TorchDataset):
}, }, }]
"""

def __init__(self, data: Sequence, transform: Callable | None = None) -> None:
def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None:
"""
Args:
data: input data to load and transform to generate dataset for model.
transform: a callable data transform on input data.

transform: a callable, a sequence of callables, or None. If `transform` is not
a `Compose` instance, it is wrapped in a `Compose` instance. Sequences of
callables are applied in order; if `None` is passed, the data is returned as is.
"""
self.data = data
self.transform: Any = transform
try:
self.transform = Compose(transform) if not isinstance(transform, Compose) else transform
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
except Exception as e:
raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e

def __len__(self) -> int:
return len(self.data)
Expand All @@ -95,7 +91,7 @@ def _transform(self, index: int):
Fetch single data item from `self.data`.
"""
data_i = self.data[index]
return apply_transform(self.transform, data_i) if self.transform is not None else data_i
return self.transform(data_i)

def __getitem__(self, index: int | slice | Sequence[int]):
"""
Expand Down Expand Up @@ -264,8 +260,6 @@ def __init__(
using the cached content and with re-created transform instances.

"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
Expand Down Expand Up @@ -323,9 +317,6 @@ def _pre_transform(self, item_transformed):
random transform object

"""
if not isinstance(self.transform, Compose):
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
Expand All @@ -346,9 +337,6 @@ def _post_transform(self, item_transformed):
the transformed element through the random transforms

"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
Expand Down Expand Up @@ -501,9 +489,6 @@ def _pre_transform(self, item_transformed):
Returns:
the transformed element up to the N transform object
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)

reset_ops_id(item_transformed)
Expand All @@ -519,9 +504,6 @@ def _post_transform(self, item_transformed):
Returns:
the final transformed result
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

return self.transform(item_transformed, start=self.cache_n_trans)


Expand Down Expand Up @@ -809,8 +791,6 @@ def __init__(
Not following these recommendations may lead to runtime errors or duplicated cache across processes.

"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.set_num = cache_num # tracking the user-provided `cache_num` option
self.set_rate = cache_rate # tracking the user-provided `cache_rate` option
Expand Down Expand Up @@ -1282,8 +1262,10 @@ def to_list(x):
data = []
for dataset in self.data:
data.extend(to_list(dataset[index]))

if self.transform is not None:
data = apply_transform(self.transform, data, map_items=False) # transform the list data
self.transform.map_items = False # Compose object map_items to false so transform is applied to list
data = self.transform(data)
# use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
return tuple(data)

Expand Down Expand Up @@ -1432,15 +1414,11 @@ def __len__(self):

def _transform(self, index: int):
data = {k: v[index] for k, v in self.arrays.items()}

if not self.transform:
return data

result = apply_transform(self.transform, data)
result = self.transform(data) if self.transform is not None else data

if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)):
return result
raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.")
raise AssertionError("With a dict supplied to Compose, should return a dict or a list of dicts.")


class CSVDataset(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_arraydataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

class TestCompose(Compose):

def __call__(self, input_, lazy):
def __call__(self, input_, lazy=False):
img = self.transforms[0](input_)
metadata = img.meta
img = self.transforms[1](img)
Expand Down
68 changes: 67 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from parameterized import parameterized

from monai.data import Dataset
from monai.transforms import Compose, LoadImaged, SimulateDelayd
from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd
from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys

TEST_CASE_1 = [(128, 128, 128)]
Expand Down Expand Up @@ -99,6 +99,72 @@ def test_dataset_lazy_on_call(self):
data[0, 0:2, 0:2] = 1


class TestTupleDataset(unittest.TestCase):
    """Exercise `Dataset` on items that are (image, label) tuples of file paths.

    Two transform configurations are covered: a default `Compose`, where the
    transform is applied element by element to the tuple, and a `Compose`
    built with `map_items=False`, where it is applied to the tuple as a whole.
    """

    def _assert_pair(self, item, expected_shape):
        """Assert `item` is a list/tuple of two entries, each shaped `expected_shape`."""
        # assertIsInstance reports the offending type on failure, unlike
        # assertTrue(isinstance(...)) which only reports "False is not true".
        self.assertIsInstance(item, (list, tuple))
        self.assertEqual(len(item), 2)
        for element in item:
            self.assertTupleEqual(element.shape, expected_shape)

    @parameterized.expand([TEST_CASE_1])
    def test_shape(self, expected_shape):
        """Load two (image, label) pairs through `Dataset` and check the output structure."""
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:
            # Write the same volume under image/label names for two samples and
            # collect the paths as the dataset items.
            test_data = []
            for sample_id in (1, 2):
                pair = tuple(
                    os.path.join(tempdir, f"test_{role}{sample_id}.nii.gz") for role in ("image", "label")
                )
                for path in pair:
                    nib.save(test_image, path)
                test_data.append(pair)

            # Here test_transform is applied element by element for the tuple.
            test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)])
            dataset = Dataset(data=test_data, transform=test_transform)
            self._assert_pair(dataset[0], expected_shape)
            self._assert_pair(dataset[1], expected_shape)

            # Here test_transform is applied to the tuple as a whole.
            test_transform = Compose(
                [
                    # LoadImage creates a channel-stacked image when applied to a tuple
                    LoadImage(),
                    # Split the stacked result back into two entries (the first one permuted)
                    Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])),
                ],
                map_items=False,
            )
            dataset = Dataset(data=test_data, transform=test_transform)
            self._assert_pair(dataset[0], expected_shape)
            self._assert_pair(dataset[1], expected_shape)


class TestDatsesetWithLazy(unittest.TestCase):
LOGGER_NAME = "a_logger_name"

Expand Down
4 changes: 3 additions & 1 deletion tests/test_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def setUp(self):

self.scale = mt.ScaleIntensity()
self.scale_call_name = "ScaleIntensity.__call__"
self.compose_call_name = "Compose.__call__"
self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)])
self.test_image = torch.rand(1, 16, 16, 16)
self.pid = os.getpid()
Expand Down Expand Up @@ -82,7 +83,7 @@ def test_profile_multithread(self):
self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16))

results = wp.get_results()
self.assertSequenceEqual(list(results), [self.scale_call_name])
self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name])

prs = results[self.scale_call_name]

Expand All @@ -98,6 +99,7 @@ def test_profile_context(self):
self.scale(self.test_image)

results = wp.get_results()

self.assertSequenceEqual(set(results), {"ScaleIntensity.__call__", "context"})

prs = results["context"]
Expand Down
Loading