Merged

Changes from all commits
38 commits
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
c52d484
Merge pull request #124 from Project-MONAI/dev
Nic-Ma Jun 7, 2021
c298e1d
Merge pull request #126 from Project-MONAI/dev
Nic-Ma Jun 9, 2021
5c3d9f0
[DKMED] add CSV datalist
Nic-Ma Jun 9, 2021
372256e
[DLMED] add group feature
Nic-Ma Jun 11, 2021
58a0fa7
[DLMED] add unit test
Nic-Ma Jun 11, 2021
e5d97f8
Merge branch 'dev' into 2310-csv-datalist
Nic-Ma Jun 11, 2021
cf817bb
[DLMED] add more unit tests
Nic-Ma Jun 11, 2021
9ca6d07
[DLMED] add optional install
Nic-Ma Jun 11, 2021
6adc759
[MONAI] python code formatting
monai-bot Jun 11, 2021
02f14e8
[DLMED] fix flake8 issue
Nic-Ma Jun 11, 2021
6aa97b9
[DLMED] add doc-strings
Nic-Ma Jun 11, 2021
c1c5201
Merge branch 'dev' into 2310-csv-datalist
Nic-Ma Jun 12, 2021
f38e825
Merge branch 'dev' into 2310-csv-datalist
ericspod Jun 17, 2021
b7e0277
Merge branch 'dev' into 2310-csv-datalist
Nic-Ma Jun 20, 2021
fcef348
[DLMED] fix typo
Nic-Ma Jun 20, 2021
d4cdf30
Merge branch 'dev' into 2310-csv-datalist
Nic-Ma Jun 21, 2021
3408ffb
[DLMED] add CSVDataset for non-iterable data
Nic-Ma Jun 21, 2021
4ccd36b
[DLMED] fix min test
Nic-Ma Jun 21, 2021
1aebfb7
[DLMED] add CSVIterableDataset base
Nic-Ma Jun 21, 2021
bca5afa
[DLMED] add CSVIterableDataset
Nic-Ma Jun 21, 2021
8a169bd
[DLMED] support multiple processes
Nic-Ma Jun 21, 2021
5f6bae7
Merge branch 'dev' into 2310-csv-datalist
Nic-Ma Jun 21, 2021
000372d
[DLMED] fix tests
Nic-Ma Jun 21, 2021
bdd67c7
[DLMED] fix flake8
Nic-Ma Jun 21, 2021
356e339
[DLMED] fix docs-build
Nic-Ma Jun 21, 2021
1903529
[DLMED] fix min tests
Nic-Ma Jun 21, 2021
48d4ef7
[DLMED] fix CI tests
Nic-Ma Jun 21, 2021
e1e3273
[MONAI] python code formatting
monai-bot Jun 21, 2021
5195b08
[DLMED] fix typo
Nic-Ma Jun 21, 2021
95f25a8
[DLMED] change sys.platform
Nic-Ma Jun 21, 2021
14f3a8e
[DLMED] skip if windows
Nic-Ma Jun 21, 2021
780ca06
[MONAI] python code formatting
monai-bot Jun 21, 2021
6e3b6d8
Merge branch 'dev' into 2310-csv-datalist
wyli Jun 21, 2021
ed3eb8b
Merge branch 'dev' into 2310-csv-datalist
Nic-Ma Jun 22, 2021
8207fe3
[DLMED] add col_types arg
Nic-Ma Jun 22, 2021
9e08e1d
[MONAI] python code formatting
monai-bot Jun 22, 2021
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -18,3 +18,4 @@ sphinxcontrib-jsmath
sphinxcontrib-qthelp
sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
pandas
12 changes: 12 additions & 0 deletions docs/source/data.rst
@@ -21,6 +21,12 @@ Generic Interfaces
:members:
:special-members: __next__

`CSVIterableDataset`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: CSVIterableDataset
:members:
:special-members: __next__

`PersistentDataset`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: PersistentDataset
@@ -75,6 +81,12 @@ Generic Interfaces
:members:
:special-members: __getitem__

`CSVDataset`
~~~~~~~~~~~~
.. autoclass:: CSVDataset
:members:
:special-members: __getitem__

Patch-based dataset
-------------------

4 changes: 2 additions & 2 deletions docs/source/installation.md
@@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is

- The options are
```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb` and `psutil`, respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python` and `pandas`, respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
1 change: 1 addition & 0 deletions monai/config/deviceconfig.py
@@ -79,6 +79,7 @@ def get_optional_config_values():
output["tqdm"] = get_package_version("tqdm")
output["lmdb"] = get_package_version("lmdb")
output["psutil"] = psutil_version
output["pandas"] = get_package_version("pandas")

return output

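The `monai[pandas]` extra makes `pandas` an optional dependency; elsewhere in this PR it is pulled in through `optional_import`, which returns the module together with an availability flag instead of raising at import time. A minimal sketch of that pattern (the sample DataFrame below is hypothetical):

```python
from monai.utils import optional_import

# optional_import returns (module, True) when pandas is installed,
# otherwise a placeholder and False -- so importing never fails outright.
pd, has_pandas = optional_import("pandas")

if has_pandas:
    df = pd.DataFrame({"subject_id": ["s1", "s2"], "label": [0, 1]})
    print(df.head())
else:
    print("pandas not installed; use the `monai[pandas]` extra to enable the CSV datasets.")
```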
4 changes: 3 additions & 1 deletion monai/data/__init__.py
@@ -15,6 +15,7 @@
ArrayDataset,
CacheDataset,
CacheNTransDataset,
CSVDataset,
Dataset,
LMDBDataset,
NPZDictItemDataset,
@@ -26,7 +27,7 @@
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .iterable_dataset import IterableDataset
from .iterable_dataset import CSVIterableDataset, IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
from .png_saver import PNGSaver
@@ -38,6 +39,7 @@
from .utils import (
compute_importance_map,
compute_shape_offset,
convert_tables_to_dicts,
correct_nifti_header_if_necessary,
create_file_basename,
decollate_batch,
74 changes: 72 additions & 2 deletions monai/data/dataset.py
@@ -29,9 +29,9 @@
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset

from monai.data.utils import first, pickle_hashing
from monai.data.utils import convert_tables_to_dicts, first, pickle_hashing
from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform
from monai.utils import MAX_SEED, get_seed, min_version, optional_import
from monai.utils import MAX_SEED, ensure_tuple, get_seed, min_version, optional_import

if TYPE_CHECKING:
from tqdm import tqdm
@@ -41,6 +41,7 @@
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

lmdb, _ = optional_import("lmdb")
pd, _ = optional_import("pandas")


class Dataset(_TorchDataset):
@@ -1061,3 +1062,72 @@ def _transform(self, index: int):
data = apply_transform(self.transform, data)

return data


class CSVDataset(Dataset):
"""
Dataset to load data from CSV files and generate a list of dictionaries,
where every dictionary maps to a row of the CSV file, and the keys of the dictionary
map to the column names of the CSV file.

It can load multiple CSV files and join the tables with the additional `kwargs` arg.
It supports loading only specific rows and columns.
It can also group several loaded columns to generate a new column; for example,
set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, output can be::

[
{"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
{"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
]

Args:
filename: the filename of the expected CSV file to load. if providing a list
of filenames, it will load all the files and join tables.
row_indices: indices of the expected rows to load. it should be a list,
every item can be an int number or a range `[start, end)` for the indices.
for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
load all the rows in the file.
col_names: names of the expected columns to load. if None, load all the columns.
col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
it should be a dictionary, every item maps to an expected column, the `key` is the column
name and the `value` is None or a dictionary to define the default value and data type.
the supported keys in dictionary are: ["type", "default"]. for example::

col_types = {
"subject_id": {"type": str},
"label": {"type": int, "default": 0},
"ehr_0": {"type": float, "default": 0.0},
"ehr_1": {"type": float, "default": 0.0},
"image": {"type": str, "default": None},
}

col_groups: args to group the loaded columns to generate a new column,
it should be a dictionary, every item maps to a group, the `key` will
be the new column name, the `value` is the names of columns to combine. for example:
`col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
transform: transform to apply to the loaded items of the dictionary data.
kwargs: additional arguments for `pandas.merge()` API to join tables.

"""

def __init__(
self,
filename: Union[str, Sequence[str]],
row_indices: Optional[Sequence[Union[int, str]]] = None,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
transform: Optional[Callable] = None,
**kwargs,
):
files = ensure_tuple(filename)
dfs = [pd.read_csv(f) for f in files]
data = convert_tables_to_dicts(
dfs=dfs,
row_indices=row_indices,
col_names=col_names,
col_types=col_types,
col_groups=col_groups,
**kwargs,
)
super().__init__(data=data, transform=transform)
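A minimal usage sketch of the new `CSVDataset` (the file names, column names, and the `subject_id` join key below are hypothetical):

```python
from monai.data import CSVDataset

# Load two CSV files, join them on "subject_id" via pandas.merge kwargs,
# keep a subset of columns, and group the ehr_* columns into one "ehr" entry per row.
dataset = CSVDataset(
    filename=["clinical.csv", "imaging.csv"],          # hypothetical files, joined with pandas.merge(**kwargs)
    row_indices=[[0, 100], 200],                        # rows 0-99 plus row 200 (assumes the table is long enough)
    col_names=["subject_id", "label", "image", "ehr_0", "ehr_1"],
    col_types={"label": {"type": int, "default": 0}},   # fill NaN labels with 0, then cast to int
    col_groups={"ehr": ["ehr_0", "ehr_1"]},             # new "ehr" column combining the two
    on="subject_id",                                    # forwarded to pandas.merge()
)
print(dataset[0])  # e.g. {"subject_id": ..., "label": ..., "image": ..., "ehr": array([...])}
```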
99 changes: 98 additions & 1 deletion monai/data/iterable_dataset.py
@@ -9,11 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Iterable, Optional
import math
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union

from torch.utils.data import IterableDataset as _TorchIterableDataset
from torch.utils.data import get_worker_info

from monai.data.utils import convert_tables_to_dicts
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, optional_import

pd, _ = optional_import("pandas")


class IterableDataset(_TorchIterableDataset):
@@ -43,3 +49,94 @@ def __iter__(self):
if self.transform is not None:
data = apply_transform(self.transform, data)
yield data


class CSVIterableDataset(IterableDataset):
"""
Iterable dataset to load CSV files and generate dictionary data.
It can be helpful when loading extremely big CSV files that can't be read into memory directly.
To accelerate the loading process, it supports multi-processing based on PyTorch DataLoader workers,
where every process executes transforms on part of every loaded chunk.
Note: the order of the output data may not match the data source in multi-processing mode.

It can load data from multiple CSV files and join the tables with the additional `kwargs` arg.
It supports loading only specific columns.
It can also group several loaded columns to generate a new column; for example,
set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, output can be::

[
{"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
{"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
]

Args:
filename: the filename of the expected CSV file to load. if providing a list
of filenames, it will load all the files and join tables.
chunksize: rows of a chunk when loading iterable data from CSV files, defaults to 1000. more details:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.
col_names: names of the expected columns to load. if None, load all the columns.
col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
it should be a dictionary, every item maps to an expected column, the `key` is the column
name and the `value` is None or a dictionary to define the default value and data type.
the supported keys in dictionary are: ["type", "default"]. for example::

col_types = {
"subject_id": {"type": str},
"label": {"type": int, "default": 0},
"ehr_0": {"type": float, "default": 0.0},
"ehr_1": {"type": float, "default": 0.0},
"image": {"type": str, "default": None},
}

col_groups: args to group the loaded columns to generate a new column,
it should be a dictionary, every item maps to a group, the `key` will
be the new column name, the `value` is the names of columns to combine. for example:
`col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
transform: transform to apply to the loaded items of the dictionary data.
kwargs: additional arguments for `pandas.merge()` API to join tables.

"""

def __init__(
self,
filename: Union[str, Sequence[str]],
chunksize: int = 1000,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
transform: Optional[Callable] = None,
**kwargs,
):
self.files = ensure_tuple(filename)
self.chunksize = chunksize
self.iters = self.reset()
self.col_names = col_names
self.col_types = col_types
self.col_groups = col_groups
self.kwargs = kwargs
super().__init__(data=None, transform=transform) # type: ignore

def reset(self, filename: Optional[Union[str, Sequence[str]]] = None):
if filename is not None:
# update files if necessary
self.files = ensure_tuple(filename)
self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files]
return self.iters

def __iter__(self):
for chunks in zip(*self.iters):
self.data = convert_tables_to_dicts(
dfs=chunks,
col_names=self.col_names,
col_types=self.col_types,
col_groups=self.col_groups,
**self.kwargs,
)
info = get_worker_info()
if info is not None:
length = len(self.data)
per_worker = int(math.ceil(length / float(info.num_workers)))
start = info.id * per_worker
self.data = self.data[start : min(start + per_worker, length)]

return super().__iter__()
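A minimal usage sketch of `CSVIterableDataset` with multiple DataLoader workers (the file name and the `ehr_*` columns are hypothetical); each worker transforms its own slice of every 1000-row chunk, so the output order may differ from the file order:

```python
from torch.utils.data import DataLoader

from monai.data import CSVIterableDataset

dataset = CSVIterableDataset(
    filename="big_table.csv",   # hypothetical file, read lazily in chunks
    chunksize=1000,             # rows per pandas.read_csv chunk
    col_groups={"ehr": [f"ehr_{i}" for i in range(10)]},
)
loader = DataLoader(dataset, batch_size=64, num_workers=2)
for batch in loader:
    # grouped numeric columns collate to a (batch, 10) array/tensor
    print(batch["ehr"].shape)
    break
```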
84 changes: 83 additions & 1 deletion monai/data/utils.py
@@ -16,9 +16,10 @@
import pickle
import warnings
from collections import defaultdict
from functools import reduce
from itertools import product, starmap
from pathlib import PurePath
from typing import Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -37,8 +38,11 @@
)
from monai.utils.enums import Method

pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")


__all__ = [
"get_random_patch",
"iter_patch_slices",
@@ -65,6 +69,7 @@
"decollate_batch",
"pad_list_data_collate",
"no_collation",
"convert_tables_to_dicts",
]


@@ -983,3 +988,80 @@ def sorted_dict(item, key=None, reverse=False):
if not isinstance(item, dict):
return item
return {k: sorted_dict(v) if isinstance(v, dict) else v for k, v in sorted(item.items(), key=key, reverse=reverse)}


def convert_tables_to_dicts(
dfs,
row_indices: Optional[Sequence[Union[int, str]]] = None,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
col_groups: Optional[Dict[str, Sequence[str]]] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Utility to join pandas tables, select rows, columns and generate groups.
Will return a list of dictionaries, where every dictionary maps to a row of data in the tables.

Args:
dfs: data table in pandas DataFrame format. if providing a list of tables, it will join them.
row_indices: indices of the expected rows to load. it should be a list,
every item can be an int number or a range `[start, end)` for the indices.
for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
load all the rows in the file.
col_names: names of the expected columns to load. if None, load all the columns.
col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
it should be a dictionary, every item maps to an expected column, the `key` is the column
name and the `value` is None or a dictionary to define the default value and data type.
the supported keys in dictionary are: ["type", "default"], and note that the value of `default`
should not be `None`. for example::

col_types = {
"subject_id": {"type": str},
"label": {"type": int, "default": 0},
"ehr_0": {"type": float, "default": 0.0},
"ehr_1": {"type": float, "default": 0.0},
}

col_groups: args to group the loaded columns to generate a new column,
it should be a dictionary, every item maps to a group, the `key` will
be the new column name, the `value` is the names of columns to combine. for example:
`col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
kwargs: additional arguments for `pandas.merge()` API to join tables.

"""
df = reduce(lambda l, r: pd.merge(l, r, **kwargs), ensure_tuple(dfs))
# parse row indices
rows: List[Union[int, str]] = []
if row_indices is None:
rows = slice(df.shape[0]) # type: ignore
else:
for i in row_indices:
if isinstance(i, (tuple, list)):
if len(i) != 2:
raise ValueError("range of row indices must contain 2 values: start and end.")
rows.extend(list(range(i[0], i[1])))
else:
rows.append(i)

# convert to a list of dictionaries corresponding to every row
data_ = df.loc[rows] if col_names is None else df.loc[rows, col_names]
if isinstance(col_types, dict):
# fill default values for NaN
defaults = {k: v["default"] for k, v in col_types.items() if v is not None and v.get("default") is not None}
if len(defaults) > 0:
data_ = data_.fillna(value=defaults)
# convert data types
types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v}
if len(types) > 0:
data_ = data_.astype(dtype=types)
data: List[Dict] = data_.to_dict(orient="records")

# group columns to generate new column
if col_groups is not None:
groups: Dict[str, List] = {}
for name, cols in col_groups.items():
groups[name] = df.loc[rows, cols].values
# invert items of groups to every row of data
data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)]

return data
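A minimal sketch of `convert_tables_to_dicts` on in-memory DataFrames (hypothetical data), showing the join via `pandas.merge` kwargs, the default value filled in for a missing label, and the grouped `meta` column:

```python
import pandas as pd

from monai.data import convert_tables_to_dicts

df_a = pd.DataFrame({"subject_id": ["s1", "s2"], "label": [1, None]})
df_b = pd.DataFrame({"subject_id": ["s1", "s2"], "meta_1": [11, 21], "meta_2": [12, 22]})

items = convert_tables_to_dicts(
    dfs=[df_a, df_b],
    col_types={"label": {"type": int, "default": 0}},   # NaN label -> 0, cast to int
    col_groups={"meta": ["meta_1", "meta_2"]},          # new "meta" column per row
    on="subject_id",                                    # forwarded to pandas.merge()
)
# e.g. [{"subject_id": "s1", "label": 1, "meta_1": 11, "meta_2": 12, "meta": array([11, 12])}, ...]
print(items[1]["label"])  # -> 0, the default filled in for the missing value
```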