In [None]:
"""
Created on Thu Oct 12 10:17 2023

Look at patterns when shuffling variables => FIGURE 6

Author: @claraburgard

"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
from matplotlib import cm


In [None]:
# Select the Qt5 matplotlib backend for this session (IPython line magic).
%matplotlib qt5

FUNCTIONS

In [None]:
def defcolorpalette(ncolors, cmap = 'Accent'):
    """Sample `ncolors` evenly spaced colours from a colormap.

    Parameters
    ----------
    ncolors : int
        Number of colours to return.
    cmap : str or callable, optional
        Name of a matplotlib colormap, or any callable mapping a float in
        [0, 1] to a colour (a Colormap instance is such a callable).

    Returns
    -------
    list
        `ncolors` colours, sampled at 0, 1/(ncolors-1), ..., 1.
    """
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which is no longer
    # available in recent matplotlib releases.
    colmap = cmap if callable(cmap) else plt.get_cmap(cmap)
    if ncolors == 1:
        # A single colour has no spacing to compute; avoid dividing by zero.
        return [colmap(0.)]
    return [colmap(float(i)/(ncolors-1.)) for i in range(ncolors)]
# Default palette: six colours sampled from the default colormap of defcolorpalette.
number_of_colors = 6
palette = defcolorpalette(number_of_colors)
def show_color_palette(palette):
    """Open a new figure showing one histogram bar per colour of `palette`.

    Parameters
    ----------
    palette : sequence
        Colours accepted by matplotlib, one bar is drawn per entry.
    """
    plt.figure()
    # One single-valued dataset per colour; sized from the palette itself so
    # the function does not depend on the module-level number_of_colors.
    plt.hist(np.ones((1, len(palette))), color = palette)
    plt.xlim([1., 1.1])
    # Only the colour swatches matter, hide both axes.
    plt.gca().xaxis.set_visible(False)
    plt.gca().yaxis.set_visible(False)
# Preview the default palette, then build and preview a reordered copy of it.
show_color_palette(palette)
new_palette = [palette[k] for k in (0, 3, 4, 1, 2, 5)]
show_color_palette(new_palette)

READ IN DATA

In [None]:
# Run configuration: these strings are spliced into the input and output
# file names used further down (trailing comments list alternative values).
nemo_run =  'bf663' #'bi646'
TS_opt = 'extrap' #'extrap_shuffboth' # extrap, whole, thermocline
norm_method =  'std' # std, interquart, minmax
exp_name = 'newbasic2'#'onlyTSdraftandslope' #'onlyTSdraftandslope' #'TSdraftbotandiceddandwcd' #'onlyTSisfdraft' #'TSdraftbotandiceddandwcdreldGL' #TSdraftslopereldGL
mod_size = 'small'

In [None]:
# Base directories.
# NOTE(review): `home_path` is not referenced again in this notebook, and
# `plot_path` is reassigned in the data-loading cell before it is first used.
home_path = '/bettik/burgardc/'
plot_path = '/bettik/burgardc/PLOTS/NN_plots/input_vars/'


In [None]:
# Names of the shuffled inputs: each entry selects one
# 'evalmetrics_shuffled<name>_2D_...' file per year in the loading loop.
# The last six entries are the ones picked up again in `var_subset` below.
var_list = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat',
                'theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std','position','watercolumn','slopesbed','slopesice','Tinfo','Sinfo']

In [None]:
# Accumulators: the *_list variables receive one entry per year,
# merged_var_list receives one merged dataset per NEMO run.
merged_var_list = []

ground_list = []
icesheet_list = []
box1_list = []
isf_mask_list = []
melt_list = []
melt_ref_list = []
melt_predic_list = []

# Input/output locations for the selected NEMO run.
inputpath_mask = '/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_'+nemo_run+'/'
inputpath_colorbar = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/raw/MASK_METADATA/'
outputpath_melt = '/bettik/burgardc/DATA/NN_PARAM/interim/MELT_RATE/SMITH_'+nemo_run+'/'
plot_path = '/bettik/burgardc/PLOTS/NN_plots/2D_patterns/'
inputpath_boxes = '/bettik/burgardc/DATA/NN_PARAM/interim/BOXES/SMITH_'+nemo_run+'/'
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/SMITH_'+nemo_run+'/'
outputpath_melt_classic = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/SMITH_'+nemo_run+'_CLASSIC/'

# One iteration per year, 1980 to 2039 inclusive.
for yy in tqdm(range(1980, 1980 + 60)):

    # Ice-shelf info for this year: keep the shelves with a finite
    # 'front_bot_depth_max', then those with 'isf_area_here' >= 2500.
    file_isf_orig = xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_'+str(yy)+'.nc')
    nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
    file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
    large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= 2500, drop=True)
    file_isf = file_isf_nonnan.sel(Nisf=large_isf)
    # 2D map holding the Nisf value of each retained shelf and 0 elsewhere.
    file_isf_mask = file_isf['ISF_mask'].where(file_isf['ISF_mask']==file_isf.Nisf).sum('Nisf')
    isf_mask_list.append(file_isf_mask)

    # Binary mask: 1 where 'ground_mask' equals 0, 0 everywhere else
    # (0 is kept, anything else becomes 3 then 1, and the result is flipped).
    grounded_msk03 = file_isf['ground_mask'].where(file_isf['ground_mask']==0,3)
    grounded_msk = (grounded_msk03.where(grounded_msk03!=3,1)-1)*-1
    ground_list.append(grounded_msk)

    # Cells equal to 1 are reset to 0, then every value >= 1 is set to 1.
    icesheet_msk_0inf = file_isf_mask.where(file_isf_mask!=1,0)
    icesheet_msk = icesheet_msk_0inf.where(icesheet_msk_0inf < 1, 1)
    icesheet_list.append(icesheet_msk)

    box_charac_all_2D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_2D_oneFRIS_'+str(yy)+'_merged75.nc')
    box_charac_all_1D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_1D_oneFRIS_'+str(yy)+'_merged75.nc')

    # Box layout of configuration 2; mask is 1 inside box number 1 and 0 elsewhere.
    box_loc_config2 = box_charac_all_2D['box_location'].sel(box_nb_tot=box_charac_all_1D['nD_config'].sel(config=2))
    # drop_vars replaces the deprecated DataArray.drop for removing a coordinate.
    box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop_vars('Nisf')
    box1_msk = box1.where(box1==1,0)
    box1_list.append(box1_msk)

    # The unshuffled evaluation file holds both 'reference_melt' and
    # 'predicted_melt', so it is opened once per year instead of twice.
    # NOTE(review): this file name hard-codes 'extrap' whereas the shuffled
    # files below use TS_opt -- confirm this is intended.
    melt_ref_2D = xr.open_dataset(outputpath_melt_nn + 'evalmetrics_2D_'+mod_size+'_'+exp_name+'_ensmean_extrap_norm'+norm_method+'_'+str(yy)+'_'+nemo_run+'.nc')
    melt_predic_2D = melt_ref_2D  # alias kept so the name stays defined
    melt_ref_list.append(melt_ref_2D['reference_melt'])
    melt_predic_list.append(melt_predic_2D['predicted_melt'])

    # Predicted melt for each shuffled input, stacked along 'shuff_var'.
    melt_yy_list = []
    for vv in var_list:
        pattern_2D_vv = xr.open_dataset(outputpath_melt_nn + 'evalmetrics_shuffled'+vv+'_2D_'+mod_size+'_'+exp_name+'_ensmean_'+TS_opt+'_norm'+norm_method+'_'+str(yy)+'_'+nemo_run+'.nc')
        melt_yy_list.append(pattern_2D_vv['predicted_melt'].to_dataset().assign_coords({'shuff_var': vv}))

    melt_yy_all = xr.concat(melt_yy_list, dim='shuff_var')
    melt_list.append(melt_yy_all.chunk({'shuff_var':5}))

# Masks are taken from a single year (time == 2035) rather than averaged.
# NOTE(review): assumes every yearly entry carries a 'time' coordinate equal to its year -- confirm.
ground_msk_all = xr.concat(ground_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})
icesheet_msk_all = xr.concat(icesheet_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})
box1_msk_all = xr.concat(box1_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})
isf_mask_all = xr.concat(isf_mask_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})

In [None]:
# Time-mean of the shuffled-input predictions, tagged with the NEMO run name.
melt2D_all = (
    xr.concat(melt_list, dim='time')
    .mean('time')
    .assign_coords({'nemo_run': nemo_run})
)

Average the reference and predicted melt over time, then merge all fields into one dataset

In [None]:
# Time-mean of the reference melt and of the unshuffled prediction.
melt_ref_all = xr.concat(melt_ref_list, dim='time').mean('time').assign_coords({'nemo_run': nemo_run})
melt_predic_all = xr.concat(melt_predic_list, dim='time').mean('time').assign_coords({'nemo_run': nemo_run})

# Gather the masks and the time-mean shuffled predictions in one dataset.
merged_vars = xr.merge([ground_msk_all, icesheet_msk_all.rename('ice_mask'), box1_msk_all.rename('box1_mask'), isf_mask_all, melt2D_all])
# Rebuild the list instead of appending to it, so that re-running this cell
# does not stack duplicate copies of the same run along 'nemo_run'.
merged_var_list = [merged_vars]

var_of_int = xr.concat(merged_var_list, dim='nemo_run')

COMPUTE MEAN ABSOLUTE ERROR BETWEEN PERMUTED AND ORIGINAL

In [None]:
# Difference between the time-mean prediction with one input shuffled and the
# time-mean unshuffled prediction, for the first (only) NEMO run.
# NOTE(review): both operands are already time-averaged, so this is the absolute
# difference of the means rather than a time-mean of absolute errors -- confirm
# this matches what is meant by "MAE" here.
diff_permuted = (var_of_int['predicted_melt'] - melt_predic_all).isel(nemo_run=0)
diff_permuted_abs = abs(diff_permuted)

In [None]:
# VARIABLES SUBSET
var_subset = ['position','watercolumn','slopesbed','slopesice','Tinfo','Sinfo']
# Largest absolute difference across the subset at each grid point.
max_MAE = diff_permuted_abs.sel(shuff_var=var_subset).max('shuff_var')

# 1-based position in var_subset of the variable reaching that maximum.
# Starts as all-NaN; if several variables tie, the later entry overwrites the earlier one.
idx_MAE = max_MAE * np.nan
for i,vv in enumerate(var_subset):
    idx_MAE = idx_MAE.where(diff_permuted_abs.sel(shuff_var=vv) != max_MAE, i+1)

# Open a dedicated figure so this map is not drawn into whichever figure
# happens to be current (e.g. the palette preview).
plt.figure()
max_MAE.plot(cmap=mpl.cm.Reds)

plt.figure()
cmap = mpl.colors.ListedColormap(palette)
idx_MAE.plot(cmap=cmap)

FIGURE 6

In [None]:

# Discrete colormap built from the reordered palette, one colour per index value.
cmaph = mpl.colors.ListedColormap(new_palette)

# One row of five panels, one per selected ice shelf.
fig, ax = plt.subplots(1,5)
fig.set_size_inches(8.25*2.5, 8.25/2)

# Fields of the first NEMO run, shared by all panels.
isf_ids_run0 = var_of_int['ISF_mask'].isel(nemo_run=0)
box1_run0 = var_of_int['box1_mask'].isel(nemo_run=0)
ground_run0 = var_of_int['ground_mask'].isel(nemo_run=0)
outline_kwargs = dict(levels=[0,1], linewidths=2, colors='black', zorder=10)

for panel, kisf in zip(ax, [10,11,66,31,44]):
    # Index map restricted to this ice shelf, cropped to its extent.
    idx_kisf = idx_MAE.where(isf_ids_run0==kisf, drop=True)
    idx_kisf.plot(ax=panel, cmap=cmaph, add_colorbar=False)
    # NOTE(review): idx_kisf (not a boolean array) is used as the `where`
    # condition for both outlines -- confirm this masks as intended.
    # Dashed outline: box-1 mask; solid outline: ground mask.
    panel.contour(idx_kisf.x, idx_kisf.y, box1_run0.where(idx_kisf), linestyles='--', **outline_kwargs)
    panel.contour(idx_kisf.x, idx_kisf.y, ground_run0.where(idx_kisf), **outline_kwargs)

    # Bare panels: no ticks, axis labels or title.
    panel.set(xticks=[], yticks=[], xlabel='', ylabel='', title='')


plt.tight_layout()
#fig.savefig(plot_path+'idx_MAE_shuffled2D_'+nemo_run+'_'+TS_opt+'.png', dpi=300)


In [None]:

# Discrete colormap built from the reordered palette, one colour per index value.
cmaph = mpl.colors.ListedColormap(new_palette)

# Same five-panel layout as the previous figure, this time with a colorbar
# below each panel.
fig, ax = plt.subplots(1,5)
fig.set_size_inches(8.25*2.5, 8.25/2)

# Fields of the first NEMO run, shared by all panels.
isf_ids_run0 = var_of_int['ISF_mask'].isel(nemo_run=0)
box1_run0 = var_of_int['box1_mask'].isel(nemo_run=0)
ground_run0 = var_of_int['ground_mask'].isel(nemo_run=0)
outline_kwargs = dict(levels=[0,1], linewidths=2, colors='black', zorder=10)

for panel, kisf in zip(ax, [10,11,66,31,44]):
    # Index map restricted to this ice shelf, cropped to its extent.
    idx_kisf = idx_MAE.where(isf_ids_run0==kisf, drop=True)
    idx_kisf.plot(ax=panel, cmap=cmaph, cbar_kwargs={"location": "bottom"})
    # NOTE(review): idx_kisf (not a boolean array) is used as the `where`
    # condition for both outlines -- confirm this masks as intended.
    # Dashed outline: box-1 mask; solid outline: ground mask.
    panel.contour(idx_kisf.x, idx_kisf.y, box1_run0.where(idx_kisf), linestyles='--', **outline_kwargs)
    panel.contour(idx_kisf.x, idx_kisf.y, ground_run0.where(idx_kisf), **outline_kwargs)

    # Bare panels: no ticks, axis labels or title.
    panel.set(xticks=[], yticks=[], xlabel='', ylabel='', title='')


plt.tight_layout()
fig.savefig(plot_path+'idx_MAE_shuffled2D_'+nemo_run+'_'+TS_opt+'_withcolorbar.png', dpi=300)
