In [None]:
"""
Created on Wed Apr 20 10:58 2022

Make a matrix showing the importance of the different input variables after shuffling them (permutation importance) when the neural network is applied to the Smith runs

Author: @claraburgard

"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

In [None]:
%matplotlib qt5

READ IN DATA

In [None]:
# Experiment settings; they are pasted into the names of the evaluation files read further down.
mod_size = 'xsmall96'  # listed alternatives: 'mini', 'small', 'medium', 'large', 'extra_large'
TS_opt = 'extrap'      # listed alternatives: extrap, whole, thermocline
norm_method = 'std'    # listed alternatives: std, interquart, minmax
# Alternative experiment names: 'onlyTSdraftandslope', 'TSdraftbotandiceddandwcd',
# 'onlyTSisfdraft', 'TSdraftbotandiceddandwcdreldGL', 'TSdraftslopereldGL'
exp_name = 'newbasic2'

In [None]:
# Base directory, and the directory the figures of this notebook are saved to.
home_path = '/bettik/burgardc/'
plot_path = home_path + 'PLOTS/NN_plots/input_vars/'


In [None]:
# Read the mask/geometry file of the SMITH_bf663 configuration and reduce it to the
# entries along 'Nisf' that are used in this notebook (presumably one entry per ice shelf).
inputpath_mask = '/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_bf663/'

# `Dataset.drop` is deprecated in xarray; `drop_vars` is the explicit call for removing
# a variable/coordinate by name and does the same thing here.
file_isf_orig = xr.open_dataset(
    inputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_1980.nc'
).drop_vars('time')

# Keep only the Nisf entries for which 'front_bot_depth_max' is finite ...
nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# ... and, of those, only the ones with 'isf_area_here' >= 2500 (units not visible here).
large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= 2500, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

# Replace the name 'Ekstrom' by its spelling with umlaut, leave all other names untouched.
file_isf['isf_name'] = file_isf['isf_name'].where(file_isf['isf_name'] != 'Ekstrom', np.array('Ekström', dtype=object))
isf_names = file_isf['isf_name']

In [None]:
# Input variables for which a "shuffled" evaluation file is read below (one file per entry).
var_list = [
    'dGL', 'dIF',
    'corrected_isfdraft', 'bathy_metry',
    'slope_bed_lon', 'slope_bed_lat',
    'slope_ice_lon', 'slope_ice_lat',
    'theta_in', 'salinity_in',
    'T_mean', 'S_mean',
    'T_std', 'S_std',
]

In [None]:
# Load the evaluation metrics written for each NEMO run:
#   * unshuffled inputs: one file per year, concatenated along 'time'
#   * shuffled inputs:   one "allyy" file per shuffled variable, concatenated along 'shuffled_var'
# Each per-run dataset is tagged with a 'nemo_run' coordinate; the runs are then
# concatenated along 'nemo_run' into `orig_all` (unshuffled) and `shuffle_all` (shuffled).
shuffle_Gt_list = []    # not filled in this cell (see the commented-out line at the end)
shuffle_box1_list = []  # not filled in this cell
res_1D_mods_list = []
shuffle_mods_list = []

# Part of the file names shared by the unshuffled and the shuffled evaluation files.
nn_tag = '1D_'+mod_size+'_'+exp_name+'_ensmean_'+TS_opt+'_norm'+norm_method

for nemo_run in ['bf663']: #, 'bi646'
    outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/SMITH_'+nemo_run+'/'

    # Start from an empty list for every run; a list shared across runs would make the
    # second run's concatenation also contain the datasets of the first run.
    res_1D_all_list = []
    for yy in range(1980, 1980 + 60):
        res_1D_yy = xr.open_dataset(outputpath_melt_nn + 'evalmetrics_'+nn_tag+'_'+str(yy)+'_'+nemo_run+'.nc')
        res_1D_all_list.append(res_1D_yy.assign_coords({'time': yy}))
    res_1D_all_xr = xr.concat(res_1D_all_list, dim='time')
    res_1D_mods_list.append(res_1D_all_xr.assign_coords({'nemo_run': nemo_run}))

    shuffle_list = []
    for vv in var_list:
        res_1D_allyy = xr.open_dataset(outputpath_melt_nn + 'evalmetrics_shuffled'+vv+'_'+nn_tag+'_allyy_'+nemo_run+'.nc')
        shuffle_list.append(res_1D_allyy.assign_coords({'shuffled_var': vv}))
    shuffle_allvars = xr.concat(shuffle_list, dim='shuffled_var').assign_coords({'nemo_run': nemo_run})
    shuffle_mods_list.append(shuffle_allvars)

# xr.concat expects a sequence of datasets, so the per-run datasets are collected in a
# list first instead of handing over a single Dataset.
shuffle_all = xr.concat(shuffle_mods_list, dim='nemo_run')
orig_all = xr.concat(res_1D_mods_list, dim='nemo_run')

#shuffle_box1_list.append(res_1D_allyy.sel(metrics='box1').assign_coords({'shuffled_var': vv}))


In [None]:
# `res_1D_orig` is not assigned anywhere in this notebook, so this cell raised a NameError
# on a fresh kernel. The unshuffled evaluation metrics are loaded into `orig_all` in the
# loading cell, so it is derived from there.
# NOTE(review): assumes `res_1D_orig` was meant to be the unshuffled results — confirm.
# squeeze() removes the length-1 'nemo_run' dimension (it raises if more than one run
# is present), so the result has no run dimension.
res_1D_orig = orig_all.squeeze('nemo_run')

# Error (predicted minus reference) of the unshuffled predictions for the 'Gt' and 'box1' metrics.
diff_Gt_orig = res_1D_orig['predicted_melt'].sel(metrics='Gt') - res_1D_orig['reference_melt'].sel(metrics='Gt')
diff_box1_orig = res_1D_orig['predicted_melt'].sel(metrics='box1') - res_1D_orig['reference_melt'].sel(metrics='box1')

In [None]:
# `shuffle_Gt_all` and `shuffle_box1_all` are not assigned anywhere in this notebook, so this
# cell raised a NameError on a fresh kernel. They are derived here from `shuffle_all`
# (built in the loading cell) by selecting the respective metric.
# NOTE(review): assumes these were the per-metric selections of the shuffled results, as the
# commented-out `.sel(metrics='box1')` line in the loading cell suggests — confirm.
# squeeze() removes the length-1 'nemo_run' dimension (it raises if more than one run is present).
shuffle_Gt_all = shuffle_all.squeeze('nemo_run').sel(metrics='Gt')
shuffle_box1_all = shuffle_all.squeeze('nemo_run').sel(metrics='box1')

# Error (predicted minus reference) of the predictions made with one shuffled input variable.
diff_Gt_all = shuffle_Gt_all['predicted_melt'] - shuffle_Gt_all['reference_melt']
# NOTE(review): the 'box1' error is averaged over 'time' here, while the unshuffled 'box1'
# error (diff_box1_orig) is not averaged before its RMSE — confirm this asymmetry is intended.
diff_box1_all = (shuffle_box1_all['predicted_melt'] - shuffle_box1_all['reference_melt']).mean('time')

In [None]:
# Root-mean-square error of the shuffled-input predictions.
# 'Gt': squared errors are averaged over 'time' and 'Nisf'.
RMSE_Gt_all = np.sqrt(
    (diff_Gt_all ** 2).mean(dim=['time', 'Nisf'])
)
# 'box1': the error was already averaged over 'time', so only 'Nisf' is reduced here.
RMSE_box1_all = np.sqrt(
    (diff_box1_all ** 2).mean(dim=['Nisf'])
)

In [None]:
# Root-mean-square error of the unshuffled predictions; for both metrics the squared
# errors are averaged over 'time' and 'Nisf'.
RMSE_Gt_orig = np.sqrt(
    (diff_Gt_orig ** 2).mean(dim=['time', 'Nisf'])
)
RMSE_box1_orig = np.sqrt(
    (diff_box1_orig ** 2).mean(dim=['time', 'Nisf'])
)

In [None]:
# Change in RMSE when an input variable is shuffled: shuffled RMSE minus unshuffled RMSE.
diff_RMSE_Gt = RMSE_Gt_all - RMSE_Gt_orig
diff_RMSE_box1 = RMSE_box1_all - RMSE_box1_orig

# Two panels sharing the y axis: left panel for the 'Gt' metric, right panel for 'box1'.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8.24*1.25/1.5, 8.24*1.25/2), sharey=True)

# One marker per shuffled variable, going through var_list from its last to its first entry.
for vv in reversed(var_list):
    axs[0].scatter(diff_RMSE_Gt.sel(shuffled_var=vv), vv, marker='o')
    axs[1].scatter(diff_RMSE_box1.sel(shuffled_var=vv), vv, marker='o')

# Optional manual x range for the left panel: axs[0].set_xlim(0,20)
sns.despine()


In [None]:
# Alternative list of variable names, used for one of the heatmaps further down.
sub_varlist = [
    'dGL', 'dIF',
    'corrected_isfdraft', 'bathy_metry',
    'slope_bed_lon', 'slope_bed_lat',
    'slope_ice_lon', 'slope_ice_lat',
    'isf_area', 'entry_depth_max',
    'isfdraft_conc', 'u_tide',
    'water_column',
]

In [None]:
# Show how many variable names are in sub_varlist.
len(sub_varlist)

In [None]:
# Open a new figure first, as the other heatmap cells do: without an `ax` argument seaborn
# draws into the currently active axes, which with a non-inline backend is the last panel
# of the scatter figure created above.
plt.figure()
# Single-column heatmap of the RMSE change ('Gt' metric), one row per shuffled variable;
# expand_dims adds the length-1 dimension that makes the data two-dimensional.
sns.heatmap(
    diff_RMSE_Gt.sel(shuffled_var=var_list).round(2).expand_dims(dim={"dim1": 1}).T
)

In [None]:

# Annotated single-column heatmap of the absolute RMSE change for the 'Gt' metric,
# one row per shuffled variable (rows labelled with var_list).
plt.figure()
sns.heatmap(
    abs(diff_RMSE_Gt.sel(shuffled_var=var_list).round(2).expand_dims(dim={"dim1": 1}).T),
    annot=True,
    center=0,
    yticklabels=var_list,
    cmap=mpl.cm.Reds,
)


In [None]:

# Annotated single-column heatmap of the absolute RMSE change for the 'box1' metric,
# one row per shuffled variable (rows labelled with var_list).
plt.figure()
sns.heatmap(
    abs(diff_RMSE_box1.sel(shuffled_var=var_list).round(2).expand_dims(dim={"dim1": 1}).T),
    annot=True,
    center=0,
    yticklabels=var_list,
    cmap=mpl.cm.Reds,
)


In [None]:

# Heatmap of the 'box1' RMSE change for the variables in sub_varlist, first entry along 'nn_model'.
# NOTE(review): this notebook does not create an 'nn_model' dimension itself, and sub_varlist
# contains names that are not in var_list (the list the 'shuffled_var' coordinate is built
# from in the loading cell) — this cell presumably belongs to a different experiment setup;
# confirm or remove.
plt.figure()
sns.heatmap(diff_RMSE_box1.isel(nn_model=0).sel(shuffled_var=sub_varlist).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=sub_varlist)


In [None]:
# Heatmap of the 'Gt' RMSE change for the variables in var_list, saved as PNG into plot_path.
# NOTE(review): `run_list` and `timetag` are not defined anywhere in the visible notebook, and
# no 'nn_model' dimension is created here — the cell presumably relies on state from an
# earlier version of the notebook; define these names or remove the cell.
plt.figure()
sns.heatmap(diff_RMSE_Gt.isel(nn_model=0).sel(shuffled_var=var_list).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=var_list, xticklabels=run_list)
plt.savefig(plot_path+'permutation_importance_Gt_yr_'+timetag+'.png')

In [None]:
# Heatmap of the 'box1' RMSE change for the variables in var_list, saved as PNG into plot_path.
# NOTE(review): `run_list` and `timetag` are not defined anywhere in the visible notebook, and
# no 'nn_model' dimension is created here — the cell presumably relies on state from an
# earlier version of the notebook; define these names or remove the cell.
plt.figure()
sns.heatmap(diff_RMSE_box1.isel(nn_model=0).sel(shuffled_var=var_list).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=var_list, xticklabels=run_list)
plt.savefig(plot_path+'permutation_importance_box1_'+timetag+'.png')

In [None]:
# Inspect the 'Gt' RMSE change for the shuffled variable 'T_profiles'.
# NOTE(review): 'T_profiles' is not one of the names in var_list, which is what the
# 'shuffled_var' coordinate is built from in the loading cell — this selection presumably
# only works for a different experiment; confirm or remove.
diff_RMSE_Gt.sel(shuffled_var='T_profiles')

In [None]:
# Display the directory the figures are saved to.
plot_path