Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixes

- (plaid/examples) fix circular imports
- (sample/dataset/problem_definition) fix incoherent path argument names in save/load methods -> `path` is now used everywhere

### Removed

Expand Down
176 changes: 100 additions & 76 deletions src/plaid/containers/dataset.py

Large diffs are not rendered by default.

81 changes: 47 additions & 34 deletions src/plaid/containers/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class Sample(BaseModel):

def __init__(
self,
path: Optional[Union[str, Path]] = None,
directory_path: Optional[Union[str, Path]] = None,
mesh_base_name: str = "Base",
mesh_zone_name: str = "Zone",
Expand All @@ -153,7 +154,8 @@ def __init__(
"""Initialize an empty :class:`Sample <plaid.containers.sample.Sample>`.

Args:
directory_path (Union[str, Path], optional): The path from which to load PLAID sample files.
path (Union[str,Path], optional): The path from which to load PLAID sample files.
directory_path (Union[str,Path], optional): Deprecated, use `path` instead.
mesh_base_name (str, optional): The base name for the mesh. Defaults to 'Base'.
mesh_zone_name (str, optional): The zone name for the mesh. Defaults to 'Zone'.
meshes (dict[float, CGNSTree], optional): A dictionary mapping time steps to CGNSTrees. Defaults to None.
Expand Down Expand Up @@ -193,8 +195,19 @@ def __init__(
self._paths: Optional[dict[float, list[CGNSPath]]] = paths

if directory_path is not None:
directory_path = Path(directory_path)
self.load(directory_path)
if path is not None:
raise ValueError(
"Arguments `path` and `directory_path` cannot be both set. Use only `path` as `directory_path` is deprecated."
)
else:
path = directory_path
logger.warning(
"DeprecationWarning: 'directory_path' is deprecated, use 'path' instead."
)

if path is not None:
path = Path(path)
self.load(path)

self._defaults: dict[str, Optional[str]] = {
"active_base": None,
Expand Down Expand Up @@ -703,7 +716,7 @@ def link_tree(
"""Link the geometrical features of the CGNS tree of the current sample at a given time, to the ones of another sample.

Args:
path_linked_sample (Union[str, Path]): The absolute path of the folder containing the linked CGNS
path_linked_sample (Union[str,Path]): The absolute path of the folder containing the linked CGNS
linked_sample (Sample): The linked sample
linked_time (float): The time step of the linked CGNS in the linked sample
time (float): The time step the current sample to which the CGNS tree is linked.
Expand Down Expand Up @@ -2110,27 +2123,27 @@ def merge_features(self, sample: Self, in_place: bool = False) -> Self:
)

# -------------------------------------------------------------------------#
def save(self, dir_path: Union[str, Path], overwrite: bool = False) -> None:
"""Save the Sample in directory `dir_path`.
def save(self, path: Union[str, Path], overwrite: bool = False) -> None:
"""Save the Sample in directory `path`.

Args:
dir_path (Union[str,Path]): relative or absolute directory path.
path (Union[str,Path]): relative or absolute directory path.
overwrite (bool): target directory overwritten if True.
"""
dir_path = Path(dir_path)
path = Path(path)

if dir_path.is_dir():
if path.is_dir():
if overwrite:
shutil.rmtree(dir_path)
logger.warning(f"Existing {dir_path} directory has been reset.")
elif len(list(dir_path.glob("*"))):
shutil.rmtree(path)
logger.warning(f"Existing {path} directory has been reset.")
elif len(list(path.glob("*"))):
raise ValueError(
f"directory {dir_path} already exists and is not empty. Set `overwrite` to True if needed."
f"directory {path} already exists and is not empty. Set `overwrite` to True if needed."
)

dir_path.mkdir(exist_ok=True)
path.mkdir(exist_ok=True)

mesh_dir = dir_path / "meshes"
mesh_dir = path / "meshes"

if self._meshes is not None:
mesh_dir.mkdir()
Expand All @@ -2149,7 +2162,7 @@ def save(self, dir_path: Union[str, Path], overwrite: bool = False) -> None:
scalars = np.array(scalars).reshape((1, -1))
header = ",".join(scalars_names)
np.savetxt(
dir_path / "scalars.csv",
path / "scalars.csv",
scalars,
header=header,
delimiter=",",
Expand All @@ -2163,21 +2176,21 @@ def save(self, dir_path: Union[str, Path], overwrite: bool = False) -> None:
data = np.vstack((ts[0], ts[1])).T
header = ",".join(["t", ts_name])
np.savetxt(
dir_path / f"time_series_{ts_name}.csv",
path / f"time_series_{ts_name}.csv",
data,
header=header,
delimiter=",",
comments="",
)

@classmethod
def load_from_dir(cls, dir_path: Union[str, Path]) -> Self:
"""Load the Sample from directory `dir_path`.
def load_from_dir(cls, path: Union[str, Path]) -> Self:
"""Load the Sample from directory `path`.

This is a class method, you don't need to instantiate a `Sample` first.

Args:
dir_path (Union[str,Path]): Relative or absolute directory path.
path (Union[str,Path]): Relative or absolute directory path.

Returns:
Sample
Expand All @@ -2193,16 +2206,16 @@ def load_from_dir(cls, dir_path: Union[str, Path]) -> Self:
Note:
It calls 'load' function during execution.
"""
dir_path = Path(dir_path)
path = Path(path)
instance = cls()
instance.load(dir_path)
instance.load(path)
return instance

def load(self, dir_path: Union[str, Path]) -> None:
"""Load the Sample from directory `dir_path`.
def load(self, path: Union[str, Path]) -> None:
"""Load the Sample from directory `path`.

Args:
dir_path (Union[str,Path]): Relative or absolute directory path.
path (Union[str,Path]): Relative or absolute directory path.

Raises:
FileNotFoundError: Triggered if the provided directory does not exist.
Expand All @@ -2213,20 +2226,20 @@ def load(self, dir_path: Union[str, Path]) -> None:

from plaid import Sample
sample = Sample()
sample.load(dir_path)
sample.load(path)
print(sample)
>>> Sample(3 scalars, 1 timestamp, 3 fields)

"""
dir_path = Path(dir_path)
path = Path(path)

if not dir_path.exists():
raise FileNotFoundError(f'Directory "{dir_path}" does not exist. Abort')
if not path.exists():
raise FileNotFoundError(f'Directory "{path}" does not exist. Abort')

if not dir_path.is_dir():
raise FileExistsError(f'"{dir_path}" is not a directory. Abort')
if not path.is_dir():
raise FileExistsError(f'"{path}" is not a directory. Abort')

meshes_dir = dir_path / "meshes"
meshes_dir = path / "meshes"
if meshes_dir.is_dir():
meshes_names = list(meshes_dir.glob("*"))
nb_meshes = len(meshes_names)
Expand All @@ -2245,7 +2258,7 @@ def load(self, dir_path: Union[str, Path]) -> None:
for i in range(len(self._links[time])): # pragma: no cover
self._links[time][i][0] = str(meshes_dir / self._links[time][i][0])

scalars_fname = dir_path / "scalars.csv"
scalars_fname = path / "scalars.csv"
if scalars_fname.is_file():
names = np.loadtxt(
scalars_fname, dtype=str, max_rows=1, delimiter=","
Expand All @@ -2256,7 +2269,7 @@ def load(self, dir_path: Union[str, Path]) -> None:
for name, value in zip(names, scalars):
self.add_scalar(name, value)

time_series_files = list(dir_path.glob("time_series_*.csv"))
time_series_files = list(path.glob("time_series_*.csv"))
for ts_fname in time_series_files:
names = np.loadtxt(ts_fname, dtype=str, max_rows=1, delimiter=",").reshape(
(-1,)
Expand Down
72 changes: 46 additions & 26 deletions src/plaid/problem_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import csv
import logging
from pathlib import Path
from typing import Union
from typing import Optional, Union

import yaml

Expand All @@ -44,13 +44,18 @@
class ProblemDefinition(object):
"""Gathers all necessary informations to define a learning problem."""

def __init__(self, directory_path: Union[str, Path] = None) -> None:
def __init__(
self,
path: Optional[Union[str, Path]] = None,
directory_path: Optional[Union[str, Path]] = None,
) -> None:
"""Initialize an empty :class:`ProblemDefinition <plaid.problem_definition.ProblemDefinition>`.

Use :meth:`add_inputs <plaid.problem_definition.ProblemDefinition.add_inputs>` or :meth:`add_output_scalars_names <plaid.problem_definition.ProblemDefinition.add_output_scalars_names>` to feed the :class:`ProblemDefinition`

Args:
directory_path (Union[str, Path], optional): The path from which to load PLAID problem definition files.
path (Union[str,Path], optional): The path from which to load PLAID problem definition files.
directory_path (Union[str,Path], optional): Deprecated, use `path` instead.

Example:
.. code-block:: python
Expand Down Expand Up @@ -79,8 +84,19 @@ def __init__(self, directory_path: Union[str, Path] = None) -> None:
self._split: dict[str, IndexType] = None

if directory_path is not None:
directory_path = Path(directory_path)
self._load_from_dir_(directory_path)
if path is not None:
raise ValueError(
"Arguments `path` and `directory_path` cannot be both set. Use only `path` as `directory_path` is deprecated."
)
else:
path = directory_path
logger.warning(
"DeprecationWarning: 'directory_path' is deprecated, use 'path' instead."
)

if path is not None:
path = Path(path)
self._load_from_dir_(path)

# -------------------------------------------------------------------------#
def get_task(self) -> str:
Expand Down Expand Up @@ -890,11 +906,11 @@ def get_all_indices(self) -> list[int]:
# return res

# -------------------------------------------------------------------------#
def _save_to_dir_(self, savedir: Path) -> None:
def _save_to_dir_(self, path: Union[str, Path]) -> None:
"""Save problem information, inputs, outputs, and split to the specified directory in YAML and CSV formats.

Args:
savedir (Path): The directory where the problem information will be saved.
path (Union[str,Path]): The directory where the problem information will be saved.

Example:
.. code-block:: python
Expand All @@ -903,8 +919,10 @@ def _save_to_dir_(self, savedir: Path) -> None:
problem = ProblemDefinition()
problem._save_to_dir_("/path/to/save_directory")
"""
if not (savedir.is_dir()): # pragma: no cover
savedir.mkdir()
path = Path(path)

if not (path.is_dir()):
path.mkdir()

data = {
"task": self._task,
Expand All @@ -918,39 +936,39 @@ def _save_to_dir_(self, savedir: Path) -> None:
"output_meshes": self.out_meshes_names, # list[output mesh name]
}

pbdef_fname = savedir / "problem_infos.yaml"
pbdef_fname = path / "problem_infos.yaml"
with open(pbdef_fname, "w") as file:
yaml.dump(data, file, default_flow_style=False, sort_keys=False)

split_fname = savedir / "split.csv"
split_fname = path / "split.csv"
if self._split is not None:
with open(split_fname, "w", newline="") as file:
write = csv.writer(file)
for name, indices in self._split.items():
write.writerow([name] + list(indices))

@classmethod
def load(cls, save_dir: str) -> Self: # pragma: no cover
def load(cls, path: Union[str, Path]) -> Self: # pragma: no cover
"""Load data from a specified directory.

Args:
save_dir (str): The path from which to load files.
path (Union[str,Path]): The path from which to load files.

Returns:
Self: The loaded dataset (Dataset).
"""
instance = cls()
instance._load_from_dir_(save_dir)
instance._load_from_dir_(path)
return instance

def _load_from_dir_(self, save_dir: Path) -> None:
def _load_from_dir_(self, path: Union[str, Path]) -> None:
"""Load problem information, inputs, outputs, and split from the specified directory in YAML and CSV formats.

Args:
save_dir (Path): The directory from which to load the problem information.
path (Union[str,Path]): The directory from which to load the problem information.

Raises:
FileNotFoundError: Triggered if the provided directory does not exist.
FileNotFoundError: Triggered if the provided directory or file problem_infos.yaml does not exist
FileExistsError: Triggered if the provided path is a file instead of a directory.

Example:
Expand All @@ -960,20 +978,22 @@ def _load_from_dir_(self, save_dir: Path) -> None:
problem = ProblemDefinition()
problem._load_from_dir_("/path/to/load_directory")
"""
if not save_dir.exists(): # pragma: no cover
raise FileNotFoundError(f'Directory "{save_dir}" does not exist. Abort')
path = Path(path)

if not path.exists():
raise FileNotFoundError(f'Directory "{path}" does not exist. Abort')

if not save_dir.is_dir(): # pragma: no cover
raise FileExistsError(f'"{save_dir}" is not a directory. Abort')
if not path.is_dir():
raise FileExistsError(f'"{path}" is not a directory. Abort')

pbdef_fname = save_dir / "problem_infos.yaml"
pbdef_fname = path / "problem_infos.yaml"
data = {} # To avoid crash if pbdef_fname does not exist
if pbdef_fname.is_file():
with open(pbdef_fname, "r") as file:
data = yaml.safe_load(file)
else: # pragma: no cover
logger.warning(
f"file with path `{pbdef_fname}` does not exist. Task, inputs, and outputs will not be set"
else:
raise FileNotFoundError(
f"file with path `{pbdef_fname}` does not exist. Abort"
)

self._task = data["task"]
Expand All @@ -986,7 +1006,7 @@ def _load_from_dir_(self, save_dir: Path) -> None:
self.in_meshes_names = data["input_meshes"]
self.out_meshes_names = data["output_meshes"]

split_fname = save_dir / "split.csv"
split_fname = path / "split.csv"
split = {}
if split_fname.is_file():
with open(split_fname) as file:
Expand Down
Loading
Loading