In [None]:
%env EQX_ON_ERROR=nan
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.2
# %env JAX_PLATFORMS=cpu

In [None]:
import jax

jax.config.update("jax_enable_x64", True)

from kinetix.reactions import (
    NitrateReduction,
    NitriteReduction,
    Species,
    AerobicRespiration,
    Reactions,
    NNReaction,
)
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt
import diffrax
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass
import dataclasses
from typing import Any
from functools import partial
import pytensor.tensor as pt
import pytensor
import pymc as pm
import nutpie
import xarray as xr
import optimistix
import equinox as eqx
import arviz
import seaborn as sns


def rhs(t, log_state: Species, reactions: Reactions):
    """Vector field for the ODE solver: the specific rates scaled by 3600.

    Args:
        t: Current time; part of the ODE-term signature but not used here.
        log_state: State at which `reactions.specific_rate` is evaluated.
        reactions: Object providing `specific_rate(log_state)`.

    Returns:
        A `Species` whose four fields are 3600 times the matching rate field.
    """
    # NOTE(review): 3600 looks like a per-second -> per-hour conversion; confirm units.
    specific = reactions.specific_rate(log_state)
    scaled = {
        name: 3600 * getattr(specific, name)
        for name in ("nitrate", "nitrite", "biomass", "oxygen_liq")
    }
    return Species(**scaled)


# First CPU device reported by JAX; `make_solver` pins its jitted solve to it.
cpu_device = jax.devices("cpu")[0]


def make_solver(
    *,
    t_max,
    t_points,
    rtol=1e-8,
    atol=1e-8,
    solver=None,
    t0=0,
    dt0=None,
    max_steps=1024 * 32,
):
    """Build a jitted function that integrates `rhs` and saves at `t_points`.

    Args:
        t_max: End time of the integration (passed to `diffeqsolve` as `t1`).
        t_points: Times at which the solution is saved.
        rtol: Relative tolerance of the PID step-size controller.
        atol: Absolute tolerance of the PID step-size controller.
        solver: diffrax solver instance; `diffrax.Tsit5()` is used when None.
        t0: Start time of the integration.
        dt0: Initial step size, forwarded unchanged to `diffeqsolve`.
        max_steps: Upper bound on the number of solver steps. Previously
            hard-coded to the same value that is now the default.

    Returns:
        A function `solve(y0, reactions)` returning the `diffeqsolve` result;
        callers in this notebook read the saved states from its `.ys`.
    """
    if solver is None:
        solver = diffrax.Tsit5()

    term = ODETerm(rhs)
    stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol)
    saveat = SaveAt(ts=t_points)

    # The unused `root_rhs` helper and the commented-out root-find for the
    # initial nitrite value were dropped; `y0` is integrated exactly as given.
    @eqx.filter_jit(device=cpu_device)
    def solve(y0: Species, reactions: Reactions):
        return diffeqsolve(
            term,
            solver,
            t0=t0,
            t1=t_max,
            dt0=dt0,
            y0=y0,
            saveat=saveat,
            args=reactions,
            stepsize_controller=stepsize_controller,
            max_steps=max_steps,
        )

    return solve

In [None]:
# `load_dataset` reads everything into memory and closes the file again,
# whereas `open_dataset(...).load()` left the file handle open.
data = xr.load_dataset("../data/data_Qu_2015.nc")

In [None]:
# Measurement time axis and the two observed series used by the likelihoods below.
time = data.time
nitrite = data["NO2-"]
biomass = data["cells"]

In [None]:
# Remove outlier: blank out index 0 along the first axis of the cell counts
# (presumably the first time point, all replicates -- confirm).
# Work on a copy so the loaded `data` is not modified through a shared buffer.
biomass = data["cells"].copy()
biomass.values[0] = np.nan

In [None]:
# Coordinate values for the model dimensions: the measured time points and a
# 50-point evenly spaced grid spanning the same range.
coords = dict(
    time_measured=time.values,
    time_dense=jnp.linspace(time.values.min(), time.values.max(), num=50),
)

# PyMC model: priors on the initial state and on each reaction's log-parameters,
# the ODE solved through a JAX op on two time grids, and likelihoods for the
# nitrite and biomass measurements. Variables created inside a nested
# `pm.Model("name")` are referred to later in this notebook as "name::var".
with pm.Model(coords=coords) as model:
    with pm.Model("y0"):
        # TODO remove first cell observation?
        biomass_log = pm.StudentT("biomass_log", mu=21.0, sigma=0.5, nu=10)
        # nitrite_log = pm.Normal("nitrite_log", mu=-15.0, sigma=5)
        # NOTE(review): `nitrite_log` is registered as a free variable but never
        # used below -- `y0_pt` hard-codes nitrite=-16.0. Confirm whether this
        # prior should feed the initial state or be removed.
        nitrite_log = pm.StudentT("nitrite_log", mu=-18.0, sigma=5, nu=10)
        # Initial state; nitrate and oxygen are fixed constants given as logs,
        # only biomass is a random variable.
        y0_pt = Species(
            nitrate=np.log(2e-3),
            nitrite=-16.0,
            biomass=biomass_log,
            oxygen_liq=np.log(9.3e-5),
        )
        # y0_pt = Species(
        #    nitrate=2e-3,
        #    nitrite=np.exp(-20),  # TODO
        #    biomass=pt.exp(biomass_log),
        #    oxygen_liq=9.3e-5,
        # )

    # Disabled switch: the neural-network reaction is never built, so
    # `reaction_nn` is None and no "nn::val*" variables exist in the model.
    if False:
        with pm.Model("nn"):
            key = jax.random.key(0)
            nn_jax = eqx.nn.MLP(
                4,
                4,
                width_size=4,
                depth=1,
                activation=jax.nn.gelu,
                key=key,
            )

            count = 0

            def make_node(x):
                # Replace every inexact-array leaf by a Normal RV of the same
                # shape named "val<count>"; other leaves pass through unchanged.
                global count
                count += 1
                if not eqx.is_inexact_array(x):
                    return x
                val = pm.Normal(f"val{count}", shape=x.shape)
                return val

            nn_pt = jax.tree.map(make_node, nn_jax)
            reaction_nn = NNReaction(nn_pt, Species.zeros())
    else:
        reaction_nn = None

    with pm.Model("nitrate_reduction"):
        oxygen_inhib_log = pm.Normal("oxygen_inhib_log", mu=-15, sigma=3)
        nu_max_log = pm.StudentT("nu_max_log", mu=-44, sigma=1.5, nu=10)
        # log_K is held fixed at log(5e-6) rather than sampled.
        nitrate_reduction = NitrateReduction(
            log_nu_max=nu_max_log,
            log_K=pt.as_tensor(np.log(5e-6)),
            log_oxygen_inhib=oxygen_inhib_log,
        )

    with pm.Model("nitrite_reduction"):
        oxygen_inhib_log = pm.Normal("oxygen_inhib_log", mu=-15, sigma=3)
        nu_max_log = pm.StudentT("nu_max_log", mu=-44, sigma=1.5, nu=10)
        # log_K is held fixed at log(5e-6) rather than sampled.
        nitrite_reduction = NitriteReduction(
            log_nu_max=nu_max_log,
            log_K=pt.as_tensor(np.log(5e-6)),
            log_oxygen_inhib=oxygen_inhib_log,
        )

    #    prior_Ks_log = stats.t(loc=-13, scale=2, df=10).rvs(10000)
    #    prior_rmax_log = stats.norm(loc=-44, scale=1.5).rvs(10000)
    with pm.Model("aerobic_respiration"):
        nu_max_log = pm.Normal("nu_max_log", mu=-44, sigma=1.5)
        K_log = pm.StudentT("K_log", mu=-13, sigma=2, nu=10)
        growth_yield_log = pm.StudentT(
            "growth_yield_log",
            mu=np.log(7e14),
            sigma=0.5,
            nu=10,
        )
        aerobic_respiration = AerobicRespiration(
            log_nu_max=nu_max_log,
            log_K=K_log,
            log_growth_yield=growth_yield_log,
        )

    reactions_pt = Reactions(
        reactions=[nitrate_reduction, nitrite_reduction, aerobic_respiration],
        nn_reaction=reaction_nn,
    )

    # Dense solver only used in deterministic for plotting
    t_points = coords["time_dense"]
    t_max = max(t_points)
    solve_fn_dense = make_solver(t_max=t_max, t_points=t_points)

    @pytensor.as_jax_op
    def solve_pt(y0, reactions):
        return solve_fn_dense(y0, reactions).ys

    solution = solve_pt(y0_pt, reactions_pt)

    with pm.Model("solution_dense"):
        # Store exp() of every solution field on the dense grid.
        fields = dataclasses.fields(solution)
        for field in fields:
            pm.Deterministic(
                field.name,
                pt.exp(getattr(solution, field.name)),
                dims="time_dense",
            )

    # Solve at time locations where it was measured
    # NOTE: `t_points`, `t_max`, `solve_pt` and `solution` are rebound here; after
    # this cell the module-level `t_points` holds the measured times.
    t_points = coords["time_measured"]
    t_max = max(t_points)
    solve_fn_measured = make_solver(t_max=t_max, t_points=t_points)

    @pytensor.as_jax_op
    def solve_pt(y0, reactions):
        return solve_fn_measured(y0, reactions).ys

    solution = solve_pt(y0_pt, reactions_pt)

    with pm.Model("solution_measured"):
        # Store exp() of every solution field at the measured times.
        fields = dataclasses.fields(solution)
        for field in fields:
            pm.Deterministic(
                field.name,
                pt.exp(getattr(solution, field.name)),
                dims="time_measured",
            )

    # Nitrite likelihood
    with pm.Model("nitrite"):
        # Indices of the non-missing entries; the names assume the array is laid
        # out as (time, replicate) -- TODO confirm.
        time_idx, replicate_idx = nitrite.notnull().values.nonzero()

        # Only the first branch runs; the other two are unreachable alternatives.
        if True:
            # Two-stage error model: a latent value scattered around the
            # un-exponentiated solution with `sigma_rel`, then a Normal
            # likelihood around exp(latent) with `sigma_abs` (HalfNormal * 1e-6).
            mu = solution.nitrite[time_idx]
            observed = nitrite.values[time_idx, replicate_idx]
            sigma_rel = pm.HalfNormal("sigma_rel", sigma=0.05)
            with_rel_error = pm.Normal(
                "with_rel_error", mu=mu, sigma=sigma_rel, shape=len(observed)
            )
            sigma_abs = pm.HalfNormal("micro_sigma_abs", sigma=1) * 1e-6
            pm.Normal(
                "y", mu=pt.exp(with_rel_error), sigma=sigma_abs, observed=observed
            )
        elif False:
            mu = pt.exp(solution.nitrite[time_idx])
            observed = nitrite.values[time_idx, replicate_idx]
            sigma = pm.HalfNormal("sigma", sigma=1) * 1e-4
            pm.Normal("y", mu=mu, sigma=sigma, observed=observed)
        else:
            mu = solution.nitrite[time_idx]
            observed = np.log(nitrite.values[time_idx, replicate_idx])
            sigma = pm.HalfNormal("sigma", sigma=0.1)
            dist = pm.StudentT.dist(mu=mu, sigma=sigma, nu=10)
            # NOTE(review): no `upper` is passed here; check against the
            # pm.Censored signature before re-enabling this branch.
            pm.Censored("y", dist, lower=-16, observed=np.clip(observed, -16, np.inf))

    # biomass likelihood
    with pm.Model("biomass"):
        # Normal likelihood on the log of the observed cell counts, centred on
        # the un-exponentiated biomass solution.
        time_idx, replicate_idx = biomass.notnull().values.nonzero()
        # mu = pt.exp(solution.biomass[time_idx] - np.log(1e11))
        mu = solution.biomass[time_idx]
        observed = np.log(biomass.values[time_idx, replicate_idx])  # / 1e11
        sigma = pm.HalfNormal("sigma", sigma=0.2)
        # sigma = 0.5
        pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

In [None]:
# Compile the PyMC model for nutpie, using JAX for both the log-density and its gradient.
compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")

In [None]:
# Display `n_dim` of the compiled model (presumably the number of unconstrained parameters).
compiled.n_dim

In [None]:
# Draw one concrete value for every symbolic leaf so the ODE can be solved
# directly with diffrax, outside the PyMC graph.
reactions = jax.tree.map(pm.draw, reactions_pt)
# Leaves of `y0_pt` that are already plain numbers/arrays are kept as they are.
y0 = jax.tree.map(
    lambda x: x if isinstance(x, np.ndarray) or np.isscalar(x) else pm.draw(x), y0_pt
)

# `rhs` is intentionally not redefined here: this cell used to repeat the
# definition from the imports cell verbatim, which only shadowed it.

In [None]:
# Stand-alone solver setup for a direct diffrax solve (independent of `make_solver`).
solver = diffrax.Tsit5()
# root_finder = optimistix.Dogleg(rtol=1e-9, atol=1e-9, norm=optimistix.two_norm)
# solver = diffrax.Kvaerno5(root_find_max_steps=10, root_finder=root_finder)

term = ODETerm(rhs)
# Unlike the controller in `make_solver`, this one sets `norm=optimistix.two_norm`.
stepsize_controller = diffrax.PIDController(
    rtol=1e-8,
    atol=1e-8,
    norm=optimistix.two_norm,
)


# Dataclass registered as a JAX pytree; groups the three save specifications
# that are passed together to `SaveAt(subs=...)` below.
@jax.tree_util.register_dataclass
@dataclass
class ODEOutput:
    # State saved on the dense time grid.
    dense: diffrax.SubSaveAt
    # State saved at the measured time points.
    measured: diffrax.SubSaveAt
    # Output of `rhs` (used as the save function) at the saved times.
    rhs: diffrax.SubSaveAt


# Save the state on the dense grid and at the measured times, plus the value of
# `rhs` at the measured times. The measured times are read from `coords`
# directly instead of the `t_points` variable left over from the model cell,
# so this cell does not depend on that cell's last assignment.
subs = ODEOutput(
    dense=diffrax.SubSaveAt(ts=coords["time_dense"]),
    rhs=diffrax.SubSaveAt(ts=coords["time_measured"], fn=rhs),
    measured=diffrax.SubSaveAt(ts=coords["time_measured"]),
)
t_vals = SaveAt(subs=subs)


# Integrate from 0 to the last measured time with the drawn parameters.
result = diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=coords["time_measured"].max(),
    dt0=None,
    y0=y0,
    saveat=t_vals,
    args=reactions,
    stepsize_controller=stepsize_controller,
    max_steps=1024 * 32,
)

In [None]:
%%time
# Draw 50 prior-predictive samples with a fixed seed and time the cell.
with model:
    prior = pm.sample_prior_predictive(draws=50, random_seed=42)

In [None]:
# Prior-predictive biomass on the dense time grid, one line per draw, with the
# replicate-mean measurements scattered on top.
prior.prior["solution_dense::biomass"].plot.line(
    x="time_dense", hue="draw", col="chain", add_legend=False
)
plt.scatter(
    biomass.time,
    biomass.mean("replicate"),
)

In [None]:
# Prior-predictive nitrate on the dense time grid, one line per draw.
prior.prior["solution_dense::nitrate"].plot.line(
    x="time_dense", hue="draw", col="chain", add_legend=False
)

In [None]:
# Prior-predictive nitrite on the dense time grid, one line per draw, with the
# replicate-mean measurements scattered on top.
prior.prior["solution_dense::nitrite"].plot.line(
    x="time_dense", hue="draw", col="chain", add_legend=False
)
plt.scatter(
    nitrite.time,
    nitrite.mean("replicate"),
)

In [None]:
# Prior-predictive dissolved oxygen on the dense time grid, one line per draw.
prior.prior["solution_dense::oxygen_liq"].plot.line(
    x="time_dense", hue="draw", col="chain", add_legend=False
)

In [None]:
# Same nitrite prior-predictive plot as in the earlier cell, without the measurements overlay.
(prior.prior["solution_dense::nitrite"]).plot.line(
    x="time_dense", hue="draw", col="chain", add_legend=False
)

In [None]:
# Prior draws of the dense biomass solution at the first dense time point.
prior.prior["solution_dense::biomass"].isel(time_dense=0)

In [None]:
# Log of the dense nitrite solution at the second dense time point (index 1).
np.log(prior.prior["solution_dense::nitrite"].isel(time_dense=1))

In [None]:
# Uses nutpie's private `_make_logp_func`; presumably returns the compiled
# log-density callable -- private API, may change between versions.
func = compiled._make_logp_func()

In [None]:
%%time
# Time one evaluation of `func` at points from the private `_make_initial_points`.
# NOTE(review): 51 is presumably a seed -- confirm.
func(compiled._make_initial_points(51))

In [None]:
# Configure nutpie's transform adaptation on the compiled model. Note that this
# rebinds `compiled`, so re-running the cell re-wraps the already configured object.
compiled = compiled.with_transform_adapt(
    verbose=True,
    show_progress=True,
    max_patience=50,
    num_layers=20,
    # window_size=256,
    num_diag_windows=7,
    nn_width=32,
    nn_depth=2,
    debug_save_bijection=True,
)

# Start sampling without blocking the cell; the returned handle is queried
# (`inspect`) and stopped (`abort`) in later cells.
sampler = nutpie.sample(
    compiled,
    tune=2000,
    chains=1,
    transform_adapt=True,
    store_unconstrained=True,
    store_gradient=True,
    store_divergences=True,
    # window_switch_freq=128,
    maxdepth=9,
    blocking=False,
    seed=42,
    progress_rate=1000,
    # max_energy_error=10,
)

In [None]:
def box_cox(x, lam):
    """Box-Cox power transform of `x` with exponent `lam`.

    Args:
        x: Scalar or array of values; the transform is only real-valued for
            positive `x` when `lam` is fractional.
        lam: Transform exponent. A scalar `lam == 0` selects the limiting
            case `log(x)` instead of dividing by zero.

    Returns:
        `(x**lam - 1) / lam`, or `log(x)` when `lam` is exactly zero.
    """
    # Only a scalar lam is special-cased; array-valued lam keeps the raw formula.
    if np.ndim(lam) == 0 and lam == 0:
        return np.log(x)
    return (x**lam - 1) / lam


# Grid for plotting `box_cox`. NOTE(review): it includes x <= 0, where x**lam is
# not real-valued for fractional lam -- confirm the intended domain.
x = np.linspace(-10, 10, 1000)

In [None]:
# Plot the transform for lam = 0.2 over the grid `x`.
plt.plot(x, box_cox(x, 0.2))

In [None]:
# Take a snapshot from the sampler before saving: `trace` was otherwise only
# assigned in a later cell, so this cell raised NameError on a fresh
# Restart & Run All. Sampling is non-blocking, so the snapshot may be partial.
trace = sampler.inspect()
trace.to_netcdf("trace-nf-nn.nc")

In [None]:
# Snapshot of whatever the non-blocking sampler has produced so far.
trace = sampler.inspect()

In [None]:
# Display the `draw` coordinate of the warmup statistics (shows how many warmup draws exist).
trace.warmup_sample_stats.draw

In [None]:
# Stop the sampler and keep what `abort()` returns in `out`; later cells keep
# using the earlier `trace` snapshot, not `out`.
out = sampler.abort()

In [None]:
# Transformed gradient and transformed position for the last 1000 warmup draws of chain 0.
grad = trace.warmup_sample_stats.transformed_gradient.isel(
    draw=slice(-1000, None), chain=0
)
draw = trace.warmup_sample_stats.transformed_position.isel(
    draw=slice(-1000, None), chain=0
)

In [None]:
# Squared sum of gradient and position per draw and parameter.
# NOTE(review): presumably zero when the transformed posterior is standard
# normal (gradient == -position) -- confirm this is the intended diagnostic.
loss = (grad + draw) ** 2

In [None]:
import scipy

In [None]:
# Per-parameter sqrt of (std of positions / std of gradients) across draws.
diag = np.sqrt(draw.std("draw") / grad.std("draw"))

In [None]:
# Display the per-parameter ratio computed above.
diag

In [None]:
# Transformed position against transformed gradient for the parameter at index 3.
plt.scatter(
    draw.isel(unconstrained_parameter=3),
    grad.isel(unconstrained_parameter=3),
)

In [None]:
# Mean of `loss` over draws for each parameter, sorted ascending.
loss.mean("draw").to_pandas().sort_values()

In [None]:
# log10 of the step size over the warmup draws.
np.log10(trace.warmup_sample_stats.step_size).plot.line(x="draw")

In [None]:
# Most recent entry of nutpie's private `_BIJECTION_TRACE` list (private API;
# presumably filled because `debug_save_bijection=True` was set -- confirm).
_, fit, points = nutpie.transform_adapter._BIJECTION_TRACE[-1]

In [None]:
# Heatmap of the `params` of the second nested bijection inside the first bijection of `fit`.
sns.heatmap(fit.bijections[0].bijection.bijections[1].params, center=0)

In [None]:
# sampler.abort()

In [None]:
# Warmup trace plots for the variables matched by "sigma" with `filter_vars="like"`.
arviz.plot_trace((trace.warmup_posterior), var_names=["sigma"], filter_vars="like")
plt.tight_layout()

In [None]:
# The active nitrite likelihood defines `sigma_rel` and `micro_sigma_abs`; there
# is no `nitrite::sigma` variable in that branch, so asking for it by exact
# name failed. Plot the noise parameters that actually exist.
arviz.plot_trace(
    trace.warmup_posterior,
    var_names=["nitrite::sigma_rel", "nitrite::micro_sigma_abs", "biomass::sigma"],
);

In [None]:
# The "nn" sub-model is only built when the `if False:` switch in the model cell
# is turned on; skip the plot instead of failing when the variable is absent.
if "nn::val1" in trace.warmup_posterior:
    arviz.plot_trace(trace.warmup_posterior, var_names=["nn::val1"])

In [None]:
# The "nn" sub-model is only built when the `if False:` switch in the model cell
# is turned on; skip the plot instead of failing when the variable is absent.
if "nn::val2" in trace.warmup_posterior:
    arviz.plot_trace(trace.warmup_posterior, var_names=["nn::val2"])

In [None]:
# The "nn" sub-model is only built when the `if False:` switch in the model cell
# is turned on; skip the plot instead of failing when the variable is absent.
if "nn::val3" in trace.warmup_posterior:
    arviz.plot_trace(trace.warmup_posterior, var_names=["nn::val3"])

In [None]:
# The "nn" sub-model is only built when the `if False:` switch in the model cell
# is turned on; skip the plot instead of failing when the variable is absent.
if "nn::val4" in trace.warmup_posterior:
    arviz.plot_trace(trace.warmup_posterior, var_names=["nn::val4"])

In [None]:
# Warmup trace plots of every free random variable, skipping the first 10 draws.
arviz.plot_trace(
    trace.warmup_posterior.isel(draw=slice(10, None)),
    var_names=[var.name for var in model.free_RVs],
)
plt.tight_layout();

In [None]:
# Dense biomass solution for the last 800 warmup draws, with the replicate-mean
# measurements drawn on top (high zorder).
trace.warmup_posterior.isel(draw=slice(-800, None)).assign_coords(**coords)[
    "solution_dense::biomass"
].plot.line(x="time_dense", col="chain", hue="draw", add_legend=False)
plt.scatter(
    biomass.time,
    biomass.mean("replicate"),
    zorder=1000,
)

In [None]:
# Same biomass plot for the post-warmup draws. NOTE: reads `trace.posterior`, so
# the snapshot must have been taken after warmup finished.
trace.posterior.assign_coords(**coords)["solution_dense::biomass"].plot.line(
    x="time_dense", col="chain", hue="draw", add_legend=False
)
plt.scatter(
    biomass.time,
    biomass.mean("replicate"),
    zorder=1000,
)

In [None]:
# Dense nitrite solution for the last 800 warmup draws, with the replicate-mean
# measurements drawn on top, on a linear y axis.
trace.warmup_posterior.isel(draw=slice(-800, None)).assign_coords(**coords)[
    "solution_dense::nitrite"
].plot.line(x="time_dense", col="chain", hue="draw", add_legend=False)
plt.scatter(
    nitrite.time,
    nitrite.mean("replicate"),
    zorder=1000,
)
# plt.yscale("log")

In [None]:
# exp() of the latent `with_rel_error` values for the last 20 warmup draws,
# plotted against the observation index on a log y axis.
np.exp(
    trace.warmup_posterior.isel(draw=slice(-20, None)).assign_coords(**coords)[
        "nitrite::with_rel_error"
    ]
).plot.line(
    x="nitrite::with_rel_error_dim_0", col="chain", hue="draw", add_legend=False
)
# NOTE(review): the factor 1.8 presumably stretches the time index onto the
# observation-index axis (about 1.8 observations per time point) -- confirm.
plt.scatter(
    np.arange(len(nitrite.time)) * 1.8,
    nitrite.mean("replicate"),
    zorder=1000,
)
plt.yscale("log")
plt.ylim(np.exp(-25), 1)

In [None]:
# Dense nitrite solution for the last 20 warmup draws with the replicate-mean
# measurements, on a log y axis limited to [exp(-25), 1].
trace.warmup_posterior.isel(draw=slice(-20, None)).assign_coords(**coords)[
    "solution_dense::nitrite"
].plot.line(x="time_dense", col="chain", hue="draw", add_legend=False)
plt.scatter(
    nitrite.time,
    nitrite.mean("replicate"),
    zorder=1000,
)
plt.yscale("log")
plt.ylim(np.exp(-25), 1)

In [None]:
# Dense dissolved-oxygen solution for the last 50 warmup draws; the "O2" series
# from the dataset is drawn on a second y axis for comparison.
trace.warmup_posterior.isel(draw=slice(-50, None)).assign_coords(**coords)[
    "solution_dense::oxygen_liq"
].plot.line(x="time_dense", col="chain", hue="draw", add_legend=False)
# plt.yscale("log")
ax2 = plt.twinx()
data["O2"].plot.line(x="time", ax=ax2, marker="+");
# ax2.set_yscale("log")

In [None]:
# "nn::val1" only exists when the "nn" sub-model is enabled in the model cell;
# skip the heatmap of the last warmup draw instead of failing when it is absent.
if "nn::val1" in trace.warmup_posterior:
    sns.heatmap(trace.warmup_posterior.isel(draw=-1, chain=0)["nn::val1"], center=0)

In [None]:
# "nn::val3" only exists when the "nn" sub-model is enabled in the model cell;
# skip the heatmap of the last warmup draw instead of failing when it is absent.
if "nn::val3" in trace.warmup_posterior:
    sns.heatmap(trace.warmup_posterior.isel(draw=-1, chain=0)["nn::val3"], center=0)

In [None]:
# "nn::val2" only exists when the "nn" sub-model is enabled in the model cell;
# skip the line plot of the last warmup draw instead of failing when it is absent.
if "nn::val2" in trace.warmup_posterior:
    trace.warmup_posterior.isel(draw=-1, chain=0)["nn::val2"].plot.line()

In [None]:
# "nn::val4" only exists when the "nn" sub-model is enabled in the model cell;
# skip the line plot of the last warmup draw instead of failing when it is absent.
if "nn::val4" in trace.warmup_posterior:
    trace.warmup_posterior.isel(draw=-1, chain=0)["nn::val4"].plot.line()

In [None]:
# Pair plot of the stored unconstrained draws for the last 500 warmup iterations.
arviz.plot_pair(
    trace.warmup_sample_stats.isel(draw=slice(-500, None)),
    var_names=["unconstrained_draw"],
    filter_vars="like",
);

In [None]:
# Pair plot of the transformed positions for the last 1000 warmup iterations.
arviz.plot_pair(
    trace.warmup_sample_stats.isel(draw=slice(-1000, None)),
    var_names=["transformed_position"],
    filter_vars="like",
);

In [None]:
# Pair plot of the stored unconstrained draws after warmup. NOTE: reads
# `trace.sample_stats`, so the snapshot must contain post-warmup draws.
arviz.plot_pair(
    trace.sample_stats,
    var_names=["unconstrained_draw"],
    filter_vars="like",
);

In [None]:
# Pair plot of the transformed positions after warmup. NOTE: reads
# `trace.sample_stats`, so the snapshot must contain post-warmup draws.
arviz.plot_pair(
    trace.sample_stats,
    var_names=["transformed_position"],
    filter_vars="like",
);