In [7]:
# tests_and_example.py
import numpy as onp
import jax
import jax.numpy as jnp
from functools import partial

import vsem as vsem_numpy
from vsem_jax import (
    make_vsem_input_from_named, 
    solve_vsem_jax, 
    get_vsem_driver,
    build_vectorized_partial_forward_model
)

In [5]:
# Single-run example: build a JAX model input from a few named parameters
# and integrate it over a synthetic driver.
# driver is presumably a 1D numpy or jnp array (one value per day) — TODO confirm.
_, driver = get_vsem_driver(200)   # 200-day driver; the first return value is discarded

# Only these three parameters are given explicitly. NOTE(review): presumably
# make_vsem_input_from_named fills the remaining ones from defaults — confirm.
params = {"lue": 0.002, "gamma": 0.4, "veg_init": 3.0}
vsem_input = make_vsem_input_from_named(params, driver)
out = solve_vsem_jax(vsem_input)  # jnp array shape (n_days, 5)


In [8]:
# Batched example: each row of `arr` is one run; its columns supply values for
# the parameters listed in the builder call, in that same order.
fwd = build_vectorized_partial_forward_model(driver, ["lue", "gamma", "veg_init"])
arr = jnp.array([[0.002, 0.4, 3.0], [0.0025, 0.45, 2.7]])  # (n_runs, 3)
batch_out = fwd(arr)  # (n_runs, n_days, 5)

In [9]:
# Expected (n_runs, n_days, 5): here 2 runs x 200 days x 5 outputs.
batch_out.shape

(2, 200, 5)

In [14]:
# test_vsem_jax.py
import numpy as np
import jax.numpy as jnp
from vsem_jax import (solve_vsem_jax, make_vsem_input_from_named,
                      build_vectorized_partial_forward_model,
                      get_vsem_default_pars_dict, get_vsem_par_names,
                      get_vsem_output_names)
# Reference NumPy implementation of VSEM, used as ground truth for the JAX port
import vsem as vsem_np  # assumes the old code is in vsem.py

def get_par_default_array_lowercase(defaults=None, par_names=None):
    """Return the JAX default parameter values as an array in reference order.

    Each reference (``vsem_np``) parameter name is looked up in the defaults
    mapping first exactly as given and then lower-cased. This lets the result
    line up with the reference parameter vector even when the JAX defaults
    dict is keyed by lower-case names while the reference names are upper-case.

    Parameters
    ----------
    defaults : mapping, optional
        Parameter name -> default value. Defaults to
        ``get_vsem_default_pars_dict()``.
    par_names : iterable of str, optional
        Parameter names in the desired output order. Defaults to
        ``vsem_np.get_vsem_par_names()``.

    Returns
    -------
    numpy.ndarray
        One value per entry of ``par_names``, in that order.

    Raises
    ------
    KeyError
        If a name is present in ``defaults`` neither as given nor lower-cased.
    """
    if defaults is None:
        defaults = get_vsem_default_pars_dict()
    if par_names is None:
        par_names = vsem_np.get_vsem_par_names()

    values = []
    for name in par_names:
        # Exact key first (keeps the original behaviour when it already worked),
        # then the lower-cased form used by the JAX side.
        key = name if name in defaults else str(name).lower()
        values.append(defaults[key])  # KeyError reports the lower-cased key if both miss
    return np.array(values)

# 1) Test single-run equivalence
def test_single_run_equivalence():
    # generate a synthetic PAR driver
    n_days = 200
    time_steps, PAR = vsem_np.get_PAR_driver(n_days, rng=np.random.default_rng(123))

    # default parameters as provided by original implementation
    par_default = vsem_np.get_vsem_default_pars()["value"].values  # length 11, original order

    # create named mapping keyed by lowercase canonical names (we align indices)
    canonical = [name.lower() for name in vsem_np.get_vsem_par_names()]
    par_named = {canonical[i]: par_default[i] for i in range(len(par_default))}

    vsem_input = make_vsem_input_from_named(par_named, PAR)
    out_jax = np.array(solve_vsem_jax(vsem_input))  # convert to numpy for comparison
    out_np = vsem_np.solve_vsem(PAR, par_default)

    # Compare all outputs up to small tolerance
    assert out_jax.shape == out_np.shape
    close = np.allclose(out_jax, out_np, atol=1e-8, rtol=1e-6)
    print("Single-run match:", close)
    if not close:
        diff = np.max(np.abs(out_jax - out_np))
        print("Max abs diff:", diff)
    return close

# 2) Test vectorized forward model
def test_vectorized_forward_model():
    n_days = 100
    _, PAR = vsem_np.get_PAR_driver(n_days, rng=np.random.default_rng(42))

    # We'll vary three parameters: lue, gamma, veg_init
    par_names = ["lue", "gamma", "veg_init"]
    fwd = build_vectorized_partial_forward_model(PAR, par_names)

    # create 4 runs with different values
    arr = np.array([
        [0.002, 0.4, 3.0],
        [0.0025, 0.4, 2.5],
        [0.0015, 0.45, 4.0],
        [0.002, 0.35, 3.5]
    ])
    out_jax_batch = np.array(fwd(arr))  # shape (4, n_days, 5)

    # brute-force compute using original solver for each run
    out_np_batch = []
    # default full parameter vector
    default_pars = vsem_np.get_vsem_default_pars()["value"].values
    all_par_names = vsem_np.get_vsem_par_names()
    # prepare index mapping
    # idxs = [all_par_names.index(name.upper()) for name in par_names]  # original uses upper-case names
    idxs = [2, 3, 8]
    for row in arr:
        par_run = default_pars.copy()
        for j, idx in enumerate(idxs):
            par_run[idx] = row[j]
        out_np_batch.append(vsem_np.solve_vsem(PAR, par_run))
    out_np_batch = np.stack(out_np_batch, axis=0)

    ok = np.allclose(out_jax_batch, out_np_batch, atol=1e-8, rtol=1e-6)
    print("Vectorized batch match:", ok)
    if not ok:
        print("Max diff:", np.max(np.abs(out_jax_batch - out_np_batch)))
    return ok



In [15]:
# Run the single-run equivalence check and print the boolean it returns.
print("Single run test:", test_single_run_equivalence())

Single-run match: True
Single run test: True


In [16]:
# Run the vectorized forward-model check and print the boolean it returns.
print("Vectorized test:", test_vectorized_forward_model())

Vectorized batch match: True
Vectorized test: True
