In [None]:
"""
Created on Mon Mar 13 17:28 2023

Look at the training-history files of the neural-network grid search
(grid_search_lrelu): for every combination of number of layers and layer size,
take the minimum validation loss of each cross-validation run (CV_ISF: split
over ice shelves, CV_TBLOCK: split over time blocks) and compare architectures.

@author: Clara Burgard
"""

In [None]:
import glob
import os
import re

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr

## Read in data

In [None]:
# Where figures produced by this notebook are written (a directory, not a file)
plot_path = '/bettik/PLOTS/'
# Root directory of the grid-search model / training-history files
inputpath_models = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/grid_search_lrelu/'
# Directory containing the ice-shelf mask and metadata files
outputpath_mask = '/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_bi646/'

## Cross-validation (CV) results

In [None]:
# One CV_TBLOCK training history (1 layer, size 32, file tags noisf000 / notblock001)
hist_test = pd.read_csv(f'{inputpath_models}CV_TBLOCK/history_1_32_newbasic_noisf000_notblock001_TSextrap_normstd.csv')

In [None]:
# Lowest validation loss reached during that run
hist_test.val_loss.min()

In [None]:
# For every architecture of the grid search (number of layers x layer size),
# collect the minimum validation loss of each CV_ISF run. The ice-shelf index of
# a run is encoded in its file name as "noisfXXX". Result: `all_mse` with dims
# (layer_amount, layer_size, Nisf). Architectures with no history files yield
# an empty slice, which shows up as NaN after concatenation.
isf_tag = re.compile(r'noisf(\d+)')

da_list = []
for layer_nb in range(1, 7):
    size_list = []
    for layer_size in [32, 64, 96, 128, 256]:
        min_list = []
        nisf_list = []
        # glob returns files in filesystem order; sort so the Nisf coordinate
        # order is reproducible from one run to the next.
        file_list = sorted(glob.glob(
            inputpath_models + 'CV_ISF/history_' + str(layer_nb) + '_' + str(layer_size)
            + '_newbasic_noisf*_notblock*_TSextrap_normstd.csv'))
        for ff in file_list:
            hist_test = pd.read_csv(ff)
            min_list.append(hist_test['val_loss'].min())
            # Parse the index from the base name only, e.g. '..._noisf038_...' -> 38,
            # so directory names can never interfere with the parsing.
            nisf_list.append(int(isf_tag.search(os.path.basename(ff)).group(1)))
        layer_da = xr.DataArray(data=np.array(min_list), dims=['Nisf']).assign_coords(
            {'layer_size': layer_size, 'Nisf': nisf_list})
        size_list.append(layer_da)

    size_da = xr.concat(size_list, dim='layer_size')
    da_list.append(size_da.assign_coords({'layer_amount': layer_nb}))

all_mse = xr.concat(da_list, dim='layer_amount')

In [None]:
# Minimum validation loss averaged over the CV_ISF folds, per architecture
all_mse.mean(dim='Nisf').plot()

In [None]:
# Scatter of the architecture grid (layer size vs. number of layers), coloured
# by the minimum validation loss averaged over the CV_ISF folds.
ls, la = xr.broadcast(all_mse.layer_size, all_mse.layer_amount)
# Order the colour values like the broadcast coordinates explicitly instead of
# relying on `.T` and on all_mse's current dimension order.
mean_loss = all_mse.mean('Nisf').transpose(*ls.dims).values
plt.scatter(ls, la, c=mean_loss, cmap=mpl.cm.autumn_r)
plt.colorbar(label='mean of min val_loss (CV_ISF)')
plt.xlabel('layer size')
plt.ylabel('number of layers');

In [None]:
# Heatmap of the mean (over CV_ISF folds) of the minimum validation loss.
# seaborn only keeps tick labels for pandas input; a raw DataArray goes through
# np.asarray and would be labelled 0..n-1, so convert to a DataFrame
# (rows = layer_amount, columns = layer_size) first.
mean_mse_isf = all_mse.mean('Nisf').transpose('layer_amount', 'layer_size').to_pandas()
ax = sns.heatmap(mean_mse_isf, annot=True)
ax.set_title('CV_ISF: mean of min val_loss');


In [None]:
# For every architecture of the grid search (number of layers x layer size),
# collect the minimum validation loss of each CV_TBLOCK run. The time-block index
# of a run is encoded in its file name as "notblockXXX". Result: `all_mse` with
# dims (layer_amount, layer_size, tblock). Architectures with no history files
# yield an empty slice, which shows up as NaN after concatenation.
tblock_tag = re.compile(r'notblock(\d+)')

da_list = []
for layer_nb in range(1, 7):
    size_list = []
    for layer_size in [32, 64, 96, 128, 256]:
        min_list = []
        tblock_list = []
        # glob returns files in filesystem order; sort so the tblock coordinate
        # order is reproducible from one run to the next.
        file_list = sorted(glob.glob(
            inputpath_models + 'CV_TBLOCK/history_' + str(layer_nb) + '_' + str(layer_size)
            + '_newbasic_noisf*_notblock*_TSextrap_normstd.csv'))
        for ff in file_list:
            hist_test = pd.read_csv(ff)
            min_list.append(hist_test['val_loss'].min())
            # Parse the index from the base name only, e.g. '..._notblock003_...' -> 3.
            tblock_list.append(int(tblock_tag.search(os.path.basename(ff)).group(1)))
        layer_da = xr.DataArray(data=np.array(min_list), dims=['tblock']).assign_coords(
            {'layer_size': layer_size, 'tblock': tblock_list})
        size_list.append(layer_da)

    size_da = xr.concat(size_list, dim='layer_size')
    da_list.append(size_da.assign_coords({'layer_amount': layer_nb}))

all_mse = xr.concat(da_list, dim='layer_amount')

In [None]:
# Minimum validation loss averaged over the CV_TBLOCK folds, per architecture
all_mse.mean(dim='tblock').plot()

In [None]:
# Inspect the assembled CV_TBLOCK DataArray (dims: layer_amount, layer_size, tblock)
all_mse

In [None]:
# Heatmap of the mean (over CV_TBLOCK folds) of the minimum validation loss.
# Convert to a DataFrame so seaborn labels the axes with the actual layer
# amounts (rows) and layer sizes (columns) instead of 0..n-1 positions.
mean_mse_tblock = all_mse.mean('tblock').transpose('layer_amount', 'layer_size').to_pandas()
ax = sns.heatmap(mean_mse_tblock, annot=True)
ax.set_title('CV_TBLOCK: mean of min val_loss')
# plot_path is a directory; saving to the bare directory path fails, so a file
# name has to be appended.
plt.savefig(plot_path + 'heatmap_CV_TBLOCK_mean_min_val_loss.png', bbox_inches='tight')

In [None]:
# Scatter of the architecture grid (layer size vs. number of layers), coloured
# by the minimum validation loss averaged over the CV_TBLOCK folds.
ls, la = xr.broadcast(all_mse.layer_size, all_mse.layer_amount)
# Order the colour values like the broadcast coordinates explicitly instead of
# relying on `.T` and on all_mse's current dimension order.
mean_loss = all_mse.mean('tblock').transpose(*ls.dims).values
plt.scatter(ls, la, c=mean_loss, cmap=mpl.cm.autumn_r)
plt.colorbar(label='mean of min val_loss (CV_TBLOCK)')
plt.xlabel('layer size')
plt.ylabel('number of layers');

In [None]:
# Ice-shelf masks and metadata (provides the isf_name lookup used below)
file_isf = xr.open_dataset(f'{outputpath_mask}nemo_5km_isf_masks_and_info_and_distance_oneFRIS_1970.nc')

In [None]:
# Name of the ice shelf with index Nisf = 38
file_isf.isf_name.sel(Nisf=38)

In [None]:
# Ice-shelf names indexed by Nisf, used for tick labels
isf_names = file_isf.isf_name

In [None]:
# Image of `plotted_var` with one row per model size (mod_size) and one column
# per ice shelf, labelled with the ice-shelf names from isf_names.
# NOTE(review): `all_rmse` is not defined anywhere in this notebook. The CV cells
# above build `all_mse`, whose dims do not include mod_size. This cell looks
# copied from another analysis and raises NameError on a fresh kernel; point it
# at the intended variable before running.
plotted_var = all_rmse

# Transpose explicitly so rows are model sizes and columns are ice shelves,
# matching the tick setup below regardless of the input's dimension order.
image_data = plotted_var.transpose('mod_size', 'Nisf').values

x = np.arange(len(plotted_var.mod_size))  # row positions (model sizes)
y = np.arange(len(plotted_var.Nisf))      # column positions (ice shelves)

fig, axs = plt.subplots(1, 1, figsize=((len(y) + 1) / 4, (len(x) + 2) / 2.75))
# NOTE(review): the symmetric -100..100 range with a diverging colormap suggests
# a signed quantity -- confirm these limits fit the plotted variable.
ax0 = axs.imshow(image_data, cmap=plt.cm.coolwarm, vmin=-100, vmax=100)

# Model sizes on the right-hand axis (labelled with their values, not indices),
# ice-shelf names along the top.
axs.yaxis.tick_right()
axs.set_yticks(x)
axs.set_yticklabels(plotted_var.mod_size.values)
axs.xaxis.tick_top()
axs.set_xticks(y)
axs.set_xticklabels(labels=isf_names.sel(Nisf=plotted_var.Nisf).values, rotation=90)
fig.colorbar(ax0, ax=axs)
