Skip to content
Merged
8 changes: 4 additions & 4 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ class CacheDataset(Dataset):
def __init__(
self,
data: Sequence,
transform: Union[Sequence[Callable], Callable],
transform: Optional[Union[Sequence[Callable], Callable]] = None,
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: Optional[int] = 1,
Expand Down Expand Up @@ -856,7 +856,7 @@ class SmartCacheDataset(Randomizable, CacheDataset):
Args:
data: input data to load and transform to generate dataset for model.
transform: transforms to execute operations on input data.
replace_rate: percentage of the cached items to be replaced in every epoch.
replace_rate: percentage of the cached items to be replaced in every epoch (default to 0.1).
cache_num: number of items to be cached. Default is `sys.maxsize`.
will take the minimum of (cache_num, data_length x cache_rate, data_length).
cache_rate: percentage of cached data in total, default is 1.0 (cache all).
Expand All @@ -883,8 +883,8 @@ class SmartCacheDataset(Randomizable, CacheDataset):
def __init__(
self,
data: Sequence,
transform: Union[Sequence[Callable], Callable],
replace_rate: float,
transform: Optional[Union[Sequence[Callable], Callable]] = None,
replace_rate: float = 0.1,
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_init_workers: Optional[int] = 1,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_cachedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def test_shape(self, transform, expected_shape):
data4 = dataset[-1]
self.assertEqual(len(data3), 1)

if transform is None:
# Check without providing transfrom
dataset2 = CacheDataset(data=test_data, cache_rate=0.5, as_contiguous=True)
for k in ["image", "label", "extra"]:
self.assertEqual(dataset[0][k], dataset2[0][k])

if transform is None:
self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz"))
self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz"))
Expand Down
11 changes: 11 additions & 0 deletions tests/test_smartcachedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def test_shape(self, replace_rate, num_replace_workers, transform):
num_init_workers=4,
num_replace_workers=num_replace_workers,
)
if transform is None:
# Check without providing transfrom
dataset2 = SmartCacheDataset(
data=test_data,
replace_rate=replace_rate,
cache_num=16,
num_init_workers=4,
num_replace_workers=num_replace_workers,
)
for k in ["image", "label", "extra"]:
self.assertEqual(dataset[0][k], dataset2[0][k])

self.assertEqual(len(dataset._cache), dataset.cache_num)
for i in range(dataset.cache_num):
Expand Down
28 changes: 9 additions & 19 deletions tests/test_wsireader.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@

TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW

TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel
TEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)] # no color channel
TEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)] # one color channel
TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels
TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color


Expand All @@ -106,20 +108,6 @@ def save_rgba_tiff(array: np.ndarray, filename: str, mode: str):
return filename


def save_gray_tiff(array: np.ndarray, filename: str):
"""
Save numpy array into a TIFF file

Args:
array: numpy ndarray with any shape
filename: the filename to be used for the tiff file.
"""
img_gray = array
imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack")

return filename


@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
def setUpModule(): # noqa: N802
hash_type = testing_data_config("images", FILE_KEY, "hash_type")
Expand Down Expand Up @@ -187,13 +175,15 @@ def test_read_rgba(self, img_expected):
self.assertIsNone(assert_array_equal(image["RGB"], img_expected))
self.assertIsNone(assert_array_equal(image["RGBA"], img_expected))

@parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D])
@parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D])
@skipUnless(has_tiff, "Requires tifffile.")
def test_read_malformats(self, img_expected):
if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1):
# Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230
return
reader = WSIReader(self.backend)
file_path = save_gray_tiff(
img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff")
)
file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff")
imwrite(file_path, img_expected, shape=img_expected.shape)
with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)):
with reader.read(file_path) as img_obj:
reader.get_data(img_obj)
Expand Down
16 changes: 10 additions & 6 deletions tests/test_wsireader_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@

TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW

TEST_CASE_ERROR_GRAY = [np.ones((16, 16, 2), dtype=np.uint8)] # wrong color channel
TEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)] # no color channel
TEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)] # one color channel
TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels
TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color


Expand Down Expand Up @@ -103,7 +105,7 @@ def save_gray_tiff(array: np.ndarray, filename: str):
filename: the filename to be used for the tiff file.
"""
img_gray = array
imwrite(filename, img_gray, shape=img_gray.shape, photometric="minisblack")
imwrite(filename, img_gray, shape=img_gray.shape)

return filename

Expand Down Expand Up @@ -180,13 +182,15 @@ def test_read_rgba(self, img_expected):
self.assertIsNone(assert_array_equal(image["RGB"], img_expected))
self.assertIsNone(assert_array_equal(image["RGBA"], img_expected))

@parameterized.expand([TEST_CASE_ERROR_GRAY, TEST_CASE_ERROR_3D])
@parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D])
@skipUnless(has_tiff, "Requires tifffile.")
def test_read_malformats(self, img_expected):
if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1):
# Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230
return
reader = WSIReader(self.backend)
file_path = save_gray_tiff(
img_expected, os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff")
)
file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff")
imwrite(file_path, img_expected, shape=img_expected.shape)
with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)):
with reader.read(file_path) as img_obj:
reader.get_data(img_obj)
Expand Down