Skip to content

Commit

Permalink
hotstart tests #68
Browse files Browse the repository at this point in the history
  • Loading branch information
sjordan29 committed Mar 14, 2024
1 parent 30d0203 commit e14441d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 12 deletions.
62 changes: 51 additions & 11 deletions src/clearwater_modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def __init__(
self.updateable_static_variables = updateable_static_variables
self.__non_updateable_static_variables: list[str] | None = None

# create list of temporal variables
if self._track_dynamic_variables:
self.temporal_variables = self.state_variables_names + \
self.updateable_static_variables + self.dynamic_variables_names
else:
self.temporal_variables = self.state_variables_names + \
self.updateable_static_variables

if isinstance(self.initial_state_values, dict) and isinstance(self.static_variable_values, dict):
print('Initializing from dicts...')
self.dataset: xr.Dataset = self._init_dataset_from_dicts(
Expand All @@ -89,6 +97,7 @@ def __init__(
print('Initializing from hotstart dataset...')
self.dataset: xr.Dataset = self._init_from_dataset(
hotstart_dataset,
time_steps,
)
self.hotstart_dataset = None

Expand Down Expand Up @@ -133,14 +142,6 @@ def _init_dataset_from_dicts(
else:
initial_state_values[static] = static_variable_values.pop(static)

# create list of temporal variables
if self._track_dynamic_variables:
self.temporal_variables = self.state_variables_names + \
self.updateable_static_variables + self.dynamic_variables_names
else:
self.temporal_variables = self.state_variables_names + \
self.updateable_static_variables

# initialize the main model dataset
dataset: xr.Dataset = self._init_state_arrays(
initial_state_values,
Expand All @@ -159,13 +160,52 @@ def _init_dataset_from_dicts(
print('Model initialized from input dicts successfully!.')
return dataset

def _init_from_dataset(self, hotstart_dataset: xr.Dataset) -> xr.Dataset:
def _init_from_dataset(
self,
hotstart_dataset: xr.Dataset,
time_steps: int
) -> xr.Dataset:
"""Initialize the model from a hotstart dataset."""
if self.time_dim not in hotstart_dataset.dims:
raise ValueError(
f'Hotstart dataset must have a {self.time_dim} dimension.'
)
return hotstart_dataset
else:
coords = {
key: value if key != self.time_dim
else np.arange(time_steps)
for key, value in hotstart_dataset.coords.items()
}

new_hotstart_dataset = xr.Dataset(
data_vars={
var_name: (
hotstart_dataset[var_name].dims,
np.full(
tuple(
hotstart_dataset[var_name].sizes[dim]
for dim in hotstart_dataset[var_name].dims
),
np.nan
)
)
for var_name, _ in hotstart_dataset.data_vars.items()
},
coords={
**coords
}
)

# set temporal variables to the last timestep of the hotstart dataset
new_hotstart_dataset[self.temporal_variables].loc[
{self.time_dim: 0}
] = hotstart_dataset[self.temporal_variables].isel(
{self.time_dim: -1}
)

new_hotstart_dataset[self._non_updateable_static_variables] = hotstart_dataset[self._non_updateable_static_variables]

return new_hotstart_dataset

def _init_state_arrays(
self,
Expand Down Expand Up @@ -426,7 +466,7 @@ def _non_updateable_static_variables(self) -> list[str]:
def track_dynamic_variables(self) -> bool:
"""Track dynamic variables property."""
return self._track_dynamic_variables

@track_dynamic_variables.setter
def track_dynamic_variables(self, value: bool) -> bool:
if self._track_dynamic_variables == value:
Expand Down
8 changes: 7 additions & 1 deletion tests/test_3_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,20 @@ def test_variable_attributes(model: Model) -> None:
assert 'description' in model.dataset[var].attrs


def test_model_hotstart(model: Model) -> None:
def test_model_hotstart(
model: Model,
time_steps: int,
) -> None:
"""Test if the hotstart works."""
ds = model.increment_timestep()
ds.attrs['hotstart'] = True
ds = ds.isel(time_step=slice(0,2))

hotstart_model = MockModel(
time_steps=time_steps,
hotstart_dataset=ds,
)

assert isinstance(hotstart_model, Model)
assert len(hotstart_model.dataset[model.time_dim]) == 2
assert model.dataset.attrs.get('hotstart') == True
Expand Down

0 comments on commit e14441d

Please sign in to comment.