diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index f1b481d598..f0416a6bdb 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -32,7 +32,7 @@
 from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
 from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform
-from monai.utils import MAX_SEED, ensure_tuple, get_seed, look_up_option, min_version, optional_import
+from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import
 from monai.utils.misc import first

 if TYPE_CHECKING:
@@ -1222,8 +1222,9 @@ class CSVDataset(Dataset):
         ]

     Args:
-        filename: the filename of expected CSV file to load. if providing a list
-            of filenames, it will load all the files and join tables.
+        src: the source to load, usually the filename of a CSV file: a str, URL, path object or file-like object.
+            a pandas `DataFrame` can also be provided directly, which skips loading from a filename.
+            if a list of filenames or `DataFrame` objects is provided, the tables will be joined.
         row_indices: indices of the expected rows to load. it should be a list,
             every item can be a int number or a range `[start, end)` for the indices.
             for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
@@ -1249,11 +1250,15 @@
         transform: transform to apply on the loaded items of a dictionary data.
         kwargs: additional arguments for `pandas.merge()` API to join tables.

+    .. deprecated:: 0.8.0
+        ``filename`` is deprecated, use ``src`` instead.
+
     """

+    @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
     def __init__(
         self,
-        filename: Union[str, Sequence[str]],
+        src: Optional[Union[str, Sequence[str]]] = None,  # also can be a `DataFrame` or a sequence of `DataFrame`
         row_indices: Optional[Sequence[Union[int, str]]] = None,
         col_names: Optional[Sequence[str]] = None,
         col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
@@ -1261,8 +1266,16 @@ def __init__(
         transform: Optional[Callable] = None,
         **kwargs,
     ):
-        files = ensure_tuple(filename)
-        dfs = [pd.read_csv(f) for f in files]
+        srcs = (src,) if not isinstance(src, (tuple, list)) else src
+        dfs: List = []
+        for i in srcs:
+            if isinstance(i, str):
+                dfs.append(pd.read_csv(i))
+            elif isinstance(i, pd.DataFrame):
+                dfs.append(i)
+            else:
+                raise ValueError("`src` must be a file path or a pandas `DataFrame`.")
+
         data = convert_tables_to_dicts(
             dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs
         )
diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py
index d1365fa220..e463c8c6ed 100644
--- a/monai/data/iterable_dataset.py
+++ b/monai/data/iterable_dataset.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

 import numpy as np
 from torch.utils.data import IterableDataset as _TorchIterableDataset
@@ -18,7 +18,7 @@
 from monai.data.utils import convert_tables_to_dicts
 from monai.transforms import apply_transform
 from monai.transforms.transform import Randomizable
-from monai.utils import ensure_tuple, optional_import
+from monai.utils import deprecated_arg, optional_import

 pd, _ = optional_import("pandas")

@@ -147,8 +147,9 @@ class CSVIterableDataset(IterableDataset):
         ]

     Args:
-        filename: the filename of CSV file to load. it can be a str, URL, path object or file-like object.
-            if providing a list of filenames, it will load all the files and join tables.
+        src: the source to load, usually the filename of a CSV file: a str, URL, path object or file-like object.
+            an iterable object for stream input can also be provided directly, which skips loading from a filename.
+            if a list of filenames or iterables is provided, the tables will be joined.
         chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details:
             https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.
         buffer_size: size of the buffer to store the loaded chunks, if None, set to `2 x chunksize`.
@@ -177,11 +178,15 @@
             https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
         kwargs: additional arguments for `pandas.merge()` API to join tables.

+    .. deprecated:: 0.8.0
+        ``filename`` is deprecated, use ``src`` instead.
+
     """

+    @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
     def __init__(
         self,
-        filename: Union[str, Sequence[str]],
+        src: Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]],
         chunksize: int = 1000,
         buffer_size: Optional[int] = None,
         col_names: Optional[Sequence[str]] = None,
@@ -192,7 +197,7 @@ def __init__(
         seed: int = 0,
         **kwargs,
     ):
-        self.files = ensure_tuple(filename)
+        self.src = src
         self.chunksize = chunksize
         self.buffer_size = 2 * chunksize if buffer_size is None else buffer_size
         self.col_names = col_names
@@ -201,16 +206,46 @@ def __init__(
         self.shuffle = shuffle
         self.seed = seed
         self.kwargs = kwargs
-        self.iters = self.reset()
+        self.iters: List[Iterable] = self.reset()
         super().__init__(data=None, transform=transform)  # type: ignore

-    def reset(self, filename: Optional[Union[str, Sequence[str]]] = None):
-        if filename is not None:
-            # update files if necessary
-            self.files = ensure_tuple(filename)
-        self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files]
+    @deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
+    def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, Sequence[Iterable]]]] = None):
+        """
+        Reset the pandas `TextFileReader` iterable object to read data. For more details, please check:
+        https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.
+
+        Args:
+            src: if not None, the new source to read from: the filename of a CSV file (a str, URL,
+                path object or file-like object) or an iterable object for stream input, as in `__init__`.
+                if a list of filenames or iterables is provided, the tables will be joined.
+                default to `self.src`.
+
+        """
+        src = self.src if src is None else src
+        srcs = (src,) if not isinstance(src, (tuple, list)) else src
+        self.iters = []
+        for i in srcs:
+            if isinstance(i, str):
+                self.iters.append(pd.read_csv(i, chunksize=self.chunksize))
+            elif isinstance(i, Iterable):
+                self.iters.append(i)
+            else:
+                raise ValueError("`src` must be a file path or an iterable object.")
         return self.iters

+    def close(self):
+        """
+        Close the pandas `TextFileReader` iterable objects.
+        If the input `src` is a file path, the `TextFileReader` was created internally and needs to be closed here.
+        If the input `src` is an iterable object, it is up to the user's requirements whether to close it here.
+        For more details, please check:
+        https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.
+
+        """
+        for i in self.iters:
+            i.close()
+
     def _flattened(self):
         for chunks in zip(*self.iters):
             yield from convert_tables_to_dicts(
diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index b136eb7f1f..9b58be04e8 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -19,7 +19,6 @@
     def monai_mish(x, inplace: bool = False):
         return torch.nn.functional.mish(x, inplace=inplace)

-
 else:

     def monai_mish(x, inplace: bool = False):
@@ -31,7 +30,6 @@ def monai_mish(x, inplace: bool = False):
     def monai_swish(x, inplace: bool = False):
         return torch.nn.functional.silu(x, inplace=inplace)

-
 else:

     def monai_swish(x, inplace: bool = False):
diff --git a/tests/test_csv_dataset.py b/tests/test_csv_dataset.py
index d187f4e64d..2a38409ba0 100644
--- a/tests/test_csv_dataset.py
+++ b/tests/test_csv_dataset.py
@@ -14,6 +14,7 @@
 import unittest

 import numpy as np
+import pandas as pd

 from monai.data import CSVDataset
 from monai.transforms import ToNumpyd
@@ -57,6 +58,7 @@ def prepare_csv_file(data, filepath):
             filepath1 = os.path.join(tempdir, "test_data1.csv")
             filepath2 = os.path.join(tempdir, "test_data2.csv")
             filepath3 = os.path.join(tempdir, "test_data3.csv")
+            filepaths = [filepath1, filepath2, filepath3]
             prepare_csv_file(test_data1, filepath1)
             prepare_csv_file(test_data2, filepath2)
             prepare_csv_file(test_data3, filepath3)
@@ -76,7 +78,7 @@
             )

             # test multiple CSV files, join tables with kwargs
-            dataset = CSVDataset([filepath1, filepath2, filepath3], on="subject_id")
+            dataset = CSVDataset(filepaths, on="subject_id")
             self.assertDictEqual(
                 {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[3].items()},
                 {
@@ -102,7 +104,7 @@

             # test selected rows and columns
             dataset = CSVDataset(
-                filename=[filepath1, filepath2, filepath3],
+                src=filepaths,
                 row_indices=[[0, 2], 3],  # load row: 0, 1, 3
                 col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
             )
@@ -120,7 +122,7 @@

             # test group columns
             dataset = CSVDataset(
-                filename=[filepath1, filepath2, filepath3],
+                src=filepaths,
                 row_indices=[1, 3],  # load row: 1, 3
                 col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"],
                 col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]},
@@ -133,9 +135,7 @@

             # test transform
             dataset = CSVDataset(
-                filename=[filepath1, filepath2, filepath3],
-                col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
-                transform=ToNumpyd(keys="ehr"),
+                src=filepaths, col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr")
             )
             self.assertEqual(len(dataset), 5)
             expected = [
@@ -151,7 +151,7 @@

             # test default values and dtype
             dataset = CSVDataset(
-                filename=[filepath1, filepath2, filepath3],
+                src=filepaths,
                 col_names=["subject_id", "image", "ehr_1", "ehr_9", "meta_1"],
                 col_types={"image": {"type": str, "default": "No image"}, "ehr_1": {"type": int, "default": 0}},
                 how="outer",  # generate NaN values in this merge mode
@@ -161,6 +161,29 @@
             self.assertEqual(type(dataset[-1]["ehr_1"]), int)
             np.testing.assert_allclose(dataset[-1]["ehr_9"], 3.3537, rtol=1e-2)

+            # test pre-loaded DataFrame
+            df = pd.read_csv(filepath1)
+            dataset = CSVDataset(src=df)
+            self.assertDictEqual(
+                {k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},
+                {
+                    "subject_id": "s000002",
+                    "label": 4,
+                    "image": "./imgs/s000002.png",
+                    "ehr_0": 3.7725,
+                    "ehr_1": 4.2118,
+                    "ehr_2": 4.6353,
+                },
+            )
+
+            # test pre-loaded multiple DataFrames, join tables with kwargs
+            dfs = [pd.read_csv(i) for i in filepaths]
+            dataset = CSVDataset(src=dfs, on="subject_id")
+            self.assertEqual(dataset[3]["subject_id"], "s000003")
+            self.assertEqual(dataset[3]["label"], 1)
+            self.assertEqual(round(dataset[3]["ehr_0"], 4), 3.3333)
+            self.assertEqual(dataset[3]["meta_0"], False)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py
index fae8e0ba8d..2d2ea2a9c0 100644
--- a/tests/test_csv_iterable_dataset.py
+++ b/tests/test_csv_iterable_dataset.py
@@ -15,6 +15,7 @@
 import unittest

 import numpy as np
+import pandas as pd

 from monai.data import CSVIterableDataset, DataLoader
 from monai.transforms import ToNumpyd
@@ -58,6 +59,7 @@ def prepare_csv_file(data, filepath):
             filepath1 = os.path.join(tempdir, "test_data1.csv")
             filepath2 = os.path.join(tempdir, "test_data2.csv")
             filepath3 = os.path.join(tempdir, "test_data3.csv")
+            filepaths = [filepath1, filepath2, filepath3]
             prepare_csv_file(test_data1, filepath1)
             prepare_csv_file(test_data2, filepath2)
             prepare_csv_file(test_data3, filepath3)
@@ -81,18 +83,20 @@
                 )
                 break
             self.assertEqual(count, 3)
+            dataset.close()

             # test reset iterables
-            dataset.reset(filename=filepath3)
+            dataset.reset(src=filepath3)
             count = 0
             for i, item in enumerate(dataset):
                 count += 1
                 if i == 4:
                     self.assertEqual(item["meta_0"], False)
             self.assertEqual(count, 5)
+            dataset.close()

             # test multiple CSV files, join tables with kwargs
-            dataset = CSVIterableDataset([filepath1, filepath2, filepath3], on="subject_id", shuffle=False)
+            dataset = CSVIterableDataset(filepaths, on="subject_id", shuffle=False)
             count = 0
             for item in dataset:
                 count += 1
@@ -120,13 +124,11 @@
                 },
             )
             self.assertEqual(count, 5)
+            dataset.close()

             # test selected columns and chunk size
             dataset = CSVIterableDataset(
-                filename=[filepath1, filepath2, filepath3],
-                chunksize=2,
-                col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
-                shuffle=False,
+                src=filepaths, chunksize=2, col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], shuffle=False
             )
             count = 0
             for item in dataset:
@@ -143,10 +145,11 @@
                 },
             )
             self.assertEqual(count, 5)
+            dataset.close()

             # test group columns
             dataset = CSVIterableDataset(
-                filename=[filepath1, filepath2, filepath3],
+                src=filepaths,
                 col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"],
                 col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]},
                 shuffle=False,
@@ -161,12 +164,13 @@
             )
             np.testing.assert_allclose(item["meta12"], [False, True])
             self.assertEqual(count, 5)
+            dataset.close()

             # test transform
             dataset = CSVIterableDataset(
                 chunksize=2,
                 buffer_size=4,
-                filename=[filepath1, filepath2, filepath3],
+                src=filepaths,
                 col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
                 transform=ToNumpyd(keys="ehr"),
                 shuffle=True,
@@ -185,6 +189,7 @@
                 self.assertTrue(isinstance(item["ehr"], np.ndarray))
                 np.testing.assert_allclose(np.around(item["ehr"], 4), exp)
             self.assertEqual(count, 5)
+            dataset.close()

             # test multiple processes loading
             dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label"), shuffle=False)
@@ -200,6 +205,45 @@
                 np.testing.assert_allclose(item["label"], [4])
                 self.assertListEqual(item["image"], ["./imgs/s000002.png"])
             self.assertEqual(count, 3)
+            dataset.close()
+
+            # test iterable stream
+            iters = pd.read_csv(filepath1, chunksize=1000)
+            dataset = CSVIterableDataset(src=iters, shuffle=False)
+            count = 0
+            for item in dataset:
+                count += 1
+                if count == 3:
+                    self.assertDictEqual(
+                        {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()},
+                        {
+                            "subject_id": "s000002",
+                            "label": 4,
+                            "image": "./imgs/s000002.png",
+                            "ehr_0": 3.7725,
+                            "ehr_1": 4.2118,
+                            "ehr_2": 4.6353,
+                        },
+                    )
+                    break
+            self.assertEqual(count, 3)
+            dataset.close()
+
+            # test multiple iterable streams, join tables with kwargs
+            iters = [pd.read_csv(i, chunksize=1000) for i in filepaths]
+            dataset = CSVIterableDataset(src=iters, on="subject_id", shuffle=False)
+            count = 0
+            for item in dataset:
+                count += 1
+                if count == 4:
+                    self.assertEqual(item["subject_id"], "s000003")
+                    self.assertEqual(item["label"], 1)
+                    self.assertEqual(round(item["ehr_0"], 4), 3.3333)
+                    self.assertEqual(item["meta_0"], False)
+            self.assertEqual(count, 5)
+            # manually close the pre-loaded iterables instead of `dataset.close()`
+            for i in iters:
+                i.close()


 if __name__ == "__main__":
     unittest.main()
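
A minimal usage sketch of the new `src` argument on `CSVDataset`, assuming a hypothetical `patients.csv` (the file name is illustrative, not part of this PR's test data):

import pandas as pd

from monai.data import CSVDataset

# parse the CSV once with pandas, then hand the DataFrame over directly:
# per this PR, CSVDataset skips its internal pd.read_csv() call for DataFrame inputs
df = pd.read_csv("patients.csv")  # hypothetical file
dataset = CSVDataset(src=df)
print(dataset[0])  # a dict keyed by the CSV column names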
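
A corresponding sketch of the streaming path, where the caller creates the pandas `TextFileReader` and passes it as `src`; because the reader is caller-owned here, it is closed manually rather than via `dataset.close()`, mirroring the last test above (again, `patients.csv` and the `subject_id` column are hypothetical):

import pandas as pd

from monai.data import CSVIterableDataset

reader = pd.read_csv("patients.csv", chunksize=1000)  # caller-created iterable
dataset = CSVIterableDataset(src=reader, shuffle=False)
for item in dataset:  # each item is a dict for one CSV row
    print(item["subject_id"])  # assumes a subject_id column exists
reader.close()  # caller-owned, so close it explicitly instead of dataset.close()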