diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 156708c7dd..adcfad360f 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1342,6 +1342,7 @@ class CSVDataset(Dataset): be the new column name, the `value` is the names of columns to combine. for example: `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}` transform: transform to apply on the loaded items of a dictionary data. + kwargs_read_csv: dictionary args to pass to pandas `read_csv` function. kwargs: additional arguments for `pandas.merge()` API to join tables. .. deprecated:: 0.8.0 @@ -1358,13 +1359,14 @@ def __init__( col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None, col_groups: Optional[Dict[str, Sequence[str]]] = None, transform: Optional[Callable] = None, + kwargs_read_csv: Optional[Dict] = None, **kwargs, ): srcs = (src,) if not isinstance(src, (tuple, list)) else src dfs: List = [] for i in srcs: if isinstance(i, str): - dfs.append(pd.read_csv(i)) + dfs.append(pd.read_csv(i, **kwargs_read_csv) if kwargs_read_csv else pd.read_csv(i)) elif isinstance(i, pd.DataFrame): dfs.append(i) else: diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 19efc925fc..957d0e8a56 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -176,6 +176,7 @@ class CSVIterableDataset(IterableDataset): seed: random seed to initialize the random state for all the workers if `shuffle` is True, set `seed += 1` in every iter() call, refer to the PyTorch idea: https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98. + kwargs_read_csv: dictionary args to pass to pandas `read_csv` function. Default to ``{"chunksize", chunksize}``. kwargs: additional arguments for `pandas.merge()` API to join tables. .. deprecated:: 0.8.0 @@ -195,6 +196,7 @@ def __init__( transform: Optional[Callable] = None, shuffle: bool = False, seed: int = 0, + kwargs_read_csv: Optional[Dict] = None, **kwargs, ): self.src = src @@ -205,6 +207,7 @@ def __init__( self.col_groups = col_groups self.shuffle = shuffle self.seed = seed + self.kwargs_read_csv = kwargs_read_csv or {"chunksize", chunksize} # in case treating deprecated arg `filename` as kwargs, remove it from `kwargs` kwargs.pop("filename", None) self.kwargs = kwargs @@ -230,7 +233,7 @@ def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, S self.iters = [] for i in srcs: if isinstance(i, str): - self.iters.append(pd.read_csv(i, chunksize=self.chunksize)) + self.iters.append(pd.read_csv(i, **self.kwargs_read_csv)) elif isinstance(i, Iterable): self.iters.append(i) else: