# Testing different computation engine approaches

Computation engine = the part of the code that is shared across all modules, AKA the inherited `base.increment_timestep()` function, which currently passes `Model.computation_order` into `utils.iter_computations()`.

Note that `Model.computation_order` takes an optional dependency injection that overrides `utils.iter_computations()`; we use this to test different approaches.
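A minimal sketch of this kind of dependency injection, assuming the override is passed as an optional argument when a timestep is incremented. The exact parameter name and where `clearwater_modules` accepts it are assumptions; `Model`, `default_iter_computations`, and the `iter_computations` keyword below are hypothetical, and plain dicts stand in for `xr.Dataset`.

```python
from typing import Any, Callable, Optional

State = dict[str, Any]
Step = Callable[[State], None]
# A "computation engine" takes the state and the ordered steps and returns the new state.
Engine = Callable[[State, list[Step]], State]


def default_iter_computations(state: State, compute_order: list[Step]) -> State:
    """Hypothetical default engine: run each step in order."""
    for step in compute_order:
        step(state)
    return state


class Model:
    """Hypothetical minimal model, not the clearwater_modules API."""

    def __init__(self, state: State, computation_order: list[Step]) -> None:
        self.state = state
        self.computation_order = computation_order

    def increment_timestep(self, iter_computations: Optional[Engine] = None) -> State:
        # Use the injected engine if one is given, otherwise the default.
        engine = iter_computations or default_iter_computations
        # Shallow copy so the previous timestep's dict is not mutated in place.
        self.state = engine(dict(self.state), self.computation_order)
        return self.state


def double_x(state: State) -> None:
    state["x"] = state["x"] * 2


model = Model({"x": 1}, [double_x])
model.increment_timestep()  # runs with the default engine

calls = []


def counting_engine(state: State, compute_order: list[Step]) -> State:
    # Alternative engine under test: record the call, then delegate.
    calls.append(len(compute_order))
    return default_iter_computations(state, compute_order)


model.increment_timestep(iter_computations=counting_engine)
```

Swapping engines this way lets different approaches be timed against the same model without editing the shared base code.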

**Testing approaches:**
* V1 approach: `utils.iter_computations` is decorated with `numba.jit(forceobj=True)`, which forces Numba's object mode and attempts to JIT compile a function whose arguments are our custom `base.Variable` objects and an `xr.Dataset`.
* V2 approach: the same as V1, but without any JIT compilation.
* V3 approach: aim for a true JIT compile by passing plain `numpy` arrays to the process functions, then writing the results back to `xarray`.

In [1]:
import clearwater_modules
import xarray as xr

# Pull in test data

In [2]:
%%time
# Load the xarray tutorial dataset as test data
test_ds: xr.Dataset = xr.tutorial.load_dataset('air_temperature')
test_ds.attrs = {}
# Rename 'air' to the TSM state variable name and keep only the first time step
test_ds = test_ds.rename_vars(
    {'air': 'water_temp_c'}
).isel(time=0)
# Mock surface area and volume as arrays of ones (x / x) on the same grid
test_ds['surface_area'] = test_ds.water_temp_c / test_ds.water_temp_c
test_ds['volume'] = test_ds.water_temp_c / test_ds.water_temp_c

test_ds

CPU times: total: 281 ms
Wall time: 295 ms


# Initialize the TSM model

In [3]:
dir(clearwater_modules.tsm)

['EnergyBudget',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'constants',
 'dynamic_variables',
 'model',
 'processes',
 'state_variables',
 'static_variables']

In [4]:
clearwater_modules.tsm.EnergyBudget.get_state_variables()

[Variable(name='water_temp_c', long_name='Water temperature', units='degC', description='TSM state variable for water temperature', use='state', process=CPUDispatcher(<function t_water_c at 0x000001F667EF47C0>)),
 Variable(name='surface_area', long_name='Surface area', units='m^2', description='Surface area', use='state', process=<function mock_surface_area at 0x000001F6642FE8E0>),
 Variable(name='volume', long_name='Volume', units='m^3', description='Volume', use='state', process=<function mock_volume at 0x000001F667EF4CC0>)]
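The repr above shows that each `Variable` bundles metadata with a `process` function. Assuming, as the later use of `sorter.get_process_args()` suggests, that a process's parameter names are the names of the variables it reads, a computation order can be derived with a topological sort. This sketch uses a hypothetical `Variable` stand-in and hypothetical helpers `get_process_args` and `order_variables` built on the standard library's `graphlib`; it is not necessarily how `clearwater_modules` builds `computation_order`.

```python
import inspect
from dataclasses import dataclass
from graphlib import TopologicalSorter
from typing import Callable


@dataclass(frozen=True)
class Variable:
    """Hypothetical stand-in with the fields shown in the repr above."""

    name: str
    long_name: str
    units: str
    description: str
    use: str
    process: Callable


def get_process_args(process: Callable) -> list[str]:
    # Assumption: a process's parameter names are the variable names it reads.
    return list(inspect.signature(process).parameters)


def order_variables(variables: list[Variable]) -> list[Variable]:
    """Order variables so each comes after the variables its process reads."""
    by_name = {v.name: v for v in variables}
    graph = {
        v.name: {
            arg
            for arg in get_process_args(v.process)
            # Skip inputs that are not computed here, and self-references
            # (a state variable may read its own previous value).
            if arg in by_name and arg != v.name
        }
        for v in variables
    }
    return [by_name[name] for name in TopologicalSorter(graph).static_order()]


# Toy process functions for illustration only (not the real TSM equations).
def toy_surface_area(width):
    return width * 2.0


def toy_volume(surface_area, depth):
    return surface_area * depth


def toy_water_temp_c(water_temp_c, volume):
    return water_temp_c + 1.0 / volume


variables = [
    Variable("water_temp_c", "Water temperature", "degC", "toy", "state", toy_water_temp_c),
    Variable("volume", "Volume", "m^3", "toy", "state", toy_volume),
    Variable("surface_area", "Surface area", "m^2", "toy", "state", toy_surface_area),
]
print([v.name for v in order_variables(variables)])
# ['surface_area', 'volume', 'water_temp_c']
```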

In [5]:
%%time
tsm_model = clearwater_modules.tsm.EnergyBudget(
    initial_state_values={
        'water_temp_c': test_ds.water_temp_c,
        'surface_area': test_ds.surface_area,
        'volume': test_ds.volume,
    },
)

Initializing from dicts...
Model initialized from input dicts successfully!.
CPU times: total: 78.1 ms
Wall time: 72.7 ms


# Define the non-default approaches

In [6]:
from clearwater_modules.utils import (
    Variable,
    sorter,
)

In [7]:
# V2: no JIT compilation
def v2_iter_computations(
    input_dataset: xr.Dataset,
    compute_order: list[Variable],
) -> xr.Dataset:
    for var in compute_order:
        # Argument names of the process function, used as dataset variable names
        input_vars: list[str] = sorter.get_process_args(var.process)
        # Apply the process to the matching DataArrays and store the result
        input_dataset[var.name] = xr.apply_ufunc(
            var.process,
            *[input_dataset[name] for name in input_vars],
        )
    return input_dataset

In [8]:
# V3: try a full JIT compile by passing raw numpy arrays to the process functions
from collections.abc import Callable

import numpy as np

def get_args(var: Variable) -> tuple[str, Callable, list[str]]:
    """Return the output name, the process function, and its argument names."""
    func = var.process
    args = sorter.get_process_args(var.process)
    return (var.name, func, args)

def get_arrays(input_dataset: xr.Dataset, arg_names: list[str]) -> list[np.ndarray]:
    """Pull the raw numpy arrays for the given variable names."""
    return [input_dataset[name].values for name in arg_names]

def get_inputs(input_dataset: xr.Dataset, var: Variable):
    name, func, args = get_args(var)
    arrays = get_arrays(input_dataset, args)
    return name, func, args, arrays

def v3_iter_computations(
    input_dataset: xr.Dataset,
    compute_order: list[Variable],
) -> xr.Dataset:
    for var in compute_order:
        # Read inputs inside the loop so each step sees results written by earlier steps
        name, func, args, arrays = get_inputs(input_dataset, var)
        array = func(*arrays)
        # Label the result with the dims of its first input, so the dim order matches the array axes
        dims = input_dataset[args[0]].dims if args else tuple(input_dataset.dims)
        input_dataset[name] = (dims, array)
    return input_dataset
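V3 relies on reading each step's inputs only when that step runs, so variables computed earlier in the same timestep are visible to later ones. A minimal sketch of the same idea, with a plain dict of `numpy` arrays standing in for `xr.Dataset`; `iter_computations_numpy` and the toy process functions are hypothetical.

```python
from typing import Callable

import numpy as np

# Each step: (output name, process function, input variable names).
Step = tuple[str, Callable[..., np.ndarray], list[str]]


def iter_computations_numpy(
    state: dict[str, np.ndarray],
    compute_order: list[Step],
) -> dict[str, np.ndarray]:
    """Sketch of the V3 idea with a plain dict standing in for xr.Dataset."""
    for name, func, arg_names in compute_order:
        # Look inputs up only when this step runs, so outputs written by
        # earlier steps in the same timestep are visible here.
        arrays = [state[arg] for arg in arg_names]
        # Call the process on plain ndarrays (Numba's nopython mode accepts
        # ndarrays but not xr.DataArray objects).
        state[name] = func(*arrays)
    return state


def toy_area(width, length):
    return width * length


def toy_volume(area, depth):
    return area * depth


state = {
    "width": np.array([1.0, 2.0]),
    "length": np.array([3.0, 4.0]),
    "depth": np.array([2.0, 0.5]),
}
order = [
    ("area", toy_area, ["width", "length"]),
    ("volume", toy_volume, ["area", "depth"]),
]
result = iter_computations_numpy(state, order)
```

Gathering every step's inputs up front instead (for example, building a list of all input arrays before the loop) would fail here with a `KeyError`, because `area` does not exist until the first step has run.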

# Run compute iterations with each version

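**Findings:**
* Most of each timestep is spent not on calculations but on `xarray` I/O (reading variables out of the dataset and writing results back).
* As is (`forceobj=True`), the JIT-compiled V1 and the non-JIT V2 run at about the same speed (~36 ms per call each).
* V3, which skips `xr.apply_ufunc` and works on raw `numpy` arrays, takes roughly half as long (~19 ms per call).
* Using `map_blocks()` instead of `apply_ufunc()` seemed like a decent approach; however, it passes `xr.DataArray`s into our process functions, which are currently JIT compiled, so the call throws an error.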
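One way to check where the time goes is to time the read, compute, and write stages of the loop separately. A stdlib-only sketch; `StageTimer` is a hypothetical helper, and the toy dict loop stands in for the real `input_dataset[...]` reads, process calls, and writes. In the notebook, you would wrap those three parts of `v2_iter_computations` or `v3_iter_computations` instead.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Iterator


class StageTimer:
    """Hypothetical helper: accumulate wall-clock time per named stage."""

    def __init__(self) -> None:
        self.totals: defaultdict[str, float] = defaultdict(float)

    @contextmanager
    def stage(self, name: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            # Add the elapsed time even if the wrapped code raises.
            self.totals[name] += time.perf_counter() - start


# Toy loop: a dict of lists stands in for the dataset.
timer = StageTimer()
state = {"x": list(range(1000))}
for _ in range(3):
    with timer.stage("read"):
        x = state["x"]
    with timer.stage("compute"):
        y = [v * 2 for v in x]
    with timer.stage("write"):
        state["y"] = y

print({name: f"{seconds * 1e3:.3f} ms" for name, seconds in timer.totals.items()})
```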

In [13]:
%%timeit
tsm_model.increment_timestep()

47.4 ms ± 3.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%%timeit
clearwater_modules.utils.iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)

36 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%%timeit
v2_iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)

35.8 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
%%timeit
v3_iter_computations(
    tsm_model.dataset.isel(time_step=-1),
    tsm_model.computation_order,
)

19 ms ± 924 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
