# Parameter estimation with Photosynthesis: Comparing Uniform Sampling and Gradient Descent

Our goal is to estimate values for the parameter Vcmax that will make the model match observational data. The approach is as follows:

1. Define a version of `ci_func` (the model) that takes Vcmax as input (and all other inputs as constants). Note that `Je` will be derived as `Je = Jmax = 1.67 * Vcmax`. 
2. Use uniform sampling (evaluating an evenly spaced grid of candidate values) to estimate values of Vcmax.
3. Use gradient descent to estimate values of Vcmax. 


## Defining the function

First, we need to rewrite `ci_func` such that `vcmax_z` is an input parameter. We had previously defined it as a constant within the function.

In [4]:
import math
import numpy as np
from jax import jit
import jax.numpy as jnp



def hybrid(x0, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c):
    """Solve ``ci_func(ci, ...) == 0`` for ``ci`` starting from ``x0``.

    Runs secant iterations from the two starting points ``x0`` and
    ``0.99 * x0``. As soon as two successive iterates have residuals of
    opposite sign, the root is bracketed and the search is handed to
    ``brent``.

    Termination:
      * an exact zero residual at either starting point;
      * a secant step smaller than 1% of the new iterate (``eps``);
      * ``|f| <= 1e-4`` (``eps1``);
      * a sign change, which switches to ``brent``;
      * more than ``itmax`` iterations. The solver then falls back to the
        iterate with the smallest residual seen so far. ``gs_mol`` is
        re-evaluated at that same point, so the returned pair is consistent.

    Args:
        x0: Initial guess for ci.
        lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c:
            Forwarded unchanged to ``ci_func`` (and to ``brent``).

    Returns:
        ``(ci, gs_mol, n_iter)``: the root estimate, the stomatal conductance
        from the most recent ``ci_func`` evaluation, and the number of secant
        iterations performed.
    """
    eps = 1e-2  # relative secant-step tolerance
    eps1 = 1e-4  # absolute residual tolerance
    itmax = 40
    n_iter = 0
    tol = 0.0

    f0, gs_mol, _ = ci_func(x0, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c)
    if f0 == 0.0:
        return x0, gs_mol, n_iter

    # Track the best iterate as a fallback if the iteration does not converge.
    # NOTE(review): this compares signed residuals (f1 < minf), not |f|; kept
    # as-is to preserve the solver's behaviour -- confirm this is intended.
    minx, minf = x0, f0
    x1 = x0 * 0.99

    f1, gs_mol, _ = ci_func(x1, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c)
    if f1 == 0.0:
        return x1, gs_mol, n_iter
    if f1 < minf:
        minx, minf = x1, f1

    while True:
        n_iter += 1
        # Secant step through (x0, f0) and (x1, f1).
        dx = -f1 * (x1 - x0) / (f1 - f0)
        x = x1 + dx
        tol = abs(x) * eps

        if abs(dx) < tol:
            # Converged on step size; gs_mol is from the evaluation at x1.
            return x, gs_mol, n_iter

        x0, f0 = x1, f1
        x1 = x

        f1, gs_mol, _ = ci_func(x1, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c)

        if f1 < minf:
            minx, minf = x1, f1

        if abs(f1) <= eps1:
            return x1, gs_mol, n_iter

        if f1 * f0 < 0.0:
            # Root is bracketed by [x0, x1]: finish with Brent's method.
            x, gs_mol = brent(
                x0,
                x1,
                f0,
                f1,
                tol,
                p,
                iv,
                c,
                gb_mol,
                vcmax_z,
                cair,
                oair,
                lmr_z,
                par_z,
                rh_can,
                gs_mol,
            )
            return x, gs_mol, n_iter

        if n_iter > itmax:
            # No convergence: return the best iterate seen, with gs_mol
            # recomputed there so ci and gs_mol describe the same point.
            _, gs_mol, _ = ci_func(
                minx, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c
            )
            return minx, gs_mol, n_iter


def brent(
    x1,
    x2,
    f1: float,
    f2: float,
    tol,
    ip,
    iv,
    ic,
    gb_mol,
    vcmax_z,
    cair,
    oair,
    lmr_z,
    par_z,
    rh_can,
    gs_mol,
):
    """Refine a bracketed root of ``ci_func`` with Brent's method.

    Follows the classic Brent ("zbrent") scheme. It uses inverse quadratic
    interpolation (or a secant step when only two distinct points are known)
    when that step stays inside the bracket and shrinks fast enough, and falls
    back to bisection otherwise.

    Args:
        x1, x2: Bracket end points.
        f1, f2: ``ci_func`` residuals at ``x1``/``x2``; they must not share a
            strict sign.
        tol: Absolute tolerance; half of it is added to the relative
            tolerance ``2 * eps * |b|`` to form the convergence threshold.
        ip, iv, ic: Forwarded to ``ci_func`` as its ``p``, ``iv``, ``c``
            arguments (renamed here because ``p`` is reused as an
            interpolation variable below).
        gb_mol, vcmax_z, cair, oair, lmr_z, par_z, rh_can: Forwarded to
            ``ci_func`` unchanged.
        gs_mol: Returned unchanged if convergence is detected before any new
            ``ci_func`` evaluation. Otherwise it is replaced by the value from
            the latest evaluation.

    Returns:
        ``(x, gs_mol)``: the root estimate and the matching stomatal
        conductance (see ``gs_mol`` above).

    Raises:
        ValueError: if ``f1`` and ``f2`` are both strictly positive or both
            strictly negative (root not bracketed).
    """
    itmax = 20  # maximum number of Brent iterations
    eps = 1e-2  # relative part of the convergence tolerance (see tol1)
    iter = 0
    a = x1
    b = x2
    fa = f1
    fb = f2

    # Same strict sign at both ends means no root is bracketed.
    if (fa > 0 and fb > 0) or (fa < 0 and fb < 0):
        print("root must be bracketed for brent")
        raise ValueError("f(a) and f(b) must have opposite signs for Brent's method.")

    # c is the contrapoint: the root always lies between b and c.
    c = b
    fc = fb

    while iter != itmax:
        iter += 1
        # b and c on the same side of the root: reset the contrapoint to a
        # and reset the step history (d, e) to the full bracket width.
        if (fb > 0 and fc > 0) or (fb < 0 and fc < 0):
            c = a
            fc = fa
            d = b - a
            e = d

        # Keep b as the best estimate (smallest |f|).
        if abs(fc) < abs(fb):
            a = b
            b = c
            c = a
            fa = fb
            fb = fc
            fc = fa

        # Convergence threshold and half-width of the current bracket.
        tol1 = 2 * eps * abs(b) + 0.5 * tol
        xm = 0.5 * (c - b)

        if abs(xm) <= tol1 or fb == 0:
            x = b
            return x, gs_mol

        # Attempt interpolation only if the previous step was large enough
        # and |f| is decreasing.
        if abs(e) >= tol1 and abs(fa) > abs(fb):
            s = fb / fa

            if a == c:
                # Only two distinct points: secant step.
                p = 2 * xm * s
                q = 1 - s
            else:
                # Inverse quadratic interpolation through a, b, c.
                q = fa / fc
                r = fb / fc
                p = s * (2 * xm * q * (q - r) - (b - a) * (r - 1))
                q = (q - 1) * (r - 1) * (s - 1)

            if p > 0:
                q = -q

            p = abs(p)

            # Accept the interpolated step only if it stays inside the bracket
            # and is shrinking fast enough; otherwise bisect.
            if 2 * p < min(3 * xm * q - abs(tol1 * q), abs(e * q)):
                e = d
                d = p / q
            else:
                d = xm
                e = d
        else:
            d = xm
            e = d

        a = b
        fa = fb

        # Take the step, but never smaller than tol1 (in the direction of xm).
        if abs(d) > tol1:
            b = b + d
        else:
            b = b + jnp.copysign(jnp.array([tol1]), jnp.array([xm]))[0]

        fb, gs_mol, _ = ci_func(
            b, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, ip, iv, ic
        )

        if fb == 0:
            break

    # NOTE: this also prints when an exact root (fb == 0) was hit on the
    # final allowed iteration.
    if iter == itmax:
        print("brent exceeding maximum iterations", b, fb)

    x = b
    return x, gs_mol


def quadratic_roots(a, b, c):
    """Return the two roots of ``a*x**2 + b*x + c = 0``.

    The roots are ordered like the textbook formula:
    ``root1 = (-b - sqrt(disc)) / (2a)`` and ``root2 = (-b + sqrt(disc)) / (2a)``.
    They are computed with the cancellation-free form
    ``q = -(b + sign(b) * sqrt(disc)) / 2``, giving roots ``q / a`` and ``c / q``.
    The naive formula subtracts two nearly equal numbers when ``|4ac| << b**2``.
    In float32 that can wipe out the small root entirely; for example,
    ``a=1, b=-1e4, c=1`` yields 0 instead of 1e-4.

    A negative discriminant yields NaN roots, as before.

    Args:
        a, b, c: Polynomial coefficients (Python floats or JAX scalars).

    Returns:
        ``(root1, root2)`` as JAX scalars.
    """
    sqrt_discriminant = jnp.sqrt(b**2 - 4 * a * c)
    if b >= 0:
        q = -0.5 * (b + sqrt_discriminant)
        root1 = q / a  # == (-b - sqrt) / (2a)
        # q == 0 only when b == 0 and disc == 0, i.e. a double root at 0.
        root2 = c / q if q != 0 else root1
    else:
        # Here q = 0.5 * (-b + sqrt) > 0, so c / q is always defined.
        q = -0.5 * (b - sqrt_discriminant)
        root2 = q / a  # == (-b + sqrt) / (2a)
        root1 = c / q
    return root1, root2


def ci_func(
    ci,
    lmr_z,
    par_z,
    gb_mol,
    vcmax_z,
    cair,
    oair,
    rh_can,
    p,
    iv,
    c,
    c3flag=True,
    stomatalcond_mtd=1,
):
    """Evaluate the intercellular-CO2 residual for a trial ``ci``.

    Computes net photosynthesis ``an`` at ``ci`` and the stomatal conductance
    ``gs_mol`` implied by it. It then returns the residual
    ``ci - cair + an * forc_pbot * (1.4 / gb_mol + 1.6 / gs_mol)``, which is
    zero when ``ci`` is self-consistent; ``hybrid``/``brent`` search for that
    root.

    Vcmax is an input here, and the electron-transport rate is tied to it as
    ``je = 1.67 * vcmax_z``, which assumes full irradiance.

    Args:
        ci: Trial intercellular CO2.
        lmr_z: Respiration term subtracted from gross photosynthesis
            (``an = ag - lmr_z``).
        par_z: Absorbed PAR; only used by the C4 branch.
        gb_mol: Boundary-layer conductance.
        vcmax_z: Maximum carboxylation rate.
        cair, oair: Ambient CO2 and O2.
        rh_can: Relative humidity used by both stomatal-conductance models.
        p, iv, c: Not used in the computation; kept so the solvers can call
            every model function with the same argument list.
        c3flag: ``True`` for the C3 rate equations, ``False`` for the C4 branch.
        stomatalcond_mtd: ``1`` = Medlyn (2011), ``2`` = Ball-Berry (1987).

    Returns:
        ``(fval, gs_mol, an)``: the residual, the stomatal conductance and
        the net photosynthesis at ``ci``.

    Raises:
        ValueError: if ``stomatalcond_mtd`` is neither 1 nor 2. Without this
            check ``gs_mol`` would be 0 and the residual would divide by zero.
    """
    # Constants
    # NOTE(review): forc_pbot looks like air pressure in Pa, but 121000 differs
    # from the standard 101325 -- confirm the intended value.
    forc_pbot = 121000.0
    medlynslope = 6.0
    medlynintercept = 10000.0
    cp = 4.275
    kc = 40.49
    ko = 27840.0
    qe = 1.0
    tpu_z = 31.5
    kp_z = 1.0
    bbb = 100.0
    mbb = 9.0
    theta_cj = 0.98
    theta_ip = 0.95
    stomatalcond_mtd_medlyn2011 = 1
    stomatalcond_mtd_bb1987 = 2

    # THIS ASSUMES FULL IRRADIANCE
    je = vcmax_z * 1.67

    # C3 or C4 photosynthesis: Rubisco-limited (ac), light/electron-limited
    # (aj) and product-limited (ap) rates.
    if c3flag:
        ac = vcmax_z * max(ci - cp, 0.0) / (ci + kc * (1.0 + oair / ko))
        aj = je * max(ci - cp, 0.0) / (4.0 * ci + 8.0 * cp)
        ap = 3.0 * tpu_z
    else:
        ac = vcmax_z
        aj = qe * par_z * 4.6
        ap = kp_z * max(ci, 0.0) / forc_pbot

    # Gross photosynthesis: smooth co-limitation, first between ac and aj
    # (curvature theta_cj), then between that result and ap (theta_ip).
    # The smaller quadratic root is the co-limited rate.
    aquad = theta_cj
    bquad = -(ac + aj)
    cquad = ac * aj
    r1, r2 = quadratic_roots(aquad, bquad, cquad)
    ai = min(r1, r2)

    aquad = theta_ip
    bquad = -(ai + ap)
    cquad = ai * ap
    r1, r2 = quadratic_roots(aquad, bquad, cquad)
    ag = max(0.0, min(r1, r2))

    # Net photosynthesis
    an = ag - lmr_z

    # [Anthony] Note that I've removed the line in the original code that checks for negative net photosynthesis.
    # if an < 0.0:
    #     # print("NEGATIVE PHOTOSYNTHESIS")
    #     fval = 0.0
    #     return fval, None, None

    # Quadratic gs_mol calculation; cs is presumably CO2 at the leaf surface.
    cs = cair - 1.4 / gb_mol * an * forc_pbot
    if stomatalcond_mtd == stomatalcond_mtd_medlyn2011:
        term = 1.6 * an / (cs / forc_pbot * 1.0e06)
        aquad = 1.0
        bquad = -(
            2.0 * (medlynintercept * 1.0e-06 + term)
            + (medlynslope * term) ** 2 / (gb_mol * 1.0e-06 * rh_can)
        )
        cquad = (
            medlynintercept**2 * 1.0e-12
            + (
                2.0 * medlynintercept * 1.0e-06
                + term * (1.0 - medlynslope**2 / rh_can)
            )
            * term
        )
        r1, r2 = quadratic_roots(aquad, bquad, cquad)
        gs_mol = max(r1, r2) * 1.0e06
    elif stomatalcond_mtd == stomatalcond_mtd_bb1987:
        aquad = cs
        bquad = cs * (gb_mol - bbb) - mbb * an * forc_pbot
        cquad = -gb_mol * (cs * bbb + mbb * an * forc_pbot * rh_can)
        r1, r2 = quadratic_roots(aquad, bquad, cquad)
        gs_mol = max(r1, r2)
    else:
        raise ValueError(
            f"Unsupported stomatalcond_mtd={stomatalcond_mtd!r}; expected "
            f"{stomatalcond_mtd_medlyn2011} (Medlyn 2011) or "
            f"{stomatalcond_mtd_bb1987} (Ball-Berry 1987)."
        )

    # Residual of the ci balance; zero when ci is self-consistent.
    fval = ci - cair + an * forc_pbot * (1.4 / gb_mol + 1.6 / gs_mol)

    return fval, gs_mol, an


def main(
    ci,
    lmr_z,
    par_z,
    gb_mol,
    vcmax_z,
    cair,
    oair,
    rh_can,
    p,
    iv,
    c,
    c3flag=True,
    stomatalcond_mtd=1,
):
    """Solve for ci at one set of conditions and return ``(ci, gs_mol, an)``.

    ``ci`` is the solver's initial guess. The remaining arguments are
    forwarded to ``hybrid``/``ci_func``.

    Raises:
        NotImplementedError: if ``c3flag`` or ``stomatalcond_mtd`` differ from
            their defaults. ``hybrid`` and ``brent`` call ``ci_func`` without
            these options, so the root would still be found for the default
            model. Failing loudly avoids silently returning results for a
            model other than the one requested.
    """
    if not c3flag or stomatalcond_mtd != 1:
        raise NotImplementedError(
            "main() only supports c3flag=True and stomatalcond_mtd=1: the "
            "hybrid/brent solver does not forward these options to ci_func."
        )

    ci_val, gs_mol, _ = hybrid(
        ci, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c
    )

    # hybrid returns only ci and gs_mol; re-evaluate once to recover an.
    _, _, an = ci_func(ci_val, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c)

    return ci_val, gs_mol, an


## Visualizing the task

Our goal here is to find a value of `vcmax_z` such that the model's photosynthesis curve matches the observed data. To visualize this comparison, we need to graph the model's predictions together with an interpolated version of the observed data. 

In [5]:
import plotly.graph_objects as go
import pandas as pd

def get_model_values(cair_values: np.ndarray, vcmax_z: float):
    """Run the model over a sweep of ambient CO2 values for a single Vcmax.

    Only ``cair`` varies between runs; every other input is held at the fixed
    values below. ``ci = 35`` is only the solver's starting guess.

    Args:
        cair_values: Ambient CO2 values to evaluate.
        vcmax_z: Vcmax used for every evaluation.

    Returns:
        A dict with ``"cair"`` (the input, unchanged) and ``"ci"``, ``"an"`` and
        ``"gs_mol"`` as ``jnp`` arrays aligned element-wise with ``cair_values``.
    """
    ci = 35  # initial guess handed to the root solver
    lmr_z = 4
    par_z = 500
    gb_mol = 50_000
    oair = 21000
    rh_can = 0.40
    p = 1
    iv = 1
    c = 1

    outputs = [
        main(ci, lmr_z, par_z, gb_mol, vcmax_z, cair, oair, rh_can, p, iv, c)
        for cair in cair_values
    ]
    # main returns (ci, gs_mol, an) for each cair.
    ci_values = [ci_val for ci_val, _, _ in outputs]
    gs_mol_values = [gs for _, gs, _ in outputs]
    an_values = [an for _, _, an in outputs]

    return {
        "cair": cair_values,
        "ci": jnp.array(ci_values),
        "an": jnp.array(an_values),
        "gs_mol": jnp.array(gs_mol_values),
    }


def get_observed_values(filename: str):
    """Load the measured An-Ci curve from a CSV file.

    Reads the ``Photo`` and ``Ci`` columns and rescales them: ``Photo`` is
    divided by 4 and ``Ci`` by 10 (presumably to match the model's units;
    confirm). The points are then ordered by increasing Ci, which the
    interpolation helper requires.

    Returns:
        ``{"ci": ..., "an": ...}`` as ``jnp`` arrays sorted by ``ci``.
    """
    raw = pd.read_csv(filename)
    scaled = pd.DataFrame(
        {
            "an": raw["Photo"] / 4,
            "ci": raw["Ci"] / 10,
        }
    ).sort_values(by="ci")
    return {column: jnp.array(scaled[column].values) for column in ("ci", "an")}



def linear_interp_with_extrapolation(x, y, x_new):
    """Piecewise-linear interpolation of ``y(x)`` at ``x_new``, written for JAX.

    Only ``jax.numpy`` primitives are used, so the result can be
    differentiated with ``jax.grad``. NumPy's and SciPy's interpolation
    functions cannot be traced by JAX's autodiff.

    ``x`` must be sorted in ascending order (required by ``searchsorted``).
    Query points outside ``[min(x), max(x)]`` are clamped to ``y[0]`` or
    ``y[-1]``, i.e. constant extrapolation.
    """
    # Left end of the bracketing segment, clamped to a valid segment index.
    seg = jnp.clip(jnp.searchsorted(x, x_new) - 1, 0, len(x) - 2)
    left_x, right_x = x[seg], x[seg + 1]
    left_y, right_y = y[seg], y[seg + 1]

    interpolated = left_y + (x_new - left_x) * (right_y - left_y) / (right_x - left_x)

    below_range = x_new < jnp.min(x)
    above_range = x_new > jnp.max(x)
    return jnp.where(above_range, y[-1], jnp.where(below_range, y[0], interpolated))


Let's choose an arbitrary Vcmax value and graph the model results with the observed results.

In [6]:
# Pick an arbitrary Vcmax and compare the model's An-Ci curve with the measurements.
vcmax_z = 30
cair_values = np.linspace(10, 100, 19)
model_values = get_model_values(cair_values, vcmax_z)

observed_values = get_observed_values("2021-07-15-Me2-05.csv")


# Resample the observed curve at 50 evenly spaced Ci values running from the
# model's first to last predicted Ci, so both curves cover the same range.
x = np.linspace(model_values['ci'][0], model_values['ci'][-1], 50)
y = linear_interp_with_extrapolation(observed_values['ci'], observed_values['an'], x)

# Graph predicted and measured values.
fig = go.Figure(data=go.Scatter(x=model_values['ci'], y=model_values['an'], mode="markers+lines", name="Predicted"))
fig.add_trace(go.Scatter(x=x, y=y, name="Measured"))
fig.update(
    layout_title_text="An vs. Ci",
    layout_xaxis_title_text="Ci",
    layout_yaxis_title_text="An",
)
fig.show()

## Defining a loss function

For our loss function, we want to determine how far apart the predicted and observed curves are from each other. 

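To quantify this, we run the model at 100 evenly spaced `cair` values. We then linearly interpolate the observed An–Ci curve at the Ci values the model produces, and sum the squared differences in An. This is a sum of squared errors, not a root mean squared error.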

In [7]:
import functools


@functools.lru_cache(maxsize=8)
def _load_observed_values(filename: str):
    """Read and preprocess an observation file once per filename.

    ``loss`` runs many times (every grid point and every gradient-descent
    step), and the file contents do not change between calls, so the parsed
    result is cached. NOTE: edits to the CSV during a session are only picked
    up after ``_load_observed_values.cache_clear()``.
    """
    return get_observed_values(filename)


def loss(cair_values, vcmax_z: float, observed_file: str = "2021-07-15-Me2-05.csv") -> float:
    """Sum of squared An errors between the model and the measurements.

    The model runs at each ``cair`` in ``cair_values``. The measured An-Ci
    curve is interpolated at the model's resulting Ci values, and the squared
    An differences are summed.

    Args:
        cair_values: Ambient CO2 values at which to run the model.
        vcmax_z: Vcmax to evaluate; ``jax.grad`` differentiates with respect
            to this argument.
        observed_file: CSV file holding the measurements.

    Returns:
        The scalar sum of squared errors (a JAX scalar).
    """
    model_values = get_model_values(cair_values, vcmax_z)
    observed = _load_observed_values(observed_file)
    # Compare point-for-point at the model's own Ci values.
    observed_an = linear_interp_with_extrapolation(observed["ci"], observed["an"], model_values["ci"])
    return jnp.sum((model_values["an"] - observed_an) ** 2)


In [18]:
from functools import partial

# Number of cair values at which the model is evaluated inside the loss.
sample_points = 100
cair_values = np.linspace(10, 100, sample_points)
# Bind the cair grid so the loss becomes a function of Vcmax alone.
vcmax_loss = partial(loss, cair_values)

Since our loss function is pretty simple (only one parameter), we can graph it to see roughly where minimum loss appears.

In [20]:
# Graph loss function for a range of Vcmax values
vcmax_values = np.linspace(35, 45, 30)
loss_values = np.zeros_like(vcmax_values)

for i, vcmax in enumerate(vcmax_values):
    loss_values[i] = vcmax_loss(vcmax)

fig = go.Figure(data=go.Scatter(x=vcmax_values, y=loss_values))
fig.show()


### Method 1: Uniform Sampling

First, we can look for an optimal value of Vcmax by evaluating the loss at 50 evenly spaced Vcmax values between 0 and 100 and keeping the best one.

In [21]:
# Evaluate the loss on an evenly spaced grid of Vcmax values.
num_points = 50
vcmax_values = np.linspace(0, 100, num_points)
loss_table = np.zeros(len(vcmax_values))
for i, vcmax_z in enumerate(vcmax_values):
    loss_table[i] = vcmax_loss(vcmax_z)
    print(f"vcmax_z: {vcmax_z}, loss: {loss_table[i]}")


# Find the minimum loss, ignoring NaNs. np.nanargmin raises ValueError if
# every entry is NaN. An argmin over a fully masked array would instead
# silently return index 0.
min_loss_index = np.nanargmin(loss_table)
min_loss = loss_table[min_loss_index]
uniform_vcmax = vcmax_values[min_loss_index]

print(f"Best vcmax_z: {uniform_vcmax}, min_loss: {min_loss}")


vcmax_z: 0.0, loss: 15866.146484375
vcmax_z: 2.0408163265306123, loss: 13958.0615234375
vcmax_z: 4.081632653061225, loss: 12197.337890625
vcmax_z: 6.122448979591837, loss: 10580.466796875
vcmax_z: 8.16326530612245, loss: 9093.908203125
vcmax_z: 10.204081632653061, loss: 7703.740234375
vcmax_z: 12.244897959183675, loss: 6278.212890625
vcmax_z: 14.285714285714286, loss: 5047.794921875
vcmax_z: 16.3265306122449, loss: 4061.49755859375
vcmax_z: 18.367346938775512, loss: 3241.072021484375
vcmax_z: 20.408163265306122, loss: 2548.092041015625
vcmax_z: 22.448979591836736, loss: 1949.752197265625
vcmax_z: 24.48979591836735, loss: 1444.1715087890625
vcmax_z: 26.53061224489796, loss: 1024.386474609375
vcmax_z: 28.571428571428573, loss: 685.7720336914062
vcmax_z: 30.612244897959183, loss: 421.1419372558594
vcmax_z: 32.6530612244898, loss: 225.9942626953125
vcmax_z: 34.69387755102041, loss: 95.13872528076172
vcmax_z: 36.734693877551024, loss: 23.735992431640625
vcmax_z: 38.775510204081634, loss: 7.

### Method 2: Gradient Descent

Rather than evaluating a fixed grid of candidate values, we can use gradient descent to find the value of Vcmax that minimizes the loss function. This should take fewer iterations to reach a lower loss value.

In [10]:
import jax
from functools import partial


# Derivative of the loss with respect to Vcmax. jax.grad differentiates through
# all the Python code vcmax_loss calls, including the iterative ci solver;
# no jit is applied here.
grad_loss = jax.grad(vcmax_loss)

def gradient_descent(param, learning_rate, n_steps):
    """Run ``n_steps`` of plain gradient descent on ``vcmax_loss``.

    Each step moves ``param`` against ``grad_loss(param)`` scaled by
    ``learning_rate``, then prints the new value and its loss. Printing the
    loss costs one extra full model evaluation per step.

    Returns:
        The parameter value after the final step.
    """
    for _step in range(n_steps):
        step_size = learning_rate * grad_loss(param)
        param = param - step_size
        current_loss = vcmax_loss(param)
        print(f"param: {param}, loss: {current_loss}")
    return param

In [None]:
# Start from Vcmax = 30 and take 10 gradient-descent steps with learning rate 0.1.
initial_vcmax = 30.0
gradient_vcmax = gradient_descent(initial_vcmax, 0.1, 10)

param: 42.22740119445518, loss: 90.23400270689132
param: 37.94320406120169, loss: 7.66977584121633
param: 38.51257723758147, loss: 6.493522362680697
param: 38.364562362527145, loss: 6.397030387708359
param: 38.4004026475763, loss: 6.392220936062135
param: 38.391457873291884, loss: 6.391920489292645
param: 38.39362389936614, loss: 6.391901901933405
param: 38.39315641591215, loss: 6.3919010832882295
param: 38.393273666898956, loss: 6.391901031790853
param: 38.3932442574269, loss: 6.391901028550744


In [None]:
# Report the gradient-descent estimate (re-runs the model once more to get its loss).
print(f"Best vcmax: {gradient_vcmax}, loss: {vcmax_loss(gradient_vcmax)}")

Best vcmax: 38.3932442574269, loss: 6.391901028550744


### Graphing the results

In [None]:
# A coarser 19-point cair grid for plotting. Rebinding cair_values here does
# not affect vcmax_loss, which already captured its own grid via partial.
cair_values = np.linspace(10, 100, 19)

observed_values = get_observed_values("2021-07-15-Me2-05.csv")

# Model curves at the Vcmax found by each method.
uniform_values = get_model_values(cair_values, uniform_vcmax)
gradient_values = get_model_values(cair_values, gradient_vcmax)

# For plotting, truncate observed results to the range from uniform sampling.
observed_x = np.linspace(uniform_values['ci'][0], uniform_values['ci'][-1], 50)
observed_y = linear_interp_with_extrapolation(observed_values['ci'], observed_values['an'], observed_x)

# Graph predicted and measured values.
fig = go.Figure(data=go.Scatter(x=observed_x, y=observed_y, name="Measured"))
fig.add_trace(go.Scatter(x=uniform_values['ci'], y=uniform_values['an'], name="Uniform Sampling"))
fig.add_trace(go.Scatter(x=gradient_values['ci'], y=gradient_values['an'], name="Gradient Descent"))
fig.update(
    layout_title_text="An vs. Ci",
    layout_xaxis_title_text="Ci",
    layout_yaxis_title_text="An",
)
fig.show()

## Observations

How do uniform sampling and gradient descent stack up? Surprisingly, uniform sampling performs better, achieving slightly lower loss than gradient descent. Moreover, the gradient descent version runs much slower than the uniform sampling version. It takes around 12 seconds to set up the functions, and around 15 seconds per iteration. Why is it so slow?

That said, the gradient descent version has some advantages:
1. Makes far fewer calls to the model
2. Makes no assumptions about the range of input parameters
3. Can run gradient descent across multiple parameters simultaneously

Also, the low loss of the uniform sampling method might be specific to this case. Its success depends on the shape of the loss curve. 