diff --git a/src/plaid/containers/dataset.py b/src/plaid/containers/dataset.py index 42b10f04..b2c7b00e 100644 --- a/src/plaid/containers/dataset.py +++ b/src/plaid/containers/dataset.py @@ -27,8 +27,13 @@ import numpy as np import yaml +from packaging.specifiers import SpecifierSet +from packaging.version import Version +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass from tqdm import tqdm +import plaid from plaid.constants import AUTHORIZED_INFO_KEYS from plaid.containers.sample import Sample from plaid.containers.utils import check_features_size_homogeneity @@ -62,6 +67,62 @@ def process_sample(path: Union[str, Path]) -> tuple: # pragma: no cover # %% Classes +class DataclassToDict: + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, key, value): + assert hasattr(self, key), ( + f"{self.__class__.__name__} has no attribute '{key}'." + ) + return setattr(self, key, value) + + +@dataclass() +class LegalDescription(DataclassToDict): + """Container for legal information.""" + + owner: Optional[str] = None + license: Optional[str] = None + + +@dataclass() +class DataProductionDescription(DataclassToDict): + """Container for data production information.""" + + owner: Optional[str] = None + license: Optional[str] = None + type: Optional[str] = None + physics: Optional[str] = None + simulator: Optional[str] = None + hardware: Optional[str] = None + computation_duration: Optional[str] = None + script: Optional[str] = None + contact: Optional[str] = None + location: Optional[str] = None + + +@dataclass() +class DataDescription(DataclassToDict): + """Container for data information.""" + + number_of_samples: Optional[str] = None + number_of_splits: Optional[str] = None + DOE: Optional[str] = None + inputs: Optional[str] = None + outputs: Optional[str] = None + + +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class DatasetInfos(DataclassToDict): + """Container for dataset information.""" + + version: Union[Version, SpecifierSet] = Version(plaid.__version__) + legal: Optional[LegalDescription] = None + data_production: Optional[DataProductionDescription] = None + data_description: Optional[DataDescription] = None + + class Dataset(object): """A set of samples, and optionnaly some other informations about the Dataset.""" @@ -113,7 +174,7 @@ def __init__( """ self._samples: dict[int, Sample] = {} # sample_id -> sample # info_name -> description - self._infos: dict[str, dict[str, str]] = {} + self._infos = DatasetInfos() if directory_path is not None: if path is not None: @@ -1470,29 +1531,32 @@ def _save_to_dir_(self, path: Union[str, Path], verbose: bool = False) -> None: if verbose: # pragma: no cover print(f"Saving database to: {path}") + # Save infos + assert "plaid" in self._infos + assert "version" in self._infos["plaid"] + plaid_version = Version(plaid.__version__) + if ( + isinstance(self._infos["plaid"]["version"], SpecifierSet) + or self._infos["plaid"]["version"] != plaid_version + ): + logger.warning( + f"Version mismatch: Dataset was loaded from version: {self._infos['plaid']['version']}, and will be saved with version: {plaid_version}" + ) + self._infos["plaid"]["old_version"] = str(self._infos["plaid"]["version"]) + self._infos["plaid"]["version"] = str(plaid_version) + infos_fname = path / "infos.yaml" + with open(infos_fname, "w") as file: + yaml.dump(self._infos, file, default_flow_style=False, sort_keys=False) + + # Save samples samples_dir = path / "samples" if not (samples_dir.is_dir()): samples_dir.mkdir(parents=True) - # ---# save samples for i_sample, sample in tqdm(self._samples.items(), disable=not (verbose)): sample_fname = samples_dir / f"sample_{i_sample:09d}" sample.save(sample_fname) - # ---# save infos - if len(self._infos) > 0: - infos_fname = path / "infos.yaml" - with open(infos_fname, "w") as file: - yaml.dump(self._infos, file, default_flow_style=False, sort_keys=False) - - # #---# save stats - # stats_fname = path / 'stats.yaml' - # self._stats.save(stats_fname) - - # #---# save flags - # flags_fname = path / 'flags.yaml' - # self._flags.save(flags_fname) - def _load_from_dir_( self, path: Union[str, Path],