In [None]:
"""
Created on Wed Apr 20 10:58 2022

Build a matrix of the importance of each input variable, measured as the change in RMSE after shuffling that variable (permutation importance)

Author: @claraburgard

"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

In [None]:
# Select the Qt5 GUI backend for matplotlib (figures are shown through Qt rather than inline).
%matplotlib qt5

READ IN DATA

In [None]:
# Root directory of the user space; the plot directory is derived from it.
home_path = '/bettik/burgardc/'
plot_path = home_path + 'PLOTS/NN_plots/input_vars/'

In [None]:
# Ice-shelf mask file of the 5 km NEMO configuration (run OPM006).
inputpath_mask = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_OPM006/'
file_isf_orig = xr.open_dataset(
    inputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_new.nc'
)

# Step 1: keep only the ice shelves whose 'front_bot_depth_max' is finite.
nonnan_Nisf = (
    file_isf_orig['Nisf']
    .where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True)
    .astype(int)
)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# Step 2: of those, keep only the ice shelves with 'isf_area_here' >= 2500.
large_isf = (
    file_isf_nonnan['Nisf']
    .where(file_isf_nonnan['isf_area_here'] >= 2500, drop=True)
)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

# Step 3: replace the ASCII spelling 'Ekstrom' by 'Ekström'; all other names are kept.
file_isf['isf_name'] = file_isf['isf_name'].where(
    file_isf['isf_name'] != 'Ekstrom',
    other=np.array('Ekström', dtype=object),
)
isf_names = file_isf['isf_name']

In [None]:
# Assign every ice shelf in file_isf to one of 6 regions, based on its name.
# Members of each region, keyed by region name.
region_members = {
    'East and West Ross': ['Ross','Nickerson','Sulzberger', 'Cook'],
    'Weddell': ['Filchner','Ronne'],
    'Dronning Maud Land': ['Ekström','Nivl','Prince Harald','Riiser-Larsen','Fimbul','Roi Baudouin','Lazarev','Stancomb Brunt','Jelbart','Borchgrevink'],
    'Amundsen': ['Getz','Thwaites','Crosson','Dotson','Cosgrove','Pine Island'],
    'Peninsula and Bellinghausen': ['Venable','George VI','Abbot','Stange','Larsen C','Bach','Larsen D','Wilkins'],
    'East Antarctica': ['Amery','Moscow Univ.','Tracy Tremenchus','Totten','West','Shackleton'],
}
# Inverted lookup: ice-shelf name -> region name (each name appears in exactly one region).
region_by_name = {
    shelf: region for region, shelves in region_members.items() for shelf in shelves
}

region_list = []
unassigned_names = []
for kisf in file_isf.Nisf:
    # Look the name up once per ice shelf; .item() turns the 0-d array into a plain Python object.
    shelf_name = file_isf['isf_name'].sel(Nisf=kisf).values.item()
    if shelf_name in region_by_name:
        region_list.append(region_by_name[shelf_name])
    else:
        print('Argh, help me, '+str(shelf_name)+' has no region assigned!')
        unassigned_names.append(str(shelf_name))

# An ice shelf without a region would leave region_list shorter than the Nisf
# dimension, so stop here with an explicit message instead of failing in the
# assignment below.
if unassigned_names:
    raise ValueError('No region assigned for ice shelves: '+', '.join(unassigned_names))

# 6 regions
# regions = ['East and West Ross','Weddell','Dronning Maud Land','Amundsen','Peninsula and Bellinghausen','East Antarctica']
file_isf['region'] = xr.DataArray(data=region_list,dims='Nisf')

# Order in which the regions are processed further down.
regions = ['Weddell','Peninsula and Bellinghausen','Amundsen','East and West Ross','East Antarctica','Dronning Maud Land']

In [None]:
# Ice-shelf IDs grouped region by region (in the order given by `regions`)
# and flattened into a single 1-D array.
nisf_by_reg_list = np.concatenate(
    [
        file_isf.Nisf.where(file_isf['region'] == reg, drop=True).values
        for reg in regions
    ]
)

In [None]:
# Full list of shuffled variables: 12 scalar inputs, then the 68 'T_xxx'
# entries, the 68 'S_xxx' entries and finally 'water_column'.
# NOTE(review): var_list is reassigned in the two following cells; on a
# top-to-bottom run only the last assignment is used.
var_list = (
    ['dGL', 'dIF', 'corrected_isfdraft', 'bathy_metry', 'slope_bed_lon',
     'slope_bed_lat', 'slope_ice_lon', 'slope_ice_lat', 'isf_area',
     'entry_depth_max', 'isfdraft_conc', 'u_tide']
    + ['T_%03d' % level for level in range(1, 69)]
    + ['S_%03d' % level for level in range(1, 69)]
    + ['water_column']
)

In [None]:
# Variant of the variable list in which the profiles are shuffled as two
# blocks ('T_profiles', 'S_profiles') instead of level by level.
# NOTE(review): this overwrites the previous var_list and is itself
# overwritten by the next cell.
var_list = [
    'dGL',
    'dIF',
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'isf_area',
    'entry_depth_max',
    'isfdraft_conc',
    'u_tide',
    'water_column',
    'T_profiles',
    'S_profiles',
]

In [None]:
# Variant of the variable list without 'dGL'/'dIF' and with 'theta_in' and
# 'salinity_in' in place of the profile entries.
# NOTE(review): being the last assignment, this is the var_list that the
# cells below see on a top-to-bottom run.
var_list = [
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'isf_area',
    'entry_depth_max',
    'isfdraft_conc',
    'u_tide',
    'theta_in',
    'salinity_in',
    'water_column',
]

In [None]:
# Collect, for every shuffled variable, NEMO run and NN model (timetag), the
# difference between the parameterised melt and the NEMO reference melt.
#
# Results of this cell:
#   diff_Gt_all / diff_box1_all   : differences with one variable shuffled,
#                                   dims include (shuffled_var, nemo_run, nn_model)
#   diff_Gt_orig / diff_box1_orig : differences without shuffling,
#                                   dims include (nemo_run, nn_model)
#   ref_Gt_all / ref_box1_all     : NEMO reference values, concatenated over nemo_run
#
# Compared with a naive triple loop, the reference files and the unshuffled
# evaluation files are read once per run (they do not depend on the shuffled
# variable), every file is closed after reading, and the unshuffled
# differences carry an 'nn_model' dimension so that more than one entry in
# timetag_list is handled correctly.

#run_list = ['OPM006','OPM016','OPM018','OPM021','OPM026','OPM027','OPM031']
#run_list = ['OPM006','OPM016','OPM018','OPM021','OPM026','OPM027','OPM031-2'] #'OPM031-1',
run_list = ['OPM021','OPM027']  #
#timetag_list = ['20220427-0957','20220427-1002',
#                '20220427-1052','20220427-1021',
#                '20220427-1058','20220427-1042',
#                '20220427-1059','20220427-1051']
timetag_list = ['20220427-1051'] #'20220427-1051' '20220427-1059'

# Directory holding one sub-directory of evaluation metrics per NN model (timetag).
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/'


def _diff_to_reference(eval_file, ref_Gt, ref_box1_tmean):
    """Return (diff_Gt, diff_box1) = parameterised minus reference melt for one file.

    eval_file      : path of an 'eval_metrics_*.nc' file.
    ref_Gt         : reference 'melt_Gt_per_y_tot' DataArray.
    ref_box1_tmean : time mean of the reference 'mean_melt_box1_myr' DataArray.

    Both results are loaded into memory before the file is closed.
    """
    with xr.open_dataset(eval_file) as ds_melt_param:
        diff_Gt = (ds_melt_param['melt_1D_Gt_per_y'] - ref_Gt).load()
        diff_box1 = (ds_melt_param['melt_1D_mean_myr_box1'].mean('time') - ref_box1_tmean).load()
    return diff_Gt, diff_box1


def _concat_labelled(da_list, dim, labels):
    """Concatenate DataArrays along a new dimension `dim` and label it with `labels`."""
    return xr.concat(da_list, dim=dim).assign_coords(**{dim: labels})


### PASS 1: READ IN THE REFERENCE AND THE UNSHUFFLED PARAM FILES (once per run)

ref_Gt_list = []
ref_box1_list = []
ref_box1_tmean_list = []

diff_Gt_orig_list = []
diff_box1_orig_list = []

for nemo_run0 in run_list:

    # 'OPM031-1' and 'OPM031-2' share the reference files of run 'OPM031'.
    if nemo_run0 in ['OPM031-1','OPM031-2']:
        nemo_run = 'OPM031'
    else:
        nemo_run = nemo_run0

    outputpath_melt = home_path+'DATA/BASAL_MELT_PARAM/processed/MELT_RATE/nemo_5km_'+nemo_run+'/'

    # Load the reference into memory so that the files can be closed right away.
    with xr.open_dataset(outputpath_melt+'melt_rates_1D_NEMO.nc') as NEMO_melt_rates_1D:
        ref_Gt = NEMO_melt_rates_1D['melt_Gt_per_y_tot'].load()
    with xr.open_dataset(outputpath_melt+'melt_rates_box1_NEMO.nc') as NEMO_box1_myr:
        ref_box1 = NEMO_box1_myr['mean_melt_box1_myr'].load()
    ref_box1_tmean = ref_box1.mean('time')

    ref_Gt_list.append(ref_Gt)
    ref_box1_list.append(ref_box1)
    ref_box1_tmean_list.append(ref_box1_tmean)

    diff_Gt_sub_list = []
    diff_box1_sub_list = []
    for timetag in timetag_list:
        new_path_output = outputpath_melt_nn+timetag+'/'
        diff_Gt, diff_box1 = _diff_to_reference(new_path_output+'eval_metrics_'+nemo_run0+'.nc',
                                                ref_Gt, ref_box1_tmean)
        diff_Gt_sub_list.append(diff_Gt)
        diff_box1_sub_list.append(diff_box1)

    diff_Gt_orig_list.append(_concat_labelled(diff_Gt_sub_list, 'nn_model', timetag_list))
    diff_box1_orig_list.append(_concat_labelled(diff_box1_sub_list, 'nn_model', timetag_list))


### PASS 2: READ IN THE PARAM FILES WITH ONE SHUFFLED VARIABLE

diff_Gt_list = []
diff_box1_list = []

for shuff_var in var_list:
    print(shuff_var)

    diff_Gt_nrun_list = []
    diff_box1_nrun_list = []

    for n, nemo_run0 in enumerate(run_list):

        ref_Gt = ref_Gt_list[n]
        ref_box1_tmean = ref_box1_tmean_list[n]

        diff_Gt_sub_list = []
        diff_box1_sub_list = []

        # NOTE: `timetag` is deliberately a cell-level loop variable; the
        # plotting cells further down read it after this loop has finished.
        for timetag in timetag_list:
            new_path_output = outputpath_melt_nn+timetag+'/'
            diff_Gt, diff_box1 = _diff_to_reference(new_path_output+'eval_metrics_'+nemo_run0+'_shuffled'+shuff_var+'.nc',
                                                    ref_Gt, ref_box1_tmean)
            diff_Gt_sub_list.append(diff_Gt)
            diff_box1_sub_list.append(diff_box1)

        diff_Gt_nrun_list.append(_concat_labelled(diff_Gt_sub_list, 'nn_model', timetag_list))
        diff_box1_nrun_list.append(_concat_labelled(diff_box1_sub_list, 'nn_model', timetag_list))

    diff_Gt_list.append(_concat_labelled(diff_Gt_nrun_list, 'nemo_run', run_list))
    diff_box1_list.append(_concat_labelled(diff_box1_nrun_list, 'nemo_run', run_list))

diff_Gt_all = _concat_labelled(diff_Gt_list, 'shuffled_var', var_list)
diff_box1_all = _concat_labelled(diff_box1_list, 'shuffled_var', var_list)

ref_Gt_all = _concat_labelled(ref_Gt_list, 'nemo_run', run_list)
ref_box1_all = _concat_labelled(ref_box1_list, 'nemo_run', run_list)

diff_Gt_orig = _concat_labelled(diff_Gt_orig_list, 'nemo_run', run_list)
diff_box1_orig = _concat_labelled(diff_box1_orig_list, 'nemo_run', run_list)

In [None]:
# RMSE of the shuffled-variable runs: root of the mean squared difference,
# averaged over time and ice shelves (Gt) or over ice shelves only (box1,
# which is already a time mean).
RMSE_Gt_all = np.sqrt(np.square(diff_Gt_all).mean(dim=['time', 'Nisf']))
RMSE_box1_all = np.sqrt(np.square(diff_box1_all).mean(dim=['Nisf']))

In [None]:
# Same RMSE definition as for the shuffled runs, applied to the unshuffled
# differences; this is the baseline the shuffled RMSE is compared against.
RMSE_Gt_orig = np.sqrt(np.square(diff_Gt_orig).mean(dim=['time', 'Nisf']))
RMSE_box1_orig = np.sqrt(np.square(diff_box1_orig).mean(dim=['Nisf']))

In [None]:
# Inspect the baseline (unshuffled) RMSE of the 'melt_1D_Gt_per_y' differences.
RMSE_Gt_orig

In [None]:
# Inspect the RMSE obtained when 'u_tide' is the shuffled variable.
RMSE_Gt_all.sel(shuffled_var='u_tide')

In [None]:
# Show the baseline RMSE again, for comparison with the shuffled value displayed above.
RMSE_Gt_orig

In [None]:
# Permutation importance: RMSE with one variable shuffled minus the RMSE of
# the unshuffled parameterisation.
diff_RMSE_Gt = RMSE_Gt_all - RMSE_Gt_orig
diff_RMSE_box1 = RMSE_box1_all - RMSE_box1_orig

# One colour per NEMO run. Runs that are not listed get `fallback_color`
# instead of leaving the colour undefined or carried over from the previous run.
run_colors = {
    'OPM006': 'magenta',
    'OPM016': 'orange',
    'OPM018': 'brown',
    'OPM021': 'red',
    'OPM026': 'yellowgreen',
    'OPM027': 'deepskyblue',
    'OPM031-1': 'blue',
    'OPM031-2': 'purple',
}
fallback_color = 'grey'

# Two panels sharing the variable axis: left = 'melt_1D_Gt_per_y' metric,
# right = 'melt_1D_mean_myr_box1' metric.
fig, axs = plt.subplots(1, 2,figsize=(8.24*1.25/1.5,8.24*1.25/2),sharey=True)
#plt.figure()

# Reversed order so that the first variable of var_list ends up at the top.
for var in var_list[::-1]:

    for nrun in run_list:

        ccolor = run_colors.get(nrun, fallback_color)

        # Only the first NN model (nn_model index 0) is shown.
        axs[0].scatter(diff_RMSE_Gt.isel(nn_model=0).sel(shuffled_var=var,nemo_run=nrun),var,marker='o',c=ccolor)
        axs[1].scatter(diff_RMSE_box1.isel(nn_model=0).sel(shuffled_var=var,nemo_run=nrun),var,marker='o',c=ccolor)

axs[0].set_xlabel('RMSE change after shuffling (melt_1D_Gt_per_y)')
axs[1].set_xlabel('RMSE change after shuffling (melt_1D_mean_myr_box1)')

#axs[0].set_xlim(0,20)
sns.despine()


In [None]:
# Subset of variables shown in the heatmaps below (no profile entries).
# NOTE(review): 'dGL' and 'dIF' are not part of the last var_list
# definition above - confirm which var_list this subset is meant for.
sub_varlist = [
    'dGL',
    'dIF',
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'isf_area',
    'entry_depth_max',
    'isfdraft_conc',
    'u_tide',
    'water_column',
]

In [None]:
# Number of variables in the subset.
len(sub_varlist)

In [None]:
# Shape check: values for the first NN model and first NEMO run, restricted to sub_varlist.
diff_RMSE_Gt.isel(nn_model=0,nemo_run=0).sel(shuffled_var=sub_varlist).round(2).values.shape

In [None]:

# Heatmap of the RMSE change ('melt_1D_Gt_per_y' metric) for the variables in
# sub_varlist; first NN model only, values rounded to 2 decimals.
plt.figure()
sns.heatmap(
    diff_RMSE_Gt.isel(nn_model=0).sel(shuffled_var=sub_varlist).round(2).rename('diff_RMSE'),
    annot=True,
    center=0,
    cmap=mpl.cm.Reds,
    yticklabels=sub_varlist,
)


In [None]:

# Heatmap of the RMSE change ('melt_1D_mean_myr_box1' metric) for the
# variables in sub_varlist; first NN model only, values rounded to 2 decimals.
plt.figure()
sns.heatmap(
    diff_RMSE_box1.isel(nn_model=0).sel(shuffled_var=sub_varlist).round(2).rename('diff_RMSE'),
    annot=True,
    center=0,
    cmap=mpl.cm.Reds,
    yticklabels=sub_varlist,
)


In [None]:
# Heatmap of the RMSE change ('melt_1D_Gt_per_y' metric) for all variables in
# var_list (rows) and all NEMO runs (columns), saved to plot_path.
plt.figure()
sns.heatmap(diff_RMSE_Gt.isel(nn_model=0).sel(shuffled_var=var_list).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=var_list, xticklabels=run_list)
# The figure shows nn_model index 0, whose label is timetag_list[0]; use that
# in the file name rather than the `timetag` loop variable left over from the
# loading loop (which holds the LAST timetag).
plt.savefig(plot_path+'permutation_importance_Gt_yr_'+timetag_list[0]+'.png')

In [None]:
# Heatmap of the RMSE change ('melt_1D_mean_myr_box1' metric) for all
# variables in var_list (rows) and all NEMO runs (columns), saved to plot_path.
plt.figure()
sns.heatmap(diff_RMSE_box1.isel(nn_model=0).sel(shuffled_var=var_list).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=var_list, xticklabels=run_list)
# The figure shows nn_model index 0, whose label is timetag_list[0]; use that
# in the file name rather than the `timetag` loop variable left over from the
# loading loop (which holds the LAST timetag).
plt.savefig(plot_path+'permutation_importance_box1_'+timetag_list[0]+'.png')

In [None]:
# Inspect the RMSE change obtained when 'T_profiles' is the shuffled variable.
# NOTE(review): 'T_profiles' only appears in the second var_list definition,
# not in the last one - presumably this selection fails unless that variant
# was the one used to build diff_RMSE_Gt; confirm.
diff_RMSE_Gt.sel(shuffled_var='T_profiles')

In [None]:
# Show the directory the figures are written to.
plot_path