diff --git a/CHANGELOG.md b/CHANGELOG.md index 72a9c928..4d43359d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixes - (plaid/examples) fix circular imports +- (sample/dataset/problem_definition) fix incoherent path argument names in save/load methods -> `path` is now used everywhere ### Removed diff --git a/src/plaid/containers/dataset.py b/src/plaid/containers/dataset.py index f9f29d7d..a28c8f23 100644 --- a/src/plaid/containers/dataset.py +++ b/src/plaid/containers/dataset.py @@ -45,18 +45,18 @@ # %% Functions -def process_sample(sample_path: Union[str, Path]) -> tuple: # pragma: no cover +def process_sample(path: Union[str, Path]) -> tuple: # pragma: no cover """Load Sample from path. Args: - sample_path (Union[str,Path]): The path of the Sample. + path (Union[str,Path]): The path to the Sample. Returns: tuple: The loaded Sample and its ID. """ - sample_path = Path(sample_path) - id = int(sample_path.stem.split("_")[-1]) - return id, Sample(sample_path) + path = Path(path) + id = int(path.stem.split("_")[-1]) + return id, Sample(path) # %% Classes @@ -67,18 +67,20 @@ class Dataset(object): def __init__( self, + path: Optional[Union[str, Path]] = None, directory_path: Optional[Union[str, Path]] = None, verbose: bool = False, processes_number: int = 0, ) -> None: """Initialize a :class:`Dataset `. - If `directory_path` is not specified it initializes an empty :class:`Dataset ` that should be fed with :class:`Samples `. + If `path` is not specified it initializes an empty :class:`Dataset ` that should be fed with :class:`Samples `. Use :meth:`add_sample ` or :meth:`add_samples ` to feed the :class:`Dataset` Args: - directory_path (Union[str,Path], optional): The path from which to load PLAID dataset files. + path (Union[str,Path], optional): The path from which to load PLAID dataset files. + directory_path (Union[str,Path], optional): Deprecated, use `path` instead. verbose (bool, optional): Explicitly displays the operations performed. Defaults to False. processes_number (int, optional): Number of processes used to load files (-1 to use all available ressources, 0 to disable multiprocessing). Defaults to 0. @@ -114,15 +116,24 @@ def __init__( self._infos: dict[str, dict[str, str]] = {} if directory_path is not None: - directory_path = Path(directory_path) - - if directory_path.suffix == ".plaid": - self.load( - directory_path, verbose=verbose, processes_number=processes_number + if path is not None: + raise ValueError( + "Arguments `path` and `directory_path` cannot be both set. Use only `path` as `directory_path` is deprecated." ) + else: + path = directory_path + logger.warning( + "DeprecationWarning: 'directory_path' is deprecated, use 'path' instead." + ) + + if path is not None: + path = Path(path) + + if path.suffix == ".plaid": + self.load(path, verbose=verbose, processes_number=processes_number) else: self._load_from_dir_( - directory_path, verbose=verbose, processes_number=processes_number + path, verbose=verbose, processes_number=processes_number ) def copy(self) -> Self: @@ -996,37 +1007,37 @@ def merge_dataset_by_features(cls, datasets_list: list[Self]) -> Self: merged_dataset = merged_dataset.merge_features(dataset, in_place=False) return merged_dataset - def save(self, fname: Union[str, Path]) -> None: + def save(self, path: Union[str, Path]) -> None: """Saves the data set to a TAR (Tape Archive) file. It creates a temporary intermediate directory to store temporary files during the loading process. Args: - fname (Union[str,Path]): The path to which the data set will be saved. + path (Union[str,Path]): The path to which the data set will be saved. Raises: ValueError: If the randomly generated temporary dir name is already used (extremely unlikely!). """ - fname = Path(fname) + path = Path(path) - # First : creates a directory to save everything in an + # First : creates a directory to save everything in an # arborescence on disk - save_dir = fname.parent / f"tmpsavedir_{generate_random_ASCII()}" - if save_dir.is_dir(): # pragma: no cover + tmp_dir = path.parent / f"tmpsavedir_{generate_random_ASCII()}" + if tmp_dir.is_dir(): # pragma: no cover raise ValueError( - f"temporary intermediate directory <{save_dir}> already exits" + f"temporary intermediate directory <{tmp_dir}> already exits" ) - save_dir.mkdir(parents=True) + tmp_dir.mkdir(parents=True) - self._save_to_dir_(save_dir) + self._save_to_dir_(tmp_dir) - # Then : tar dir in file + # Then : tar dir in file # TODO: avoid using subprocess by using lib tarfile - ARGUMENTS = ["tar", "-cf", fname, "-C", save_dir, "."] + ARGUMENTS = ["tar", "-cf", path, "-C", tmp_dir, "."] subprocess.call(ARGUMENTS) - # Finally : removes directory - shutil.rmtree(save_dir) + # Finally : removes directory + shutil.rmtree(tmp_dir) def summarize_features(self) -> str: """Show the name of each feature and the number of samples containing it. @@ -1274,27 +1285,27 @@ def from_list_of_samples( @classmethod def load_from_file( - cls, fname: Union[str, Path], verbose: bool = False, processes_number: int = 0 + cls, path: Union[str, Path], verbose: bool = False, processes_number: int = 0 ) -> Self: """Load data from a specified TAR (Tape Archive) file. Args: - fname (Union[str,Path]): The path to the data file to be loaded. + path (Union[str,Path]): The path to the data file to be loaded. verbose (bool, optional): Explicitly displays the operations performed. Defaults to False. processes_number (int, optional): Number of processes used to load files (-1 to use all available ressources, 0 to disable multiprocessing). Defaults to 0. Returns: Self: The loaded dataset (Dataset). """ - fname = Path(fname) + path = Path(path) instance = cls() - instance.load(fname, verbose, processes_number) + instance.load(path, verbose, processes_number) return instance @classmethod def load_from_dir( cls, - dname: Union[str, Path], + path: Union[str, Path], ids: Optional[list[int]] = None, verbose: bool = False, processes_number: int = 0, @@ -1302,7 +1313,7 @@ def load_from_dir( """Load data from a specified directory. Args: - dname (Union[str,Path]): The path from which to load files. + path (Union[str,Path]): The path from which to load files. ids (list, optional): The specific sample IDs to load from the dataset. Defaults to None. verbose (bool, optional): Explicitly displays the operations performed. Defaults to False. processes_number (int, optional): Number of processes used to load files (-1 to use all available ressources, 0 to disable multiprocessing). Defaults to 0. @@ -1310,22 +1321,22 @@ def load_from_dir( Returns: Self: The loaded dataset (Dataset). """ - dname = Path(dname) + path = Path(path) instance = cls() instance._load_from_dir_( - dname, ids=ids, verbose=verbose, processes_number=processes_number + path, ids=ids, verbose=verbose, processes_number=processes_number ) return instance def load( - self, fname: Union[str, Path], verbose: bool = False, processes_number: int = 0 + self, path: Union[str, Path], verbose: bool = False, processes_number: int = 0 ) -> None: """Load data from a specified TAR (Tape Archive) file. It creates a temporary intermediate directory to store temporary files during the loading process. Args: - fname (Union[str,Path]): The path to the data file to be loaded. + path (Union[str,Path]): The path to the data file to be loaded. verbose (bool, optional): Explicitly displays the operations performed. Defaults to False. processes_number (int, optional): Number of processes used to load files (-1 to use all available ressources, 0 to disable multiprocessing). Defaults to 0. @@ -1333,18 +1344,18 @@ def load( ValueError: If a randomly generated temporary directory already exists, indicating a potential conflict during the loading process (extremely unlikely). """ - fname = Path(fname) + path = Path(path) - inputdir = fname.parent / f"tmploaddir_{generate_random_ASCII()}" + inputdir = path.parent / f"tmploaddir_{generate_random_ASCII()}" if inputdir.is_dir(): # pragma: no cover raise ValueError( f"temporary intermediate directory <{inputdir}> already exits" ) inputdir.mkdir(parents=True) - # First : untar file to a directory + # First : untar file to a directory # TODO: avoid using subprocess by using a lib tarfile - arguments = ["tar", "-xf", fname, "-C", inputdir] + arguments = ["tar", "-xf", path, "-C", inputdir] subprocess.call(arguments) # Then : load data from directory @@ -1359,44 +1370,57 @@ def load( def add_to_dir( self, sample: Sample, + path: Optional[Union[str, Path]] = None, save_dir: Optional[Union[str, Path]] = None, verbose: bool = False, ) -> None: """Add a sample to the dataset and save it to the specified directory. Notes: - If `save_dir` is None, will look for `self.save_dir` which will be retrieved from last previous call to load or save. - `save_dir` given in argument will take precedence over `self.save_dir` and overwrite it. + If `path` is None, will look for `self.path` which will be retrieved from last previous call to load or save. + `path` given in argument will take precedence over `self.path` and overwrite it. Args: sample (Sample): The sample to add. - save_dir (Union[str,Path], optional): The directory in which to save the sample. Defaults to None. + path (Union[str,Path], optional): The directory in which to save the sample. Defaults to None. + save_dir (Union[str,Path], optional): Deprecated, use `path` instead. verbose (bool, optional): If True, will print additional information. Defaults to False. Raises: - ValueError: If both self.save_dir and save_dir are None. + ValueError: If both self.path and path are None. """ if save_dir is not None: - save_dir = Path(save_dir) - self.save_dir = save_dir + if path is not None: + raise ValueError( + "Arguments `path` and `save_dir` cannot be both set. Use only `path` as `save_dir` is deprecated." + ) + else: + path = save_dir + logger.warning( + "DeprecationWarning: 'save_dir' is deprecated, use 'path' instead." + ) + + if path is not None: + path = Path(path) + self.path = path else: - if not hasattr(self, "save_dir") or self.save_dir is None: + if not hasattr(self, "path") or self.path is None: raise ValueError( - "self.save_dir and save_dir are None, we don't know where to save, specify one of them before" + "self.path and path are None, we don't know where to save, specify one of them before" ) # --- sample is not only saved to dir, but also added to the dataset # self.add_sample(sample) - # --- if dataset already contains other Samples, they will all be saved to save_dir - # self._save_to_dir_(self.save_dir) + # --- if dataset already contains other Samples, they will all be saved to path + # self._save_to_dir_(self.path) - if not self.save_dir.is_dir(): - self.save_dir.mkdir(parents=True) + if not self.path.is_dir(): + self.path.mkdir(parents=True) if verbose: - print(f"Saving database to: {self.save_dir}") + print(f"Saving database to: {self.path}") - samples_dir = self.save_dir / "samples" + samples_dir = self.path / "samples" if not samples_dir.is_dir(): samples_dir.mkdir(parents=True) @@ -1414,23 +1438,23 @@ def add_to_dir( sample_fname = samples_dir / f"sample_{i_sample:09d}" sample.save(sample_fname) - def _save_to_dir_(self, save_dir: Union[str, Path], verbose: bool = False) -> None: + def _save_to_dir_(self, path: Union[str, Path], verbose: bool = False) -> None: """Saves the dataset into a sub-directory `samples` and creates an 'infos.yaml' file to store additional information about the dataset. Args: - save_dir (Union[str,Path]): The path in which to save the files. + path (Union[str,Path]): The path in which to save the files. verbose (bool, optional): Explicitly displays the operations performed. Defaults to False. """ - save_dir = Path(save_dir) - if not (save_dir.is_dir()): - save_dir.mkdir(parents=True) + path = Path(path) + if not (path.is_dir()): + path.mkdir(parents=True) - self.save_dir = save_dir + self.path = path if verbose: # pragma: no cover - print(f"Saving database to: {save_dir}") + print(f"Saving database to: {path}") - samples_dir = save_dir / "samples" + samples_dir = path / "samples" if not (samples_dir.is_dir()): samples_dir.mkdir(parents=True) @@ -1441,21 +1465,21 @@ def _save_to_dir_(self, save_dir: Union[str, Path], verbose: bool = False) -> No # ---# save infos if len(self._infos) > 0: - infos_fname = save_dir / "infos.yaml" + infos_fname = path / "infos.yaml" with open(infos_fname, "w") as file: yaml.dump(self._infos, file, default_flow_style=False, sort_keys=False) # #---# save stats - # stats_fname = save_dir / 'stats.yaml' + # stats_fname = path / 'stats.yaml' # self._stats.save(stats_fname) # #---# save flags - # flags_fname = save_dir / 'flags.yaml' + # flags_fname = path / 'flags.yaml' # self._flags.save(flags_fname) def _load_from_dir_( self, - save_dir: Union[str, Path], + path: Union[str, Path], ids: Optional[list[int]] = None, verbose: bool = False, processes_number: int = 0, @@ -1463,7 +1487,7 @@ def _load_from_dir_( """Loads a dataset from a sample directory and retrieves additional information about the dataset from an 'infos.yaml' file, if available. Args: - save_dir (Union[str,Path]): The path from which to load files. + path (Union[str,Path]): The path from which to load files. ids (list, optional): The specific sample IDs to load from the dataset. Defaults to None. verbose (bool, optional): Explicitly displays the operations performed. Defaults to False. processes_number (int, optional): Number of processes used to load files (-1 to use all available ressources, 0 to disable multiprocessing). Defaults to 0. @@ -1473,22 +1497,22 @@ def _load_from_dir_( FileExistsError: Triggered if the provided path is a file instead of a directory. ValueError: Triggered if the number of processes is < -1. """ - save_dir = Path(save_dir) - if not save_dir.is_dir(): + path = Path(path) + if not path.is_dir(): raise FileNotFoundError( - f'"{save_dir}" is not a directory or does not exist. Abort' + f'"{path}" is not a directory or does not exist. Abort' ) if processes_number < -1: raise ValueError("Number of processes cannot be < -1") - self.save_dir = save_dir + self.path = path if verbose: # pragma: no cover - print(f"Reading database located at: {save_dir}") + print(f"Reading database located at: {path}") sample_paths = sorted( - [path for path in (save_dir / "samples").glob("sample_*") if path.is_dir()] + [path for path in (path / "samples").glob("sample_*") if path.is_dir()] ) if ids is not None: @@ -1550,7 +1574,7 @@ def update(self, *a): self.set_sample(id, sample) """ - infos_fname = save_dir / "infos.yaml" + infos_fname = path / "infos.yaml" if infos_fname.is_file(): with open(infos_fname, "r") as file: self._infos = yaml.safe_load(file) @@ -1559,14 +1583,14 @@ def update(self, *a): print("Warning: dataset contains no sample") @staticmethod - def _load_number_of_samples_(_savedir: Union[str, Path]) -> int: # pragma: no cover + def _load_number_of_samples_(_path: Union[str, Path]) -> int: """Warning: This method is deprecated, use instead :meth:`plaid.get_number_of_samples `. This function counts the number of sample files in a specified directory, which is useful for determining the total number of samples in a dataset. Args: - save_dir (Union[str,Path]): The path to the directory where sample files are stored. + path (Union[str,Path]): The path to the directory where sample files are stored. Returns: int: The number of sample files found in the specified directory. diff --git a/src/plaid/containers/sample.py b/src/plaid/containers/sample.py index 96f4e53e..1dd17ca0 100644 --- a/src/plaid/containers/sample.py +++ b/src/plaid/containers/sample.py @@ -141,6 +141,7 @@ class Sample(BaseModel): def __init__( self, + path: Optional[Union[str, Path]] = None, directory_path: Optional[Union[str, Path]] = None, mesh_base_name: str = "Base", mesh_zone_name: str = "Zone", @@ -153,7 +154,8 @@ def __init__( """Initialize an empty :class:`Sample `. Args: - directory_path (Union[str, Path], optional): The path from which to load PLAID sample files. + path (Union[str,Path], optional): The path from which to load PLAID sample files. + directory_path (Union[str,Path], optional): Deprecated, use `path` instead. mesh_base_name (str, optional): The base name for the mesh. Defaults to 'Base'. mesh_zone_name (str, optional): The zone name for the mesh. Defaults to 'Zone'. meshes (dict[float, CGNSTree], optional): A dictionary mapping time steps to CGNSTrees. Defaults to None. @@ -193,8 +195,19 @@ def __init__( self._paths: Optional[dict[float, list[CGNSPath]]] = paths if directory_path is not None: - directory_path = Path(directory_path) - self.load(directory_path) + if path is not None: + raise ValueError( + "Arguments `path` and `directory_path` cannot be both set. Use only `path` as `directory_path` is deprecated." + ) + else: + path = directory_path + logger.warning( + "DeprecationWarning: 'directory_path' is deprecated, use 'path' instead." + ) + + if path is not None: + path = Path(path) + self.load(path) self._defaults: dict[str, Optional[str]] = { "active_base": None, @@ -703,7 +716,7 @@ def link_tree( """Link the geometrical features of the CGNS tree of the current sample at a given time, to the ones of another sample. Args: - path_linked_sample (Union[str, Path]): The absolute path of the folder containing the linked CGNS + path_linked_sample (Union[str,Path]): The absolute path of the folder containing the linked CGNS linked_sample (Sample): The linked sample linked_time (float): The time step of the linked CGNS in the linked sample time (float): The time step the current sample to which the CGNS tree is linked. @@ -2110,27 +2123,27 @@ def merge_features(self, sample: Self, in_place: bool = False) -> Self: ) # -------------------------------------------------------------------------# - def save(self, dir_path: Union[str, Path], overwrite: bool = False) -> None: - """Save the Sample in directory `dir_path`. + def save(self, path: Union[str, Path], overwrite: bool = False) -> None: + """Save the Sample in directory `path`. Args: - dir_path (Union[str,Path]): relative or absolute directory path. + path (Union[str,Path]): relative or absolute directory path. overwrite (bool): target directory overwritten if True. """ - dir_path = Path(dir_path) + path = Path(path) - if dir_path.is_dir(): + if path.is_dir(): if overwrite: - shutil.rmtree(dir_path) - logger.warning(f"Existing {dir_path} directory has been reset.") - elif len(list(dir_path.glob("*"))): + shutil.rmtree(path) + logger.warning(f"Existing {path} directory has been reset.") + elif len(list(path.glob("*"))): raise ValueError( - f"directory {dir_path} already exists and is not empty. Set `overwrite` to True if needed." + f"directory {path} already exists and is not empty. Set `overwrite` to True if needed." ) - dir_path.mkdir(exist_ok=True) + path.mkdir(exist_ok=True) - mesh_dir = dir_path / "meshes" + mesh_dir = path / "meshes" if self._meshes is not None: mesh_dir.mkdir() @@ -2149,7 +2162,7 @@ def save(self, dir_path: Union[str, Path], overwrite: bool = False) -> None: scalars = np.array(scalars).reshape((1, -1)) header = ",".join(scalars_names) np.savetxt( - dir_path / "scalars.csv", + path / "scalars.csv", scalars, header=header, delimiter=",", @@ -2163,7 +2176,7 @@ def save(self, dir_path: Union[str, Path], overwrite: bool = False) -> None: data = np.vstack((ts[0], ts[1])).T header = ",".join(["t", ts_name]) np.savetxt( - dir_path / f"time_series_{ts_name}.csv", + path / f"time_series_{ts_name}.csv", data, header=header, delimiter=",", @@ -2171,13 +2184,13 @@ def save(self, dir_path: Union[str, Path], overwrite: bool = False) -> None: ) @classmethod - def load_from_dir(cls, dir_path: Union[str, Path]) -> Self: - """Load the Sample from directory `dir_path`. + def load_from_dir(cls, path: Union[str, Path]) -> Self: + """Load the Sample from directory `path`. This is a class method, you don't need to instantiate a `Sample` first. Args: - dir_path (Union[str,Path]): Relative or absolute directory path. + path (Union[str,Path]): Relative or absolute directory path. Returns: Sample @@ -2193,16 +2206,16 @@ def load_from_dir(cls, dir_path: Union[str, Path]) -> Self: Note: It calls 'load' function during execution. """ - dir_path = Path(dir_path) + path = Path(path) instance = cls() - instance.load(dir_path) + instance.load(path) return instance - def load(self, dir_path: Union[str, Path]) -> None: - """Load the Sample from directory `dir_path`. + def load(self, path: Union[str, Path]) -> None: + """Load the Sample from directory `path`. Args: - dir_path (Union[str,Path]): Relative or absolute directory path. + path (Union[str,Path]): Relative or absolute directory path. Raises: FileNotFoundError: Triggered if the provided directory does not exist. @@ -2213,20 +2226,20 @@ def load(self, dir_path: Union[str, Path]) -> None: from plaid import Sample sample = Sample() - sample.load(dir_path) + sample.load(path) print(sample) >>> Sample(3 scalars, 1 timestamp, 3 fields) """ - dir_path = Path(dir_path) + path = Path(path) - if not dir_path.exists(): - raise FileNotFoundError(f'Directory "{dir_path}" does not exist. Abort') + if not path.exists(): + raise FileNotFoundError(f'Directory "{path}" does not exist. Abort') - if not dir_path.is_dir(): - raise FileExistsError(f'"{dir_path}" is not a directory. Abort') + if not path.is_dir(): + raise FileExistsError(f'"{path}" is not a directory. Abort') - meshes_dir = dir_path / "meshes" + meshes_dir = path / "meshes" if meshes_dir.is_dir(): meshes_names = list(meshes_dir.glob("*")) nb_meshes = len(meshes_names) @@ -2245,7 +2258,7 @@ def load(self, dir_path: Union[str, Path]) -> None: for i in range(len(self._links[time])): # pragma: no cover self._links[time][i][0] = str(meshes_dir / self._links[time][i][0]) - scalars_fname = dir_path / "scalars.csv" + scalars_fname = path / "scalars.csv" if scalars_fname.is_file(): names = np.loadtxt( scalars_fname, dtype=str, max_rows=1, delimiter="," @@ -2256,7 +2269,7 @@ def load(self, dir_path: Union[str, Path]) -> None: for name, value in zip(names, scalars): self.add_scalar(name, value) - time_series_files = list(dir_path.glob("time_series_*.csv")) + time_series_files = list(path.glob("time_series_*.csv")) for ts_fname in time_series_files: names = np.loadtxt(ts_fname, dtype=str, max_rows=1, delimiter=",").reshape( (-1,) diff --git a/src/plaid/problem_definition.py b/src/plaid/problem_definition.py index 5148e5ab..5e11aab8 100644 --- a/src/plaid/problem_definition.py +++ b/src/plaid/problem_definition.py @@ -21,7 +21,7 @@ import csv import logging from pathlib import Path -from typing import Union +from typing import Optional, Union import yaml @@ -44,13 +44,18 @@ class ProblemDefinition(object): """Gathers all necessary informations to define a learning problem.""" - def __init__(self, directory_path: Union[str, Path] = None) -> None: + def __init__( + self, + path: Optional[Union[str, Path]] = None, + directory_path: Optional[Union[str, Path]] = None, + ) -> None: """Initialize an empty :class:`ProblemDefinition `. Use :meth:`add_inputs ` or :meth:`add_output_scalars_names ` to feed the :class:`ProblemDefinition` Args: - directory_path (Union[str, Path], optional): The path from which to load PLAID problem definition files. + path (Union[str,Path], optional): The path from which to load PLAID problem definition files. + directory_path (Union[str,Path], optional): Deprecated, use `path` instead. Example: .. code-block:: python @@ -79,8 +84,19 @@ def __init__(self, directory_path: Union[str, Path] = None) -> None: self._split: dict[str, IndexType] = None if directory_path is not None: - directory_path = Path(directory_path) - self._load_from_dir_(directory_path) + if path is not None: + raise ValueError( + "Arguments `path` and `directory_path` cannot be both set. Use only `path` as `directory_path` is deprecated." + ) + else: + path = directory_path + logger.warning( + "DeprecationWarning: 'directory_path' is deprecated, use 'path' instead." + ) + + if path is not None: + path = Path(path) + self._load_from_dir_(path) # -------------------------------------------------------------------------# def get_task(self) -> str: @@ -890,11 +906,11 @@ def get_all_indices(self) -> list[int]: # return res # -------------------------------------------------------------------------# - def _save_to_dir_(self, savedir: Path) -> None: + def _save_to_dir_(self, path: Union[str, Path]) -> None: """Save problem information, inputs, outputs, and split to the specified directory in YAML and CSV formats. Args: - savedir (Path): The directory where the problem information will be saved. + path (Union[str,Path]): The directory where the problem information will be saved. Example: .. code-block:: python @@ -903,8 +919,10 @@ def _save_to_dir_(self, savedir: Path) -> None: problem = ProblemDefinition() problem._save_to_dir_("/path/to/save_directory") """ - if not (savedir.is_dir()): # pragma: no cover - savedir.mkdir() + path = Path(path) + + if not (path.is_dir()): + path.mkdir() data = { "task": self._task, @@ -918,11 +936,11 @@ def _save_to_dir_(self, savedir: Path) -> None: "output_meshes": self.out_meshes_names, # list[output mesh name] } - pbdef_fname = savedir / "problem_infos.yaml" + pbdef_fname = path / "problem_infos.yaml" with open(pbdef_fname, "w") as file: yaml.dump(data, file, default_flow_style=False, sort_keys=False) - split_fname = savedir / "split.csv" + split_fname = path / "split.csv" if self._split is not None: with open(split_fname, "w", newline="") as file: write = csv.writer(file) @@ -930,27 +948,27 @@ def _save_to_dir_(self, savedir: Path) -> None: write.writerow([name] + list(indices)) @classmethod - def load(cls, save_dir: str) -> Self: # pragma: no cover + def load(cls, path: Union[str, Path]) -> Self: # pragma: no cover """Load data from a specified directory. Args: - save_dir (str): The path from which to load files. + path (Union[str,Path]): The path from which to load files. Returns: Self: The loaded dataset (Dataset). """ instance = cls() - instance._load_from_dir_(save_dir) + instance._load_from_dir_(path) return instance - def _load_from_dir_(self, save_dir: Path) -> None: + def _load_from_dir_(self, path: Union[str, Path]) -> None: """Load problem information, inputs, outputs, and split from the specified directory in YAML and CSV formats. Args: - save_dir (Path): The directory from which to load the problem information. + path (Union[str,Path]): The directory from which to load the problem information. Raises: - FileNotFoundError: Triggered if the provided directory does not exist. + FileNotFoundError: Triggered if the provided directory or file problem_infos.yaml does not exist FileExistsError: Triggered if the provided path is a file instead of a directory. Example: @@ -960,20 +978,22 @@ def _load_from_dir_(self, save_dir: Path) -> None: problem = ProblemDefinition() problem._load_from_dir_("/path/to/load_directory") """ - if not save_dir.exists(): # pragma: no cover - raise FileNotFoundError(f'Directory "{save_dir}" does not exist. Abort') + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f'Directory "{path}" does not exist. Abort') - if not save_dir.is_dir(): # pragma: no cover - raise FileExistsError(f'"{save_dir}" is not a directory. Abort') + if not path.is_dir(): + raise FileExistsError(f'"{path}" is not a directory. Abort') - pbdef_fname = save_dir / "problem_infos.yaml" + pbdef_fname = path / "problem_infos.yaml" data = {} # To avoid crash if pbdef_fname does not exist if pbdef_fname.is_file(): with open(pbdef_fname, "r") as file: data = yaml.safe_load(file) - else: # pragma: no cover - logger.warning( - f"file with path `{pbdef_fname}` does not exist. Task, inputs, and outputs will not be set" + else: + raise FileNotFoundError( + f"file with path `{pbdef_fname}` does not exist. Abort" ) self._task = data["task"] @@ -986,7 +1006,7 @@ def _load_from_dir_(self, save_dir: Path) -> None: self.in_meshes_names = data["input_meshes"] self.out_meshes_names = data["output_meshes"] - split_fname = save_dir / "split.csv" + split_fname = path / "split.csv" split = {} if split_fname.is_file(): with open(split_fname) as file: diff --git a/tests/containers/test_dataset.py b/tests/containers/test_dataset.py index 3e09d465..6f766fb2 100644 --- a/tests/containers/test_dataset.py +++ b/tests/containers/test_dataset.py @@ -16,7 +16,7 @@ import plaid from plaid.containers.dataset import Dataset from plaid.containers.sample import Sample -from plaid.utils.base import ShapeError +from plaid.utils.base import DeprecatedError, ShapeError # %% Fixtures @@ -79,6 +79,19 @@ def test___init__file_provided(self, current_directory): with pytest.raises(FileNotFoundError): Dataset(dataset_path) + def test__init__path(self, current_directory): + dataset_path = current_directory / "dataset" + Dataset(path=dataset_path) + + def test__init__directory_path(self, current_directory): + dataset_path = current_directory / "dataset" + Dataset(directory_path=dataset_path) + + def test__init__both_path_and_directory_path(self, current_directory): + dataset_path = current_directory / "dataset" + with pytest.raises(ValueError): + Dataset(path=dataset_path, directory_path=dataset_path) + # -------------------------------------------------------------------------# def test_get_samples(self, dataset_with_samples, nb_samples): dataset_with_samples.get_samples() @@ -991,6 +1004,21 @@ def test_add_to_dir_verbose(self, empty_dataset, sample, tmp_path, capsys): captured = capsys.readouterr() assert "Saving database to" in captured.out + def test__add_to_dir__path(self, empty_dataset, sample, current_directory): + save_dir = current_directory / "my_dataset_dir" + empty_dataset.add_to_dir(sample, path=save_dir) + + def test__add_to_dir__save_dir(self, empty_dataset, sample, current_directory): + save_dir = current_directory / "my_dataset_dir" + empty_dataset.add_to_dir(sample, save_dir=save_dir) + + def test__add_to_dir__both_path_and_save_dir( + self, empty_dataset, sample, current_directory + ): + save_dir = current_directory / "my_dataset_dir" + with pytest.raises(ValueError): + empty_dataset.add_to_dir(sample, path=save_dir, save_dir=save_dir) + # -------------------------------------------------------------------------# def test__save_to_dir_(self, dataset_with_samples, tmp_path): savedir = tmp_path / "testdir" @@ -1013,6 +1041,11 @@ def test__load_from_dir_(self, dataset_with_samples, infos, tmp_path): new_dataset._load_from_dir_(savedir, [1, 2]) assert len(new_dataset) == 2 + # -------------------------------------------------------------------------# + def test__load_number_of_samples_(self, tmp_path): + with pytest.raises(DeprecatedError): + Dataset._load_number_of_samples_(tmp_path) + # -------------------------------------------------------------------------# def test_set_samples(self, dataset, samples): dataset.set_samples({i: samp for i, samp in enumerate(samples)}) diff --git a/tests/containers/test_sample.py b/tests/containers/test_sample.py index 969d2e75..ae0d9ac9 100644 --- a/tests/containers/test_sample.py +++ b/tests/containers/test_sample.py @@ -150,12 +150,12 @@ def current_directory() -> Path: class Test_Sample: # -------------------------------------------------------------------------# def test___init__(self, current_directory): - dataset_path_1 = current_directory / "dataset" / "samples" / "sample_000000000" - dataset_path_2 = current_directory / "dataset" / "samples" / "sample_000000001" - dataset_path_3 = current_directory / "dataset" / "samples" / "sample_000000002" - sample_already_filled_1 = Sample(dataset_path_1) - sample_already_filled_2 = Sample(dataset_path_2) - sample_already_filled_3 = Sample(dataset_path_3) + sample_path_1 = current_directory / "dataset" / "samples" / "sample_000000000" + sample_path_2 = current_directory / "dataset" / "samples" / "sample_000000001" + sample_path_3 = current_directory / "dataset" / "samples" / "sample_000000002" + sample_already_filled_1 = Sample(sample_path_1) + sample_already_filled_2 = Sample(sample_path_2) + sample_already_filled_3 = Sample(sample_path_3) assert ( sample_already_filled_1._meshes is not None and sample_already_filled_1._scalars is not None @@ -170,14 +170,27 @@ def test___init__(self, current_directory): ) def test__init__unknown_directory(self, current_directory): - dataset_path = current_directory / "dataset" / "samples" / "sample_000000298" + sample_path = current_directory / "dataset" / "samples" / "sample_000000298" with pytest.raises(FileNotFoundError): - Sample(dataset_path) + Sample(sample_path) def test__init__file_provided(self, current_directory): - dataset_path = current_directory / "dataset" / "samples" / "sample_000067392" + sample_path = current_directory / "dataset" / "samples" / "sample_000067392" with pytest.raises(FileExistsError): - Sample(dataset_path) + Sample(sample_path) + + def test__init__path(self, current_directory): + sample_path = current_directory / "dataset" / "samples" / "sample_000000000" + Sample(path=sample_path) + + def test__init__directory_path(self, current_directory): + sample_path = current_directory / "dataset" / "samples" / "sample_000000000" + Sample(directory_path=sample_path) + + def test__init__both_path_and_directory_path(self, current_directory): + sample_path = current_directory / "dataset" / "samples" / "sample_000000000" + with pytest.raises(ValueError): + Sample(path=sample_path, directory_path=sample_path) def test_copy(self, sample_with_tree_and_scalar_and_time_series): sample_with_tree_and_scalar_and_time_series.copy() diff --git a/tests/test_problem_definition.py b/tests/test_problem_definition.py index 4c984e35..72c42668 100644 --- a/tests/test_problem_definition.py +++ b/tests/test_problem_definition.py @@ -50,6 +50,19 @@ def test__init__(self, problem_definition): assert problem_definition.get_task() is None print(problem_definition) + def test__init__path(self, current_directory): + d_path = current_directory / "problem_definition" + ProblemDefinition(path=d_path) + + def test__init__directory_path(self, current_directory): + d_path = current_directory / "problem_definition" + ProblemDefinition(directory_path=d_path) + + def test__init__both_path_and_directory_path(self, current_directory): + d_path = current_directory / "problem_definition" + with pytest.raises(ValueError): + ProblemDefinition(path=d_path, directory_path=d_path) + # -------------------------------------------------------------------------# def test_task(self, problem_definition): # Unauthorized task @@ -370,6 +383,9 @@ def test_save(self, problem_definition, current_directory): problem_definition._save_to_dir_(current_directory / "problem_definition") + def test__save_to_dir_(self, problem_definition, tmp_path): + problem_definition._save_to_dir_(tmp_path / "problem_definition") + def test_load_path_object(self, current_directory): from pathlib import Path @@ -411,3 +427,21 @@ def test_load(self, current_directory): ) all_split = problem.get_split() assert all_split["train"] == [0, 1, 2] and all_split["test"] == [3, 4] + + def test__load_from_dir__empty_dir(self, tmp_path): + problem = ProblemDefinition() + with pytest.raises(FileNotFoundError): + problem._load_from_dir_(tmp_path) + + def test__load_from_dir__non_existing_dir(self): + problem = ProblemDefinition() + non_existing_dir = Path("non_existing_path") + with pytest.raises(FileNotFoundError): + problem._load_from_dir_(non_existing_dir) + + def test__load_from_dir__path_is_file(self, tmp_path): + problem = ProblemDefinition() + file_path = tmp_path / "file.yaml" + file_path.touch() # Create an empty file + with pytest.raises(FileExistsError): + problem._load_from_dir_(file_path)