diff --git a/hydrolib/core/io/bc/models.py b/hydrolib/core/io/bc/models.py index 39e8bf5c4..f76f93f38 100644 --- a/hydrolib/core/io/bc/models.py +++ b/hydrolib/core/io/bc/models.py @@ -11,19 +11,25 @@ import logging from enum import Enum from pathlib import Path -from typing import Callable, List, Literal, NamedTuple, Optional, Set, Union +from typing import Callable, Dict, List, Literal, Optional, Set, Union from pydantic import Extra from pydantic.class_validators import root_validator, validator from pydantic.fields import Field from hydrolib.core.io.ini.io_models import Property, Section -from hydrolib.core.io.ini.models import DataBlockINIBasedModel, INIGeneral, INIModel +from hydrolib.core.io.ini.models import ( + BaseModel, + DataBlockINIBasedModel, + INIGeneral, + INIModel, +) from hydrolib.core.io.ini.parser import Parser, ParserConfig from hydrolib.core.io.ini.serializer import SerializerConfig, write_ini from hydrolib.core.io.ini.util import ( get_enum_validator, get_from_subclass_defaults, + get_key_renaming_root_validator, get_split_string_on_delimiter_validator, make_list_validator, ) @@ -32,9 +38,18 @@ class VerticalInterpolation(str, Enum): + """Enum class containing the valid values for the vertical position type, + which defines what the numeric values for vertical position specification mean. + """ + linear = "linear" + """str: Linear interpolation between vertical positions.""" + log = "log" + """str: Logarithmic interpolation between vertical positions (e.g. vertical velocity profiles).""" + block = "block" + """str: Equal to the value at the directly lower specified vertical position.""" class VerticalPositionType(str, Enum): @@ -49,22 +64,40 @@ class VerticalPositionType(str, Enum): z_datum = "ZDatum" """str: z-coordinate with respect to the reference level of the model.""" + z_surf = "ZSurf" + """str: Absolute distance from the free surface downward.""" + class TimeInterpolation(str, Enum): + """Enum class containing the valid values for the time interpolation.""" + linear = "linear" + """str: Linear interpolation between times.""" + block_from = "blockFrom" + """str: Equal to that at the start of the time interval (latest specified time value).""" + block_to = "blockTo" + """str: Equal to that at the end of the time interval (upcoming specified time value).""" -class QuantityUnitPair(NamedTuple): - """A .bc file header lines tuple containing a quantity name and its unit.""" +class QuantityUnitPair(BaseModel): + """A .bc file header lines tuple containing a quantity name, its unit and optionally a vertical position index.""" quantity: str + """str: Name of quantity.""" + unit: str + """str: Unit of quantity.""" + + vertpositionindex: Optional[int] = Field(alias="vertPositionIndex") + """int (optional): This is a (one-based) index into the verticalposition-specification, assigning a vertical position to the quantity (t3D-blocks only).""" def _to_properties(self): yield Property(key="quantity", value=self.quantity) yield Property(key="unit", value=self.unit) + if self.vertpositionindex is not None: + yield Property(key="vertPositionIndex", value=self.vertpositionindex) class ForcingBase(DataBlockINIBasedModel): @@ -74,18 +107,17 @@ class ForcingBase(DataBlockINIBasedModel): Typically subclassed, for the specific types of forcing data, e.g, TimeSeries. This model is for example referenced under a [ForcingModel][hydrolib.core.io.bc.models.ForcingModel]`.forcing[..]`. - - Attributes: - name (str): Unique identifier that identifies the location for this forcing data. - function (str): Function type of the data in the actual datablock. - quantityunitpair (List[QuantityUnitPair]): list of header line tuples for one or - more quantities + their unit. Describes the columns in the actual datablock. """ _header: Literal["Forcing"] = "Forcing" name: str = Field(alias="name") + """str: Unique identifier that identifies the location for this forcing data.""" + function: str = Field(alias="function") + """str: Function type of the data in the actual datablock.""" + quantityunitpair: List[QuantityUnitPair] + """List[QuantityUnitPair]: List of header lines for one or more quantities and their unit. Describes the columns in the actual datablock.""" def _exclude_fields(self) -> Set: return {"quantityunitpair"}.union(super()._exclude_fields()) @@ -113,7 +145,9 @@ def _validate_quantityunitpair(cls, values): raise ValueError("unit is not provided") if isinstance(quantities, str) and isinstance(units, str): - values[quantityunitpairkey] = [(quantities, units)] + values[quantityunitpairkey] = [ + QuantityUnitPair(quantity=quantities, unit=units) + ] return values if isinstance(quantities, list) and isinstance(units, list): @@ -123,7 +157,8 @@ def _validate_quantityunitpair(cls, values): ) values[quantityunitpairkey] = [ - (quantity, unit) for quantity, unit in zip(quantities, units) + QuantityUnitPair(quantity=quantity, unit=unit) + for quantity, unit in zip(quantities, units) ] return values @@ -179,27 +214,43 @@ class TimeSeries(ForcingBase): """Subclass for a .bc file [Forcing] block with timeseries data.""" function: Literal["timeseries"] = "timeseries" + timeinterpolation: TimeInterpolation = Field(alias="timeInterpolation") + """TimeInterpolation: The type of time interpolation.""" + offset: float = Field(0.0, alias="offset") + """float: All values in the table are increased by the offset (after multiplication by factor). Defaults to 0.0.""" + factor: float = Field(1.0, alias="factor") + """float: All values in the table are multiplied with the factor. Defaults to 1.0.""" _timeinterpolation_validator = get_enum_validator( "timeinterpolation", enum=TimeInterpolation ) + _key_renaming_root_validator = get_key_renaming_root_validator( + { + "timeinterpolation": ["time_interpolation"], + } + ) + class Harmonic(ForcingBase): """Subclass for a .bc file [Forcing] block with harmonic components data.""" function: Literal["harmonic"] = "harmonic" + factor: float = Field(1.0, alias="factor") + """float: All values in the table are multiplied with the factor. Defaults to 1.0.""" class Astronomic(ForcingBase): """Subclass for a .bc file [Forcing] block with astronomic components data.""" function: Literal["astronomic"] = "astronomic" + factor: float = Field(1.0, alias="factor") + """float: All values in the table are multiplied with the factor. Defaults to 1.0.""" class HarmonicCorrection(ForcingBase): @@ -220,23 +271,174 @@ class T3D(ForcingBase): function: Literal["t3d"] = "t3d" offset: float = Field(0.0, alias="offset") + """float: All values in the table are increased by the offset (after multiplication by factor). Defaults to 0.0.""" + factor: float = Field(1.0, alias="factor") + """float: All values in the table are multiplied with the factor. Defaults to 1.0.""" - verticalpositions: List[float] = Field(alias="verticalPositions") - verticalinterpolation: VerticalInterpolation = Field(alias="verticalInterpolation") - verticalpositiontype: VerticalPositionType = Field(alias="verticalPositionType") + vertpositions: List[float] = Field(alias="vertPositions") + """List[float]: The specification of the vertical positions.""" + + vertinterpolation: VerticalInterpolation = Field( + VerticalInterpolation.linear.value, alias="vertInterpolation" + ) + """VerticalInterpolation: The type of vertical interpolation. Defaults to linear.""" + + vertpositiontype: VerticalPositionType = Field(alias="vertPositionType") + """VerticalPositionType: The vertical position type of the verticalpositions values.""" + + timeinterpolation: TimeInterpolation = Field( + TimeInterpolation.linear.value, alias="timeInterpolation" + ) + """TimeInterpolation: The type of time interpolation. Defaults to linear.""" + + _key_renaming_root_validator = get_key_renaming_root_validator( + { + "timeinterpolation": ["time_interpolation"], + "vertpositions": ["vertical_position_specification"], + "vertinterpolation": ["vertical_interpolation"], + "vertpositiontype": ["vertical_position_type"], + "vertpositionindex": ["vertical_position"], + } + ) _split_to_list = get_split_string_on_delimiter_validator( - "verticalpositions", + "vertpositions", ) _verticalinterpolation_validator = get_enum_validator( - "verticalinterpolation", enum=VerticalInterpolation + "vertinterpolation", enum=VerticalInterpolation ) _verticalpositiontype_validator = get_enum_validator( - "verticalpositiontype", enum=VerticalPositionType + "vertpositiontype", enum=VerticalPositionType + ) + _timeinterpolation_validator = get_enum_validator( + "timeinterpolation", enum=TimeInterpolation ) + @root_validator(pre=True) + def _validate_quantityunitpairs(cls, values: Dict) -> Dict: + super()._validate_quantityunitpair(values) + + quantityunitpairs = values["quantityunitpair"] + + T3D._validate_that_first_unit_is_time_and_has_no_verticalposition( + quantityunitpairs + ) + + verticalpositions = values.get("vertpositions") + if verticalpositions is None: + raise ValueError("vertPositions is not provided") + + number_of_verticalpositions = ( + len(verticalpositions) + if isinstance(verticalpositions, List) + else len(verticalpositions.split()) + ) + + verticalpositionindexes = values.get("vertpositionindex") + if verticalpositionindexes is None: + T3D._validate_that_all_non_time_quantityunitpairs_have_valid_verticalpositionindex( + quantityunitpairs, number_of_verticalpositions + ) + return values + + T3D._validate_verticalpositionindexes_and_update_quantityunitpairs( + verticalpositionindexes, + number_of_verticalpositions, + quantityunitpairs, + ) + + return values + + @staticmethod + def _validate_that_first_unit_is_time_and_has_no_verticalposition( + quantityunitpairs: List[QuantityUnitPair], + ) -> None: + if quantityunitpairs[0].quantity.lower() != "time": + raise ValueError("First quantity should be `time`") + if quantityunitpairs[0].vertpositionindex is not None: + raise ValueError("`time` quantity cannot have vertical position index") + + @staticmethod + def _validate_that_all_non_time_quantityunitpairs_have_valid_verticalpositionindex( + quantityunitpairs: List[QuantityUnitPair], maximum_verticalpositionindex: int + ) -> None: + for quantityunitpair in quantityunitpairs[1:]: + verticalpositionindex = quantityunitpair.vertpositionindex + + if not T3D._is_valid_verticalpositionindex( + verticalpositionindex, maximum_verticalpositionindex + ): + raise ValueError( + f"Vertical position index should be between 1 and {maximum_verticalpositionindex}, but {verticalpositionindex} was given" + ) + + @staticmethod + def _validate_verticalpositionindexes_and_update_quantityunitpairs( + verticalpositionindexes: List[int], + number_of_verticalpositions: int, + quantityunitpairs: List[QuantityUnitPair], + ) -> None: + if verticalpositionindexes is None: + raise ValueError("vertPositionIndex is not provided") + + if len(verticalpositionindexes) != len(quantityunitpairs) - 1: + raise ValueError( + "Number of vertical position indexes should be equal to the number of quantities/units - 1" + ) + + T3D._validate_that_verticalpositionindexes_are_valid( + verticalpositionindexes, number_of_verticalpositions + ) + + T3D._add_verticalpositionindex_to_quantityunitpairs( + quantityunitpairs[1:], verticalpositionindexes + ) + + @staticmethod + def _validate_that_verticalpositionindexes_are_valid( + verticalpositionindexes: List[int], number_of_vertical_positions: int + ) -> None: + for verticalpositionindexstring in verticalpositionindexes: + verticalpositionindex = ( + int(verticalpositionindexstring) + if verticalpositionindexstring + else None + ) + if not T3D._is_valid_verticalpositionindex( + verticalpositionindex, number_of_vertical_positions + ): + raise ValueError( + f"Vertical position index should be between 1 and {number_of_vertical_positions}" + ) + + @staticmethod + def _is_valid_verticalpositionindex( + verticalpositionindex: int, number_of_vertical_positions: int + ) -> bool: + one_based_index_offset = 1 + + return ( + verticalpositionindex is not None + and verticalpositionindex >= one_based_index_offset + and verticalpositionindex <= number_of_vertical_positions + ) + + @staticmethod + def _add_verticalpositionindex_to_quantityunitpairs( + quantityunitpairs: List[QuantityUnitPair], verticalpositionindexes: List[int] + ) -> None: + if len(quantityunitpairs) != len(verticalpositionindexes): + raise ValueError( + "Number of quantityunitpairs and verticalpositionindexes should be equal" + ) + + for (quantityunitpair, verticalpositionindex) in zip( + quantityunitpairs, verticalpositionindexes + ): + quantityunitpair.vertpositionindex = verticalpositionindex + class QHTable(ForcingBase): """Subclass for a .bc file [Forcing] block with Q-h table data.""" @@ -250,13 +452,18 @@ class Constant(ForcingBase): function: Literal["constant"] = "constant" offset: float = Field(0.0, alias="offset") + """float: All values in the table are increased by the offset (after multiplication by factor). Defaults to 0.0.""" + factor: float = Field(1.0, alias="factor") + """float: All values in the table are multiplied with the factor. Defaults to 1.0.""" class ForcingGeneral(INIGeneral): """`[General]` section with .bc file metadata.""" fileversion: str = Field("1.01", alias="fileVersion") + """str: The file version.""" + filetype: Literal["boundConds"] = Field("boundConds", alias="fileType") @@ -266,16 +473,15 @@ class ForcingModel(INIModel): This model is for example referenced under a [ExtModel][hydrolib.core.io.ext.models.ExtModel]`.boundary[..].forcingfile[..]`. - - Attributes: - general (ForcingGeneral): `[General]` block with file metadata. - forcing (List[ForcingBase]): List of `[Forcing]` blocks for all forcing - definitions in a single .bc file. Actual data is stored in - forcing[..].datablock from [hydrolib.core.io.ini.models.DataBlockINIBasedModel.datablock] or [hydrolib.core.io.ini.models.DataBlockINIBasedModel]. """ general: ForcingGeneral = ForcingGeneral() + """ForcingGeneral: `[General]` block with file metadata.""" + forcing: List[ForcingBase] = [] + """List[ForcingBase]: List of `[Forcing]` blocks for all forcing + definitions in a single .bc file. Actual data is stored in + forcing[..].datablock from [hydrolib.core.io.ini.models.DataBlockINIBasedModel.datablock].""" _split_to_list = make_list_validator("forcing") diff --git a/hydrolib/core/io/ini/util.py b/hydrolib/core/io/ini/util.py index 0fca0b861..e9b706206 100644 --- a/hydrolib/core/io/ini/util.py +++ b/hydrolib/core/io/ini/util.py @@ -462,3 +462,28 @@ def is_valid_coordinates_with_num_coordinates_specification() -> bool: raise ValueError(error) return root_validator(allow_reuse=True)(validate_location_specification) + + +def get_key_renaming_root_validator(keys_to_rename: Dict[str, List[str]]): + """ + Gets a root validator that renames the provided keys to support backwards compatibility. + + Args: + keys_to_rename Dict[str, List[str]]: Dictionary of keys and a list of old keys that + should be converted to the current key. + """ + + def rename_keys(cls, values: Dict) -> Dict: + for current_keyword, old_keywords in keys_to_rename.items(): + if current_keyword in values: + continue + + for old_keyword in old_keywords: + if (value := values.get(old_keyword)) is not None: + values[current_keyword] = value + del values[old_keyword] + break + + return values + + return root_validator(allow_reuse=True, pre=True)(rename_keys) diff --git a/hydrolib/core/io/mdu/models.py b/hydrolib/core/io/mdu/models.py index d4e4a44da..83d7c3d17 100644 --- a/hydrolib/core/io/mdu/models.py +++ b/hydrolib/core/io/mdu/models.py @@ -30,6 +30,7 @@ class AutoStartOption(IntEnum): Enum class containing the valid values for the AutoStart attribute in the [General][hydrolib.core.io.mdu.models.General] class. """ + no = 0 autostart = 1 autostartstop = 2 diff --git a/tests/data/reference/bc/t3d_backwards_compatibility.bc b/tests/data/reference/bc/t3d_backwards_compatibility.bc new file mode 100644 index 000000000..55ab69a4e --- /dev/null +++ b/tests/data/reference/bc/t3d_backwards_compatibility.bc @@ -0,0 +1,43 @@ +# written by HYDROLIB-core 0.3.0 + +[General] +fileVersion = 1.01 +fileType = boundConds + +[Forcing] +name = boundary_timeseries +function = timeseries +Time Interpolation = blockTo +offset = 1.23 +factor = 2.34 +quantity = time +unit = minutes since 2015-01-01 00:00:00 +quantity = dischargebnd +unit = m³/s +0.0 1.23 +60.0 2.34 +120.0 3.45 + +[Forcing] +name = boundary_t3d +function = t3d +offset = 1.23 +factor = 2.34 +Vertical Position Specification = 3.45 4.56 5.67 +Vertical Interpolation = log +Vertical Position Type = percBed +Time Interpolation = linear +quantity = time +unit = m +quantity = salinitybnd +unit = ppt +Vertical Position = 1 +quantity = salinitybnd +unit = ppt +Vertical Position = 2 +quantity = salinitybnd +unit = ppt +Vertical Position = 3 +0.0 1.0 2.0 3.0 +60.0 4.0 5.0 6.0 +120.0 7.0 8.0 9.0 diff --git a/tests/data/reference/bc/test.bc b/tests/data/reference/bc/test.bc index 55635ceb0..3aa4b2771 100644 --- a/tests/data/reference/bc/test.bc +++ b/tests/data/reference/bc/test.bc @@ -44,21 +44,25 @@ unit = deg 60.0 3.45 4.56 [Forcing] -name = boundary_t3d -function = t3d -offset = 1.23 -factor = 2.34 -verticalPositions = 3.45 4.56 5.67 -verticalInterpolation = log -verticalPositionType = percBed -quantity = time -unit = m -quantity = salinitybnd -unit = ppt -quantity = salinitybnd -unit = ppt -quantity = salinitybnd -unit = ppt +name = boundary_t3d +function = t3d +offset = 1.23 +factor = 2.34 +vertPositions = 3.45 4.56 5.67 +vertInterpolation = log +vertPositionType = percBed +timeInterpolation = linear +quantity = time +unit = m +quantity = salinitybnd +unit = ppt +vertPositionIndex = 1 +quantity = salinitybnd +unit = ppt +vertPositionIndex = 2 +quantity = salinitybnd +unit = ppt +vertPositionIndex = 3 0.0 1.0 2.0 3.0 60.0 4.0 5.0 6.0 120.0 7.0 8.0 9.0 diff --git a/tests/io/ini/test_util.py b/tests/io/ini/test_util.py index 67ae1872f..935b3eaca 100644 --- a/tests/io/ini/test_util.py +++ b/tests/io/ini/test_util.py @@ -1,12 +1,14 @@ from typing import Dict, List, Optional import pytest +from pydantic import Extra from pydantic.error_wrappers import ValidationError from hydrolib.core.basemodel import BaseModel from hydrolib.core.io.ini.util import ( LocationValidationConfiguration, LocationValidationFieldNames, + get_key_renaming_root_validator, get_location_specification_rootvalidator, ) @@ -198,3 +200,42 @@ def test_correct_1d_fields_locationtype_is_added( values ) assert validated_values == expected_values + + +class TestGetKeyRenamingRootValidator: + class DummyModel(BaseModel): + """Dummy model to test the validation of the location specification.""" + + randomproperty: str + + validator = get_key_renaming_root_validator( + { + "randomproperty": [ + "randomProperty", + "random_property", + "oldRandomProperty", + ], + } + ) + + class Config: + extra = Extra.allow + + @pytest.mark.parametrize( + "old_key", ["randomProperty", "random_property", "oldRandomProperty"] + ) + def test_old_keys_are_correctly_renamed_to_current_keyword(self, old_key: str): + values = {old_key: "randomString"} + + model = TestGetKeyRenamingRootValidator.DummyModel(**values) + + assert model.randomproperty == "randomString" + + def test_unknown_key_still_raises_error(self): + values = {"randomKeyThatNeverExisted": "randomString"} + + with pytest.raises(ValidationError) as error: + TestGetKeyRenamingRootValidator.DummyModel(**values) + + expected_message = "field required" + assert expected_message in str(error.value) diff --git a/tests/io/test_bc.py b/tests/io/test_bc.py index d4b247f64..2c06d4211 100644 --- a/tests/io/test_bc.py +++ b/tests/io/test_bc.py @@ -1,5 +1,6 @@ import inspect from pathlib import Path +from typing import List import pytest from pydantic.error_wrappers import ValidationError @@ -20,6 +21,7 @@ VerticalInterpolation, VerticalPositionType, ) +from hydrolib.core.io.ini.models import BaseModel from hydrolib.core.io.ini.parser import Parser, ParserConfig from ..utils import ( @@ -30,12 +32,26 @@ test_reference_dir, ) +TEST_BC_FILE = "test.bc" +TEST_BC_FILE_KEYWORDS_WITH_SPACES = "t3d_backwards_compatibility.bc" + class TestQuantityUnitPair: def test_create_quantityunitpair(self): pair = QuantityUnitPair(quantity="some_quantity", unit="some_unit") + assert isinstance(pair, BaseModel) + assert pair.quantity == "some_quantity" + assert pair.unit == "some_unit" + assert pair.vertpositionindex is None + + def test_create_quantityunitpair_with_verticalpositionindex(self): + pair = QuantityUnitPair( + quantity="some_quantity", unit="some_unit", vertpositionindex=123 + ) + assert isinstance(pair, BaseModel) assert pair.quantity == "some_quantity" assert pair.unit == "some_unit" + assert pair.vertpositionindex == 123 class TestTimeSeries: @@ -92,6 +108,31 @@ def test_read_bc_expected_result(self): assert forcing.quantityunitpair[1].unit == "m" assert forcing.datablock[1] == [1440.0, 2.5] + def test_load_timeseries_model_with_old_keyword_that_contain_spaces(self): + bc_file = Path(test_reference_dir / "bc" / TEST_BC_FILE_KEYWORDS_WITH_SPACES) + forcingmodel = ForcingModel(bc_file) + + timeseries = next( + (x for x in forcingmodel.forcing if x.function == "timeseries"), None + ) + assert timeseries is not None + assert timeseries.name == "boundary_timeseries" + assert timeseries.timeinterpolation == TimeInterpolation.block_to + assert timeseries.offset == 1.23 + assert timeseries.factor == 2.34 + + quantityunitpairs = timeseries.quantityunitpair + assert len(quantityunitpairs) == 2 + assert quantityunitpairs[0].quantity == "time" + assert quantityunitpairs[0].unit == "minutes since 2015-01-01 00:00:00" + assert quantityunitpairs[1].quantity == "dischargebnd" + + assert timeseries.datablock == [ + [0.0, 1.23], + [60.0, 2.34], + [120.0, 3.45], + ] + class TestForcingBase: @pytest.mark.parametrize( @@ -183,24 +224,24 @@ def test_forcing_model(self): assert isinstance(m.forcing[-1], TimeSeries) def test_read_bc_missing_field_raises_correct_error(self): - file = "missing_field.bc" + bc_file = "missing_field.bc" identifier = "Boundary2" - filepath = invalid_test_data_dir / file + filepath = invalid_test_data_dir / bc_file with pytest.raises(ValidationError) as error: ForcingModel(filepath) - expected_message1 = f"{file} -> forcing -> 1 -> {identifier}" + expected_message1 = f"{bc_file} -> forcing -> 1 -> {identifier}" expected_message2 = "quantity is not provided" assert expected_message1 in str(error.value) assert expected_message2 in str(error.value) def test_save_forcing_model(self): - file = Path(test_output_dir / "test.bc") - reference_file = Path(test_reference_dir / "bc" / "test.bc") + bc_file = Path(test_output_dir / TEST_BC_FILE) + reference_file = Path(test_reference_dir / "bc" / TEST_BC_FILE) forcingmodel = ForcingModel() - forcingmodel.filepath = file + forcingmodel.filepath = bc_file timeseries = TimeSeries(**_create_time_series_values()) harmonic = Harmonic(**_create_harmonic_values(False)) @@ -217,8 +258,8 @@ def test_save_forcing_model(self): forcingmodel.forcing.append(constant) forcingmodel.save() - assert file.is_file() == True - assert_files_equal(file, reference_file, skip_lines=[0, 3]) + assert bc_file.is_file() == True + assert_files_equal(bc_file, reference_file, skip_lines=[0, 3]) @pytest.mark.parametrize("cls", [Astronomic, AstronomicCorrection]) def test_astronomic_values_with_strings_in_datablock_are_parsed_correctly( @@ -240,6 +281,7 @@ class TestT3D: ("percBed", VerticalPositionType.percentage_bed), ("ZBed", VerticalPositionType.z_bed), ("ZDatum", VerticalPositionType.z_datum), + ("ZSurf", VerticalPositionType.z_surf), ], ) def test_initialize_t3d( @@ -248,7 +290,7 @@ def test_initialize_t3d( exp_vertical_position_type: VerticalPositionType, ): values = _create_t3d_values() - values["verticalpositiontype"] = vertical_position_type + values["vertpositiontype"] = vertical_position_type t3d = T3D(**values) @@ -256,25 +298,26 @@ def test_initialize_t3d( assert t3d.function == "t3d" assert t3d.offset == 1.23 assert t3d.factor == 2.34 + assert t3d.timeinterpolation == TimeInterpolation.linear - assert len(t3d.verticalpositions) == 3 - assert t3d.verticalpositions[0] == 3.45 - assert t3d.verticalpositions[1] == 4.56 - assert t3d.verticalpositions[2] == 5.67 + assert len(t3d.vertpositions) == 3 + assert t3d.vertpositions[0] == 3.45 + assert t3d.vertpositions[1] == 4.56 + assert t3d.vertpositions[2] == 5.67 - assert t3d.verticalinterpolation == VerticalInterpolation.log - assert t3d.verticalpositiontype == exp_vertical_position_type + assert t3d.vertinterpolation == VerticalInterpolation.log + assert t3d.vertpositiontype == exp_vertical_position_type assert len(t3d.quantityunitpair) == 4 assert t3d.quantityunitpair[0] == QuantityUnitPair(quantity="time", unit="m") assert t3d.quantityunitpair[1] == QuantityUnitPair( - quantity="salinitybnd", unit="ppt" + quantity="salinitybnd", unit="ppt", vertpositionindex=1 ) assert t3d.quantityunitpair[2] == QuantityUnitPair( - quantity="salinitybnd", unit="ppt" + quantity="salinitybnd", unit="ppt", vertpositionindex=2 ) assert t3d.quantityunitpair[3] == QuantityUnitPair( - quantity="salinitybnd", unit="ppt" + quantity="salinitybnd", unit="ppt", vertpositionindex=3 ) assert len(t3d.datablock) == 3 @@ -282,6 +325,312 @@ def test_initialize_t3d( assert t3d.datablock[1] == [60, 4, 5, 6] assert t3d.datablock[2] == [120, 7, 8, 9] + def test_create_t3d_first_quantity_not_time_raises_error(self): + values = _create_t3d_values() + + values["quantityunitpair"] = [ + _create_quantityunitpair("salinitybnd", "ppt"), + _create_quantityunitpair("time", "m"), + ] + + with pytest.raises(ValidationError) as error: + T3D(**values) + + expected_message = "First quantity should be `time`" + assert expected_message in str(error.value) + + def test_create_t3d_time_quantity_with_verticalpositionindex_raises_error(self): + values = _create_t3d_values() + + values["quantityunitpair"] = [ + _create_quantityunitpair("time", "m", 1), + ] + + with pytest.raises(ValidationError) as error: + T3D(**values) + + expected_message = "`time` quantity cannot have vertical position index" + assert expected_message in str(error.value) + + def test_create_t3d_verticalpositionindex_missing_for_non_time_unit_raises_error( + self, + ): + values = _create_t3d_values() + + values["quantityunitpair"] = [ + _create_quantityunitpair("time", "m"), + _create_quantityunitpair("salinitybnd", "ppt", None), + ] + + with pytest.raises(ValidationError) as error: + T3D(**values) + + expected_maximum_index = len(values["vertpositions"].split()) + expected_message = ( + f"Vertical position index should be between 1 and {expected_maximum_index}" + ) + assert expected_message in str(error.value) + + @pytest.mark.parametrize( + "vertpositions, verticalpositionindexes", + [ + pytest.param([1.23, 4.56], [0, 1], id="vertpositionindex is one-based"), + pytest.param( + [1.23, 4.56], + [1, 3], + id="vertpositionindex bigger than vertpositions length", + ), + pytest.param([1.23, 4.56], [1, None], id="too few vertpositionindex"), + pytest.param([1.23, 4.56], [1, 2, 3], id="too many vertpositionindex"), + ], + ) + def test_create_t3d_verticalposition_in_quantityunitpair_has_invalid_value_raises_error( + self, + vertpositions: List[float], + verticalpositionindexes: List[int], + ): + values = _create_t3d_values() + + time_quantityunitpair = [_create_quantityunitpair("time", "m")] + other_quantutyunitpairs = [] + for i in range(len(verticalpositionindexes)): + other_quantutyunitpairs.append( + _create_quantityunitpair( + "randomQuantity", "randomUnit", verticalpositionindexes[i] + ) + ) + + values["quantityunitpair"] = time_quantityunitpair + other_quantutyunitpairs + values["vertpositions"] = vertpositions + + with pytest.raises(ValidationError) as error: + T3D(**values) + + maximum_verticalpositionindex = len(vertpositions) + expected_message = f"Vertical position index should be between 1 and {maximum_verticalpositionindex}" + assert expected_message in str(error.value) + + @pytest.mark.parametrize( + "number_of_quantities_and_units, number_of_verticalpositionindexes", + [ + pytest.param(4, 2, id="4 quantities, but 2 verticalpositionindexes"), + pytest.param(2, 3, id="2 quantities, but 3 verticalpositionindexes"), + ], + ) + def test_create_t3d_number_of_verticalindexpositions_not_as_expected_raises_error( + self, + number_of_quantities_and_units: int, + number_of_verticalpositionindexes: int, + ): + values = _create_t3d_values() + + del values["quantityunitpair"] + + onebased_index_offset = 1 + + values["quantity"] = ["time"] + [ + str(i + onebased_index_offset) + for i in range(number_of_quantities_and_units) + ] + values["unit"] = ["m"] + [ + str(i + onebased_index_offset) + for i in range(number_of_quantities_and_units) + ] + values["vertpositionindex"] = [None] + [ + str(i + onebased_index_offset) + for i in range(number_of_verticalpositionindexes) + ] + + with pytest.raises(ValidationError) as error: + T3D(**values) + + expected_message = "Number of vertical position indexes should be equal to the number of quantities/units - 1" + assert expected_message in str(error.value) + + @pytest.mark.parametrize( + "verticalpositionindexes", + [ + pytest.param([0], id="vertpositionindex is one-based"), + pytest.param([None], id="vertpositionindex cannot be None"), + pytest.param( + [4], + id="vertpositionindex cannot be larger than number of vertical positions", + ), + ], + ) + def test_create_t3d_verticalpositionindex_has_invalid_value_raises_error( + self, verticalpositionindexes: List[int] + ): + values = _create_t3d_values() + + del values["quantityunitpair"] + + values["quantity"] = ["time", "randomQuantity"] + values["unit"] = ["randomUnit", "randomUnit"] + values["vertpositionindex"] = verticalpositionindexes + + with pytest.raises(ValidationError) as error: + T3D(**values) + + number_of_vertical_positions = len(values["vertpositions"].split()) + expected_message = f"Vertical position index should be between 1 and {number_of_vertical_positions}" + assert expected_message in str(error.value) + + def test_create_t3d_creates_correct_quantityunitpairs(self): + values = _create_t3d_values() + + t3d = T3D(**values) + + quantityunitpairs = t3d.quantityunitpair + expected_quantityunitpairs = values["quantityunitpair"] + + TestT3D._validate_that_correct_quantityunitpairs_are_created( + quantityunitpairs, expected_quantityunitpairs + ) + + def test_create_t3d_creates_correct_quantityunitpairs_using_verticalpositionindexes( + self, + ): + values = _create_t3d_values() + + del values["quantityunitpair"] + + values["quantity"] = ["time", "randomQuantity1", "randomQuantity2"] + values["unit"] = ["randomUnit", "randomUnit", "randomUnit"] + values["vertpositionindex"] = [2, 3] + + t3d = T3D(**values) + + quantityunitpairs = t3d.quantityunitpair + expected_quantityunitpairs = [] + + for quantity, unit, verticalpositionindex in zip( + values["quantity"], values["unit"], [None] + values["vertpositionindex"] + ): + expected_quantityunitpairs.append( + _create_quantityunitpair(quantity, unit, verticalpositionindex) + ) + + TestT3D._validate_that_correct_quantityunitpairs_are_created( + quantityunitpairs, expected_quantityunitpairs + ) + + def test_create_t3d_timeinterpolation_defaults_to_linear(self): + values = _create_t3d_values() + + del values["timeinterpolation"] + + t3d = T3D(**values) + + assert t3d.timeinterpolation == "linear" + + def test_create_t3d_verticalinterpolation_defaults_to_linear(self): + values = _create_t3d_values() + + del values["vertinterpolation"] + + t3d = T3D(**values) + + assert t3d.vertinterpolation == "linear" + + def test_create_t3d_without_specifying_vertpositions_raises_error(self): + values = _create_t3d_values() + + del values["vertpositions"] + + with pytest.raises(ValidationError) as error: + T3D(**values) + + expected_message = "vertPositions is not provided" + assert expected_message in str(error.value) + + def test_load_forcing_model(self): + bc_file = Path(test_reference_dir / "bc" / TEST_BC_FILE) + forcingmodel = ForcingModel(bc_file) + + t3d = next((x for x in forcingmodel.forcing if x.function == "t3d"), None) + assert t3d is not None + assert t3d.name == "boundary_t3d" + assert t3d.offset == 1.23 + assert t3d.factor == 2.34 + assert t3d.vertpositions == [3.45, 4.56, 5.67] + assert t3d.vertinterpolation == "log" + assert t3d.vertpositiontype == "percBed" + assert t3d.timeinterpolation == "linear" + + quantityunitpairs = t3d.quantityunitpair + assert len(quantityunitpairs) == 4 + assert quantityunitpairs[0].quantity == "time" + assert quantityunitpairs[0].unit == "m" + assert quantityunitpairs[0].vertpositionindex == None + assert quantityunitpairs[1].quantity == "salinitybnd" + assert quantityunitpairs[1].unit == "ppt" + assert quantityunitpairs[1].vertpositionindex == 1 + assert quantityunitpairs[2].quantity == "salinitybnd" + assert quantityunitpairs[2].unit == "ppt" + assert quantityunitpairs[2].vertpositionindex == 2 + assert quantityunitpairs[3].quantity == "salinitybnd" + assert quantityunitpairs[3].unit == "ppt" + assert quantityunitpairs[3].vertpositionindex == 3 + + assert t3d.datablock == [ + [0.0, 1.0, 2.0, 3.0], + [60.0, 4.0, 5.0, 6.0], + [120.0, 7.0, 8.0, 9.0], + ] + + def test_load_t3d_model_with_old_keyword_that_contains_spaces(self): + bc_file = Path(test_reference_dir / "bc" / TEST_BC_FILE_KEYWORDS_WITH_SPACES) + forcingmodel = ForcingModel(bc_file) + + t3d = next((x for x in forcingmodel.forcing if x.function == "t3d"), None) + assert t3d is not None + assert t3d.name == "boundary_t3d" + assert t3d.offset == 1.23 + assert t3d.factor == 2.34 + assert t3d.vertpositions == [3.45, 4.56, 5.67] + assert t3d.vertinterpolation == "log" + assert t3d.vertpositiontype == "percBed" + assert t3d.timeinterpolation == "linear" + + quantityunitpairs = t3d.quantityunitpair + assert len(quantityunitpairs) == 4 + assert quantityunitpairs[0].quantity == "time" + assert quantityunitpairs[0].unit == "m" + assert quantityunitpairs[0].vertpositionindex == None + assert quantityunitpairs[1].quantity == "salinitybnd" + assert quantityunitpairs[1].unit == "ppt" + assert quantityunitpairs[1].vertpositionindex == 1 + assert quantityunitpairs[2].quantity == "salinitybnd" + assert quantityunitpairs[2].unit == "ppt" + assert quantityunitpairs[2].vertpositionindex == 2 + assert quantityunitpairs[3].quantity == "salinitybnd" + assert quantityunitpairs[3].unit == "ppt" + assert quantityunitpairs[3].vertpositionindex == 3 + + assert t3d.datablock == [ + [0.0, 1.0, 2.0, 3.0], + [60.0, 4.0, 5.0, 6.0], + [120.0, 7.0, 8.0, 9.0], + ] + + @staticmethod + def _validate_that_correct_quantityunitpairs_are_created( + quantityunitpairs: List[QuantityUnitPair], + expected_quantityunitpairs: List[QuantityUnitPair], + ): + assert len(quantityunitpairs) == len(expected_quantityunitpairs) + + for quantityunitpair, expected_quantityunitpair in zip( + quantityunitpairs, expected_quantityunitpairs + ): + assert quantityunitpair.quantity == expected_quantityunitpair.quantity + assert quantityunitpair.unit == expected_quantityunitpair.unit + assert ( + quantityunitpair.vertpositionindex + == expected_quantityunitpair.vertpositionindex + ) + def _create_time_series_values(): return dict( @@ -291,8 +640,8 @@ def _create_time_series_values(): offset="1.23", factor="2.34", quantityunitpair=[ - ("time", "minutes since 2015-01-01 00:00:00"), - ("dischargebnd", "m³/s"), + _create_quantityunitpair("time", "minutes since 2015-01-01 00:00:00"), + _create_quantityunitpair("dischargebnd", "m³/s"), ], datablock=[["0", "1.23"], ["60", "2.34"], ["120", "3.45"]], ) @@ -304,9 +653,9 @@ def _create_harmonic_values(iscorrection: bool): name=f"boundary_{function}", function=function, quantityunitpair=[ - ("harmonic component", "minutes"), - ("waterlevelbnd amplitude", "m"), - ("waterlevelbnd phase", "deg"), + _create_quantityunitpair("harmonic component", "minutes"), + _create_quantityunitpair("waterlevelbnd amplitude", "m"), + _create_quantityunitpair("waterlevelbnd phase", "deg"), ], datablock=[ ["0", "1.23", "2.34"], @@ -321,9 +670,9 @@ def _create_astronomic_values(iscorrection: bool): name=f"boundary_{function}", function=function, quantityunitpair=[ - ("astronomic component", "-"), - ("waterlevelbnd amplitude", "m"), - ("waterlevelbnd phase", "deg"), + _create_quantityunitpair("astronomic component", "-"), + _create_quantityunitpair("waterlevelbnd amplitude", "m"), + _create_quantityunitpair("waterlevelbnd phase", "deg"), ], datablock=[ ["A0", "1.23", "2.34"], @@ -333,20 +682,27 @@ def _create_astronomic_values(iscorrection: bool): ) +def _create_quantityunitpair(quantity, unit, verticalpositionindex=None): + return QuantityUnitPair( + quantity=quantity, unit=unit, vertpositionindex=verticalpositionindex + ) + + def _create_t3d_values(): return dict( name="boundary_t3d", function="t3d", offset="1.23", factor="2.34", - verticalpositions="3.45 4.56 5.67", - verticalinterpolation="log", - verticalpositiontype="percBed", + vertpositions="3.45 4.56 5.67", + vertinterpolation="log", + vertpositiontype="percBed", + timeinterpolation="linear", quantityunitpair=[ - ("time", "m"), - ("salinitybnd", "ppt"), - ("salinitybnd", "ppt"), - ("salinitybnd", "ppt"), + _create_quantityunitpair("time", "m"), + _create_quantityunitpair("salinitybnd", "ppt", 1), + _create_quantityunitpair("salinitybnd", "ppt", 2), + _create_quantityunitpair("salinitybnd", "ppt", 3), ], datablock=[ ["0", "1", "2", "3"], @@ -361,8 +717,8 @@ def _create_qhtable_values(): name="boundary_qhtable", function="qhtable", quantityunitpair=[ - ("qhbnd discharge", "m3/s"), - ("qhbnd waterlevel", "m"), + _create_quantityunitpair("qhbnd discharge", "m3/s"), + _create_quantityunitpair("qhbnd waterlevel", "m"), ], datablock=[ ["1.23", "2.34"], @@ -379,7 +735,7 @@ def _create_constant_values(): factor="2.34", timeinterpolation="linear", quantityunitpair=[ - ("waterlevelbnd", "m"), + _create_quantityunitpair("waterlevelbnd", "m"), ], datablock=[ ["3.45"], diff --git a/tests/test_model.py b/tests/test_model.py index 06c45f8ae..51e2782b5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,7 +7,7 @@ from pydantic.error_wrappers import ValidationError from hydrolib.core.basemodel import DiskOnlyFileModel -from hydrolib.core.io.bc.models import ForcingBase, ForcingModel +from hydrolib.core.io.bc.models import ForcingBase, ForcingModel, QuantityUnitPair from hydrolib.core.io.dimr.models import ( DIMR, ComponentOrCouplerRef, @@ -349,7 +349,11 @@ def test_boundary_with_forcing_file_without_match_returns_none(): def _create_forcing(name: str, quantity: str) -> ForcingBase: - return ForcingBase(name=name, quantityunitpair=[(quantity, "")], function="") + return ForcingBase( + name=name, + quantityunitpair=[QuantityUnitPair(quantity=quantity, unit="")], + function="", + ) def _create_boundary(data: Dict) -> Boundary: