# Xarray development notebook

**Author:** Xavier R Nogueira

**Things to explore:**
* Can we dynamically feed slices of an xarray `Dataset`/`DataArray` into Numba JIT-compiled functions using `xarray.apply_ufunc`?

**Performance Notes:**
* Calling a Numba JIT-compiled function from an ordinary Python loop is slower than calling the plain Python function (1.55 s vs. 1.02 s wall time for 1M iterations below).
* Parallelizing via Numba is slower than running serially at 100k iterations (for a simple equation), but slightly faster at 1M.
* JIT-compiling the loop itself, i.e. a JIT-compiled wrapper AROUND the JIT-compiled function, is by far the fastest.
* Could we loop and blast through computing results, push them onto a stack of sorts, and write them into xarray asynchronously? Avoiding the write bottleneck could make this fast.
* Another idea is to loop over a dictionary of function names and params: pop each entry from the dict and add its result to another dict. Basically, an ordered dict would control the flow. **I like this idea, but how can we make it work for xarray?**
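The ordered-dictionary idea, combined with buffering results before writing, can be sketched roughly as follows. The step names, the `heat_gain`/`net_flux` helpers, and the variable names are hypothetical placeholders, not Clearwater equations. A plain `dict` preserves insertion order (Python 3.7+), so it can drive the execution order. Each result is buffered in a second dict and written into the `Dataset` once at the end with `Dataset.assign`, rather than after every step.

```python
import numpy as np
import xarray as xr


def heat_gain(air_temp_k, water_temp_k):
    # Hypothetical step: temperature difference (K)
    return air_temp_k - water_temp_k


def net_flux(heat_gain, coeff):
    # Hypothetical step: scaled heat gain
    return coeff * heat_gain


ds = xr.Dataset(
    {
        "air_temp_k": ("x", np.array([290.0, 295.0, 300.0])),
        "water_temp_k": ("x", np.array([285.0, 285.0, 285.0])),
        "coeff": ((), 2.0),
    }
)

# Insertion order controls execution order: each entry maps an output name
# to (function, names of the inputs it needs)
pipeline = {
    "heat_gain": (heat_gain, ("air_temp_k", "water_temp_k")),
    "net_flux": (net_flux, ("heat_gain", "coeff")),
}

buffer = {}  # results are collected here instead of written immediately
while pipeline:
    # Pop the oldest entry (FIFO) from the front of the dict
    name = next(iter(pipeline))
    func, input_names = pipeline.pop(name)
    # Inputs may come from the original Dataset or from earlier results
    args = [buffer[n] if n in buffer else ds[n] for n in input_names]
    buffer[name] = xr.apply_ufunc(func, *args)

# Single write into the Dataset at the end
ds = ds.assign(buffer)
print(ds["net_flux"].values)  # [10. 20. 30.]
```

This sketch is synchronous. An asynchronous writer would replace the final `assign`, but that part is not shown here.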

**Construction Notes:**
* One can call `xarray.apply_ufunc` programmatically by unpacking a list of inputs with `*args` syntax. This means we can rapidly evaluate a time step.
* It might be faster to project all constants into xarray first: passing a constant in via the `*args` syntax slows things down, which I'm assuming is because it gets broadcast on every call.
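The two ways of handling a constant described above are compared in this small sketch. `xr.apply_ufunc` passes a plain Python scalar straight through to the function, and NumPy then broadcasts it inside the arithmetic. A constant stored in the `Dataset` beforehand (here with `xr.full_like`) arrives as a full-size array instead. The function `flux`, the variable names, and the value `1004.0` are arbitrary illustrations. The sketch does not measure which route is faster. It only shows that both give the same result, and how the `*args` list can be built from variable names.

```python
import numpy as np
import xarray as xr


def flux(coeff, air_temp_k, water_temp_k):
    # Simple element-wise equation used only for illustration
    return coeff * (air_temp_k - water_temp_k)


ds = xr.Dataset(
    {
        "air_temp_k": (("time", "x"), np.full((2, 3), 295.0)),
        "water_temp_k": (("time", "x"), np.full((2, 3), 290.0)),
    }
)

# Option 1: pass the constant as a scalar in the *args list
args_scalar = [1004.0, ds["air_temp_k"], ds["water_temp_k"]]
out_scalar = xr.apply_ufunc(flux, *args_scalar)

# Option 2: project the constant into the Dataset once, then pass only variables
ds["coeff"] = xr.full_like(ds["air_temp_k"], 1004.0)
names = ["coeff", "air_temp_k", "water_temp_k"]
out_projected = xr.apply_ufunc(flux, *[ds[n] for n in names])

print(bool((out_scalar == out_projected).all()))  # True
```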

In [1]:
import numba
import xarray as xr
import clearwater_modules_python

In [2]:
dir(clearwater_modules_python)

['__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 'shared_equations',
 'tsm']

In [3]:
dir(clearwater_modules_python.tsm)

['EnergyBudget',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'constants',
 'equations',
 'model']

In [4]:
dir(clearwater_modules_python.tsm.equations)

['__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'dTdt_sediment_c',
 'dTdt_water_c',
 'density_air',
 'emissivity_air',
 'mixing_ratio_air',
 'numba',
 'q_latent',
 'q_net',
 'q_sediment',
 'q_sensible',
 'wind_function']

In [5]:
clearwater_modules_python.tsm.equations.q_sensible.__annotations__

{'wind_kh_kw': float,
 'ri_function': float,
 'cp_air': float,
 'density_water': float,
 'wind_function': float,
 'air_temp_k': float,
 'water_temp_k': float,
 'return': float}
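Since the equations expose their argument names, the `*args` list for `xr.apply_ufunc` could be assembled by name. This sketch assumes a `Dataset` whose variable names match the parameter names. `call_by_name` is a hypothetical helper, and the stand-in `q_sensible` repeats the formula defined later in this notebook. It reads names from `inspect.signature` and falls back to a Numba dispatcher's `.py_func`.

```python
import inspect

import numpy as np
import xarray as xr


def q_sensible(wind_kh_kw, ri_function, cp_air, density_water,
               wind_function, air_temp_k, water_temp_k):
    # Same formula as the q_sensible cell later in this notebook
    return (wind_kh_kw * ri_function * cp_air * density_water
            * wind_function * (air_temp_k - water_temp_k))


def call_by_name(func, ds):
    # Numba dispatchers keep the original Python function in .py_func
    py_func = getattr(func, "py_func", func)
    names = list(inspect.signature(py_func).parameters)
    # Pull each argument from the Dataset in signature order
    return xr.apply_ufunc(func, *[ds[name] for name in names])


ds = xr.Dataset({
    "wind_kh_kw": ((), 1.0),
    "ri_function": ((), 1.0),
    "cp_air": ((), 1005.0),
    "density_water": ((), 1000.0),
    "wind_function": ("x", np.array([1e-3, 2e-3])),
    "air_temp_k": ("x", np.array([295.0, 295.0])),
    "water_temp_k": ("x", np.array([290.0, 290.0])),
})
print(call_by_name(q_sensible, ds).values)
```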

# Numba comparison

In [6]:
TEST_ITERS = 1000000

In [7]:
def q_sensible(
    wind_kh_kw: float,
    ri_function: float,
    cp_air: float,
    density_water: float,
    wind_function: float,
    air_temp_k: float,
    water_temp_k: float,
) -> float:
    # TODO: check if the return units are correct
    """Sensible heat flux (W/m2).

    Args:
        wind_kh_kw: Diffusivity ratio (unitless)
        ri_function: Richardson number (unitless)
        cp_air: Specific heat of air (J/kg/K)
        density_water: Water density (kg/m^3)
        wind_function: Wind function (unitless)
        air_temp_k: Air temperature (K)
        water_temp_k: Water temperature (K)
    """
    return (
        wind_kh_kw *
        ri_function *
        cp_air * density_water * wind_function *
        (air_temp_k - water_temp_k)
    )
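The function body computes the product below. The symbols map to the arguments as follows: $K_h/K_w$ is `wind_kh_kw`, $f(Ri)$ is `ri_function`, $c_{p,\text{air}}$ is `cp_air`, $\rho_w$ is `density_water`, $f(U)$ is `wind_function`, and $T_a$, $T_w$ are `air_temp_k`, `water_temp_k`. These symbols are shorthand chosen here for readability. The equation restates the code as written, including the return units that the TODO still flags.

$$
Q_h = \frac{K_h}{K_w} \cdot f(Ri) \cdot c_{p,\text{air}} \cdot \rho_w \cdot f(U) \cdot \left(T_a - T_w\right)
$$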

In [8]:
%%time
for i in range(TEST_ITERS):
    clearwater_modules_python.tsm.equations.q_sensible(
        float(i),
        float(i*2),
        float(i*3),
        float(i*4),
        float(i*5),
        float(i*6),
        float(i*7),
    )

CPU times: total: 1.33 s
Wall time: 1.55 s


In [9]:
%%time
for i in range(TEST_ITERS):
    q_sensible(
        float(i),
        float(i*2),
        float(i*3),
        float(i*4),
        float(i*5),
        float(i*6),
        float(i*7),
    )

CPU times: total: 781 ms
Wall time: 1.02 s


In [33]:
%%time
@numba.njit
def iter_numba(func: callable):
    for i in range(TEST_ITERS):
        func(
            float(i),
            float(i*2),
            float(i*3),
            float(i*4),
            float(i*5),
            float(i*6),
            float(i*7),
        )
iter_numba(clearwater_modules_python.tsm.equations.q_sensible)

CPU times: total: 109 ms
Wall time: 138 ms


In [12]:
ds = xr.tutorial.open_dataset('air_temperature')
ds

In [36]:
test_list = [
    ds.air,
    ds.air*2,
    ds.air*3,
    ds.air*4,
    ds.air*5,
    ds.air*6,
    ds.air*7,
]

In [37]:
%%time
out = xr.apply_ufunc(
    clearwater_modules_python.tsm.equations.q_sensible,
    *test_list,
)

CPU times: total: 62.5 ms
Wall time: 91.8 ms


In [30]:
%%time
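# Note: dask='parallelized' only takes effect for dask-backed (chunked) inputs;
# ds.air is NumPy-backed here, so this runs eagerly. Unlike In [37], the
# ds.air*k multiplications are also inside the timed cell.
xr.apply_ufunc(
    clearwater_modules_python.tsm.equations.q_sensible,
    ds.air,
    ds.air*2,
    ds.air*3,
    ds.air*4,
    ds.air*5,
    ds.air*6,
    ds.air*7,
    dask='parallelized',
)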

CPU times: total: 125 ms
Wall time: 130 ms
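`dask='parallelized'` has an effect only when at least one input is backed by a dask array. With the NumPy-backed `ds.air`, `apply_ufunc` runs eagerly and the setting is never consulted. The sketch below shows the chunked version and assumes dask is installed. The synthetic `air` array and the chunk size are arbitrary stand-ins for `ds.air`, and `q_sensible_np` restates the same formula in plain NumPy.

```python
import numpy as np
import xarray as xr


def q_sensible_np(wind_kh_kw, ri_function, cp_air, density_water,
                  wind_function, air_temp_k, water_temp_k):
    # Same formula as q_sensible, applied to whole NumPy chunks
    return (wind_kh_kw * ri_function * cp_air * density_water
            * wind_function * (air_temp_k - water_temp_k))


rng = np.random.default_rng(0)
air = xr.DataArray(rng.uniform(240.0, 310.0, size=(100, 25, 53)),
                   dims=("time", "lat", "lon"))

# Chunk along time so dask can process blocks in parallel
air_chunked = air.chunk({"time": 25})
args = [air_chunked * k for k in range(1, 8)]

lazy = xr.apply_ufunc(q_sensible_np, *args,
                      dask="parallelized", output_dtypes=[float])
print(lazy.chunks)        # dask-backed: nothing computed yet
result = lazy.compute()   # triggers the parallel computation
```

Whether this beats the eager version depends on array size and chunking. For an array the size of the tutorial dataset, dask's scheduling overhead may well dominate.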


In [15]:
out