In [8]:
import pandas as pd
from pathlib import Path

from hydrosystem.discharge.nam import NAMParameters
from hydrosystem.discharge.nam.observation import NAMObservation, NAMTarget
from jax import numpy as jnp
import jax
import numpy as np

path_to_processed_data = Path("../src/data/processed")
import json

from hydrosystem.discharge.nam.parameters import NAMParameters
from hydrosystem.discharge import nam

In [58]:
START_YEAR, END_YEAR = 2069, 2098  # inclusive analysis window (calendar years)
# Row count of a complete daily series over the window (10957 days for 2069-2098).
EXPECTED_DAYS = len(pd.date_range(f"{START_YEAR}-01-01", f"{END_YEAR}-12-31", freq="D"))


def load_future_timeseries(filename: str, bias_correction: str) -> pd.DataFrame:
    """Load one future-climate forcing table restricted to the analysis window.

    Parameters
    ----------
    filename : str
        CSV file name inside ``path_to_processed_data``.
    bias_correction : str
        Label written to a new constant ``bias_correction`` column.

    Returns
    -------
    pd.DataFrame
        Rows with ``START_YEAR <= date.year <= END_YEAR``, with ``date`` reduced to
        its calendar day, ``model`` lower-cased, and a ``bias_correction`` column.
        The original row index is kept (not reset).
    """
    raw = pd.read_csv(path_to_processed_data / filename)
    # Parse, drop any time-of-day component via .dt.date, then back to datetime64.
    raw["date"] = pd.to_datetime(pd.to_datetime(raw["date"]).dt.date)
    in_window = raw["date"].dt.year.between(START_YEAR, END_YEAR)  # inclusive on both ends
    # .assign returns a new frame, so no column is written into a filtered view.
    return raw.loc[in_window].assign(
        model=lambda d: d["model"].str.lower(),
        bias_correction=bias_correction,
    )


def build_observations(ts: pd.DataFrame) -> dict:
    """Build one NAMObservation per model from a forcing table.

    Only models whose row count equals ``EXPECTED_DAYS`` are kept. Presumably that
    means a complete daily series; the check does not detect a missing day offset
    by a duplicated one. The arrays keep the table's row order for each model.
    Models are in first-appearance order.
    """
    return {
        model: NAMObservation(
            p=jnp.asarray(group["p"]),
            epot=jnp.asarray(group["epot"]),
            t=jnp.asarray(group["t"]),
        )
        # sort=False keeps models in first-appearance order; rows keep file order.
        for model, group in ts.groupby("model", sort=False)
        if len(group) == EXPECTED_DAYS
    }


ts_deltachange = load_future_timeseries("timeseries_future_deltachange.csv", "deltachange")
ts_dbs = load_future_timeseries("timeseries_future_dbs.csv", "dbs")

# {bias_correction: {model: NAMObservation}}
observations = {
    "deltachange": build_observations(ts_deltachange),
    "dbs": build_observations(ts_dbs),
}

with open("optimized_params_penalty.json", "r") as json_file:
    params = NAMParameters(**json.load(json_file))

In [59]:
def work(path, obs):
    """Run the NAM model for one (bias_correction, model) leaf and tabulate the outputs.

    Parameters
    ----------
    path : tuple
        JAX key path of the leaf inside ``observations``. ``path[0].key`` is the
        bias-correction label and ``path[1].key`` is the model name.
    obs : NAMObservation
        Forcing arrays for that leaf, built from the matching rows of
        ``ts_deltachange`` / ``ts_dbs`` in their row order.

    Returns
    -------
    pd.DataFrame
        One row per time step with date, labels and the NAM outputs
        (q, eact, perc, recharge, storage).
    """
    bias_correction, model = path[0].key, path[1].key
    _, pred = nam.predict(params, obs)

    # Use the dates of the exact rows that produced `obs` (same frame, same model,
    # same row order) rather than those of one reference model. This keeps every
    # series aligned with its own dates even if models are ordered differently.
    source = {"deltachange": ts_deltachange, "dbs": ts_dbs}[bias_correction]
    dates = source.loc[source["model"] == model, "date"].reset_index(drop=True)

    # NOTE(review): assumes each pred.* output has one value per input time step.
    return pd.DataFrame({
        "date": dates,
        "bias_correction": bias_correction,
        "model": model,
        "q": pred.q,
        "eact": pred.eact,
        "perc": pred.perc,
        "recharge": pred.recharge,
        "storage": pred.storage,
    })


# Same nested {bias_correction: {model: ...}} layout, with DataFrames as leaves.
predictions = jax.tree.map_with_path(work, observations, is_leaf=lambda x: isinstance(x, NAMObservation))

In [64]:
# Flatten {bias_correction: {model: DataFrame}} into one long prediction table.
preds = pd.concat(
    [frame for per_model in predictions.values() for frame in per_model.values()],
    axis=0,
    ignore_index=True,
)

# Attach the forcing inputs (from both bias-correction variants) to each prediction row.
ts_concat = pd.concat([ts_deltachange, ts_dbs], ignore_index=True)
preds_merged = preds.merge(
    ts_concat,
    on=["date", "model", "bias_correction"],
    how="left",
)

# A left merge preserves the row count only when the forcing table has at most one row
# per key. Duplicates would silently multiply prediction rows in the exported file.
assert len(preds_merged) == len(preds), (
    "forcing data has duplicate (date, model, bias_correction) keys"
)

preds_merged.to_csv(path_to_processed_data / "future_projections.csv", index=False)