In [27]:
import pandas as pd
from pathlib import Path

from hydrosystem.discharge.nam import NAMParameters
from hydrosystem.discharge.nam.observation import NAMObservation, NAMTarget
from jax import numpy as jnp
import jax

# Directory holding the processed input CSVs and the projection output written at the end.
# The relative path is resolved against the kernel's current working directory.
path_to_processed_data = Path("../src/data/processed")
import json

from hydrosystem.discharge.nam.parameters import NAMParameters
from hydrosystem.discharge import nam

In [18]:
# Load the two bias-corrected future forcing sets ("deltachange" and "dbs"),
# build one NAMObservation per climate model, and load the calibrated parameters.
ts_deltachange = pd.read_csv(path_to_processed_data / "timeseries_future_deltachange.csv")
ts_deltachange["date"] = pd.to_datetime(ts_deltachange["date"])

ts_dbs = pd.read_csv(path_to_processed_data / "timeseries_future_dbs.csv")
ts_dbs["date"] = pd.to_datetime(ts_dbs["date"])


def _observations_by_model(forcing: pd.DataFrame) -> dict:
    """Build one NAMObservation per distinct value of ``forcing["model"]``.

    Parameters
    ----------
    forcing : pd.DataFrame
        Must contain the columns ``model``, ``date``, ``p``, ``epot`` and ``t``.

    Returns
    -------
    dict
        ``{model_name: NAMObservation}``, with models in the order they first
        appear in ``forcing`` and rows kept in their original file order.

    Raises
    ------
    AssertionError
        If a model's rows are not in non-decreasing date order.
    """
    observations_per_model = {}
    # sort=False keeps models in first-appearance order (same as Series.unique()),
    # and groupby keeps each group's rows in file order. This is one pass over
    # the frame instead of three boolean masks per model.
    for model, rows in forcing.groupby("model", sort=False):
        # Only the raw series (no dates) reach NAM, so row order is the time axis.
        # Assumes the simulation is sequential in time -- TODO confirm in nam.predict.
        assert rows["date"].is_monotonic_increasing, f"rows for model {model!r} are not sorted by date"
        observations_per_model[model] = NAMObservation(
            p=jnp.asarray(rows["p"]),
            epot=jnp.asarray(rows["epot"]),
            t=jnp.asarray(rows["t"]),
        )
    return observations_per_model


# {bias_correction_method: {model_name: NAMObservation}}
observations = {
    "deltachange": _observations_by_model(ts_deltachange),
    "dbs": _observations_by_model(ts_dbs),
}

# JSON text is UTF-8 by specification; don't depend on the platform default encoding.
with open("optimized_params_penalty.json", "r", encoding="utf-8") as json_file:
    params = NAMParameters(**json.load(json_file))

In [35]:
def work(path, obs):
    """Simulate one forcing set with NAM and return its outputs as a table.

    Intended as the callback for ``jax.tree.map_with_path`` over the nested
    ``{bias_correction: {model: NAMObservation}}`` mapping. The simulation uses
    the module-level ``params``.

    Parameters
    ----------
    path : sequence of jax key entries
        Key path of the leaf: ``path[0].key`` is the bias-correction name and
        ``path[1].key`` is the climate-model name.
    obs : NAMObservation
        Forcing series passed unchanged to ``nam.predict``.

    Returns
    -------
    pd.DataFrame
        Columns ``q``, ``eact``, ``perc``, ``recharge`` and ``storage`` taken from
        the prediction. The bias-correction and model names are repeated on
        every row.
    """
    # nam.predict's result unpacks into two parts; only the second is used here.
    _, prediction = nam.predict(params, obs)
    bias_correction_key = path[0].key
    model_key = path[1].key

    columns = {
        # NOTE(review): "bias_corrrection" is misspelled (three r's). It is kept
        # as-is because it is a column name in the saved CSV; rename it only
        # together with every downstream reader.
        "bias_corrrection": bias_correction_key,
        "model": model_key,
    }
    for output_name in ("q", "eact", "perc", "recharge", "storage"):
        columns[output_name] = getattr(prediction, output_name)
    return pd.DataFrame(columns)

# Apply `work` to every forcing set; `predictions` mirrors the nesting of
# `observations`, with a DataFrame at each leaf. Stopping the traversal at
# NAMObservation nodes hands each one to `work` whole, in case NAMObservation
# is itself a registered pytree whose fields would otherwise be visited.
predictions = jax.tree.map_with_path(
    work,
    observations,
    is_leaf=lambda node: isinstance(node, NAMObservation),
)

In [40]:
# Flatten {bias_correction: {model: DataFrame}} into one long table, in the
# dict's iteration order, and write it next to the input data.
prediction_frames = [
    frame
    for frames_by_model in predictions.values()
    for frame in frames_by_model.values()
]
preds = pd.concat(prediction_frames, axis=0, ignore_index=True)
preds.to_csv(path_to_processed_data / "climate_projections_q.csv", index=False)