In [None]:
import sys
sys.path.append('..')  # Add the parent folder to the sys.path
sys.path.append('../..')  # Add the parent of the parent folder to the sys.path
from pathlib import Path
from neuralhydrology.nh_run import start_run

In [None]:
# Train the model configured in example4.yml. `gpu` is a parameter of start_run (not of
# Path); per neuralhydrology's start_run docs a negative id requests CPU training.
start_run(config_file=Path("example4.yml"), gpu=-1)

In [None]:
# Run directory whose validation results are evaluated below -- edit to point at your own run.
run_dir = "runs/dev/arch1"

import pickle
import matplotlib.pyplot as plt
from neuralhydrology.evaluation import metrics
from math import sqrt

def mean(l):
    """Return the arithmetic mean of the numbers in ``l``.

    Raises ZeroDivisionError when ``l`` is empty.
    """
    count = len(l)
    return sum(l) / count

def stdev(l):
    """Return the population standard deviation of ``l``.

    The sum of squared deviations is divided by ``len(l)``, not ``len(l) - 1``.
    """
    centre = mean(l)
    squared_dev = sum((x - centre) ** 2 for x in l)
    return sqrt(squared_dev / len(l))

# Basins to evaluate and the epoch whose validation results are loaded.
# NOTE(review): these IDs differ from the basins passed to generate_data further down -- confirm intended.
basins = ['01350080', '01350140']
epoch = 25
# Result folders are named with a zero-padded 3-digit epoch, e.g. "model_epoch025".
# `nb_epochs` (a string) is also read by the architecture-comparison cell below.
nb_epochs = f"{epoch:03d}"

# One panel per basin (squeeze=False keeps `axs` 2-D even for a single basin).
fig, axs = plt.subplots(len(basins), 1, figsize=(14, 4 * len(basins)), squeeze=False)

with open(run_dir + "/validation/model_epoch" + nb_epochs + "/validation_results.p", "rb") as fp:
    # NOTE: pickle.load can execute arbitrary code -- only load result files you produced yourself.
    results = pickle.load(fp)

metric_values = []  # one metrics dict per basin, in the order of `basins`
for ax, basin in zip(axs[:, 0], basins):
    qsim = results[basin]['1D']['xr']['QObs(mm/d)_sim']
    qobs = results[basin]['1D']['xr']['QObs(mm/d)_obs']
    # Metrics use only the last entry along `time_step`.
    metric_values.append(metrics.calculate_all_metrics(qobs.isel(time_step=-1),
                                                       qsim.isel(time_step=-1)))
    ax.plot(qobs, label='observed')
    ax.plot(qsim, label="epoch " + nb_epochs, alpha=0.5)
    ax.set_title("basin " + basin)
    ax.set_ylabel("Q (mm/d)")
    ax.legend()  # per-axes legend; a single plt.legend() would only label the last panel
plt.show()

# Summarise each metric across basins as "mean (population std)".
for metric_name in metric_values[0]:
    per_basin = [basin_metrics[metric_name] for basin_metrics in metric_values]
    print(f"{metric_name}: {mean(per_basin):.3f} ({stdev(per_basin):.3f})")

In [None]:
import pickle
from math import exp

class TR:
    """Nonlinear soil-moisture store used by ``generate_data`` to make synthetic runoff.

    Calling the object advances the store by one step. It takes precipitation ``p``
    and potential evapotranspiration ``ep`` and returns ``(q, e)``: the water released
    by the store and the actual evapotranspiration. The storage ``self.s`` is updated
    in place, rounded to 4 decimals, and asserted to stay within ``[0, smax]``.

    NOTE(review): the update equations look like a probability-distributed
    (Xinanjiang/PDM-style) capacity curve with shape ``b`` when wetting and a
    power-law drying curve with exponent ``c`` -- confirm against the reference model.
    """

    def __init__(self, b, smax,c, s0=0) -> None:
        """Create a store.

        Args:
            b: shape exponent of the wetting branch (enters as ``b + 1``; must not be -1).
            smax: storage capacity; must be non-zero (storage is divided by it).
            c: exponent of the drying branch (enters as ``1 - c``; ``c == 1`` raises
                ZeroDivisionError whenever ``p`` is below the clamped ``ep``).
            s0: initial storage; should lie in ``[0, smax]``.
        """
        self.b = b
        self.smax = smax
        self.s = s0
        self.c = c
    
    def __call__(self, p,ep):
        """Alias for ``forward``."""
        return self.forward(p,ep)

    def forward(self,p,ep):
        """Advance the store by one step and return ``(q, e)``.

        Args:
            p: precipitation input for the step.
            ep: potential evapotranspiration; negative values are treated as 0.

        Returns:
            ``(q, e)``: water released by the store (0 when drying) and actual ET.

        Raises:
            AssertionError: if ``e``, ``q`` or the new storage leave their valid
                range (checks are skipped under ``python -O``).
        """
        # Clamp negative PET to 0 (multiplying by the boolean) and round to 4 decimals.
        ep = round(ep*(ep>0), 4)
        # Defaults for the wetting branch: no release computed yet, ET fully satisfied.
        q,e = 0.0, ep
        smax, b, s,c = self.smax, self.b, self.s, self.c
        sbar = s/smax  # relative fill of the store
        me = p-ep      # net water input for the step
        a = me/smax    # net input relative to capacity (only used when drying)
        if me >= 0:
            # Wetting: ET is met in full and part of the surplus is released as q,
            # so that (up to rounding) new storage = s + me - q.
            q = me + s - smax
            temp = (1-sbar)**(1/(b+1)) - me/((b+1)*smax)
            if temp > 0:
                # Store not yet full: remaining unfilled volume is smax*temp**(b+1).
                temp = smax*(temp**(b+1))
                # q uses the unrounded volume while self.s uses the rounded one, so
                # mass balance only holds to about 1e-4.
                q += temp
                temp = round(temp,4)
                q = round(q,4)
                self.s = smax - temp
            else:
                # Store fills completely; the whole remaining surplus becomes q.
                q = round(q,4)
                self.s = smax
                
        else:
            # Drying: no release (q stays 0). e is chosen so that the new storage is
            # smax*temp**(1/(1-c)) when temp > 0, or 0 (store emptied) otherwise.
            e = p + s
            temp = a*(1-c)+sbar**(1-c)
            if temp > 0:
                e-=smax*temp**(1/(1-c))
            e = round(e,4)
            self.s += p-e
        self.s = round(self.s,4)
        # Sanity checks on the step (disabled when Python runs with -O).
        assert e >= 0
        assert self.s >= 0
        assert self.s <= smax
        assert q >= 0
        return q,e

class LR:
    """Linear reservoir with rate constant ``k``.

    Each call adds an inflow ``p`` for one unit step and returns the outflow ``q``.
    The storage ``self.s`` is updated in place as ``s + (p - q)``.
    """

    def __init__(self, k, s0=0) -> None:
        self.k = k    # rate constant; must be non-zero (it appears in a denominator)
        self.s = s0   # current storage

    def __call__(self, p):
        return self.forward(p)

    def forward(self, p):
        """Route inflow ``p`` through the reservoir for one step and return the outflow."""
        k, storage = self.k, self.s
        growth = exp(k)
        # End-of-step storage (p*e^k + k*s - p) / (k*e^k), which equals
        # p/k + (s - p/k) * e^-k: the exact solution of ds/dt = p - k*s over one unit step.
        storage_end = (p * growth + k * storage - p) / (k * growth)
        q = p + storage - storage_end
        self.s = storage + (p - q)
        assert q >= 0
        assert self.s >= 0
        return q

def generate_data(name, b, smax, c, k1, k2, data_dir='../../data', weights=(0.7, 0.3)):
    """Write a synthetic streamflow file for basin ``name``.

    Daily precipitation is read from the basin's daymet forcing file and PET from
    ``<data_dir>/pet.pkl``. Each day's values are pushed through a ``TR`` store; the
    store's release is split between two ``LR`` reservoirs (rate constants ``k1`` and
    ``k2``) that receive ``weights[0]`` and ``weights[1]`` of it, and their outflows
    are summed. One line ``"<name> <date> <q> A"`` per day is written to
    ``<data_dir>/fake/usgs_streamflow/02/<name>_streamflow_qc.txt``.

    Args:
        name: basin id, used both as a key into pet.pkl and in file names.
        b, smax, c: parameters of the ``TR`` store.
        k1, k2: rate constants of the two ``LR`` reservoirs.
        data_dir: root data folder (default is the original relative path).
        weights: fractions of the store release routed to the k1 / k2 reservoirs.

    NOTE(review): if this file is later read by neuralhydrology's CAMELS-US loader,
    that loader treats discharge as ft^3/s and rescales it to mm/d using the basin
    area from the forcing header, so the values written here will be rescaled.
    Confirm this is intended.
    """
    import os
    import warnings

    with open(f"{data_dir}/pet.pkl", 'rb') as f:
        # NOTE: pickle.load can execute arbitrary code -- only load files you trust.
        pets = pickle.load(f)[name]['PET(mm/d)'].tolist()

    forcing_path = f"{data_dir}/fake/basin_mean_forcing/daymet/02/{name}_lump_cida_forcing_leap.txt"
    precip = []
    dates = []
    with open(forcing_path) as f:
        # Skip the 4 header lines (presumably metadata + column names -- verify format).
        for _ in range(4):
            f.readline()
        for line in f:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            fields = line.split('\t')
            precip.append(float(fields[2]))  # third tab-separated column
            dates.append(fields[0])          # timestamp field; last 3 chars dropped on write

    # Series are paired by position; zip() silently drops the longer tail, so say so.
    if len(pets) != len(precip):
        warnings.warn(f"{name}: {len(precip)} forcing rows but {len(pets)} PET values; "
                      "series are paired by position and the excess is ignored.")

    tr = TR(b, smax, c)
    lr1 = LR(k1)
    lr2 = LR(k2)
    out_dir = f"{data_dir}/fake/usgs_streamflow/02"
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/{name}_streamflow_qc.txt", 'w') as f:
        for d, p, e in zip(dates, precip, pets):
            output_tr = tr(p, e)[0]
            output = lr1(weights[0] * output_tr) + lr2(weights[1] * output_tr)
            f.write(' '.join([name, d[:-3], str(output), 'A']))
            f.write('\n')

# Synthetic basins and their model parameters: (basin id, b, smax, c, k1, k2).
for _basin_id, *_params in [
    ('01333000', 0.7, 50.0, 0.5, 2.0, 0.1),
    ('01350000', 2.0, 20.0, 0.9, 2.0, 0.5),
]:
    generate_data(_basin_id, *_params)

In [None]:
# Compare architectures: for each run "runs/dev/arch<N>", load the validation results of
# epoch `nb_epochs` and summarise every metric across `basins` (both defined in an earlier
# cell) as a rounded mean and population std.
nb_archs = 5
ms = []    # ms[a][j]: mean over basins of the j-th metric for architecture a+1 (3 decimals)
stds = []  # stds[a][j]: population std over basins of the j-th metric for architecture a+1
for arch in range(1, nb_archs + 1):
    run_dir = "runs/dev/arch" + str(arch)
    with open(run_dir + "/validation/model_epoch" + nb_epochs + "/validation_results.p", "rb") as fp:
        # NOTE: pickle.load can execute arbitrary code -- only load result files you produced.
        results = pickle.load(fp)

    # One metrics dict per basin. The last architecture's list stays at module level because
    # the table code below reads the metric names from metric_values[0].
    metric_values = []
    for basin in basins:
        qsim = results[basin]['1D']['xr']['QObs(mm/d)_sim']
        qobs = results[basin]['1D']['xr']['QObs(mm/d)_obs']
        metric_values.append(metrics.calculate_all_metrics(qobs.isel(time_step=-1),
                                                           qsim.isel(time_step=-1)))

    arch_means, arch_stds = [], []
    for key in metric_values[0]:
        vals = [basin_metrics[key] for basin_metrics in metric_values]
        arch_means.append(round(mean(vals), 3))
        arch_stds.append(round(stdev(vals), 3))
    ms.append(arch_means)
    stds.append(arch_stds)

# Build a LaTeX table body: one column per architecture, one row per metric. Metric names
# come from the last architecture's `metric_values`, in the same order used to fill `ms`/`stds`.
show_std = False  # set True to append the across-basin std in parentheses

header_cells = "".join(f" & \\textbf{{Architecture {arch}}}" for arch in range(1, nb_archs + 1))
res = "\\textbf{}" + header_cells + " \\\\ \\hline \n"
for j, key in enumerate(metric_values[0].keys()):
    # Escape characters that are special in LaTeX text (no-op for plain metric names).
    escaped_key = key
    for special in "&%_#":
        escaped_key = escaped_key.replace(special, "\\" + special)
    res += escaped_key
    for arch_means, arch_stds in zip(ms, stds):
        res += ' & ' + str(arch_means[j])
        if show_std:
            res += '(' + str(arch_stds[j]) + ')'
    res += "\\\\ \\hline \n"

print(res)