In [None]:
import inspect

import numpy as np
import xarray as xr
from matplotlib import pyplot as plt

import bgk

In [None]:
from typing import Callable, TypeVar, Generic

T = TypeVar("T")


class FormulaChain(Generic[T]):
    def __init__(self, **formulae: T | Callable[..., T]) -> None:
        for var_name, formula in formulae.items():
            if not callable(formula):
                self.__dict__[var_name] = formula
            else:
                param_var_names = formula.__code__.co_varnames[: formula.__code__.co_argcount]

                def formula_wrapper(
                    _captured_param_var_names=param_var_names,
                    _captured_var_name=var_name,
                    _captured_formula=formula,
                    **var_vals: T,
                ) -> float:
                    # see https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result for why these captures are necessary
                    # TODO cleanup by extracting this into a class that defines __call__
                    params = {}
                    for name in _captured_param_var_names:
                        if name in var_vals:
                            params[name] = var_vals[name]
                        elif name in self.__dict__ and not callable(self.__dict__[name]):
                            params[name] = self.__dict__[name]
                        elif name in self.__dict__:
                            params[name] = self.__dict__[name](**var_vals)
                        else:
                            raise RuntimeError(f"couldn't determine value of '{name}' when trying to calculate '{_captured_var_name}'")
                    return _captured_formula(**params)

                self.__dict__[var_name] = formula_wrapper

    def __getitem__(self, var_name: str) -> T | Callable[..., T]:
        return self.__dict__[var_name]

In [None]:
# PSC run providing the data queried by later cells (input["rho"], input.interpolate_value(..., "Psi")).
RUN_DIR = "/mnt/lustre/IAM851/jm1667/psc-runs/case1/trials/exact/B00.25-n128"
# NOTE: `input` shadows the Python builtin; the name is kept because later cells refer to it.
input = bgk.RunManager(RUN_DIR).run_input

In [None]:
# Constants and formulas for case 2. FormulaChain resolves each lambda's argument names from the
# keyword values passed at call time, from the constants below, or by evaluating other formulas here.
# NOTE(review): named "case 2", but `input` is loaded from a `case1` run directory — confirm intended.
formulae_case_2 = FormulaChain(
    # scalar constants
    h0=0.9,
    k=0.9,
    xi=0.1,
    B0=0.25,
    # f = pi^(-3/2) * exp(-w) * (1 - h0 * exp(-k * l^2 - xi * p^2))
    f=lambda w, l, p, h0, k, xi: np.pi**-1.5 * np.exp(-w) * (1 - h0 * np.exp(-k * l**2 - xi * p**2)),
    # w = (v_rho^2 + v_phi^2 + v_x^2) / 2 - psi
    w=lambda v_rho, v_phi, v_x, psi: 0.5 * (v_rho**2 + v_phi**2 + v_x**2) - psi,
    # l = 2 * rho * (v_phi - A_phi)
    l=lambda rho, v_phi, A_phi: 2 * rho * (v_phi - A_phi),
    # p = v_x - A_x
    p=lambda v_x, A_x: v_x - A_x,
    # A_phi = B0 * rho / 2 (grows linearly in rho); A_x is a constant
    A_phi=lambda rho, B0: 0.5 * rho * B0,
    A_x=0.1,
    # psi(rho): calls input.interpolate_value(rho_i, "Psi") once per element of rho (vectorize=True),
    # presumably interpolating the run's "Psi" field — TODO confirm. `input` is looked up when psi is
    # evaluated, not when this cell runs.
    psi=lambda rho: xr.apply_ufunc(input.interpolate_value, rho, "Psi", vectorize=True),
)

In [None]:
class Distribution:
    """Evaluate derived quantities of a FormulaChain on a tensor-product grid.

    Each keyword grid variable is turned into a 1-D DataArray whose dimension and coordinate are
    named after that variable, so formulas combining several of them broadcast over the full grid.
    The requested derived variables are evaluated with those grid variables as inputs and merged
    into a single Dataset available as ``self.data``.
    """

    def __init__(self, formulae: FormulaChain[xr.DataArray], *derived_var_names: str, **var_vals: np.ndarray) -> None:
        """
        Args:
            formulae: chain supplying the derived variables (constants or formulas).
            *derived_var_names: names of the chain entries to evaluate and store.
            **var_vals: 1-D sample points for each independent grid variable.
        """
        grid = {name: xr.DataArray(values, coords=[(name, values)]) for name, values in var_vals.items()}
        # Pass xr.merge a concrete list rather than a one-shot generator.
        self.data = xr.merge([self._evaluate_entry(formulae, name, grid) for name in derived_var_names])

    @staticmethod
    def _evaluate_entry(formulae: FormulaChain[xr.DataArray], name: str, grid: dict) -> xr.DataArray:
        """Evaluate one chain entry on the grid and return it as a DataArray named ``name``."""
        entry = formulae[name]
        value = entry(**grid) if callable(entry) else entry
        # Constants, and formulas that use no grid variable, yield plain scalars/arrays;
        # wrap them so they can be named and merged like grid-dependent results.
        if not isinstance(value, xr.DataArray):
            value = xr.DataArray(value)
        return value.rename(name)

    def __getitem__(self, key: str) -> xr.DataArray:
        """Return the evaluated variable ``key`` from the merged Dataset."""
        return self.data[key]

In [None]:
# Sampling grid for f. rho stops short of the run's maximum rho (endpoint=False);
# each velocity component is sampled uniformly on [-V_MAX, V_MAX].
N_RHO = 51
V_MAX = 3
N_V = 101
# NOTE: f spans all N_RHO * N_V**3 (~5.3e7) grid points, ~420 MB per float64 array; with
# formulae_case_2 the intermediate w (and exp(-w)) span the same grid, so peak memory is several GB.
dist = Distribution(
    formulae_case_2,
    "f",
    rho=np.linspace(0, input["rho"].max(), N_RHO, endpoint=False),
    v_rho=np.linspace(-V_MAX, V_MAX, N_V),
    v_phi=np.linspace(-V_MAX, V_MAX, N_V),
    v_x=np.linspace(-V_MAX, V_MAX, N_V),
)

In [None]:
# Reduced distribution at one radius: f integrated over v_rho at rho = RHO_SAMPLE, and the same
# quantity normalized so that every v_phi slice integrates to 1 over v_x.
RHO_SAMPLE = 0.07
# Interpolating in rho before integrating over v_rho gives the same result (both operations are
# linear in f) but integrates a 3-D array instead of the full 4-D one.
f_vx_vphi = dist["f"].interp(rho=RHO_SAMPLE).integrate("v_rho")
f_vx_vphi_normalized = f_vx_vphi / f_vx_vphi.integrate("v_x")

plt.close()
fig, axs = plt.subplots(ncols=2)
f_vx_vphi.plot(ax=axs[0])
axs[0].set_title(f"f integrated over v_rho, rho = {RHO_SAMPLE}")
f_vx_vphi_normalized.plot(ax=axs[1])
axs[1].set_title(f"normalized over v_x, rho = {RHO_SAMPLE}")
fig.tight_layout()
plt.show()