Skip to content

Commit

Permalink
Fixed issue with Config.duplicate. Added test. (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeitsperre committed May 14, 2024
2 parents 1f5be4e + 18c5876 commit 143fb95
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 33 deletions.
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
History
=======

0.15.0 (unreleased)
-------------------
* Fixed bug in `Config.duplicate` dating from the switch to Pydantic V2 in 0.13 (PR #367)

0.14.1 (2024-05-07)
-------------------
* Upgraded `owslib` to `>=0.29.1`. (PR #358)
Expand Down
37 changes: 13 additions & 24 deletions ravenpy/config/rvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,51 +279,40 @@ def header(self, rv):

@field_validator("params", mode="before")
@classmethod
def _cast_to_dataclass(cls, data):
def _cast_to_dataclass(cls, data: Union[Dict, Sequence]):
"""Cast params to a dataclass."""
# Needed because pydantic v2 does not cast tuples automatically.
if data is not None and not is_dataclass(data):
if data is not None:
if is_dataclass(data):
return data

if isinstance(data, dict):
return cls.model_fields["params"].annotation(**data)

return cls.model_fields["params"].annotation(*data)
return data

@field_validator("global_parameter", mode="before")
@classmethod
def _update_defaults(cls, v, info: ValidationInfo):
"""Some configuration parameters should be updated with user given arguments, not overwritten."""
return {**cls.model_fields[info.field_name].default, **v}

# @model_validator(mode="after")
# def _parse_symbolic(self):
# """If params is numerical, convert symbolic expressions from other fields.
# """
#
# if self.params is not None:
# p = asdict(self.params)
#
# if not is_symbolic(p):
# for key in self.model_fields.keys():
# if key != "params":
# setattr(self, key, parse_symbolic(getattr(self, key), **p))
#
# return self

def set_params(self, params) -> "Config":
"""Return a new instance of Config with params set to their numerical values."""
def set_params(self, params: Union[Dict, Sequence]) -> "Config":
"""Return a new instance of Config with params frozen to their numerical values."""
# Create params with numerical values
if not self.is_symbolic:
raise ValueError(
"Setting `params` on a configuration without symbolic expressions has no effect."
"Leave `params` to its default value when instantiating the emulator configuration."
)

num_p = self.model_fields["params"].annotation(*params)
p = self.__class__._cast_to_dataclass(params)

# Parse symbolic expressions using numerical params values
out = parse_symbolic(self.__dict__, **asdict(num_p))
out["params"] = num_p
out = parse_symbolic(self.__dict__, **asdict(p))
out["params"] = p

# Instantiate config class
# Note: `construct` skips validation. benchmark to see if it speeds things up.
return self.__class__.model_construct(**out)

def set_solution(self, fn: Path, timestamp: bool = True) -> "Config":
Expand Down
20 changes: 11 additions & 9 deletions tests/test_rvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_rvi_datetime():
rvi = RVI(start_date=(dt.datetime(1990, 1, 1),))


def test_duplicate():
def test_duplicate_simple():
conf = Config(start_date="1990-01-01")

# Updating values with an alias and an attribute name
Expand All @@ -53,6 +53,13 @@ def test_duplicate():
assert out.debug_mode


def test_duplicate_emulator(gr4jcn_config):
conf, params = gr4jcn_config
conf.duplicate()

conf.duplicate(params=params)


def test_set_params():
@dataclass(config=dict(arbitrary_types_allowed=True))
class P:
Expand All @@ -65,19 +72,14 @@ class MySymbolicEmulator(Config):
alias="RainSnowTransition",
)

# Assignment through instantiation
# exp = MySymbolicEmulator(params=[0.5])
# assert exp.rain_snow_transition.temp == 0.5

# Assignment through set_params -> new instance
s = MySymbolicEmulator()
num = s.set_params([0.5])
assert num.rain_snow_transition.temp == 0.5

# Attribute assignment
# s.params = [0.5]
# assert s.rain_snow_transition == exp.rain_snow_transition
# assert s.rvp == exp.rvp
s = MySymbolicEmulator()
num = s.set_params({"X01": 0.5})
assert num.rain_snow_transition.temp == 0.5


def test_solution(get_local_testdata):
Expand Down

0 comments on commit 143fb95

Please sign in to comment.