# Testing different computation engine approaches

Computation engine = the part of the code that is shared across all modules, i.e., the inherited `base.increment_timestep()` function, which currently passes `Model.computation_order` into `utils.iter_computations()`.

Note that `Model.computation_order` takes an optional dependency injection that overrides `utils.iter_computations()`; we use this to test different approaches.

**Testing approaches:**
* V1 approach: `utils.iter_computations` is decorated with `numba.jit(forceobj=True)`, which attempts to JIT our custom `base.Variable` arguments as well as `xr.Dataset`.
* V2 approach: The same as V1, except not JIT compiled at all.
* V3 approach: A true JIT compile by using `numpy`, then writing back to `xarray`.

In [11]:
import clearwater_modules
import xarray as xr
import numpy as np

# Pull in test data

In [73]:
%%time
# Load the xarray tutorial air-temperature dataset as stand-in model input.
test_ds: xr.Dataset = xr.tutorial.load_dataset('air_temperature')
test_ds.attrs = {}
# Rename the variable and keep only the first 10 time steps. The result is
# still an `xr.Dataset` (not a DataArray), so the name is simply rebound.
test_ds = test_ds.rename_vars(
    {'air': 'water_temp_c'}
).isel(time=slice(0, 10))
# Replace the time coordinate with the integers 0..9.
test_ds['time'] = np.arange(0, 10)
# Dividing the field by itself yields same-shaped arrays of ones
# (NaN wherever the source value is 0 or NaN).
test_ds['surface_area'] = test_ds.water_temp_c / test_ds.water_temp_c
test_ds['volume'] = test_ds.water_temp_c / test_ds.water_temp_c

test_ds

CPU times: total: 31.2 ms
Wall time: 32 ms


# Compare methods of writing to xarray

In [91]:
%%timeit
# Benchmark: build one new time step and append it with `xr.concat`, 50 times.
for _ in range(50):
    # Deep copy of the final time slice of `test_ds`.
    timestep_ds = test_ds.isel(time=-1).copy(deep=True)
    # Advance the slice's time value by one and add 10 to its temperature.
    timestep_ds['time'] = timestep_ds['time'] + 1
    timestep_ds['water_temp_c'] = timestep_ds['water_temp_c'] + 10
    # `test_ds` is never reassigned, so every pass concatenates onto the same
    # original dataset and overwrites the previous `test_ds_out`.
    test_ds_out = xr.concat(
        [
            test_ds,
            timestep_ds,
        ],
        dim='time',
    )
test_ds_out

436 ms ± 47.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [75]:
test_ds.water_temp_c.dims

('time', 'lat', 'lon')

In [76]:
test_ds.time[-1].item()

59

In [77]:
# Shape of one time step of `water_temp_c`, with a trailing axis of length 50
# appended for the padded steps.
grid_shape = test_ds.water_temp_c.values[0].shape
new_shape = (*grid_shape, 50)
new_shape

(25, 53, 50)

In [78]:
%%time
# NaN-filled placeholder array of shape `new_shape`.
# `np.nan` is the canonical name; the `np.NaN` alias was removed in NumPy 2.0.
padding = np.full(new_shape, np.nan)

CPU times: total: 0 ns
Wall time: 0 ns


In [80]:
test_ds.time[-1].item()

59

In [85]:
%%timeit
# Pre-allocate NaN-padded slots for the upcoming time steps and append them
# with a single `xr.concat`, instead of concatenating once per time step.
last_time = test_ds.time[-1].item()
new_ds = xr.Dataset(
    # NOTE: every variable is handed the same `padding` array object.
    data_vars={k: (('lat', 'lon', 'time'), padding) for k in test_ds.data_vars},
    coords={
        'lat': test_ds.coords['lat'].values,
        'lon': test_ds.coords['lon'].values,
        # Start one past the last existing label; starting at `last_time`
        # itself would duplicate that label in the concatenated dataset.
        'time': np.arange(last_time + 1, last_time + 1 + new_shape[-1]),
    },
)

new_ds_concat = xr.concat(
    [
        test_ds,
        new_ds,
    ],
    dim='time',
)
new_ds_concat

4.41 ms ± 325 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [82]:
# Time label one past the end of the original data, then display the
# `water_temp_c` slice stored under that label.
next_time = test_ds.time[-1].item() + 1
write_time = int(next_time)
new_ds_concat['water_temp_c'].sel(time=write_time)

In [83]:
%%timeit
# Benchmark: write one time step into the pre-allocated dataset, 50 times.
# `wt` is used as a *positional* index along `time`; it equals the position of
# the first padded slot because `test_ds.time` was set to 0, 1, 2, ... above.
wt = int(test_ds.time[-1].item() + 1)
for _ in range(50):
    # `.isel()` returns a new object, so assigning into its result never
    # reaches `new_ds_concat`. Index the DataArray itself so the write lands
    # in the dataset's own data.
    new_ds_concat['water_temp_c'][{'time': wt}] = (
        new_ds_concat['water_temp_c'].isel(time=wt - 1) * 2
    )
new_ds_concat

83.9 ms ± 4.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [30]:
new_ds

# Init TSM model

In [None]:
dir(clearwater_modules.tsm)

In [None]:
clearwater_modules.tsm.EnergyBudget.get_state_variables()

In [None]:
%%time
# Build the TSM energy-budget model, seeding each initial state value from the
# `test_ds` variable of the same name.
state_names = ('water_temp_c', 'surface_area', 'volume')
tsm_model = clearwater_modules.tsm.EnergyBudget(
    initial_state_values={name: test_ds[name] for name in state_names},
)

# Define non-default approaches

In [None]:
from clearwater_modules.utils import (
    Variable,
    sorter,
)

In [None]:
# V1: as it was before 11/13/2023
# Imported here because the decorator below is evaluated when this cell runs,
# which is before the later cell that imports numba.
import numba


@numba.jit(forceobj=True)
def v1_iter_computations(
    input_dataset: xr.Dataset,
    compute_order: list[Variable],
) -> xr.Dataset:
    """Compute every variable in order with `xr.apply_ufunc` (object-mode JIT).

    For each variable, the argument names of `var.process` are looked up with
    `sorter.get_process_args`, the entries of `input_dataset` with those names
    are passed to `xr.apply_ufunc`, and the result is stored under `var.name`.

    Args:
        input_dataset: Dataset holding the inputs; it is written to in place.
        compute_order: Variables to compute, in the order they are iterated.

    Returns:
        The same `input_dataset` object, with the computed variables added.
    """
    for var in compute_order:
        input_vars: list[str] = sorter.get_process_args(var.process)
        input_dataset[var.name] = xr.apply_ufunc(
            var.process,
            *[input_dataset[name] for name in input_vars],
        )
    return input_dataset

# V2: identical loop to V1, but without any JIT decorator
def v2_iter_computations(
    input_dataset: xr.Dataset,
    compute_order: list[Variable],
) -> xr.Dataset:
    """Compute every variable in order with `xr.apply_ufunc`, without JIT.

    Args:
        input_dataset: Dataset holding the inputs; it is written to in place.
        compute_order: Variables to compute, in the order they are iterated.

    Returns:
        The same `input_dataset` object, with the computed variables added.
    """
    for var in compute_order:
        # Names of the dataset entries that this variable's process consumes.
        arg_names: list[str] = sorter.get_process_args(var.process)
        operands = [input_dataset[arg] for arg in arg_names]
        input_dataset[var.name] = xr.apply_ufunc(var.process, *operands)
    return input_dataset

In [None]:
# V3: Trying to full JIT compile
import numba
import numpy as np

def get_args(var: Variable) -> tuple[str, callable, list[str]]:
    """Return `(name, process, argument names)` for one variable.

    Args:
        var: Variable whose `name` and `process` attributes are read.

    Returns:
        A 3-tuple of `var.name`, `var.process`, and the argument names that
        `sorter.get_process_args` reports for that process.
    """
    func = var.process
    arg_names: list[str] = sorter.get_process_args(func)
    return (var.name, func, arg_names)

def get_arrays(
    input_dataset: xr.Dataset,
    arg_names: list[str],
) -> tuple[np.ndarray, ...]:
    """Return the named dataset entries as float64 NumPy arrays.

    Args:
        input_dataset: Dataset to read from; each `input_dataset[name].values`
            is used.
        arg_names: Names to look up, in the order the arrays are returned.

    Returns:
        One array per name, each converted with `astype(np.float64)`.
        The tuple is empty when `arg_names` is empty.
    """
    return tuple(
        input_dataset[name].values.astype(np.float64) for name in arg_names
    )

def get_inputs(input_dataset: xr.Dataset, var: Variable):
    """Bundle everything needed to compute one variable.

    Returns:
        A 4-tuple `(name, func, args, arrays)`: the first three items come
        from `get_args(var)`, and `arrays` is `get_arrays` applied to
        `input_dataset` and those argument names.
    """
    var_name, process, arg_names = get_args(var)
    return var_name, process, arg_names, get_arrays(input_dataset, arg_names)

def v3_iter_computations(
    input_dataset: xr.Dataset,
    compute_order: list[Variable],
) -> xr.Dataset:
    """Compute every variable on plain NumPy arrays and write back to xarray.

    Args:
        input_dataset: Dataset holding the inputs; it is written to in place.
        compute_order: Variables to compute, in the order they are iterated.

    Returns:
        The same `input_dataset` object, with the computed variables added.
    """
    for var in compute_order:
        # Inputs are fetched right before each computation, so a variable
        # written earlier in this loop is available to later ones.
        name, func, _arg_names, arrays = get_inputs(input_dataset, var)
        result = func(*arrays)
        # NOTE(review): this uses the dataset-wide `dims` as the dimension
        # names of the result; confirm their order matches the array's axes.
        input_dataset[name] = (input_dataset.dims, result)
    return input_dataset


@numba.njit
def inner_loop(inputs: list[tuple[str, callable, list[str], tuple[np.ndarray]]],
) -> dict:
    """Call each entry's function on its arrays and collect results by name.

    Each item of `inputs` is unpacked as `(name, func, args, arrays)`; `args`
    is not used here. The comment on `v4_iter_computations` in this notebook
    records that this does not work under numba because of the heterogeneous
    inputs.

    Returns:
        A dict mapping each `name` to `func(*arrays)`.
    """
    out_dict = {}
    for name, func, args, arrays in inputs:
        out_dict[name] = func(*arrays)
    return out_dict

# DOES NOT WORK: `inner_loop` receives heterogeneous inputs, which is a
# fundamental numba problem.
def v4_iter_computations(
    input_dataset: xr.Dataset,
    compute_order: list[Variable],
) -> xr.Dataset:
    """Collect all inputs up front, run them through `inner_loop`, write back.

    Unlike V3, the inputs of every variable are gathered into a list before
    any computation runs.

    Args:
        input_dataset: Dataset holding the inputs; it is written to in place.
        compute_order: Variables to compute, in the order they are iterated.

    Returns:
        The same `input_dataset` object, with the computed variables added.
    """
    prepared = [get_inputs(input_dataset, var) for var in compute_order]
    for name, array in inner_loop(prepared).items():
        input_dataset[name] = (input_dataset.dims, array)
    return input_dataset


# Cleaned-up take on V3
def _prep_inputs(
    input_dataset: xr.Dataset,
    var: Variable,
) -> tuple[str, callable, list[np.ndarray]]:
    """Gather what is needed to compute one variable on raw arrays.

    Args:
        input_dataset: Dataset the input arrays are read from.
        var: Variable whose `name` and `process` attributes are read.

    Returns:
        A 3-tuple of (
            name: str,
            function: callable,
            arrays: list[np.ndarray],
        )
        where `arrays` holds `input_dataset[arg].values` for each argument
        name reported by `sorter.get_process_args(var.process)`.
    """
    arg_names: list[str] = sorter.get_process_args(var.process)
    name, process = var.name, var.process
    arrays = [input_dataset[arg].values for arg in arg_names]
    return name, process, arrays


def v5_iter_computations(
    input_dataset: xr.Dataset,
    compute_order: list[Variable],
) -> xr.Dataset:
    """Compute every variable in order on raw arrays and write back to xarray.

    Args:
        input_dataset: Dataset holding the inputs; it is written to in place.
        compute_order: Variables to compute, in the order they are iterated.

    Returns:
        The same `input_dataset` object, with the computed variables added.
    """
    # Read once, before the loop adds any variable.
    # NOTE(review): this is the dataset-wide `dims`; confirm its order matches
    # the axes of the arrays the process functions return.
    dims = input_dataset.dims

    for var in compute_order:
        # Inputs are fetched right before each computation, so a variable
        # written earlier in this loop is available to later ones.
        name, func, arrays = _prep_inputs(input_dataset, var)
        result: np.ndarray = func(*arrays)
        input_dataset[name] = (dims, result)

    return input_dataset

In [None]:
type(tsm_model.dataset.dims)

# Run compute iterations with each version

**Findings:**
* The majority of the timestep is not running calculations, but rather `xarray` IO.
* As is (`forceobj=True`), JIT compile vs non-JIT compile are about the same speed
* `map_blocks()` instead of `apply_ufunc` seemed like it could be a decent approach, however, this passes `xr.DataArray`s into our process functions, which are currently JIT compiled, which throws an error.

In [None]:
%%timeit
tsm_model.increment_timestep()

In [None]:
%%timeit
v1_iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)

In [None]:
%%timeit
v2_iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)

In [None]:
%%timeit
v3_iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)

In [None]:
%%timeit
v4_iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)

In [None]:
%%timeit
v5_iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)