In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
#import pysumma as ps
import xarray as xr
import pandas as pd
from pathlib import Path
from joblib import Parallel, delayed
import pysumma as ps
import pandas as pd
from tqdm.notebook import tqdm
import seaborn as sns
from glob import glob
import shutil

sns.set_context('talk')
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so use whichever spelling this installation provides.
_bright_style = ('seaborn-v0_8-bright' if 'seaborn-v0_8-bright' in mpl.style.available
                 else 'seaborn-bright')
mpl.style.use(_bright_style)
mpl.rcParams['figure.figsize'] = (18, 12)


def read_noah_table(path):
    """Read a NOAH-style ``.TBL`` parameter table indexed by its last column.

    The index labels are stripped of surrounding whitespace and single quotes.
    """
    table = pd.read_csv(path, index_col=-1, skipinitialspace=True)
    table.index = table.index.map(lambda label: label.strip().replace("'", ""))
    return table


veg_igbp = read_noah_table('./VEGPARM_IGBP_MODIS_NOAH.TBL')
soil_rosetta = read_noah_table('./SOILPARM_ROSETTA.TBL')

In [None]:
max_iter = 500

def run_site_calib(site, max_iter=max_iter,
                   summa_exe='/pool0/data/andrbenn/ml_summa/summa/bin/summa.exe',
                   ostrich_exe='/pool0/home/andrbenn/data/naoki_calib_example/ostrich.exe',
                   python_exe='/pool0/data/andrbenn/.conda/all/bin/python'):
    """Configure and launch an OSTRICH calibration of SUMMA for one site.

    The run is started with ``run(monitor=False)``; callers in this notebook
    later call ``.monitor()`` on the returned object to wait for it.

    Parameters
    ----------
    site : str
        Site directory name under ``../sites``; its ``file_manager.txt`` is used.
    max_iter : int
        Maximum number of OSTRICH iterations (``ostrich.max_iters``).
    summa_exe, ostrich_exe, python_exe : str
        Executables passed to ``ps.Ostrich``.

    Returns
    -------
    ps.Ostrich
        The configured, launched calibration object.

    Raises
    ------
    ValueError
        If the site's soilTypeIndex is missing from SOILPARM_ROSETTA.TBL or its
        vegTypeIndex has no rooting-depth entry.
    """
    ostrich = ps.Ostrich(ostrich_exe, summa_exe, f'../sites/{site}/file_manager.txt', python_path=python_exe)

    ostrich.max_iters = max_iter
    ostrich.allow_failures = True
    ostrich.perturb_val = 0.2

    # Soil/vegetation-dependent initial values are taken from the first HRU's type indices.
    attr = ostrich.simulation.local_attributes
    soil_type = attr['soilTypeIndex'].values[0]
    veg_type = attr['vegTypeIndex'].values[0]

    soil_rosetta = pd.read_csv('./SOILPARM_ROSETTA.TBL',
                               index_col=-1, skipinitialspace=True)
    soil_params = soil_rosetta[soil_rosetta['SOILTYPINDEX'] == soil_type]
    if soil_params.empty:
        raise ValueError(f"{site}: soilTypeIndex {soil_type} not found in SOILPARM_ROSETTA.TBL")

    # Rooting depth per IGBP vegetation class. Source: Zeng 2001 AMS
    igbp_rooting_depths = {1: 1.8,  2: 3.0,  3: 2.0,   4: 2.0,  5: 2.4,  6: 2.5,  7: 3.10,  8: 1.7,
                           9: 2.4, 10: 1.5, 11: 0.02, 12: 1.5, 13: 1.5, 14: 1.5, 15: 0.01, 16: 4.0}
    if veg_type not in igbp_rooting_depths:
        raise ValueError(f"{site}: no rooting depth defined for vegTypeIndex {veg_type}")

    initial_values = {
        'rootingDepth': igbp_rooting_depths[veg_type],
        'theta_res': soil_params['theta_res'].values[0],
        'theta_sat': soil_params['theta_sat'].values[0],
    }

    # rootingDepth is allowed to vary between 50% and 150% of its table value.
    param_ranges = {
        'rootingDepth': initial_values['rootingDepth'] * np.array([0.5, 1.5]),
    }

    ostrich.obs_data_file = f'/pool0/data/andrbenn/ml_summa/sites/{site}/forcings/{site}.nc'
    ostrich.sim_calib_vars = ['scalarLatHeatTotal', 'scalarSenHeatTotal']
    ostrich.obs_calib_vars = ['Qle_cor', 'Qh_cor']
    ostrich.import_strings = 'import numpy as np'
    # NOTE(review): the two lambdas below are kept self-contained (numpy only, provided via
    # import_strings) and free of trailing comments -- pysumma presumably copies their
    # source text into the generated calibration script; confirm before restyling them.
    # Negation; presumably aligns the SUMMA flux sign convention with the observations.
    ostrich.conversion_function = lambda x: -x
    # Keep only time steps whose observations are not gap-filled, then drop the first 48 of them.
    ostrich.filter_function = lambda x,y : (
            x.isel(hru=0, gru=0, time=np.argwhere(~y['gap_filled'].isel(hru=0, drop=True).astype(bool).values).flatten()).isel(time=slice(48, None)),
            y.isel(hru=0, time=np.argwhere(~y['gap_filled'].isel(hru=0, drop=True).astype(bool).values).flatten()).isel(time=slice(48, None))
        )
    ostrich.cost_function = 'MSE'
    ostrich.maximize = False

    ostrich.calib_params = [
        ps.OstrichParam('vcmax_Kn', 0.6, (0.1, 1.2)),
        ps.OstrichParam('laiScaleParam', 1.0, (0.5, 3.0)),
        ps.OstrichParam('rootingDepth', initial_values['rootingDepth'], param_ranges['rootingDepth']),
        ps.OstrichParam('canopyWettingFactor', 0.7, (0.01, 0.9)),
        ps.OstrichParam('kAnisotropic', 1.0, (0.5, 5.0)),
        ps.OstrichParam('theta_res', initial_values['theta_res'],   (0.001,  0.2)),
        ps.OstrichParam('theta_sat', initial_values['theta_sat'],   (0.31,   0.7)),
    ]
    # Tied parameters are bounded by other (calibrated or tied) parameters.
    ostrich.add_tied_param('fieldCapacity', lower_bound='theta_res', upper_bound='theta_sat')
    ostrich.add_tied_param('critSoilTranspire', lower_bound='theta_res', upper_bound='theta_sat')
    ostrich.add_tied_param('critSoilWilting', lower_bound='theta_res', upper_bound='critSoilTranspire')

    # Calibrate on at most the first 366 days, never extending past the configured end time.
    start = pd.to_datetime(ostrich.simulation.manager['simStartTime'].value)
    stop = pd.to_datetime(ostrich.simulation.manager['simEndTime'].value)
    ostrich.simulation.manager['simEndTime'] = str(min(stop, start + pd.Timedelta('366D')))

    ostrich.write_config()
    ostrich.run(monitor=False)
    return ostrich

In [None]:
# Candidate sites are the subdirectories of ../sites. Sorting makes the run order
# reproducible (os.listdir order is arbitrary); stray files are skipped.
sites = sorted(
    name for name in os.listdir('../sites')
    if os.path.isdir(os.path.join('../sites', name))
)

# Sites to exclude from calibration (currently none).
bad_sites = []
sites = [site for site in sites if site not in bad_sites]
config = {site: {'file_manager': f'../sites/{site}/file_manager.txt'} for site in sites}
print(len(config))

summa_exe = '/pool0/data/andrbenn/ml_summa/summa/bin/summa.exe'
ostrich_exe = '/pool0/home/andrbenn/data/naoki_calib_example/ostrich.exe'
python_exe = '/pool0/data/andrbenn/.conda/all/bin/python'

In [None]:
# Launch one calibration per site while keeping at most `n_workers` in flight:
# as soon as the in-flight list is full, wait (monitor) on the oldest launch.
calibrating = []   # launched runs not yet monitored, oldest first
calibrated = []    # runs that have been monitored, in the order they were waited on
n_workers = 4
for site in tqdm(sites):
    calibrating.append(run_site_calib(site))
    if len(calibrating) != n_workers:
        continue
    oldest = calibrating.pop(0)
    oldest.monitor()
    calibrated.append(oldest)
    print(oldest.config_path)

# Wait on whatever is still in flight.
for leftover in calibrating:
    leftover.monitor()
    calibrated.append(leftover)

In [None]:
# Copy every best_calibration/params/parameter_trial.nc found under ../ to the
# sibling params/parameter_trial.nc (one directory up from best_calibration).
best_suffix = 'best_calibration/params/parameter_trial.nc'
src_params = glob(f'../**/{best_suffix}', recursive=True)
# Rewrite only the matched trailing suffix; str.replace would also rewrite any
# earlier 'best_calibration/' occurring elsewhere in the path.
dest_params = [sp[:-len(best_suffix)] + 'params/parameter_trial.nc' for sp in src_params]
for sp, dp in zip(src_params, dest_params):
    shutil.copy(sp, dp)

In [None]:
def open_stats(o):
    """Load the OSTRICH run log ``OstModel0.txt`` of one calibration.

    Parameters
    ----------
    o : object
        Calibration object with ``config_path`` (the calibration directory; its
        parent directory's name is taken as the site name) and ``max_iters``.

    Returns
    -------
    tuple of (str, pandas.DataFrame or None)
        The site name and the whitespace-delimited log. The log is ``None`` when
        the file is missing, empty, or has no more than ``max_iters`` rows.
    """
    config_dir = Path(o.config_path)
    site = config_dir.parent.name
    file = config_dir / 'OstModel0.txt'
    if not file.exists():
        return site, None
    try:
        # sep=r'\s+' is the non-deprecated equivalent of delim_whitespace=True.
        df = pd.read_csv(file, sep=r'\s+')
    except pd.errors.EmptyDataError:
        # Log file exists but nothing has been written to it yet.
        return site, None
    # A complete log needs more than max_iters rows (presumably the initial run
    # plus max_iters iterations -- TODO confirm).
    if len(df) <= o.max_iters:
        return site, None
    return site, df

In [None]:
def open_metrics(o):
    """Load the per-iteration metrics log ``metrics_log.csv`` of one calibration.

    Parameters
    ----------
    o : object
        Calibration object with ``config_path`` (the calibration directory; its
        parent directory's name is taken as the site name) and ``max_iters``.

    Returns
    -------
    tuple of (str, pandas.DataFrame or None)
        The site name and the header-less log read with columns
        kge, mae, mse, rmse, nse. The log is ``None`` when the file is missing,
        empty, or has fewer than ``max_iters`` rows.
    """
    config_dir = Path(o.config_path)
    site = config_dir.parent.name
    file = config_dir / 'metrics_log.csv'
    if not file.exists():
        return site, None
    try:
        df = pd.read_csv(file, names=['kge', 'mae', 'mse', 'rmse', 'nse'])
    except pd.errors.EmptyDataError:
        # Log file exists but nothing has been written to it yet.
        return site, None
    if len(df) < o.max_iters:
        return site, None
    return site, df

In [None]:
# Load each calibration's OSTRICH log; sites without a complete log go to bad_sites.
stats = [open_stats(o) for o in calibrated]
calib_sites = [site for site, df in stats if df is not None]
bad_sites = [site for site, df in stats if df is None]
# Keep the last max_iter + 1 rows of every site so the row count always matches
# the site-name index built below, whatever max_iter is set to.
n_rows = max_iter + 1
all_df = pd.concat([df.iloc[-n_rows:] for site, df in stats if df is not None])
all_df.index = pd.Index(np.repeat(calib_sites, n_rows))

In [None]:
# Treat the 999999.0 objective value as missing (presumably OSTRICH's penalty for
# failed runs, given allow_failures=True). Assign the whole column: the chained
# form all_df[col][mask] = ... may write to a temporary copy, and under pandas
# copy-on-write it never modifies all_df.
all_df['obj.function'] = all_df['obj.function'].replace(999999.0, np.nan)

In [None]:
# Sites whose OstModel0.txt was missing or had too few rows (open_stats returned None).
bad_sites

In [None]:
# Sorted, de-duplicated site names that made it into all_df.
calibrated_sites = np.unique(all_df.index.values)

In [None]:
# Per-site RMSE (sqrt of the MSE objective) of the best run, of the mean over all
# runs, and of the first run in each site's window.
obj_by_site = all_df.groupby(all_df.index)['obj.function']
curves = {
    'best': np.sqrt(obj_by_site.min()),
    'average of all runs': np.sqrt(obj_by_site.mean()),
    # First row of each site's window, i.e. the initial-parameter run when the window
    # starts at Run 0 (TODO confirm). Unlike .first(), this does not silently fall
    # through to a later run when the initial run failed (NaN).
    'initial parameters': np.sqrt(obj_by_site.head(1)),
}

fig, ax = plt.subplots()
for label, values in curves.items():
    # Each curve is sorted independently (a rank plot). np.sort places NaNs last and
    # they are not drawn; the builtin sorted() leaves NaN-containing data unordered.
    ax.plot(np.sort(values.to_numpy()), label=label)
for level in (0.0, 60.0, 100.0):
    ax.axhline(level, color='black')
ax.set_xlabel('Site (ranked by RMSE)')
ax.set_ylabel(r'RMSE $(W/m^2)$')
ax.set_ylim([-10, 160])
ax.legend();

In [None]:
# RMSE of every run (the objective is an MSE) and its 15-run centred rolling window.
run_rmse = np.sqrt(all_df['obj.function'])
run_number = all_df['Run']
rolling_rmse = run_rmse.rolling(15, center=True, min_periods=1)
# The window runs over the concatenated frame, so values near the start/end of a
# site's runs mix in neighbouring sites; the [10:-10] trim of the Run axis
# presumably exists to drop those edge runs.

run_rmse.groupby(run_number).mean().plot(label='mean across all sites', color='darkgrey')
rolling_rmse.mean().groupby(run_number).mean()[10:-10].plot(label='15 run running mean', linewidth=3, color='royalblue')
rolling_rmse.min().groupby(run_number).mean()[10:-10].plot(label='15 run running minimum', linewidth=3, color='crimson')
plt.ylabel(r'RMSE $(W/m^2)$')
plt.legend()
plt.xscale('linear')


In [None]:
# Best and first-iteration NSE per site, from each calibration's metrics_log.csv
# (only sites whose log is complete).
# NOTE: despite their names, best_kge / first_kge hold NSE values, not KGE.
metric_df = [open_metrics(o) for o in calibrated]
calib_sites = [site for site, frame in metric_df if frame is not None]
metric_df = [frame for site, frame in metric_df if frame is not None]

best_kge = [frame['nse'].max() for frame in metric_df]
first_kge = [frame['nse'].values[0] for frame in metric_df]

plt.plot(sorted(best_kge)[::-1], 'o-')
for level in (1.0, 0.8, 0.6, 0.4, 0.2, 0.0):
    plt.axhline(level, color='grey', linestyle=':')
plt.xlabel('Site number')
plt.ylabel('Best NSE')
plt.ylim([0, 1])