In [None]:
"""
Created on Thu Oct 12 10:17 2023

Look at patterns when shuffling variables

Author: @claraburgard

"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

In [None]:
# Use matplotlib's interactive Qt5 backend (figures open in separate windows instead of inline).
%matplotlib qt5

FUNCTIONS

In [None]:
import numpy as np
from matplotlib import cm
import cartopy
import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point
import matplotlib as mpl
import cmocean
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib.ticker import LogFormatterSciNotation
import matplotlib.colors as colors
from colorsys import hls_to_rgb

def sigdigit(a,n):
    """Round ``a`` to ``n`` significant digits.

    Zeros are left as zeros (they have no leading digit). A scalar is
    promoted to a 1-element array; arrays of any shape are accepted and the
    result keeps that shape.

      Examples:
        sigdigit([0.,1.111111,0.],2)          -> array([0. , 1.1, 0. ])
        sigdigit([999.9,1.111111,-323.684],2) -> array([1000. , 1.1, -320. ])
        sigdigit(2.2222222222,3)              -> array([2.22])
        sigdigit(0.,3)                        -> array([0.])
        sigdigit([0.,0.,0.],3)                -> array([0., 0., 0.])
        sigdigit([[1.234,0.],[56.78,-9.99]],2) -> array([[1.2, 0.], [57., -10.]])

    Parameters
    ----------
    a : scalar or array_like
        Values to round.
    n : int
        Number of significant digits to keep.

    Returns
    -------
    numpy.ndarray of float
        Rounded values, at least 1-D, same shape as ``a`` otherwise.
    """
    aa = np.atleast_1d(np.asarray(a, dtype=float))
    nonzero = aa != 0
    # Decimal magnitude 10**floor(log10|a|) of every value. Zeros keep a
    # magnitude of 1 so that log10(0) is never evaluated. The array is sized
    # from aa.shape (not np.size) so multi-dimensional input works.
    magnitude = np.ones(aa.shape)
    if not nonzero.any():
        return magnitude * 0.e0
    magnitude[nonzero] = np.power(10., np.floor(np.log10(np.abs(aa[nonzero]))))
    # Move the n kept digits in front of the decimal point, round, move back.
    return np.rint(10.**(n-1) * aa / magnitude) * 10.**(1-n) * magnitude


def smooth(x,window_len=11,window='hanning'):
    """Smooth a 1-D signal by convolving it with a normalised window.

    The signal is padded at both ends with mirrored copies of itself
    (``window_len - 1`` samples per side, the end sample itself not repeated)
    so that the edges of the output suffer less from the start-up transient
    of the convolution. The output has the same length as ``x``.

    Parameters
    ----------
    x : numpy.ndarray
        1-D input signal.
    window_len : int, optional
        Length of the smoothing window; should be odd so the window is
        centred. Values below 3 return ``x`` unchanged.
    window : {'flat', 'hanning', 'hamming', 'bartlett', 'blackman'}, optional
        Window shape. 'flat' gives a moving average; the others use the
        numpy window function of the same name.

    Returns
    -------
    numpy.ndarray
        Smoothed signal, same length as ``x``.

    Raises
    ------
    ValueError
        If ``x`` is not 1-D, is shorter than ``window_len``, or ``window`` is
        not one of the supported names.

    Example
    -------
    >>> t = np.linspace(-2, 2, 41)
    >>> x = np.sin(t) + np.random.randn(len(t)) * 0.1
    >>> y = smooth(x)

    See also: np.hanning, np.hamming, np.bartlett, np.blackman, np.convolve,
    scipy.signal.lfilter
    """
    # `raise ValueError(...)`: the former `raise(ValueError, msg)` raised a
    # tuple, which is itself a TypeError in Python 3.
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window_len<3:
        return x

    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError("Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    npad = window_len - 1
    s = np.r_[x[npad:0:-1], x, x[-2:-window_len-1:-1]]
    if window == 'flat': # moving average
        w = np.ones(window_len,'d')
    else:
        # Look the numpy window function up by name instead of eval()-ing a string.
        w = getattr(np, window)(window_len)

    y = np.convolve(w/w.sum(), s, mode='same')
    # Drop the leading padding and keep exactly len(x) samples.
    return y[npad:npad + x.size]

#===========================================================================
# Local functions to handle symmetric-log color bars:

def symlog_transform(linthresh,linscale, a):
    """Natural-log symmetric-log transform of ``a``, applied in place.

    With ``adj = linscale / (1 - 1/e)``:
      * ``|a| <= linthresh``  ->  ``a * adj``            (linear part)
      * ``|a| >  linthresh``  ->  ``sign(a) * linthresh * (adj + ln(|a| / linthresh))``
    ``a`` is expected to be a writable floating-point array; it is
    overwritten and the same object is returned.
    """
    linscale_adj = (linscale / (1.0 - np.e ** -1))
    with np.errstate(invalid="ignore"):
        beyond = np.abs(a) > linthresh
    within = ~beyond
    outer = a[beyond]
    a[beyond] = np.sign(outer) * linthresh * (linscale_adj + np.log(np.abs(outer) / linthresh))
    a[within] *= linscale_adj
    return a

def symlog_inv_transform(linthresh,linscale, a):
    """Inverse of the natural-log symmetric-log transform, applied in place.

    With ``adj = linscale / (1 - 1/e)``:
      * ``|a| <= linthresh * adj``  ->  ``a / adj``
      * ``|a| >  linthresh * adj``  ->  ``sign(a) * linthresh * exp(|a| / linthresh - adj)``
    This undoes ``symlog_transform`` for the same ``linthresh``/``linscale``.
    ``a`` is expected to be a writable floating-point array; it is
    overwritten and the same object is returned.
    """
    linscale_adj = (linscale / (1.0 - np.e ** -1))
    beyond = np.abs(a) > (linthresh * linscale_adj)
    outer = a[beyond]
    signs = np.sign(outer)
    a[beyond] = signs * linthresh * np.exp(signs * outer / linthresh - linscale_adj)
    a[~beyond] /= linscale_adj
    return a

def map_with_contourf_coolwarm(melt_2D, grounded_msk, icesheet_msk, mparam, ref_grid=None):
    """Map a 2D melt field with symmetric-log filled contours and an enlarged inset.

    The field is drawn with ``contourf`` on symmetric-log levels spanning
    -5 to 135 (colour-bar title 'm ice/yr'). The colormap runs blue (low) ->
    white -> red (high), built from ``coolwarm_r``. ``Ncool`` sets where along
    the colour bar the white centre falls; it is meant to be placed near zero.
    Grounded and ice-sheet masks are outlined in black. The window
    x in [-2000e3, -1450e3], y in [-900e3, -150e3] ("Zoom on Amundsen") is
    redrawn enlarged by ``zoomfac`` inside a white box whose lower-left
    corner is at (-50e3, -500e3).

    Parameters
    ----------
    melt_2D : xarray.DataArray
        2D field with ``x`` and ``y`` dimensions (it is subset with
        ``.isel(x=..., y=...)``).
    grounded_msk, icesheet_msk : xarray.DataArray
        Masks on the same grid as ``melt_2D``, drawn as contour lines.
    mparam : str
        Axes title.
    ref_grid : xarray.DataArray or xarray.Dataset, optional
        Supplies the ``x``/``y`` coordinates used for plotting and for
        locating the inset window. Defaults to ``melt_2D`` itself.

    Returns
    -------
    matplotlib.figure.Figure
    """
    grid = melt_2D if ref_grid is None else ref_grid

    fig, ax = plt.subplots()
    fig.set_size_inches(8.25/1.3, 8.25/1.5/1.25)

    # --- Colormap. Ncool of the 256 colours go to the blue (low) side; change
    # the Ncool:Nwarm ratio to move the white centre along the colour bar.
    Ncool = 86
    Nwarm = 256 - Ncool
    # cm.coolwarm_r is sampled directly; cm.get_cmap was removed in matplotlib 3.9.
    # For a Colormap argument it returned the object unchanged, so results are the same.
    blues = cm.coolwarm_r(np.linspace(0.5, 0.85, Ncool))  # centre -> blue; 0.85 caps the darkest blue
    reds = cm.coolwarm_r(np.linspace(0, 0.5, Nwarm))      # red -> centre
    newcmp = ListedColormap(np.append(blues[::-1, :], reds[::-1, :], axis=0))

    # --- Symmetric-log contour levels covering [minval, maxval]: integer steps
    # in (natural-log) symlog space, mapped back and rounded to 2 significant digits.
    minval = -5.0
    maxval = 135.0
    lin_threshold = 1.0
    lin_scale = 1.0
    min_exp, max_exp = symlog_transform(lin_threshold, lin_scale, np.array([minval, maxval]))
    lev_exp = np.arange(np.floor(min_exp), np.ceil(max_exp) + 1)
    levs = sigdigit(symlog_inv_transform(lin_threshold, lin_scale, lev_exp), 2)
    # NOTE(review): the levels use a natural-log symlog, while SymLogNorm uses
    # matplotlib's default base (10 in recent versions). Confirm this is intended.
    norm = mpl.colors.SymLogNorm(linthresh=lin_threshold, linscale=lin_scale, vmin=minval, vmax=maxval)

    cax = ax.contourf(grid.x, grid.y, melt_2D, levs, cmap=newcmp, norm=norm, zorder=0)
    ax.contour(grid.x, grid.y, grounded_msk, linewidths=0.5, colors='black', zorder=10)
    ax.contour(grid.x, grid.y, icesheet_msk, linewidths=0.5, colors='black', zorder=15)

    # --- Zoom on Amundsen: source window (_ori) and enlarged copy (_des).
    zoomfac = 2.85
    xll_ori, yll_ori = -2000e3, -900e3
    xur_ori, yur_ori = -1450e3, -150e3
    xll_des, yll_des = -50e3, -500e3
    xur_des = xll_des + zoomfac * (xur_ori - xll_ori)
    yur_des = yll_des + zoomfac * (yur_ori - yll_ori)
    ori_x = [xll_ori, xur_ori, xur_ori, xll_ori, xll_ori]
    ori_y = [yll_ori, yll_ori, yur_ori, yur_ori, yll_ori]
    des_x = [xll_des, xur_des, xur_des, xll_des, xll_des]
    des_y = [yll_des, yll_des, yur_des, yur_des, yll_des]
    ax.plot(ori_x, ori_y, 'k', linewidth=0.6, zorder=20)
    ax.fill(des_x, des_y, 'w', edgecolor='k', zorder=25)

    # Grid indices nearest to the corners of the source window.
    i1 = np.argmin(np.abs(grid.x.values - xll_ori))
    i2 = np.argmin(np.abs(grid.x.values - xur_ori)) + 1
    j1 = np.argmin(np.abs(grid.y.values - yll_ori))
    j2 = np.argmin(np.abs(grid.y.values - yur_ori)) + 1
    # NOTE(review): range(j2, j1) is empty unless y decreases with index. Even then
    # it omits the rows nearest to both yur_ori and yll_ori, whereas the x window
    # keeps both end columns. Confirm the intended y window.
    window = dict(x=range(i1, i2), y=range(j2, j1))
    xzoom = (xll_des + zoomfac * (grid.x - xll_ori)).isel(x=window['x'])
    yzoom = (yll_des + zoomfac * (grid.y - yll_ori)).isel(y=window['y'])

    ax.contourf(xzoom, yzoom, melt_2D.isel(**window), levs, cmap=newcmp, norm=norm, zorder=30)
    ax.contour(xzoom, yzoom, grounded_msk.isel(**window), linewidths=0.5, colors='black', zorder=30)
    ax.contour(xzoom, yzoom, icesheet_msk.isel(**window), linewidths=0.5, colors='black', zorder=40)
    ax.plot(des_x, des_y, 'k', linewidth=1.0, zorder=45)

    ratio = 1.00
    ax.set_aspect(1.0/ax.get_data_ratio()*ratio)

    # Colour bar with a label on every level: minor_thresholds=(inf, inf)
    # disables the formatter's thinning of labels.
    formatter = LogFormatterSciNotation(10, labelOnlyBase=False, minor_thresholds=(np.inf, np.inf))
    cbar = fig.colorbar(cax, format=formatter, fraction=0.035, pad=0.02, ticks=levs)
    cbar.ax.set_title('m ice/yr')
    cbar.outline.set_linewidth(0.3)
    cbar.ax.tick_params(which='both')

    ax.set_xlim(-2800e3, 2800e3)
    ax.set_ylim(-2300e3, 2300e3)
    ax.set_title(mparam)

    plt.tight_layout()
    return fig

def myround(x, base=5):
    """Round ``x`` *up* (towards +inf) to a multiple of ``base`` and cast to int.

    Despite the name this is a ceiling, not round-to-nearest:
    ``myround(11) -> 15``, ``myround(10) -> 10``. Works elementwise on arrays.
    """
    n_multiples = np.ceil(x / base)
    return (n_multiples * base).astype(int)

def get_distinct_colors(n, rng=None):
    """Return ``n`` RGB colours with evenly spaced hues.

    Hue k is ``k / n`` (k = 0..n-1). Lightness is drawn uniformly in
    [0.5, 0.6) and saturation in [0.9, 1.0), so results vary between calls
    unless a seeded ``rng`` is passed.

    Parameters
    ----------
    n : int
        Number of colours (0 gives an empty list).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the lightness/saturation jitter (anything with a
        ``random()`` method). Defaults to numpy's global random state.

    Returns
    -------
    list of (r, g, b) tuples with components in [0, 1].
    """
    if rng is None:
        rng = np.random
    distinct = []
    # Integer hue steps: np.arange(0., 360., 360./n) may return n+1 values
    # because of floating-point rounding of the step, and fails for n == 0.
    for k in range(n):
        h = k / n
        l = (50 + rng.random() * 10) / 100.
        s = (90 + rng.random() * 10) / 100.
        distinct.append(hls_to_rgb(h, l, s))
    return distinct

In [None]:
def defcolorpalette(ncolors, cmap = 'Accent'):
    """Sample ``ncolors`` evenly spaced RGBA colours from a matplotlib colormap.

    Parameters
    ----------
    ncolors : int
        Number of colours. 0 gives an empty list; 1 gives the colormap's
        first colour.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap name or object (default 'Accent').

    Returns
    -------
    list of RGBA tuples, from the start (0.0) to the end (1.0) of the colormap.
    """
    # pyplot.get_cmap: matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    colmap = plt.get_cmap(cmap)
    if ncolors == 1:
        # Avoid the 0/0 spacing below.
        return [colmap(0.0)]
    return [colmap(i / (ncolors - 1.)) for i in range(ncolors)]

number_of_colors = 6
palette = defcolorpalette(number_of_colors)

def show_color_palette(palette):
    """Draw one bar per colour of ``palette`` in a new figure, with axes hidden."""
    plt.figure()
    # Size the dummy data from the palette itself (not the global
    # number_of_colors) so palettes of any length can be shown.
    plt.hist(np.ones((1, len(palette))), color = palette)
    plt.xlim([1., 1.1])
    plt.gca().xaxis.set_visible(False)
    plt.gca().yaxis.set_visible(False)

show_color_palette(palette)
# Reordered copy of the 6-colour palette.
new_palette = [palette[0],palette[3],palette[4],palette[1],palette[2],palette[5]]
show_color_palette(new_palette)

READ IN DATA

In [None]:
# NEMO run id: selects the SMITH_<nemo_run> input/output folders and is part of every file name.
# NOTE(review): the alternatives in the trailing comment look like model sizes (cf. mod_size), not run ids.
nemo_run =  'bf663' #'mini', 'small', 'medium', 'large', 'extra_large'
# T/S option string used in the shuffled-variable file names (alternatives in the trailing comment).
TS_opt = 'extrap' #'extrap_shuffboth' # extrap, whole, thermocline
# Normalisation method, part of the file names ('_norm' + norm_method).
norm_method =  'std' # std, interquart, minmax
# Experiment name of the NN setup (earlier alternatives kept in the trailing comment).
exp_name = 'newbasic2'#'onlyTSdraftandslope' #'onlyTSdraftandslope' #'TSdraftbotandiceddandwcd' #'onlyTSisfdraft' #'TSdraftbotandiceddandwcdreldGL' #TSdraftslopereldGL
# Model size, part of the evalmetrics file names.
mod_size = 'small'

In [None]:
# NOTE(review): home_path is not referenced in the cells shown here, and plot_path is
# reassigned in the data-loading cell below, so this value is not the one used to save figures.
home_path = '/bettik/burgardc/'
plot_path = '/bettik/burgardc/PLOTS/NN_plots/input_vars/'


In [None]:
# Names of the shuffled inputs: one 'evalmetrics_shuffled<name>_...' file per entry and year is read below.
# The first 14 entries match var_single and the last 6 match var_subset further down.
var_list = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat',
                'theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std','position','watercolumn','slopesbed','slopesice','Tinfo','Sinfo']

In [None]:
merged_var_list = []

ground_list = []
icesheet_list = []
box1_list = []
isf_mask_list = []
melt_list = []
melt_ref_list = []
melt_predic_list = []

inputpath_mask = '/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_'+nemo_run+'/'
inputpath_colorbar = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/raw/MASK_METADATA/'
outputpath_melt = '/bettik/burgardc/DATA/NN_PARAM/interim/MELT_RATE/SMITH_'+nemo_run+'/'
plot_path = '/bettik/burgardc/PLOTS/NN_plots/2D_patterns/'
inputpath_boxes = '/bettik/burgardc/DATA/NN_PARAM/interim/BOXES/SMITH_'+nemo_run+'/'
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/SMITH_'+nemo_run+'/'
outputpath_melt_classic = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/SMITH_'+nemo_run+'_CLASSIC/'

# Loop over the 60 years 1980-2039. Collect masks, the box-1 mask, reference/predicted melt and,
# for every name in var_list, the melt predicted with that variable shuffled.
for yy in tqdm(range(1980, 1980 + 60)):

    # Ice-shelf masks, restricted to shelves with a finite front_bot_depth_max and isf_area_here >= 2500.
    file_isf_orig = xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_'+str(yy)+'.nc')
    nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
    file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
    large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= 2500, drop=True)
    file_isf = file_isf_nonnan.sel(Nisf=large_isf)
    # 2D map holding the Nisf id of each kept shelf (0 elsewhere).
    file_isf_mask = file_isf['ISF_mask'].where(file_isf['ISF_mask']==file_isf.Nisf).sum('Nisf')
    isf_mask_list.append(file_isf_mask)

    # 1 where ground_mask == 0, 0 everywhere else.
    grounded_msk03 = file_isf['ground_mask'].where(file_isf['ground_mask']==0,3)
    grounded_msk = (grounded_msk03.where(grounded_msk03!=3,1)-1)*-1
    ground_list.append(grounded_msk)

    # 1 where file_isf_mask is neither 0 nor 1, 0 elsewhere.
    icesheet_msk_0inf = file_isf_mask.where(file_isf_mask!=1,0)
    icesheet_msk = icesheet_msk_0inf.where(icesheet_msk_0inf < 1, 1)
    icesheet_list.append(icesheet_msk)

    # 0/1 mask of box 1 for box configuration config=2.
    box_charac_all_2D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_2D_oneFRIS_'+str(yy)+'_merged75.nc')
    box_charac_all_1D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_1D_oneFRIS_'+str(yy)+'_merged75.nc')
    box_loc_config2 = box_charac_all_2D['box_location'].sel(box_nb_tot=box_charac_all_1D['nD_config'].sel(config=2))
    # NOTE(review): isel(Nisf=1) keeps a single Nisf entry (the second one). Confirm this is intended.
    # drop_vars replaces the deprecated DataArray.drop.
    box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop_vars('Nisf')
    box1_msk = box1.where(box1==1,0)
    box1_list.append(box1_msk)

    # Reference and unshuffled predicted melt come from the same evalmetrics file, so it is opened once.
    # NOTE(review): this file name hard-codes 'extrap', while the shuffled files below use TS_opt.
    melt_ref_2D = xr.open_dataset(outputpath_melt_nn + 'evalmetrics_2D_'+mod_size+'_'+exp_name+'_ensmean_extrap_norm'+norm_method+'_'+str(yy)+'_'+nemo_run+'.nc')
    melt_ref_list.append(melt_ref_2D['reference_melt'])
    melt_predic_list.append(melt_ref_2D['predicted_melt'])

    # Predicted melt with each var_list entry shuffled, stacked along a new 'shuff_var' dimension.
    melt_yy_list = []
    for vv in var_list:
        pattern_2D_vv = xr.open_dataset(outputpath_melt_nn + 'evalmetrics_shuffled'+vv+'_2D_'+mod_size+'_'+exp_name+'_ensmean_'+TS_opt+'_norm'+norm_method+'_'+str(yy)+'_'+nemo_run+'.nc')
        melt_yy_list.append(pattern_2D_vv['predicted_melt'].to_dataset().assign_coords({'shuff_var': vv}))

    melt_yy_all = xr.concat(melt_yy_list, dim='shuff_var')
    melt_list.append(melt_yy_all.chunk({'shuff_var':5}))

# Only the masks of year 1980+55 = 2035 are kept.
# NOTE(review): .sel(time=...) needs a 'time' coordinate on the per-year masks (xr.concat along a
# new dimension does not create one). Presumably the mask files provide it; confirm.
ground_msk_all = xr.concat(ground_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})
icesheet_msk_all = xr.concat(icesheet_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})
box1_msk_all = xr.concat(box1_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})
isf_mask_all = xr.concat(isf_mask_list, dim='time').sel(time=1980+55).assign_coords({'nemo_run': nemo_run})

In [None]:
melt2D_all = xr.concat(melt_list, dim='time').mean('time').assign_coords({'nemo_run': nemo_run})

In [None]:
melt_ref_all = xr.concat(melt_ref_list, dim='time').mean('time').assign_coords({'nemo_run': nemo_run})
melt_predic_all = xr.concat(melt_predic_list, dim='time').mean('time').assign_coords({'nemo_run': nemo_run})

merged_vars = xr.merge([ground_msk_all, icesheet_msk_all.rename('ice_mask'), box1_msk_all.rename('box1_mask'), isf_mask_all, melt2D_all])
merged_var_list.append(merged_vars)

var_of_int = xr.concat(merged_var_list, dim='nemo_run')

In [None]:
# Change in predicted melt caused by shuffling each variable (shuffled minus unshuffled prediction).
diff_permuted = (var_of_int['predicted_melt'] - melt_predic_all).isel(nemo_run=0)
# Absolute change; its maximum over variables is called max_MAE below.
diff_permuted_abs = abs(diff_permuted)

In [None]:
# One map per shuffled variable of the signed change in predicted melt.
for vv in diff_permuted.shuff_var:
    plt.figure()
    diff_permuted.sel(shuff_var=vv).plot()
    plt.title(vv.values)

In [None]:
#ALL VARIABLES
var_single = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat',
                'theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std']
ccolors = get_distinct_colors(14)

idx_MAE = max_MAE * np.nan
max_MAE = diff_permuted_abs.sel(shuff_var=var_single).max('shuff_var')
for i,vv in enumerate(var_single):
    idx_MAE = idx_MAE.where(diff_permuted_abs.sel(shuff_var=vv) != max_MAE, i+1)
    
max_MAE.plot(cmap=mpl.cm.Reds)

plt.figure()
cmap = mpl.colors.ListedColormap(ccolors)
idx_MAE.plot(cmap=cmap)

In [None]:
# VARIABLES SUBSET
var_subset = ['position','watercolumn','slopesbed','slopesice','Tinfo','Sinfo']
max_MAE = diff_permuted_abs.sel(shuff_var=var_subset).max('shuff_var')

idx_MAE = max_MAE * np.nan
for i,vv in enumerate(var_subset):
    idx_MAE = idx_MAE.where(diff_permuted_abs.sel(shuff_var=vv) != max_MAE, i+1)
    
max_MAE.plot(cmap=mpl.cm.Reds)

plt.figure()
cmap = mpl.colors.ListedColormap(palette)
idx_MAE.plot(cmap=cmap)

In [None]:
# Five panels: max_MAE (as last computed above, i.e. the var_subset version when run top to bottom)
# cropped to each listed ice shelf (Nisf id), on fixed levels 0..34, with box-1 (grey dashed) and
# ground-mask (black) contours.
# NOTE(review): `.where(max_kisf)` uses max_kisf's values as the condition. This presumably treats
# NaN as True and 0 as False, so it mostly crops the masks to max_kisf's bounding box. Confirm intent.
fig, ax = plt.subplots(1,5)
fig.set_size_inches(8.25*2.5, 8.25/2)

for i,kisf in enumerate([10,11,66,31,44]):
    kisf_mask = var_of_int['ISF_mask'].isel(nemo_run=0)==kisf
    max_kisf = max_MAE.where(kisf_mask,drop=True)
    max_kisf.plot(ax=ax[i],cmap=mpl.cm.Reds, add_colorbar=False,levels=range(35))
    ax[i].contour(max_kisf.x,max_kisf.y,var_of_int['box1_mask'].isel(nemo_run=0).where(max_kisf),levels=[0,1],linewidths=2,colors='grey',linestyles='--',zorder=10)
    ax[i].contour(max_kisf.x,max_kisf.y,var_of_int['ground_mask'].isel(nemo_run=0).where(max_kisf),levels=[0,1],linewidths=2,colors='black',zorder=10)
    
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    ax[i].set_xlabel('')
    ax[i].set_ylabel('')
    ax[i].set_title('')
    
plt.tight_layout()
fig.savefig(plot_path+'max_MAE_shuffled2D_'+nemo_run+'_'+TS_opt+'.png', dpi=300)


In [None]:

# Same panels without mask contours, each with its own colour bar underneath
# (the levels are fixed, so the bars agree).
fig, ax = plt.subplots(1,5)
fig.set_size_inches(8.25*2.5, 8.25/2)

for i,kisf in enumerate([10,11,66,31,44]):
    kisf_mask = var_of_int['ISF_mask'].isel(nemo_run=0)==kisf
    max_kisf = max_MAE.where(kisf_mask,drop=True)
    max_kisf.plot(ax=ax[i],cmap=mpl.cm.Reds,levels=range(35),cbar_kwargs={"location": "bottom"})
    
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    ax[i].set_xlabel('')
    ax[i].set_ylabel('')
    ax[i].set_title('')
    
plt.tight_layout()
fig.savefig(plot_path+'max_MAE_shuffled2D_'+nemo_run+'_'+TS_opt+'_withcolorbar.png', dpi=300)


In [None]:

cmaph = mpl.colors.ListedColormap(new_palette)
    
fig, ax = plt.subplots(1,5)
fig.set_size_inches(8.25*2.5, 8.25/2)

for i,kisf in enumerate([10,11,66,31,44]):
    kisf_mask = var_of_int['ISF_mask'].isel(nemo_run=0)==kisf
    idx_kisf = idx_MAE.where(kisf_mask,drop=True)
    idx_kisf.plot(ax=ax[i],cmap=cmaph, add_colorbar=False)
    ax[i].contour(idx_kisf.x,idx_kisf.y,var_of_int['box1_mask'].isel(nemo_run=0).where(idx_kisf),levels=[0,1],linewidths=2,colors='black',linestyles='--',zorder=10)
    ax[i].contour(idx_kisf.x,idx_kisf.y,var_of_int['ground_mask'].isel(nemo_run=0).where(idx_kisf),levels=[0,1],linewidths=2,colors='black',zorder=10)
    
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    ax[i].set_xlabel('')
    ax[i].set_ylabel('')
    ax[i].set_title('')


plt.tight_layout()
fig.savefig(plot_path+'idx_MAE_shuffled2D_'+nemo_run+'_'+TS_opt+'.png', dpi=300)


In [None]:

cmaph = mpl.colors.ListedColormap(new_palette)
    
fig, ax = plt.subplots(1,5)
fig.set_size_inches(8.25*2.5, 8.25/2)

for i,kisf in enumerate([10,11,66,31,44]):
    kisf_mask = var_of_int['ISF_mask'].isel(nemo_run=0)==kisf
    idx_kisf = idx_MAE.where(kisf_mask,drop=True)
    idx_kisf.plot(ax=ax[i],cmap=cmaph, cbar_kwargs={"location": "bottom"})
    ax[i].contour(idx_kisf.x,idx_kisf.y,var_of_int['box1_mask'].isel(nemo_run=0).where(idx_kisf),levels=[0,1],linewidths=2,colors='black',linestyles='--',zorder=10)
    ax[i].contour(idx_kisf.x,idx_kisf.y,var_of_int['ground_mask'].isel(nemo_run=0).where(idx_kisf),levels=[0,1],linewidths=2,colors='black',zorder=10)
    
    ax[i].set_xticks([])
    ax[i].set_yticks([])
    ax[i].set_xlabel('')
    ax[i].set_ylabel('')
    ax[i].set_title('')


plt.tight_layout()
fig.savefig(plot_path+'idx_MAE_shuffled2D_'+nemo_run+'_'+TS_opt+'_withcolorbar.png', dpi=300)


In [None]:
max_MAE.where(kisf_mask,drop=True).plot()

In [None]:
# Same 14 single variables as var_single above.
sub_varlist = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat',
                'theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std']

In [None]:
len(sub_varlist)

In [None]:
# NOTE(review): diff_RMSE_Gt / diff_RMSE_box1 are only assigned in the "REMOVE LARGE ONES" section
# further down (from diff_*_all / diff_*_orig, which no cell shown here defines). They are indexed
# by 'shuffled_var', whereas the data loaded above uses 'shuff_var'. These cells fail on a fresh kernel.
# Normalise each RMSE change by its largest magnitude over the shuffled variables (values in [-1, 1]).
diff_RMSE_Gt_okvar = diff_RMSE_Gt.sel(shuffled_var=var_list)
diff_RMSE_Gt_norm = (diff_RMSE_Gt_okvar) / (abs(diff_RMSE_Gt_okvar).max('shuffled_var'))

In [None]:
# Same normalisation for the box-1 RMSE change.
diff_RMSE_box1_okvar = diff_RMSE_box1.sel(shuffled_var=var_list)
diff_RMSE_box1_norm = (diff_RMSE_box1_okvar) / (abs(diff_RMSE_box1_okvar).max('shuffled_var'))

In [None]:
# Annotated heatmap of the Gt RMSE change per shuffled variable (1 decimal), saved as PDF.
plt.figure(figsize=(8.24/1.5,8.24/1.25))
sns.heatmap(diff_RMSE_Gt.sel(shuffled_var=var_list).round(1).T, annot=True, fmt="g", yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm, cbar=False) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'mixedpermutation_importance_Gt_bothmodels_'+TS_opt+'_'+nemo_run+'_'+exp_name+'.pdf')

In [None]:
# Same for the box-1 RMSE change (2 decimals).
plt.figure(figsize=(8.24/1.5,8.24/1.25))
sns.heatmap(diff_RMSE_box1.sel(shuffled_var=var_list).round(2).T, annot=True, fmt="g", yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm, cbar=False) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'mixedpermutation_importance_box1_bothmodels_'+TS_opt+'_'+nemo_run+'_'+exp_name+'.pdf')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_Gt.sel(shuffled_var=sub_varlist).round(2).T, annot=True, fmt='d', yticklabels=sub_varlist, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'mixedpermutation_importance_Gt_bothmodels_subvar_'+exp_name+'.png')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_box1.sel(shuffled_var=sub_varlist).round(2).T, annot=True, fmt='d', yticklabels=sub_varlist, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'mixedpermutation_importance_box1_bothmodels_subvar_'+exp_name+'.png')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_Gt_norm.sel(shuffled_var=var_list).round(2).T, annot=True, fmt='d', yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'mixedpermutation_importance_Gt_norm_'+mod_size+'_'+exp_name+'.png')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_box1_norm.sel(shuffled_var=var_list).round(2).T, annot=True, fmt='d', yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'mixedpermutation_importance_box1_norm_'+mod_size+'_'+exp_name+'.png')

REMOVE LARGE ICE SHELVES (Nisf 4, 10, 11)

In [None]:
# RMSE over ice shelves after dropping Nisf 4, 10 and 11, for the presumably shuffled (_all)
# and reference (_orig) predictions.
# NOTE(review): diff_Gt_all, diff_box1_all, diff_Gt_orig and diff_box1_orig are not defined by any
# cell shown here. The Gt RMSE averages over ['time','Nisf'] but the box-1 RMSE over 'Nisf' only;
# confirm that box 1 has no time dimension or that this is intended.
RMSE_Gt_all = np.sqrt((diff_Gt_all**2).drop_sel(Nisf=[4,10,11]).mean(['time','Nisf']))
RMSE_box1_all = np.sqrt((diff_box1_all**2).drop_sel(Nisf=[4,10,11]).mean(['Nisf']))

In [None]:
RMSE_Gt_orig = np.sqrt((diff_Gt_orig**2).drop_sel(Nisf=[4,10,11]).mean(['time','Nisf']))
RMSE_box1_orig = np.sqrt((diff_box1_orig**2).drop_sel(Nisf=[4,10,11]).mean(['Nisf']))

In [None]:
# RMSE difference _all - _orig (the "permutation importance" plotted below).
diff_RMSE_Gt = RMSE_Gt_all - RMSE_Gt_orig
diff_RMSE_box1 = RMSE_box1_all - RMSE_box1_orig

In [None]:
# Normalise by the largest |difference| over the shuffled variables.
diff_RMSE_Gt_okvar = diff_RMSE_Gt.sel(shuffled_var=var_list)
diff_RMSE_Gt_norm = (diff_RMSE_Gt_okvar) / (abs(diff_RMSE_Gt_okvar).max('shuffled_var'))

In [None]:
diff_RMSE_box1_okvar = diff_RMSE_box1.sel(shuffled_var=var_list)
diff_RMSE_box1_norm = (diff_RMSE_box1_okvar) / (abs(diff_RMSE_box1_okvar).max('shuffled_var'))

In [None]:
# Heatmaps of the (normalised) RMSE changes without the large ice shelves ('_wolargeones' files).
plt.figure()
sns.heatmap(diff_RMSE_Gt.sel(shuffled_var=var_list).round(2).T, annot=True, yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'permutation_importance_Gt_'+mod_size+'_'+exp_name+'_wolargeones.png')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_box1.sel(shuffled_var=var_list).round(2).T, annot=True, yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'permutation_importance_box1_'+mod_size+'_'+exp_name+'_wolargeones.png')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_Gt_norm.sel(shuffled_var=var_list).round(2).T, annot=True, yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'permutation_importance_Gt_norm_'+mod_size+'_'+exp_name+'_wolargeones.png')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_box1_norm.sel(shuffled_var=var_list).round(2).T, annot=True, yticklabels=var_list, center=0, cmap=mpl.cm.coolwarm) #, cmap=mpl.cm.Reds
plt.savefig(plot_path+'permutation_importance_box1_norm_'+mod_size+'_'+exp_name+'_wolargeones.png')

In [None]:

# Magnitude of the Gt RMSE change (not saved).
plt.figure()
sns.heatmap(abs(diff_RMSE_Gt.sel(shuffled_var=var_list).round(2).T), annot=True, center=0, yticklabels=var_list) #, cmap=mpl.cm.Reds


In [None]:

# Magnitude of the box-1 RMSE change (not saved).
plt.figure()
#sns.heatmap(abs(diff_RMSE_box1.sel(shuffled_var=var_list).round(2).expand_dims(dim={"dim1": 1}).T), annot=True, center=0, yticklabels=var_list, cmap=mpl.cm.Reds) #
sns.heatmap(abs(diff_RMSE_box1.sel(shuffled_var=var_list).round(2).T), annot=True, center=0, yticklabels=var_list) #cmap="YlOrBr"


In [None]:

# NOTE(review): assumes diff_RMSE_box1 has an 'nn_model' dimension, which nothing shown here creates.
plt.figure()
sns.heatmap(diff_RMSE_box1.isel(nn_model=0).sel(shuffled_var=sub_varlist).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=sub_varlist)


In [None]:
# NOTE(review): run_list and timetag are not defined in any cell shown here.
plt.figure()
sns.heatmap(diff_RMSE_Gt.isel(nn_model=0).sel(shuffled_var=var_list).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=var_list, xticklabels=run_list)
plt.savefig(plot_path+'permutation_importance_Gt_yr_'+timetag+'.png')

In [None]:
plt.figure()
sns.heatmap(diff_RMSE_box1.isel(nn_model=0).sel(shuffled_var=var_list).round(2).rename('diff_RMSE'), annot=True, center=0, cmap=mpl.cm.Reds, yticklabels=var_list, xticklabels=run_list)
plt.savefig(plot_path+'permutation_importance_box1_'+timetag+'.png')

In [None]:
# NOTE(review): 'T_profiles' is not one of the names in var_list.
diff_RMSE_Gt.sel(shuffled_var='T_profiles')

In [None]:
# Show the current plot_path (the value set in the data-loading cell when run top to bottom).
plot_path