In [4]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import jax
from jax import random
from jax import numpy as jnp
import optax

from nam import defaults
from nam import nam_excel

In [5]:
# Set up optimization experiment

n = 100  # number of uniform draws per trainable entry
rngkey = random.key(181225)  # fixed seed so the draws are reproducible


def _sample_space(space, key, n_samples):
    """Split a bounds table into random draws and pinned constants.

    Args:
        space: mapping ``name -> bounds`` where ``bounds[0]`` is the lower and
            ``bounds[1]`` the upper limit.
        key: JAX PRNG key; it is split once per trainable entry, in mapping order.
        n_samples: number of uniform draws per trainable entry.

    Returns:
        ``(key, trainable, fixed)``: the advanced key, a dict holding an array
        of shape ``(n_samples,)`` for every entry with ``lower < upper``, and a
        dict holding the lower bound for every entry with ``lower >= upper``.
    """
    trainable, fixed = {}, {}
    for name, bounds in space.items():
        low, high = bounds[0], bounds[1]
        if low < high:
            # Fresh subkey per entry so the draws are independent.
            key, subkey = random.split(key)
            trainable[name] = random.uniform(
                subkey, shape=(n_samples,), minval=low, maxval=high
            )
        elif low >= high:
            # Degenerate interval: the entry is held constant at its lower bound.
            # Unordered bounds (e.g. NaN) land in neither dict.
            fixed[name] = low
    return key, trainable, fixed


rngkey, trainable_params, fixed_params = _sample_space(
    defaults.default_params_space, rngkey, n
)
trainable_params_keys = list(trainable_params)

rngkey, trainable_state, fixed_state = _sample_space(
    defaults.default_state_space, rngkey, n
)
trainable_state_keys = list(trainable_state)

In [None]:
# Load the data needed

In [7]:
# Display the default observations object exposed by nam.defaults.
defaults.default_observations

NAM_Observation(p=Array([0.03, 0.49, 0.05, ..., 4.73, 6.05, 6.35], dtype=float32), epot=Array([0.2967742 , 0.2967742 , 0.2967742 , ..., 0.30645162, 0.30645162,
       0.30645162], dtype=float32), t=Array([-5.2, -4.9, -5.2, ...,  1.3,  0.9, -1.7], dtype=float32))

In [4]:
# Read the semicolon-separated export (decimal commas) and parse day-first dates.
data = pd.read_csv("excel_with_defaults.csv", sep=";", decimal=",").assign(
    date=lambda frame: pd.to_datetime(frame["date"], dayfirst=True)
)

# Input series as JAX arrays; note that CSV column "temp" feeds the field "t".
observations = NAM_Observation(
    p=jnp.asarray(data["p"]),
    epot=jnp.asarray(data["epot"]),
    t=jnp.asarray(data["temp"]),
)
# Series the simulated output is compared against.
targets = jnp.asarray(data["qobs"])

In [5]:
# Run predict on the loaded observations and keep only its second return value
# (presumably the simulated series — confirm).
# NOTE(review): predict, params_all and state_all are not defined in the cells visible here.
_, r = predict(params_all, state_all, observations)

In [6]:
# Mean squared difference between r and targets, computed by hand.
jnp.mean(jnp.square(r - targets))

Array(38.237206, dtype=float32)

In [7]:
# Evaluate the loss once, passing every argument by keyword.
# NOTE(review): mse and the four parameter/state dicts are not defined in the cells visible here.
err = mse(
    params_fixed=params_fixed,
    params_trainable=params_trainable,
    state_fixed=state_fixed,
    state_trainable=state_trainable,
    obs=observations,
    target=targets
)

In [10]:
# Gradient of the loss with respect to its first two positional arguments.
mse_grad = jax.grad(mse, argnums=(0, 1))

In [12]:
# Evaluate the gradients. The differentiated arguments (params_trainable,
# state_trainable) must be passed positionally; the result holds one dict per
# differentiated argument.
mse_grad(
    params_trainable,
    state_trainable,
    params_fixed,
    state_fixed,
    obs=observations,
    target=targets
)

({'ck1': Array(-26.358887, dtype=float32),
  'ckbf': Array(1.0746789, dtype=float32),
  'ckif': Array(1.9791752, dtype=float32, weak_type=True),
  'cqof': Array(56.51227, dtype=float32, weak_type=True),
  'l_max': Array(-0.11636028, dtype=float32, weak_type=True),
  'tg': Array(4.847694, dtype=float32, weak_type=True),
  'tif': Array(-0.17708704, dtype=float32, weak_type=True),
  'tof': Array(-1.6521508, dtype=float32, weak_type=True),
  'u_max': Array(-0.16479407, dtype=float32, weak_type=True)},
 {'bf': Array(0.05019228, dtype=float32),
  'l_ratio': Array(2.1993843e-07, dtype=float32),
  'qr1': Array(-0.00294913, dtype=float32),
  'u_ratio': Array(0.00048687, dtype=float32, weak_type=True)})

In [26]:
# Work on a copy so params_trainable keeps the starting values for comparison.
copy_par = params_trainable.copy()

# Adam with a fixed step size; its state is initialised from the copied parameters.
optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(copy_par)

In [38]:
# Run 20 Adam steps on copy_par; the state dict and the fixed values are only read.
# NOTE: copy_par and opt_state carry over between executions, so re-running this
# cell continues the optimisation instead of restarting it.
grad_wrt_params = jax.grad(mse, argnums=0)  # built once; identical for every step
for _ in range(20):
    grads = grad_wrt_params(
        copy_par, state_trainable, params_fixed, state_fixed,
        obs=observations, target=targets,
    )
    updates, opt_state = optimizer.update(grads, opt_state)
    copy_par = optax.apply_updates(copy_par, updates)

In [29]:
# Loss at the starting parameters (params_trainable), as a reference value.
mse(params_trainable, state_trainable, params_fixed, state_fixed, observations, targets)

Array(38.237206, dtype=float32)

In [28]:
# Loss at copy_par, the parameter copy updated by the Adam loop.
mse(copy_par, state_trainable, params_fixed, state_fixed, observations, targets)

Array(30.062603, dtype=float32)

In [34]:
# Loss at copy_par again; presumably re-evaluated after another run of the
# training cell — confirm against the execution counts.
mse(copy_par, state_trainable, params_fixed, state_fixed, observations, targets)

Array(24.01263, dtype=float32)

In [36]:
# Loss at copy_par again; presumably re-evaluated after another run of the
# training cell — confirm against the execution counts.
mse(copy_par, state_trainable, params_fixed, state_fixed, observations, targets)

Array(19.793165, dtype=float32)

In [39]:
# Loss at copy_par again; presumably re-evaluated after another run of the
# training cell — confirm against the execution counts.
mse(copy_par, state_trainable, params_fixed, state_fixed, observations, targets)

Array(15.109293, dtype=float32)

In [40]:
# Show updated vs. starting parameters after to_physical
# (presumably a mapping to physical parameter values — confirm).
to_physical(copy_par), to_physical(params_trainable)

({'ck1': Array(0.702743, dtype=float32),
  'ckbf': Array(0.99707854, dtype=float32),
  'ckif': Array(0.03477174, dtype=float32),
  'cqof': Array(0.22045451, dtype=float32),
  'l_max': Array(100.43097, dtype=float32),
  'tg': Array(0.39887768, dtype=float32),
  'tif': Array(0.60943526, dtype=float32),
  'tof': Array(0.27950484, dtype=float32),
  'u_max': Array(5.469987, dtype=float32)},
 {'l_max': Array(100., dtype=float32, weak_type=True),
  'u_max': Array(5., dtype=float32, weak_type=True),
  'cqof': Array(0.3, dtype=float32, weak_type=True),
  'ckif': Array(0.05, dtype=float32, weak_type=True),
  'tif': Array(0.5, dtype=float32, weak_type=True),
  'tof': Array(0.2, dtype=float32, weak_type=True),
  'tg': Array(0.5, dtype=float32, weak_type=True),
  'ck1': Array(0.60653067, dtype=float32),
  'ckbf': Array(0.998002, dtype=float32)})

In [23]:
# Show both parameter dicts as stored, without applying to_physical.
copy_par, params_trainable

({'ck1': Array(0.43375212, dtype=float32),
  'ckbf': Array(6.212605, dtype=float32),
  'ckif': Array(-2.9454389, dtype=float32),
  'cqof': Array(-0.84829783, dtype=float32),
  'l_max': Array(100.001, dtype=float32),
  'tg': Array(-0.00099999, dtype=float32),
  'tif': Array(0.00099999, dtype=float32),
  'tof': Array(-1.3852943, dtype=float32),
  'u_max': Array(4.9942393, dtype=float32)},
 {'l_max': Array(100., dtype=float32, weak_type=True),
  'u_max': Array(4.9932394, dtype=float32, weak_type=True),
  'cqof': Array(-0.84729785, dtype=float32, weak_type=True),
  'ckif': Array(-2.944439, dtype=float32, weak_type=True),
  'tif': Array(0., dtype=float32, weak_type=True),
  'tof': Array(-1.3862944, dtype=float32, weak_type=True),
  'tg': Array(0., dtype=float32, weak_type=True),
  'ck1': Array(0.43275213, dtype=float32),
  'ckbf': Array(6.213605, dtype=float32)})

In [13]:
# Display the starting trainable parameters.
params_trainable

{'l_max': Array(100., dtype=float32, weak_type=True),
 'u_max': Array(4.9932394, dtype=float32, weak_type=True),
 'cqof': Array(-0.84729785, dtype=float32, weak_type=True),
 'ckif': Array(-2.944439, dtype=float32, weak_type=True),
 'tif': Array(0., dtype=float32, weak_type=True),
 'tof': Array(-1.3862944, dtype=float32, weak_type=True),
 'tg': Array(0., dtype=float32, weak_type=True),
 'ck1': Array(0.43275213, dtype=float32),
 'ckbf': Array(6.213605, dtype=float32)}

In [None]:
# NOTE(review): this call raises TypeError. jax.grad differentiates positional
# arguments only, and here every argument is passed by keyword, so there is no
# positional argument to differentiate.
# model_config and qobs are not defined in the cells visible here.
errgrad = jax.grad(mse, )(**model_config, q=qobs)

TypeError: differentiating with respect to argnums=(0, 1, 2, 3) requires at least 4 positional arguments to be passed by the caller, but got only 0 positional arguments.

In [10]:
# Interactive inspection: print the signature of whichever mse is currently bound.
help(mse)

Help on function mse in module nam:

mse(u_max: float, l_max: float, cqof: float, ckif: float, tof: float, tif: float, tg: float, ck1: float, ck2: float, ckbf: float, c_area: float, c_snow: float, s: float, u: float, l: float, qr1: float, qr2: float, bf: float, p: jax.Array, epot: jax.Array, t: jax.Array, q: jax.Array)



In [8]:
# Check the shape of the loss value; an empty tuple means it is a scalar.
err.shape

()

In [None]:
# NOTE(review): other cells call a different `mse` (keyword arguments such as
# params_trainable / state_trainable); running this cell rebinds that name.
def mse(target, obs, state, params):
    """Squared error of a single model step against one target value.

    Args:
        target: value the simulated output is compared with.
        obs: observation input forwarded to ``step``.
        state: model state forwarded to ``step``.
        params: model parameters forwarded to ``step``.

    Returns:
        ``(sim - target) ** 2`` where ``sim`` is the second item returned by
        ``step``; the first item is discarded.
    """
    _, sim = step(obs, state, params)
    return (sim - target) ** 2


# Gradient with respect to ``params`` (positional index 3).
mse_grad = jax.grad(mse, argnums=3)

In [9]:
# Display initial_params.
# NOTE(review): initial_params is not defined in the cells visible here.
initial_params

NAM_Parameters(area=1055.0, c_area=0.9, c_snow=2.0, l_max=100.0, u_max=5.0, cqof=0.3, ckif=20.0, tof=0.2, tif=0.5, tg=0.5, ck1=np.float64(0.6065306597126334), ck2=0.0, ckbf=np.float64(0.9980019986673331))

In [6]:
# Single-step error for the first observation; the arguments are presumably
# (target, obs, state, params) — confirm against the mse definition in use.
mse(14, observations[0], initial_state, initial_params)

np.float64(0.7498554846293193)

In [10]:
# Gradient of the single-step error for the same inputs; the stored result is a
# NAM_Parameters holding one gradient per field.
mse_grad(14, observations[0], initial_state, initial_params)

NAM_Parameters(area=Array(0.0244039, dtype=float32, weak_type=True), c_area=Array(0., dtype=float32, weak_type=True), c_snow=Array(0., dtype=float32, weak_type=True), l_max=Array(-0.04160437, dtype=float32, weak_type=True), u_max=Array(0., dtype=float32, weak_type=True), cqof=Array(0., dtype=float32, weak_type=True), ckif=Array(-0.10401091, dtype=float32, weak_type=True), tof=Array(0., dtype=float32, weak_type=True), tif=Array(0., dtype=float32, weak_type=True), tg=Array(0., dtype=float32, weak_type=True), ck1=Array(3.8065414, dtype=float32), ck2=Array(-7.5956464, dtype=float32, weak_type=True), ckbf=Array(18.186808, dtype=float32))

In [7]:
# Advance the model one step from initial_state; the stored result is a pair of
# (NAM_State, number).
step(observations[0], initial_state, initial_params)

(NAM_State(s=0.03, u=4.453, l=100.0, qr1=np.float64(0.359175518748274), qr2=np.float64(0.359175518748274), bf=np.float64(0.8582817188539065)),
 np.float64(14.865941963776626))

In [7]:
# Same call as the preceding cell. The stored output differs, so step or its
# inputs presumably changed between the two runs — confirm which is current.
step(observations[0], initial_state, initial_params)

(NAM_State(s=0.03, u=4.617100000000001, l=100.0, qr1=0.7740999999999999, qr2=0.7740999999999999, bf=601.96039),
 635884.88695)

In [6]:
# Display initial_state.
# NOTE(review): initial_state is not defined in the cells visible here.
initial_state

NAM_State(s=0, u=5, l=100, qr1=0.43, qr2=0, bf=0.86)

In [7]:
# Divide a stored step output by 86.4 (presumably a unit conversion — confirm
# which one, and replace both literals with named values).
635884.88695 / 86.4

7359.778784143517