Skip to content

Commit

Permalink
initialize array when initializing model (2/3)
Browse files Browse the repository at this point in the history
TO DO: hot start dataset; increment_timestep
#68
  • Loading branch information
sjordan29 committed Feb 26, 2024
1 parent a530bc1 commit 52ec704
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 25 deletions.
35 changes: 28 additions & 7 deletions examples/dev_sandbox/prof.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,41 @@
import clearwater_modules
import time
import sys
import xarray as xr
import numpy as np

def main(iters: int):

def main(iters: int, baseline: bool):
ti = time.time()
# define starting state values
state_i = {
'water_temp_c': 40.0,
'surface_area': 1.0,
'volume': 1.0,
}
if baseline:
state_i = {
'water_temp_c': 40.0,
'surface_area': 1.0,
'volume': 1.0,
}
else:
state_i = {
'water_temp_c': xr.DataArray(
np.full(10, 40),
dims='cell',
coords={'cell': np.arange(10)}),
'surface_area': xr.DataArray(
np.full(10, 1.0),
dims='cell',
coords={'cell': np.arange(10)}),
'volume': xr.DataArray(
np.full(10, 1.0),
dims='cell',
coords={'cell': np.arange(10)}),
}

# instantiate the TSM module
tsm = clearwater_modules.tsm.EnergyBudget(
time_steps=iters,
initial_state_values=state_i,
meteo_parameters={'wind_c': 1.0},
updateable_static_variables=['wind_c']
)
print(tsm.static_variable_values)
t2 = time.time()
Expand All @@ -35,4 +56,4 @@ def main(iters: int):
print('No argument given, defaulting to 100 iteration.')
iters = 100

main(iters=iters)
main(iters=iters, baseline=True)
83 changes: 65 additions & 18 deletions src/clearwater_modules/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Stored base types shared by all sub-modules."""
import warnings
import xarray as xr
import numpy as np
import clearwater_modules.utils as utils
import clearwater_modules.sorter as sorter
from clearwater_modules.shared.types import (
Expand Down Expand Up @@ -28,6 +29,7 @@ class Model(CanRegisterVariable):

def __init__(
self,
time_steps: int,
initial_state_values: Optional[InitialVariablesDict] = None,
static_variable_values: Optional[InitialVariablesDict] = None,
updateable_static_variables: Optional[list[str]] = None,
Expand All @@ -38,6 +40,7 @@ def __init__(
"""Initialize the model, should be accessed by subclasses.
Args:
time_steps: An integer to indicate the number of timesteps to run.
initial_state_values: A dict with variable names as keys, and initial
state variables as values.
static_variable_values: A dict with variable names as keys, and static
Expand Down Expand Up @@ -75,6 +78,7 @@ def __init__(
initial_state_values=self.initial_state_values,
static_variable_values=self.static_variable_values,
updateable_static_variables=self.updateable_static_variables,
time_steps=time_steps,
)

elif isinstance(hotstart_dataset, xr.Dataset):
Expand All @@ -96,6 +100,7 @@ def _init_dataset_from_dicts(
initial_state_values: InitialVariablesDict,
static_variable_values: InitialVariablesDict,
updateable_static_variables: list[str],
time_steps: int,
) -> xr.Dataset:
"""Initialize Model.dataset from dicts."""
if not isinstance(initial_state_values, dict):
Expand Down Expand Up @@ -125,10 +130,14 @@ def _init_dataset_from_dicts(
initial_state_values[static] = static_variable_values.pop(static)

# initialize the main model dataset
dataset: xr.Dataset = self._init_state_arrays(initial_state_values)
dataset: xr.Dataset = self._init_state_arrays(
initial_state_values,
time_steps,
)
dataset: xr.Dataset = self._init_static_arrays(
dataset,
static_variable_values,
time_steps,
)

print('Model initialized from input dicts successfully!.')
Expand All @@ -145,10 +154,13 @@ def _init_from_dataset(self, hotstart_dataset: xr.Dataset) -> xr.Dataset:
def _init_state_arrays(
self,
initial_state_values: InitialVariablesDict,
time_steps: int,
) -> xr.Dataset:
"""Initializes the state arrays."""
match_dims: list[str] = []
data_arrays: dict[str, xr.DataArray] = {}
coords: dict = {}
add_data: list[str] = []

for k, v in initial_state_values.items():
if k not in (self.state_variables_names + self.updateable_static_variables):
Expand All @@ -161,37 +173,72 @@ def _init_state_arrays(
else:
utils.validate_arrays(v, *list(data_arrays.values()))
data_arrays[k] = v
coords = coords | dict(data_arrays[k].coords.items())
add_data.append(k)
if len(data_arrays) > 0:
array_i = list(data_arrays.values())[0]
ds = xr.Dataset(
data_vars={
k: (
data_arrays[k].dims + (self.time_dim,),
np.full(
tuple(data_arrays[k].sizes[dim] for dim in data_arrays[k].dims) + (time_steps,),
np.nan
)
)
for k in data_arrays.keys()
},
coords={
**coords,
self.time_dim: np.arange(time_steps),
}
)
else:
array_i = xr.DataArray(
[[1.0]],
dims=['x', 'y'],
coords=[[1.0], [1.0]],
ds = xr.Dataset(
data_vars={
k: (
('x', 'y', self.time_dim),
np.full((1, 1, time_steps), np.nan)
)
for k in match_dims
},
coords={'x': [1.0], 'y': [1.0], self.time_dim: np.arange(time_steps)}
)
for var_name in match_dims:

for var_name in match_dims + add_data:
variable = self.get_variable(var_name)
attrs = {
'long_name': variable.long_name,
'units': variable.units,
'description': variable.description,
}
data_arrays[var_name] = xr.full_like(
array_i,
initial_state_values[var_name],
dtype=type(initial_state_values[var_name]),
)
data_arrays[var_name].attrs = attrs
ds = xr.Dataset(
data_vars=data_arrays,
coords=array_i.coords,
)
return ds.expand_dims({self.time_dim: [0]})

if var_name not in data_arrays.keys():
ds[var_name] = xr.DataArray(
np.full(
tuple(ds.sizes[dim] for dim in ds.dims),
np.nan
),
dims=ds.dims
)

ds[var_name].loc[{self.time_dim: 0}] = xr.full_like(
ds[var_name].isel({self.time_dim: 0}),
initial_state_values[var_name],
dtype=type(initial_state_values[var_name]),
)

else:
ds[var_name].loc[{self.time_dim: 0}] = initial_state_values[var_name]

ds[var_name].attrs = attrs

return ds # ds.expand_dims({self.time_dim: np.arange(time_steps)})

def _init_static_arrays(
self,
dataset: xr.Dataset,
static_variable_values: InitialVariablesDict,
time_steps: int,
) -> xr.Dataset:
"""Broadcasts static variables to an existing dataset.
Expand Down
2 changes: 2 additions & 0 deletions src/clearwater_modules/tsm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class EnergyBudget(base.Model):

def __init__(
self,
time_steps: int,
initial_state_values: Optional[base.InitialVariablesDict] = None,
updateable_static_variables: Optional[list[str]] = None,
meteo_parameters: Optional[dict[str, float]] = None,
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(
#static_variable_values['use_sed_temp'] = use_sed_temp

super().__init__(
time_steps=time_steps,
initial_state_values=initial_state_values,
static_variable_values=static_variable_values,
updateable_static_variables=updateable_static_variables,
Expand Down

0 comments on commit 52ec704

Please sign in to comment.