In [None]:
%pylab inline
import os
import shutil
from glob import glob
import pysumma as ps
import xarray as xr
import pandas as pd
from pathlib import Path

In [None]:
# Build the k-fold split: shuffle the site list with a fixed seed, cut it into
# `nfolds` equal groups of held-out ("test") sites, and take the complement of
# each group as that fold's training sites.
# NOTE(review): os.listdir order is filesystem-dependent, so the seeded shuffle is
# only reproducible on the same filesystem. The fold assignment must match the one
# used to train ../models/train_states_set_{n}.txt -- confirm before sorting here.
sites = os.listdir('../sites/')
bad_sites = []
sim_sites = [s for s in sites if s not in bad_sites]

seed = 50334
np.random.seed(seed)
np.random.shuffle(sim_sites)

nfolds = 4
# reshape() needs equal-sized folds; fail with a readable message instead.
assert len(sim_sites) % nfolds == 0, (
    f'{len(sim_sites)} sites cannot be split into {nfolds} equal folds')
kfold_test_sites = np.array(sim_sites).reshape(nfolds, -1)
# Training sites of fold n = every site not held out in fold n. Filtering the
# shuffled list (instead of a set difference) keeps the order deterministic.
kfold_train_sites = np.vstack([
    [s for s in sim_sites if s not in held_out]
    for held_out in map(set, kfold_test_sites)
])
model_src = [f'../models/train_states_set_{n}.txt' for n in range(nfolds)]
model_rel_dest = [f'../params/train_states_set_{n}.txt' for n in range(nfolds)]

executable = '../state_ml_summa/bin/ml_summa'

# For every fold, give each held-out site a fold-specific file manager that points
# at that fold's neural net, and copy the net into the site's params directory.
kfold_configs = []
for n in range(nfolds):
    kfold_sites = kfold_test_sites[n]
    config = {site: {'file_manager': f'../sites/{site}/file_manager.txt'} for site in kfold_sites}

    for s, c in config.items():
        # set model in file manager
        fman = ps.FileManager(c['file_manager'], name='')
        fman.options.append(ps.file_manager.FileManagerOption('neuralNetFile', model_rel_dest[n]))
        fman['outFilePrefix'] = f'state_nn_{n}_output'

        # save file manager
        fman.write(path=str(fman.original_path).replace('file_manager', f'file_manager_{n}_state_NN'))

        # copy neural net file over to params
        nn_dest = c['file_manager'].replace('file_manager.txt', model_rel_dest[n].replace('../', ''))
        shutil.copy(model_src[n], nn_dest)
        c['file_manager'] = f'../sites/{s}/file_manager_{n}_state_NN.txt'
    kfold_configs.append(config)

# Merge the per-fold configs into one ensemble config. Build a NEW dict so that
# kfold_configs[0] keeps describing fold 0 only (update() on it mutated it).
config = {
    site: site_config
    for fold_config in kfold_configs
    for site, site_config in fold_config.items()
}

In [None]:
# Run every configured site (all folds merged into `config`) with the ML SUMMA
# executable on the local machine, then gather ens.summary().
n_workers = 31
ens = ps.Ensemble(executable, config, num_workers=n_workers)
ens.run('local')
summary = ens.summary()

In [None]:
# Display the value returned by ens.summary() for the ensemble run above.
summary

In [None]:
# Map each site to the timestep output of the neural-net run from the fold that
# held it out. The name mirrors outFilePrefix = f'state_nn_{n}_output' set in the
# fold loop, followed by the run suffix (presumably the site key passed by
# pysumma -- the same `{prefix}_{site}_timestep.nc` pattern as the template runs).
all_outfiles = {
    site: f'../sites/{site}/output/state_nn_{n}_output_{site}_timestep.nc'
    for n, fold_sites in enumerate(kfold_test_sites)
    for site in fold_sites
}

In [None]:
# Monthly-mean latent heat for the sites in kfold_train_sites[0] (the held-out
# sites of folds 1..nfolds-1): corrected observations when the forcing file has
# them, the out-of-fold neural-net run, and the template SUMMA run.
fig, axes = plt.subplots(8, 8, figsize=(18, 12), sharex=True, sharey='row')
axes = axes.flatten()

plotted_axes = []
for ax, site in zip(axes, kfold_train_sites[0]):
    # Open all three datasets; if any is missing or unreadable, close the ones
    # that did open and skip the site (instead of a bare except hiding errors).
    opened = []
    try:
        for path in (all_outfiles[site],
                     f'../sites/{site}/output/template_output_{site}_timestep.nc',
                     f'../sites/{site}/forcings/{site}.nc'):
            opened.append(xr.open_dataset(path))
    except (KeyError, OSError, ValueError) as err:
        for ds in opened:
            ds.close()
        print(f'skipping {site}: {err!r}')
        continue
    nn, sim, obs = opened
    try:
        # Model fluxes are negated, presumably to match the observation sign convention.
        qle_nn = -nn['scalarLatHeatTotal'].load()
        qle_sim = -sim['scalarLatHeatTotal'].load()

        # NOTE(review): only the corrected series is drawn, so sites without
        # 'Qle_cor' show no observed line.
        if 'Qle_cor' in obs:
            qle_corr = obs['Qle_cor']
            qle_corr.groupby(obs.time.dt.month).mean(dim='time').plot(ax=ax, color='black', linewidth=2, label='Observed')

        qle_nn.groupby(nn.time.dt.month).mean(dim='time').plot(ax=ax, color='royalblue', label='Neural Net')
        qle_sim.groupby(sim.time.dt.month).mean(dim='time').plot(ax=ax, color='crimson', label='SUMMA')
    finally:
        for ds in opened:
            ds.close()
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_title(site)
    plotted_axes.append(ax)

fig.text(0.5, -0.02, r'Month', ha='center')
fig.text(-0.02, 0.5, r'Latent heat $(W/m^2)$', va='center', rotation='vertical')
plt.tight_layout(pad=0.1)
# Attach the legend to a panel that has lines; the last panel may be empty.
if plotted_axes:
    plotted_axes[0].legend()

In [None]:
# Monthly-mean sensible heat for the sites in kfold_train_sites[0]: corrected
# observations when present, the out-of-fold neural-net run (same files as the
# latent-heat figure, via all_outfiles), and the template SUMMA run.
# The NN file previously was the hard-coded 'nn_0_output_...', i.e. model 0 on the
# sites it was trained on, with a prefix this notebook no longer writes.
fig, axes = plt.subplots(8, 8, figsize=(18, 12), sharex=True, sharey='row')
axes = axes.flatten()

plotted_axes = []
for ax, site in zip(axes, kfold_train_sites[0]):
    # Open all three datasets; on failure close what opened and skip the site.
    opened = []
    try:
        for path in (all_outfiles[site],
                     f'../sites/{site}/output/template_output_{site}_timestep.nc',
                     f'../sites/{site}/forcings/{site}.nc'):
            opened.append(xr.open_dataset(path))
    except (KeyError, OSError, ValueError) as err:
        for ds in opened:
            ds.close()
        print(f'skipping {site}: {err!r}')
        continue
    nn, sim, obs = opened
    try:
        # Model fluxes are negated, presumably to match the observation sign convention.
        qh_nn = -nn['scalarSenHeatTotal'].load()
        qh_sim = -sim['scalarSenHeatTotal'].load()

        # NOTE(review): only the corrected series is drawn, so sites without
        # 'Qh_cor' show no observed line.
        if 'Qh_cor' in obs:
            qh_corr = obs['Qh_cor']
            qh_corr.groupby(obs.time.dt.month).mean(dim='time').plot(ax=ax, color='black', linewidth=2, label='Observed')

        qh_nn.groupby(nn.time.dt.month).mean(dim='time').plot(ax=ax, color='royalblue', label='Neural Net')
        qh_sim.groupby(sim.time.dt.month).mean(dim='time').plot(ax=ax, color='crimson', label='SUMMA')
    finally:
        for ds in opened:
            ds.close()
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_title(site)
    plotted_axes.append(ax)

fig.text(0.5, -0.02, r'Month', ha='center')
fig.text(-0.02, 0.5, r'Sensible heat $(W/m^2)$', va='center', rotation='vertical')
plt.tight_layout(pad=0.1)
# Attach the legend to a panel that has lines; the last panel may be empty.
if plotted_axes:
    plotted_axes[0].legend()

In [None]:
# Scratch check: canopy water (scalarCanopyWat) for 2000 through 2004.
# NOTE(review): `sim` is whichever dataset the previously executed cell left bound
# (the plotting loops above close it), so the site is implicit -- open the
# intended site explicitly before relying on this plot.
sim['scalarCanopyWat'].sel(time=slice('2000','2004')).plot()

In [None]:
# Scratch comparison for 2004, using `obs`/`sim` left over from an earlier cell:
# 100 * SWRadAtm**(1/3) overlaid on SUMMA's scalarLatHeatGround.
# NOTE(review): the factor 100 and the cube root are unexplained; the
# commented-out Qle_cor line below is dead code.
#obs['Qle_cor'].sel(time='2004').plot()
(100 * (obs['SWRadAtm'] ** (1./3))).sel(time='2004').plot()
sim['scalarLatHeatGround'].sel(time='2004').plot()

In [None]:
# NOTE(review): leftover debugging cell -- it only prints an empty line; safe to delete.
print()

In [None]:
# Median diurnal cycle of latent heat at every site: observed, corrected
# observed (when the forcing file has it) and the template SUMMA run.
# Paths use '../sites/' like the rest of the notebook (sites was listed from there).
fig, axes = plt.subplots(8, 7, figsize=(40, 25), sharex=True, sharey=True)
axes = axes.flatten()

if len(sites) > len(axes):
    print(f'only the first {len(axes)} of {len(sites)} sites fit on the grid')

# zip() stops at the last panel instead of raising IndexError on extra sites.
for ax, site in zip(axes, sites):
    # Context managers close both datasets once the site's lines are drawn.
    with xr.open_dataset(f'../sites/{site}/output/template_output_{site}_timestep.nc') as sim, \
            xr.open_dataset(f'../sites/{site}/forcings/{site}.nc') as obs:
        qle_sim = -sim['scalarLatHeatTotal'].load()
        qle_obs = obs['Qle'].load()

        if 'Qle_cor' in obs:
            qle_corr = obs['Qle_cor']
            qle_corr.groupby(obs.time.dt.hour).quantile(dim='time', q=0.5).plot(ax=ax, color='tomato', linewidth=2, label='Corrected')

        qle_obs.groupby(obs.time.dt.hour).quantile(dim='time', q=0.5).plot(ax=ax, color='black', linewidth=2, label='Observed')
        qle_sim.groupby(sim.time.dt.hour).quantile(dim='time', q=0.5).plot(ax=ax, color='slateblue', label='Simulated')
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_title(site)

plt.suptitle('Latent Heat Fluxes')
axes[0].legend()

In [None]:
# Load one site (FR-LBr) for the daily-mean inspection below. Paths follow the
# '../sites/.../template_output_{site}_timestep.nc' layout used everywhere else.
# The datasets are left open because the next cells read `sim`.
site = 'FR-LBr'
sim = xr.open_dataset(f'../sites/{site}/output/template_output_{site}_timestep.nc')
obs = xr.open_dataset(f'../sites/{site}/forcings/{site}.nc')

In [None]:
# Daily-mean total latent heat for the loaded site (negated when plotted below).
# The former `groupby(time.dt.year).apply(lambda x: x)` was an identity no-op
# (via the deprecated GroupBy.apply); it did not add a `year` dimension.
doy = sim['scalarLatHeatTotal'].resample(time='1D').mean()

In [None]:
# Overlay one daily-mean curve per calendar year, skipping the first and last
# years (presumably partial records). `doy` only has a `time` dimension, so each
# year is selected by partial datetime string on `time`, not by a `year` dim.
years = np.unique(doy.time.dt.year.values)
for year in years[1:-1]:
    plt.plot(-doy.sel(time=str(year)).values, color='blue', alpha=0.5)
plt.xlabel('Day of year')
plt.ylabel(r'Latent heat $(W/m^2)$')

In [None]:
# Plot the full daily-mean latent heat series, negated as in the cells above.
daily_qle = -doy
daily_qle.plot()