Skip to content

Commit

Permalink
update increment timestep
Browse files Browse the repository at this point in the history
#68
todo: deal with timesteps and potentially change typehints
update tests
  • Loading branch information
sjordan29 committed Feb 27, 2024
1 parent 52ec704 commit 9b06e86
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 40 deletions.
11 changes: 6 additions & 5 deletions examples/dev_sandbox/performance_profiling_tsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def run_performance_test(

# instantiate the TSM module
tsm = clearwater_modules.tsm.EnergyBudget(
time_steps=iters,
initial_state_values=state_i,
meteo_parameters=meteo_parameters,
)
Expand Down Expand Up @@ -89,9 +90,9 @@ def run_performance_test(
sys.exit(1)

log_file = sys.argv[1]
# iterations_list = [1, 10, 100, 1000, 10000, 100000]
# gridsize_list = [1, 1000, 10000]
iterations_list = [10000]
gridsize_list = [10000]
detailed_profile = True
iterations_list = [1, 10, 100, 1000, 10000, 100000]
gridsize_list = [1, 1000, 10000]
# iterations_list = [10000]
# gridsize_list = [10000]
detailed_profile = False
run_performance_test(iterations_list, gridsize_list, log_file, detailed_profile)
19 changes: 14 additions & 5 deletions examples/dev_sandbox/prof.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
import numpy as np


def main(iters: int, baseline: bool):
def main(iters: int, type: str):
ti = time.time()
# define starting state values
if baseline:
if type == 'baseline':
state_i = {
'water_temp_c': 40.0,
'surface_area': 1.0,
'volume': 1.0,
}
else:
elif type in ['arrays', 'hotstart']:
state_i = {
'water_temp_c': xr.DataArray(
np.full(10, 40),
Expand All @@ -30,6 +30,7 @@ def main(iters: int, baseline: bool):
dims='cell',
coords={'cell': np.arange(10)}),
}


# instantiate the TSM module
tsm = clearwater_modules.tsm.EnergyBudget(
Expand All @@ -38,8 +39,16 @@ def main(iters: int, baseline: bool):
meteo_parameters={'wind_c': 1.0},
updateable_static_variables=['wind_c']
)
print(tsm.static_variable_values)

t2 = time.time()

if type == 'hotstart':
tsm = clearwater_modules.tsm.EnergyBudget(
time_steps=iters,
hotstart_dataset=tsm.dataset,
)
t2 = time.time()

for _ in range(iters):
tsm.increment_timestep()
print(f'Increment timestep speed (average of {iters}): {(time.time() - t2) / 100}')
Expand All @@ -56,4 +65,4 @@ def main(iters: int, baseline: bool):
print('No argument given, defaulting to 100 iteration.')
iters = 100

main(iters=iters, baseline=True)
main(iters=iters, type='baseline')
84 changes: 54 additions & 30 deletions src/clearwater_modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
track_dynamic_variables: bool = True,
hotstart_dataset: Optional[xr.Dataset] = None,
time_dim: Optional[str] = None,
timestep: Optional[int] = 0,
) -> None:
"""Initialize the model, should be accessed by subclasses.
Expand All @@ -62,6 +63,8 @@ def __init__(
self.static_variable_values = static_variable_values
self.hotstart_dataset = hotstart_dataset
self.track_dynamic_variables = track_dynamic_variables
self.timestep = timestep
self.time_steps = time_steps + 1 # xarray indexing

if not time_dim:
time_dim = 'time_step'
Expand All @@ -78,7 +81,7 @@ def __init__(
initial_state_values=self.initial_state_values,
static_variable_values=self.static_variable_values,
updateable_static_variables=self.updateable_static_variables,
time_steps=time_steps,
time_steps=self.time_steps,
)

elif isinstance(hotstart_dataset, xr.Dataset):
Expand Down Expand Up @@ -374,52 +377,75 @@ def _non_updateable_static_variables(self) -> list[str]:
var.name for var in self.static_variables if var.name not in self.updateable_static_variables
]
return self.__non_updateable_static_variables

def _iter_computations(self):
inputs = map(
lambda x: utils._prep_inputs(
self.dataset.isel({self.time_dim: self.timestep}),
x),
self.computation_order
)
for name, func, arrays in inputs:
array: np.ndarray = func(*arrays)
dims = self.dataset[name].dims
if self.time_dim in dims:
self.dataset[name].loc[{self.time_dim: self.timestep}] = array
else:
self.dataset[name] = (dims, array)


def increment_timestep(
self,
update_state_values: Optional[dict[str, xr.DataArray]] = None,
) -> xr.Dataset:
"""Run the process."""
self.timestep +=1

if update_state_values is None:
update_state_values = {}

# get the last timestep as a xr.DataArray
last_timestep: int = self.dataset[self.time_dim].values[-1]
timestep_ds: xr.Dataset = self.dataset.isel(
{self.time_dim: -1},
).copy(deep=True)

# by default, set current timestep equal to last timestep
self.dataset[self.state_variables_names + self.updateable_static_variables].loc[
{self.time_dim: self.timestep}
] = self.dataset[self.state_variables_names + self.updateable_static_variables].isel(
{self.time_dim: self.timestep - 1}
)

# update the state variables as necessary (i.e. interacting w/ other models)
for var_name, value in update_state_values.items():
if var_name not in (self.state_variables_names + self.updateable_static_variables):
raise ValueError(
f'Variable {var_name} cannot be updated between timesteps, skipping.',
)
utils.validate_arrays(value, timestep_ds[var_name])
timestep_ds[var_name] = value
utils.validate_arrays(
value,
self.dataset[var_name].isel(
{self.time_dim: self.timestep}
)
)
self.dataset[var_name].loc[{self.time_dim: self.timestep}] = value

# add dynamic variables to ds
for dynamic_variable in self.dynamic_variables_names:
self.dataset[dynamic_variable] = xr.DataArray(
np.full(
tuple(
self.dataset[self.static_variables_names[0]].sizes[dim]
for dim in self.dataset[self.static_variables_names[0]].dims),
np.nan
),
dims=self.dataset[self.static_variables_names[0]].dims
)

# compute the dynamic variables in order
timestep_ds = utils.iter_computations(
timestep_ds,
self.computation_order,
)
if not self.track_dynamic_variables:
timestep_ds = timestep_ds.drop_vars(self.dynamic_variables_names)

timestep_ds = timestep_ds.drop_vars(self._non_updateable_static_variables)
timestep_ds = timestep_ds.expand_dims(
{self.time_dim: [last_timestep + 1]},
)
self._iter_computations()

self.dataset = xr.concat(
[
self.dataset,
timestep_ds,
],
dim=self.time_dim,
data_vars='minimal',
)
if not self.track_dynamic_variables:
self.dataset.loc[{self.time_dim: self.timestep}] = self.dataset.isel(
{self.time_dim: self.timestep}
).drop_vars(
self.dynamic_variables_names
)

# add dynamic variable attributes
if self.track_dynamic_variables:
Expand All @@ -431,8 +457,6 @@ def increment_timestep(
'description': var.description,
}

return self.dataset


def register_variable(models: CanRegisterVariable | Iterable[CanRegisterVariable]):
"""A decorator to register a variable with a model."""
Expand Down

0 comments on commit 9b06e86

Please sign in to comment.