In [None]:
"""
Created on Fri Jan 19 14:30 2024

Compute weights for the melt parameterisations by comparing their integrated melt (Gt/yr) with the NEMO training runs

@author: Clara Burgard
"""

In [None]:
import xarray as xr
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import multimelt.useful_functions as uf
import os

In [None]:
# Seaborn "paper" context: smaller fonts/lines suited to publication figures.
sns.set_context(context='paper')

In [None]:
# Qt5 backend: figures open in separate interactive windows instead of inline.
# NOTE(review): requires a display (e.g. X forwarding) when run on a remote machine.
%matplotlib qt5

READ IN DATA

In [None]:
# Root of the data/script tree; the trailing separator lets paths be appended by '+'.
home_path = os.path.join('/bettik', 'burgardc', '')

In [None]:
# Classic (physics-based) parameterisations; each name is part of an evaluation file name.
param_classic_list = [
    'linear_local',
    'quadratic_local',
    'quadratic_local_locslope',
    'lazero19',
    'boxes_4_pismyes_picopno',
]

# Neural-network parameterisation labels.
param_NN_list = [
    'NNsmall',
]

In [None]:
## Melt outputpath: integrated melt (Gt/yr) of every parameterisation for every NEMO run.
# All paths are built from `home_path` (as in the NEMO reference cell) so that the
# data root is configured in one place only.
Gt_allmod_list = []

for nemo_run in ['OPM006','OPM016','OPM018','OPM021','OPM031','ctrl94','isf94','isfru94']: #'CNRM-CM6-1',

    # Classic parameterisations: one evaluation file per parameterisation.
    outputpath_melt = home_path + 'DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CHECK_TUNING/nemo_5km_' + nemo_run + '/'

    melt1D_list = []
    for mparam in param_classic_list:
        melt1D_scenario = xr.open_dataset(outputpath_melt + 'eval_metrics_1D_' + mparam + '_newtuning_oneFRIS.nc')
        melt1D_list.append(melt1D_scenario.assign_coords({'param': mparam}))
    melt1D_classic = xr.concat(melt1D_list, dim='param')
    Gt_classic = melt1D_classic['melt_1D_Gt_per_y']

    # The NN outputs live in a different tree depending on the run family.
    if nemo_run in ['ctrl94','isf94','isfru94']:
        inputpath_profiles = home_path + 'DATA/SUMMER_PAPER/interim/T_S_PROF/nemo_5km_' + nemo_run + '/'
        outputpath_melt = home_path + 'DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_TUNING/nemo_5km_' + nemo_run + '/'
    else:
        inputpath_profiles = home_path + 'SCRIPTS/basal_melt_param/data/interim/T_S_PROF/nemo_5km_' + nemo_run + '/'
        outputpath_melt = home_path + 'SCRIPTS/basal_melt_param/data/processed/MELT_RATE/nemo_5km_' + nemo_run + '/'

    melt1D_list = []
    for mparam in param_NN_list:
        # NOTE(review): the file name does not depend on `mparam`; with more than one
        # NN entry the same file would be loaded under every label.
        melt1D_scenario = xr.open_dataset(outputpath_melt + 'evalmetrics_1D_small_newbasic2_extrap_normstd_newtuning.nc')
        melt1D_list.append(melt1D_scenario.assign_coords({'param': mparam}))
    melt1D_NN = xr.concat(melt1D_list, dim='param')
    Gt_NN = melt1D_NN['predicted_melt'].sel(metrics='Gt')

    # Classic and NN results side by side along 'param', tagged with the run name.
    Gt_all = xr.concat([Gt_classic, Gt_NN], dim='param')
    Gt_allmod_list.append(Gt_all.assign_coords({'nemo_run': nemo_run}))

Gt_allmod = xr.concat(Gt_allmod_list, dim='nemo_run')


In [None]:
# NEMO reference melt for every run: total melt (Gt/yr) and mean box-1 melt (m/yr).
run_list = ['OPM031','OPM021','OPM018','OPM016','OPM006','ctrl94','isf94','isfru94'] #'OPM026',


def _nemo_ref_dir(run):
    """Return the directory holding the NEMO reference melt files of `run`."""
    if run in ('ctrl94', 'isf94', 'isfru94'):
        return home_path + 'DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_TUNING/nemo_5km_' + run + '/'
    return home_path + 'DATA/BASAL_MELT_PARAM/processed/MELT_RATE/nemo_5km_' + run + '/'


ref_Gt_list = []
ref_box1_list = []
for run in run_list:
    ref_dir = _nemo_ref_dir(run)
    gt_ds = xr.open_dataset(ref_dir + 'melt_rates_1D_NEMO_oneFRIS.nc')
    ref_Gt_list.append(gt_ds['melt_Gt_per_y_tot'].assign_coords({'nemo_run': run}))
    box1_ds = xr.open_dataset(ref_dir + 'melt_rates_box1_NEMO_oneFRIS.nc')
    ref_box1_list.append(box1_ds['mean_melt_box1_myr'].assign_coords({'nemo_run': run}))

ref_Gt_all = xr.concat(ref_Gt_list, dim='nemo_run')
ref_box1_all = xr.concat(ref_box1_list, dim='nemo_run')

In [None]:
# Lay the runs one after another on a single pseudo time axis: run number k is
# shifted by k*150 time units. Model and reference use the same shift per run.
Gt_param_stacked_list = []
Gt_ref_stacked_list = []

for run_idx, run_label in enumerate(Gt_allmod.nemo_run):
    time_shift = run_idx * 150
    param_run = Gt_allmod.sel(nemo_run=run_label)
    ref_run = ref_Gt_all.sel(nemo_run=run_label)
    Gt_param_stacked_list.append(param_run.assign_coords({'time': Gt_allmod.time + time_shift}))
    Gt_ref_stacked_list.append(ref_run.assign_coords({'time': ref_Gt_all.time + time_shift}))

In [None]:
# Join the per-run pieces along the shifted time axis (model first, then reference).
Gt_param_stacked, Gt_ref_stacked = (
    xr.concat(pieces, dim='time')
    for pieces in (Gt_param_stacked_list, Gt_ref_stacked_list)
)

In [None]:
# Drop coordinate labels (time steps, ice shelves, parameterisations) whose values
# are all non-finite.
Gt_param_stacked_clean = Gt_param_stacked.where(np.isfinite(Gt_param_stacked), drop=True)
Gt_ref_stacked_clean = Gt_ref_stacked.where(np.isfinite(Gt_ref_stacked), drop=True)

# The next cell renumbers `time` positionally, so model and reference must keep exactly
# the same time steps; otherwise model and NEMO years would be silently shifted against
# each other before the difference is taken. Keep only stacked time labels found in both.
_param_in_ref = Gt_param_stacked_clean['time'].isin(Gt_ref_stacked_clean['time'].values)
_ref_in_param = Gt_ref_stacked_clean['time'].isin(Gt_param_stacked_clean['time'].values)
Gt_param_stacked_clean = Gt_param_stacked_clean.isel(time=np.flatnonzero(_param_in_ref.values))
Gt_ref_stacked_clean = Gt_ref_stacked_clean.isel(time=np.flatnonzero(_ref_in_param.values))

assert Gt_param_stacked_clean.sizes['time'] == Gt_ref_stacked_clean.sizes['time'], \
    'model and reference time axes still differ (duplicate stacked time labels?)'

In [None]:
# Replace the shifted time labels with a plain 0..N-1 index on both arrays.
for stacked in (Gt_param_stacked_clean, Gt_ref_stacked_clean):
    stacked['time'] = np.arange(stacked.sizes['time'])

In [None]:
# Sanity check: stacked series of the first parameterisation vs. the NEMO reference,
# for the first ice shelf (index 0 along Nisf). Both curves go on the current axes.
param_lines = Gt_param_stacked_clean.isel(param=0, Nisf=0).plot(
    label=str(Gt_param_stacked_clean.param.values[0]))
ref_lines = Gt_ref_stacked_clean.isel(Nisf=0).plot(label='NEMO reference')
# Without a legend the two curves cannot be told apart.
param_lines[0].axes.legend();

In [None]:
# Spread terms of the weighting kernel; presumably in Gt/yr like the melt
# fluxes — TODO confirm.
sigma_obs, sigma_mod = 100, 100

# Signed model-minus-NEMO difference, averaged over the stacked time axis.
# NOTE(review): positive and negative biases in different years can cancel here.
model_minus_ref = Gt_param_stacked_clean - Gt_ref_stacked_clean
diff_mod_obs = model_minus_ref.mean(dim='time')

In [None]:
# Gaussian similarity of each parameterisation's mean bias, scaled by the summed
# variances sigma_obs**2 + sigma_mod**2.
total_var = sigma_obs**2 + sigma_mod**2
s_j = np.exp(-(diff_mod_obs**2) / total_var)

# Normalise over parameterisations so each ice shelf's weights sum to 1.
weight = s_j / s_j.sum('param')

# Alias kept for the cells below; non-finite weights are NOT replaced here.
weight_clean = weight
# TODO: to save, write weight_clean.to_dataset(name='bay_weights') to
# outputpath_weights + 'bayesian_weights_davison.nc' (outputpath_weights is not
# defined in the cells shown).

In [None]:
# For each ice shelf, print the parameterisation(s) with the highest weight (ties all shown).
for isf_label in weight_clean.Nisf:
    isf_weights = weight_clean.sel(Nisf=isf_label)
    best_params = weight_clean.param.where(isf_weights == isf_weights.max(), drop=True)
    print(isf_label.values, best_params.values)

In [None]:
# For each ice shelf, print the parameterisation(s) with the lowest weight (ties all shown).
for isf_label in weight_clean.Nisf:
    isf_weights = weight_clean.sel(Nisf=isf_label)
    worst_params = weight_clean.param.where(isf_weights == isf_weights.min(), drop=True)
    print(isf_label.values, worst_params.values)