Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,7 @@ class CSVDataset(Dataset):
be the new column name, the `value` is the names of columns to combine. for example:
`col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
transform: transform to apply on the loaded items of a dictionary data.
kwargs_read_csv: dictionary args to pass to pandas `read_csv` function.
kwargs: additional arguments for `pandas.merge()` API to join tables.

.. deprecated:: 0.8.0
Expand All @@ -1358,13 +1359,14 @@ def __init__(
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
transform: Optional[Callable] = None,
kwargs_read_csv: Optional[Dict] = None,
**kwargs,
):
srcs = (src,) if not isinstance(src, (tuple, list)) else src
dfs: List = []
for i in srcs:
if isinstance(i, str):
dfs.append(pd.read_csv(i))
dfs.append(pd.read_csv(i, **kwargs_read_csv) if kwargs_read_csv else pd.read_csv(i))
elif isinstance(i, pd.DataFrame):
dfs.append(i)
else:
Expand Down
5 changes: 4 additions & 1 deletion monai/data/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class CSVIterableDataset(IterableDataset):
seed: random seed to initialize the random state for all the workers if `shuffle` is True,
set `seed += 1` in every iter() call, refer to the PyTorch idea:
https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
kwargs_read_csv: dictionary args to pass to pandas `read_csv` function. Default to ``{"chunksize", chunksize}``.
kwargs: additional arguments for `pandas.merge()` API to join tables.

.. deprecated:: 0.8.0
Expand All @@ -195,6 +196,7 @@ def __init__(
transform: Optional[Callable] = None,
shuffle: bool = False,
seed: int = 0,
kwargs_read_csv: Optional[Dict] = None,
**kwargs,
):
self.src = src
Expand All @@ -205,6 +207,7 @@ def __init__(
self.col_groups = col_groups
self.shuffle = shuffle
self.seed = seed
self.kwargs_read_csv = kwargs_read_csv or {"chunksize", chunksize}
# in case treating deprecated arg `filename` as kwargs, remove it from `kwargs`
kwargs.pop("filename", None)
self.kwargs = kwargs
Expand All @@ -230,7 +233,7 @@ def reset(self, src: Optional[Union[Union[str, Sequence[str]], Union[Iterable, S
self.iters = []
for i in srcs:
if isinstance(i, str):
self.iters.append(pd.read_csv(i, chunksize=self.chunksize))
self.iters.append(pd.read_csv(i, **self.kwargs_read_csv))
elif isinstance(i, Iterable):
self.iters.append(i)
else:
Expand Down