# OSE metrics

In [None]:
import sys
sys.path.append('..')
import xarray as xr
import numpy as np
import pyinterp
import netCDF4
from src.ose.mod_inout import *
from src.ose.mod_interp import *
from src.ose.mod_stats import *
from src.ose.mod_spectral import *
from src.ose.mod_plot import *
from src.ose.utils import *

# Raise the root logger to INFO so the `logging.info(...)` messages emitted by
# the cells below are displayed.
# NOTE(review): `logging` is not imported explicitly in this cell; it is
# presumably re-exported by one of the star imports above — confirm.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 1. Interactive results for the OI

In [None]:
# Flatten the gridded L3 SST observations into a 1-D point dataset: one entry
# per finite grid cell, carrying its own time / latitude / longitude.
data = xr.open_dataset('/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L3S_GHRSST-SSTsubskin-night_SST_UHR_NRT-NSEABALTIC_validation_s1.nc',chunks={"time":10})

# Load the SST cube a single time; it is reused for the validity mask and for
# the observation values.
sst_values = data.sea_surface_temperature.values
valid = np.isfinite(sst_values)

# Positions of the finite cells along each axis. The cube is indexed as
# (time, lat, lon), so the 1-D coordinates can be looked up directly instead of
# broadcasting three full-size 3-D arrays first.
idx_time, idx_lat, idx_lon = np.nonzero(valid)
new_time = data.time.values[idx_time]
new_lat = data.lat.values[idx_lat]
new_lon = data.lon.values[idx_lon]
new_sstobs = sst_values[valid]

# Every variable/coordinate shares the single flattened 'time' dimension.
new_data = xr.Dataset(data_vars={'sstobs':(('time'),new_sstobs)},
                      coords={'time':('time',new_time),
                              'latitude':('time',new_lat),
                              'longitude':('time',new_lon)})

In [None]:
def _label_slice(coord, lower, upper):
    """Return a label slice covering ``[lower, upper]`` that follows the sort order of *coord*."""
    values = np.asarray(coord)
    if values.size > 1 and values[0] > values[-1]:
        # Descending coordinate: label-based slicing needs the bounds reversed.
        return slice(upper, lower)
    return slice(lower, upper)


def read_l4_dataset(list_of_file, 
                    lon_min=-180., 
                    lon_max=180., 
                    lat_min=-90, 
                    lat_max=90., 
                    sel_time=slice('1900-10-01','2100-01-01'), 
                    byseason=False,
                    is_circle=True,
                    crop=False):
    """Open gridded L4 file(s) and wrap ``analysed_sst`` in a pyinterp 3-D grid.

    Parameters
    ----------
    list_of_file : path or list of paths forwarded to ``xr.open_mfdataset``
        (files are concatenated along ``time``).
    lon_min, lon_max, lat_min, lat_max : bounds (inclusive) of the kept domain.
    sel_time : time selector. With ``byseason=False`` it is passed to
        ``ds.sel(time=...)``; with ``byseason=True`` it is compared against the
        ``season`` coordinate derived from ``time.dt.season``.
    byseason : select by season label instead of by time range.
    is_circle : forwarded to ``pyinterp.Axis`` for the longitude axis.
    crop : when True, cut the dataset to the bounding box with a label-based
        ``sel`` before any other processing.

    Returns
    -------
    (x_axis, y_axis, z_axis, grid) : the pyinterp longitude, latitude and
        temporal axes and the ``pyinterp.Grid3D`` built on them, with data
        ordered as (lon, lat, time).
    """
    ds = xr.open_mfdataset(list_of_file, concat_dim ='time', combine='nested', parallel=True,
                          chunks={"time":10})
    ds = ds.assign_coords({"season":(('time'),ds.time.dt.season.data)})

    if crop:
        # The selection result must be re-assigned: ``Dataset.sel`` returns a
        # new object and leaves ``ds`` untouched.
        ds = ds.sel(lon=_label_slice(ds["lon"], lon_min, lon_max),
                    lat=_label_slice(ds["lat"], lat_min, lat_max))

    if not byseason:
        ds = ds.sel(time=sel_time, drop=True)
    else:
        ds = ds.sel(time=(ds.season==(sel_time)))

    # Mask-based crop; applied regardless of ``crop`` so the returned grid is
    # always restricted to the requested bounding box.
    ds = ds.where((ds["lon"] >= lon_min) & (ds["lon"] <= lon_max), drop=True)
    ds = ds.where((ds["lat"] >= lat_min) & (ds["lat"] <= lat_max), drop=True)
        
    x_axis = pyinterp.Axis(ds["lon"][:], is_circle=is_circle)
    y_axis = pyinterp.Axis(ds["lat"][:])
    z_axis = pyinterp.TemporalAxis(ds["time"].values)
    
    # Grid3D is built with axes in (x, y, z) order, hence the transpose.
    var = ds['analysed_sst'][:]
    var = var.transpose('lon', 'lat', 'time')

    # No explicit masking is applied here: undefined values are expected to be
    # NaN already (presumably via xarray's fill-value decoding — TODO confirm).
    grid = pyinterp.Grid3D(x_axis, y_axis, z_axis, var.data)
    
    return x_axis, y_axis, z_axis, grid


def interp_on_alongtrack(gridded_dataset, 
                         ds_alongtrack,
                         lon_min=-180, 
                         lon_max=180., 
                         lat_min=-90, 
                         lat_max=90., 
                         sel_time=slice('1900-10-01','2100-01-01'),
                         byseason=False,
                         is_circle=True,
                         crop=False):
    """Interpolate gridded maps onto the observation points and co-locate both.

    The gridded product is read through ``read_l4_dataset`` and interpolated
    with ``pyinterp.trivariate`` at the (longitude, latitude, time) of every
    point of ``ds_alongtrack``. Points where either the observation or the
    interpolated value is not finite are dropped, and only points lying at
    least 0.25 degrees inside the requested box are returned.

    Returns
    -------
    (time, lat, lon, sst_obs, sst_interp) : five 1-D arrays of equal length.
    """
    _, _, z_axis, grid = read_l4_dataset(gridded_dataset,
                                         lon_min=lon_min,
                                         lon_max=lon_max,
                                         lat_min=lat_min,
                                         lat_max=lat_max,
                                         sel_time=sel_time,
                                         byseason=byseason,
                                         is_circle=is_circle,
                                         crop=crop)

    # Apply to the observations the same time / season selection as to the maps.
    ds_alongtrack = ds_alongtrack.assign_coords({"season":(('time'),ds_alongtrack.time.dt.season.data)})
    if byseason:
        ds_alongtrack = ds_alongtrack.sel(time=(ds_alongtrack.season==(sel_time)))
    else:
        ds_alongtrack = ds_alongtrack.sel(time=sel_time, drop=True)

    obs_lon = ds_alongtrack["longitude"].values
    obs_lat = ds_alongtrack["latitude"].values
    obs_time = ds_alongtrack["time"].values
    obs_sst = ds_alongtrack["sstobs"].values

    # Points outside the grid are not an error (bounds_error=False).
    interpolated = pyinterp.trivariate(grid,
                                       obs_lon,
                                       obs_lat,
                                       z_axis.safe_cast(obs_time),
                                       bounds_error=False).reshape(obs_lon.shape)

    # Keep only the points where both the observation and the map are finite.
    keep = np.isfinite(obs_sst) & np.isfinite(interpolated)
    obs_sst = obs_sst[keep]
    obs_lon = obs_lon[keep]
    obs_lat = obs_lat[keep]
    obs_time = obs_time[keep]
    interpolated = interpolated[keep]

    # Keep interior points only (this ensures a similar number of points in
    # the statistical comparison between methods).
    inside = ((obs_lon >= lon_min+0.25) & (obs_lon <= lon_max-0.25) &
              (obs_lat >= lat_min+0.25) & (obs_lat <= lat_max-0.25))

    return obs_time[inside], obs_lat[inside], obs_lon[inside], obs_sst[inside], interpolated[inside]
    

In [None]:
# Validation box and period used for the OI product.
lon_min, lon_max = -10, 30
lat_min, lat_max = 48, 66
time_min, time_max = '2021-01-01', '2021-12-31'
sel_time = slice(time_min, time_max)

# Interpolate the OI L4 maps onto the flattened L3 observations.
file = "/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L4_GHRSST-SSTfnd-DMI_OI-NSEABALTIC_2021_validation.nc"
res = interp_on_alongtrack(file, new_data,
                           lon_min=lon_min, lon_max=lon_max,
                           lat_min=lat_min, lat_max=lat_max,
                           sel_time=sel_time,
                           is_circle=True)
(time_alongtrack,
 lat_alongtrack,
 lon_alongtrack,
 sst_alongtrack,
 sst_interp) = res

In [None]:
# Delete outputs left over from a previous run. A missing file is not an error
# (first run, or the cell being executed twice), so it is skipped silently.
import contextlib
import os

for stale_output in ('/DATASET/mbeauchamp/spa_stat.nc',
                     '/DATASET/mbeauchamp/TS.nc'):
    with contextlib.suppress(FileNotFoundError):
        os.remove(stale_output)

In [None]:
def compute_stats(time_alongtrack, 
                  lat_alongtrack, 
                  lon_alongtrack, 
                  ssh_alongtrack, 
                  ssh_map_interp, 
                  bin_lon_step,
                  bin_lat_step, 
                  bin_time_step,
                  output_filename,
                  output_filename_timeseries):
    """Bin observation / map / error statistics and write them to netCDF.

    Parameters
    ----------
    time_alongtrack, lat_alongtrack, lon_alongtrack : coordinates of the points.
    ssh_alongtrack : observed values at those points.
    ssh_map_interp : map values interpolated at those points.
    bin_lon_step, bin_lat_step : size of the spatial bins (same unit as the
        coordinates); the binning grid spans -180..180 / -90..90.
    bin_time_step : forwarded to ``write_timeserie_stat``.
    output_filename : netCDF file receiving the 'alongtrack', 'maps' and
        'diff' groups (plus an 'rmse' variable in 'diff').
    output_filename_timeseries : forwarded to ``write_timeserie_stat``.

    Returns
    -------
    The two values returned by ``write_timeserie_stat``
    (``leaderboard_nrmse``, ``leaderboard_nrmse_std``).
    """
    binning = pyinterp.Binning2D(
        pyinterp.Axis(np.arange(-180, 180, bin_lon_step), is_circle=True),
        pyinterp.Axis(np.arange(-90, 90 + bin_lat_step, bin_lat_step)))

    residual = ssh_alongtrack - ssh_map_interp

    # The context manager guarantees the file is closed even if one of the
    # binning / writing steps raises.
    with netCDF4.Dataset(output_filename,'w') as ncfile:

        # binning of the observations
        binning.push(lon_alongtrack, lat_alongtrack, ssh_alongtrack, simple=True)
        write_stat(ncfile, 'alongtrack', binning)
        binning.clear()

        # binning of the interpolated maps
        binning.push(lon_alongtrack, lat_alongtrack, ssh_map_interp, simple=True)
        write_stat(ncfile, 'maps', binning)
        binning.clear()

        # binning of the difference (observation - map)
        binning.push(lon_alongtrack, lat_alongtrack, residual, simple=True)
        write_stat(ncfile, 'diff', binning)
        binning.clear()

        # RMSE per bin = sqrt(mean of squared differences); stored in the
        # 'diff' group, whose ('lat', 'lon') dimensions must already exist.
        binning.push(lon_alongtrack, lat_alongtrack, residual**2, simple=True)
        var = ncfile.groups['diff'].createVariable('rmse', binning.variable('mean').dtype, ('lat','lon'), zlib=True)
        # binning results are (lon, lat); transpose to match ('lat', 'lon').
        var[:, :] = np.sqrt(binning.variable('mean')).T  
    
    logging.info(f'  Results saved in: {output_filename}')

    # write time series statistics
    leaderboard_nrmse, leaderboard_nrmse_std = write_timeserie_stat(ssh_alongtrack, 
                                                                    ssh_map_interp, 
                                                                    time_alongtrack, 
                                                                    bin_time_step, 
                                                                    output_filename_timeseries)
    
    return leaderboard_nrmse, leaderboard_nrmse_std

In [None]:
# Outputs
# Binning resolution (space and time) of the statistics, then the run itself.
bin_lat_step = bin_lon_step = 0.1
bin_time_step = '1D'
leaderboard_nrmse, leaderboard_nrmse_std = compute_stats(
    time_alongtrack,
    lat_alongtrack,
    lon_alongtrack,
    sst_alongtrack,
    sst_interp,
    bin_lon_step,
    bin_lat_step,
    bin_time_step,
    output_filename='/DATASET/mbeauchamp/spa_stat.nc',
    output_filename_timeseries='/DATASET/mbeauchamp/TS.nc',
)

In [None]:
# Plot the spatial statistics file written by the compute_stats call above.
plot_spatial_statistics('/DATASET/mbeauchamp/spa_stat.nc')

In [None]:
# Plot the time-series statistics file written by the compute_stats call above.
from src.ose.mod_plot import plot_temporal_statistics
plot_temporal_statistics('/DATASET/mbeauchamp/TS.nc')

In [None]:
# Delete the spectrum file left over from a previous run; a missing file is
# not an error (first run, or the cell being executed twice).
import contextlib
import os

with contextlib.suppress(FileNotFoundError):
    os.remove('/DATASET/mbeauchamp/spectrum.nc')

In [None]:
def compute_segment_alongtrack(time_alongtrack, 
                               lat_alongtrack, 
                               lon_alongtrack, 
                               ssh_alongtrack, 
                               ssh_map_interp, 
                               lenght_scale,
                               delta_x,
                               delta_t,
                               convert_lon=True):
    """Cut a point series into overlapping fixed-length segments.

    The series is split into tracks wherever two consecutive times are more
    than 4 seconds apart. Only tracks that end at such a gap are considered
    (points after the last gap are not used). Every track of at least ``npt``
    points is cut into segments of ``npt`` points, with a start offset of
    ``int(npt * 0.25)`` points between consecutive segments. A segment is
    discarded when the interpolated map contains a non-finite value inside it.

    Parameters
    ----------
    time_alongtrack : datetime64 array of the point times.
    lat_alongtrack, lon_alongtrack : coordinates of the points.
    ssh_alongtrack : observed values.
    ssh_map_interp : map values interpolated at the points.
    lenght_scale : segment length, in the same unit as ``delta_x``.
    delta_x : spacing between consecutive points.
    delta_t : not used by this function; kept for interface compatibility.
    convert_lon : when True, longitudes are wrapped to [0, 360) first.

    Returns
    -------
    (list_lon_segment, list_lat_segment, list_ssh_alongtrack_segment,
    list_ssh_map_interp_segment, npt) : median longitude / latitude of each
    kept segment, the observed and interpolated segments (masked arrays of
    ``npt`` values), and the number of points per segment. The four lists are
    empty when the series contains no time gap.
    """
    if convert_lon:
        lon_alongtrack = np.mod(lon_alongtrack,360)

    segment_overlapping = 0.25
    max_delta_t_gap = 4 * np.timedelta64(1, 's')  # max delta t of 4 seconds to cut tracks

    list_lat_segment = []
    list_lon_segment = []
    list_ssh_alongtrack_segment = []
    list_ssh_map_interp_segment = []

    # Number of points per segment for a resolution of lenght_scale
    npt = int(lenght_scale / delta_x)

    # Index of the last point before each time gap
    indi = np.where((np.diff(time_alongtrack) > max_delta_t_gap))[0]
    if indi.size == 0:
        # No gap means no delimited track: nothing to segment (previously this
        # case crashed with an IndexError on ``indi[0]``).
        return list_lon_segment, list_lat_segment, list_ssh_alongtrack_segment, list_ssh_map_interp_segment, npt

    track_segment_lenght = np.insert(np.diff(indi), [0], indi[0])

    # Long tracks only (>= npt points)
    selected_track_segment = np.where(track_segment_lenght >= npt)[0]
    segment_step = int(npt*segment_overlapping)

    for track in selected_track_segment:

        start_point = indi[track-1] if track-1 >= 0 else 0
        end_point = indi[track]

        for sub_segment_point in range(start_point, end_point - npt, segment_step):

            window = slice(sub_segment_point, sub_segment_point + npt)
            lon_window = lon_alongtrack[window]
            first_lon = lon_alongtrack[sub_segment_point]
            last_lon = lon_alongtrack[sub_segment_point + npt - 1]

            # Near-Greenwich case: the segment straddles the 0/360 seam, so
            # the median is taken on longitudes shifted to (-180, 180].
            if ((last_lon < 50.) and (first_lon > 320.)) \
                    or ((last_lon > 320.) and (first_lon < 50.)):
                tmp_lon = np.where(lon_window > 180, lon_window - 360, lon_window)
                mean_lon_sub_segment = np.median(tmp_lon)
                if mean_lon_sub_segment < 0:
                    mean_lon_sub_segment = mean_lon_sub_segment + 360.
            else:
                mean_lon_sub_segment = np.median(lon_window)

            mean_lat_sub_segment = np.median(lat_alongtrack[window])

            ssh_alongtrack_segment = np.ma.masked_invalid(ssh_alongtrack[window])
            ssh_map_interp_segment = np.ma.masked_invalid(ssh_map_interp[window])

            # A segment with any invalid map value is dropped as a whole, so
            # every kept segment has exactly npt points.
            if np.ma.is_masked(ssh_map_interp_segment):
                continue

            if ssh_alongtrack_segment.size > 0:
                list_ssh_alongtrack_segment.append(ssh_alongtrack_segment)
                list_lon_segment.append(mean_lon_sub_segment)
                list_lat_segment.append(mean_lat_sub_segment)
                list_ssh_map_interp_segment.append(ssh_map_interp_segment)

    return list_lon_segment, list_lat_segment, list_ssh_alongtrack_segment, list_ssh_map_interp_segment, npt 



In [None]:
# Sampling parameters handed to the spectral analysis.
delta_t = 0.9434  # s
velocity = 6.77   # km/s
# NOTE: the product computed on the next line is immediately overridden by the
# hard-coded value that follows it, so 2.2 is the spacing actually used.
delta_x = velocity * delta_t
delta_x = 2.2
lenght_scale = 1000 # km
# Compute the spectral scores of the OI maps against the observations; the
# last argument is the output netCDF file read later by plot_psd_score.
compute_spectral_scores(time_alongtrack, 
                        lat_alongtrack, 
                        lon_alongtrack, 
                        sst_alongtrack, 
                        sst_interp, 
                        lenght_scale,
                        delta_x,
                        delta_t,
                        '/DATASET/mbeauchamp/spectrum.nc')

In [None]:
def plot_psd_score(filename):
    """Plot the power spectra and the PSD score stored in ``filename``.

    Left panel: ``psd_ref`` and ``psd_study`` against wavelength
    (``1 / wavenumber``), log-log. Right panel: the score
    ``1 - psd_diff / psd_ref`` with a dashed line at 0.5.

    Parameters
    ----------
    filename : netCDF file holding ``wavenumber``, ``psd_ref``, ``psd_study``
        and ``psd_diff``.

    Returns
    -------
    The value returned by ``find_wavelength_05_crossing(filename)``; it is
    logged as the minimum spatial scale resolved.
    """
    resolved_scale = find_wavelength_05_crossing(filename)

    # The context manager releases the file handle once plotting is done.
    with xr.open_dataset(filename) as ds:
        wavelength = 1./ds.wavenumber

        plt.figure(figsize=(10, 5))
        ax = plt.subplot(121)
        ax.invert_xaxis()
        plt.plot(wavelength, ds.psd_ref, label='reference', color='k')
        plt.plot(wavelength, ds.psd_study, label='reconstruction', color='lime')
        plt.xlabel('wavelength [km]')
        plt.ylabel('Power Spectral Density [m$^{2}$/cy/km]')
        plt.xscale('log')
        plt.yscale('log')
        plt.legend(loc='best')
        plt.grid(which='both')

        ax = plt.subplot(122)
        ax.invert_xaxis()
        plt.plot(wavelength, (1. - ds.psd_diff/ds.psd_ref), color='k', lw=2)
        plt.xlabel('wavelength [km]')
        plt.ylabel('PSD Score [1. - PSD$_{err}$/PSD$_{ref}$]')
        plt.xscale('log')
        # Non-finite wavelengths (e.g. from a zero wavenumber) are masked out
        # before taking the extent of the 0.5 threshold line.
        finite_wavelength = np.ma.masked_invalid(wavelength)
        plt.hlines(y=0.5, 
                  xmin=np.ma.min(finite_wavelength), 
                  xmax=np.ma.max(finite_wavelength),
                  color='r',
                  lw=0.5,
                  ls='--')
        # NOTE: the vertical line / shaded area marking the resolved scales is
        # intentionally not drawn in this function.
        plt.legend(loc='best')
        plt.grid(which='both')
    
    logging.info(' ')
    logging.info(f'  Minimum spatial scale resolved = {int(resolved_scale)}km')
    
    plt.show()
    
    return resolved_scale

In [None]:
# Plot the spectra written above and keep the resolved scale returned by
# plot_psd_score.
leaderboard_psds_score = plot_psd_score('/DATASET/mbeauchamp/spectrum.nc') 

# 2. Perform the intercomparison between OI and 4DVarNet

## 2.1 New set of functions for intercomparison

In [None]:
def run_comparison(file,method, crop=False, byseason=False,compute_spectrum=True,
                   file_obs='/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L3S_GHRSST-SSTsubskin-night_SST_UHR_NRT-NSEABALTIC_validation_s1.nc'):
    """Evaluate one L4 product against the L3 observations.

    Relies on the notebook-level settings ``lon_min``, ``lon_max``,
    ``lat_min``, ``lat_max``, ``sel_time``, ``bin_lon_step``, ``bin_lat_step``,
    ``bin_time_step``, ``lenght_scale``, ``delta_x`` and ``delta_t``.

    Parameters
    ----------
    file : L4 file(s) to evaluate (forwarded to ``interp_on_alongtrack``).
    method : label used to name the output files
        (``spa_stat_<method>.nc``, ``TS_<method>.nc``, ``spectrum_<method>.nc``).
    crop, byseason : forwarded to ``interp_on_alongtrack``.
    compute_spectrum : when False the spectral analysis is skipped and -999 is
        returned as the spectral score.
    file_obs : L3 observation file.

    Returns
    -------
    (leaderboard_nrmse, leaderboard_nrmse_std, leaderboard_psds_score) where
    the last value is cast to ``int``.
    """
    # prepare L3 datasets: crop to the domain, then flatten the finite cells
    # into a 1-D point dataset. The file is closed as soon as the values have
    # been read into memory.
    with xr.open_dataset(file_obs,chunks={"time":10}) as obs_file:
        obs = obs_file.sel(lon=slice(lon_min,lon_max),
                           lat=slice(lat_min,lat_max))
        # Load the SST cube once; it serves both the mask and the values.
        sst_values = obs.sea_surface_temperature.values
        time_values = obs.time.values
        lat_values = obs.lat.values
        lon_values = obs.lon.values

    # The cube is indexed as (time, lat, lon): look the 1-D coordinates up at
    # the positions of the finite cells instead of broadcasting 3-D arrays.
    valid = np.isfinite(sst_values)
    idx_time, idx_lat, idx_lon = np.nonzero(valid)
    data_l3 = xr.Dataset(data_vars={'sstobs':(('time'),sst_values[valid])},
                      coords={'time':('time',time_values[idx_time]),
                              'latitude':('time',lat_values[idx_lat]),
                              'longitude':('time',lon_values[idx_lon])})

    # interpolation l4 to l3       
    res = interp_on_alongtrack(file,
                              data_l3,
                              lon_min=lon_min,
                              lon_max=lon_max,
                              lat_min=lat_min,
                              lat_max=lat_max,
                              sel_time=sel_time,
                              byseason=byseason,
                              is_circle=True,
                              crop=crop)
    time_alongtrack, lat_alongtrack, lon_alongtrack, sst_alongtrack, sst_interp = res    
    
    fn_rmse = '/DATASET/mbeauchamp/spa_stat_'+method+'.nc'
    fn_ts = '/DATASET/mbeauchamp/TS_'+method+'.nc'
    fn_spectrum =  '/DATASET/mbeauchamp/spectrum_'+method+'.nc'

    leaderboard_nrmse, leaderboard_nrmse_std = compute_stats(time_alongtrack, 
                                                         lat_alongtrack, 
                                                         lon_alongtrack, 
                                                         sst_alongtrack, 
                                                         sst_interp, 
                                                         bin_lon_step,
                                                         bin_lat_step, 
                                                         bin_time_step,
                                                         output_filename=fn_rmse,
                                                         output_filename_timeseries=fn_ts) 
    
    if compute_spectrum:
        compute_spectral_scores(time_alongtrack, 
                        lat_alongtrack, 
                        lon_alongtrack, 
                        sst_alongtrack, 
                        sst_interp, 
                        lenght_scale,
                        delta_x,
                        delta_t,
                        fn_spectrum)
        leaderboard_psds_score = plot_psd_score(fn_spectrum) 
    else:
        # Placeholder meaning "spectral score not computed".
        leaderboard_psds_score = -999
        
    return leaderboard_nrmse, leaderboard_nrmse_std, int(leaderboard_psds_score)   

### 2.1.1 Non-interactive plots

In [None]:
import matplotlib.ticker as mticker
from cartopy.mpl.ticker import (LongitudeFormatter, LatitudeFormatter,
                                LatitudeLocator)

def config_ticks(ax,crs=None):
    """Draw land, gridlines and tick labels on a cartopy axis.

    Parameters
    ----------
    ax : cartopy GeoAxes to decorate.
    crs : coordinate reference system of the gridlines; a new
        ``ccrs.PlateCarree()`` is used when omitted.
    """
    if crs is None:
        # Built per call so that defining the function does not require
        # cartopy to be imported yet.
        crs = ccrs.PlateCarree()

    ax.add_feature(cfeature.LAND.with_scale('10m'), edgecolor='k', facecolor='white')
    gl = ax.gridlines(crs=crs, draw_labels=True,
                      linewidth=2, color='gray', alpha=0.5, linestyle='-')
    # Labels on the bottom and left sides only.
    gl.top_labels = False
    gl.right_labels = False
    gl.xlines = True
   
    gl.xformatter = LongitudeFormatter()
    gl.yformatter = LatitudeFormatter()
   
    # At most 6 ticks per axis.
    gl.xlocator = mticker.MaxNLocator(6)
    gl.ylocator = mticker.MaxNLocator(6)
   
    gl.xlabel_style = {'size': 20, 'color': 'k', 'rotation':45}
    gl.ylabel_style = {'size': 20, 'color': 'k', 'rotation':45}
    
def itrcp_spa(resfile, ds, lon, lat, methods, figsize,
              crs=None):
    """Plot one map per method on a shared colour scale and save the figure.

    Parameters
    ----------
    resfile : path of the image file to write.
    ds : list of 2-D fields (one per method) exposing ``.max()`` and ``.values``.
    lon, lat : coordinates of the fields; also define the map extent.
    methods : panel titles, one per entry of ``ds``.
    figsize : matplotlib figure size.
    crs : map projection; a new ``ccrs.Orthographic(0, 45)`` is used when
        omitted. NOTE: the ``_threshold`` of a projection passed by the caller
        is divided by 100 in place.

    Returns
    -------
    The matplotlib figure (already saved and closed).
    """
    if crs is None:
        # A fresh projection per call: the in-place ``_threshold`` change below
        # would otherwise accumulate on a shared default instance.
        crs = ccrs.Orthographic(0,45)
    crs._threshold /= 100.

    extent = [np.min(lon),np.max(lon),np.min(lat),np.max(lat)]

    # Panels are laid out on two columns.
    fig = plt.figure(figsize=figsize)
    nr = int(np.ceil(len(ds)/2))
    nc = int(np.min((2,len(ds))))   
    gs = gridspec.GridSpec(nr, nc)

    # Common colour scale from 0 to the largest value over all methods.
    vmax = max([field.max() for field in ds])
    vmin = 0.
    cmap_rmse = plt.cm.viridis
    norm_rmse = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    for i in range(len(ds)):
        ir = int(np.floor(i/2.))
        ic = np.mod(i,2)
        ax = fig.add_subplot(gs[ir,ic], projection=crs, transform=ccrs.PlateCarree())
        plot(ax,lon,lat,ds[i].values,'',extent=extent,cmap=cmap_rmse,
             norm=norm_rmse,colorbar=False, fmt=False)
        config_ticks(ax)
        ax.set_title(methods[i],fontsize=25)

    # Single horizontal colorbar shared by all panels
    cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.02])
    cbar_ax.tick_params(labelsize=20)
    sm = plt.cm.ScalarMappable(cmap=cmap_rmse, norm=norm_rmse)
    sm._A = []
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal', pad=3.0)
    cbar.ax.set_title("[K]",fontsize=22,y=-2)

    plt.savefig(resfile,bbox_inches='tight')
    plt.close()                       
    return fig

def itrcp_spa_diff(resfile, ds_baseline, ds, lon, lat,  
                   methods, figsize, crs=None):
    """Plot the relative difference (in %) of each field to a baseline.

    Parameters
    ----------
    resfile : path of the image file to write.
    ds_baseline : a single baseline field, or a list with one baseline per
        entry of ``ds``.
    ds : list of 2-D fields to compare with the baseline.
    lon, lat : coordinates of the fields; also define the map extent.
    methods : panel titles, one per entry of ``ds``.
    figsize : matplotlib figure size.
    crs : map projection; a new ``ccrs.Orthographic(0, 45)`` is used when
        omitted. NOTE: the ``_threshold`` of a projection passed by the caller
        is divided by 100 in place.

    Returns
    -------
    The matplotlib figure (already saved and closed).
    """
    if crs is None:
        # A fresh projection per call: the in-place ``_threshold`` change below
        # would otherwise accumulate on a shared default instance.
        crs = ccrs.Orthographic(0,45)
    crs._threshold /= 100.

    extent = [np.min(lon),np.max(lon),np.min(lat),np.max(lat)]

    # Panels are laid out on two columns.
    fig = plt.figure(figsize=figsize)
    nr = int(np.ceil(len(ds)/2))
    nc = int(np.min((2,len(ds))))   
    gs = gridspec.GridSpec(nr, nc)

    # Relative difference to the baseline, in percent.
    if not isinstance(ds_baseline,list):
        delta_rmse = [100*(ds[i] - ds_baseline)/ds_baseline for i in range(len(ds))]
    else:
        delta_rmse = [100*(ds[i] - ds_baseline[i])/ds_baseline[i] for i in range(len(ds))]

    # Fixed symmetric colour range of +/- 20 %.
    vmax = 20
    vmin = -20.
    cmap_rmse = plt.cm.seismic
    norm_rmse = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    for i in range(len(ds)):
        ir = int(np.floor(i/2.))
        ic = np.mod(i,2)
        ax = fig.add_subplot(gs[ir,ic], projection=crs, transform=ccrs.PlateCarree())
        plot(ax,lon,lat,delta_rmse[i].values,'',extent=extent,cmap=cmap_rmse,
             norm=norm_rmse,colorbar=False,fmt=False)  
        config_ticks(ax)
        ax.set_title(methods[i],fontsize=25)    

    # Single horizontal colorbar shared by all panels
    cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.02])
    cbar_ax.tick_params(labelsize=20)
    sm = plt.cm.ScalarMappable(cmap=cmap_rmse, norm=norm_rmse)
    sm._A = []
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal', pad=3.0)
    cbar.ax.set_title("[%]",fontsize=22,y=-2)

    plt.savefig(resfile,bbox_inches='tight')
    plt.close()                       
    return fig

def itrcp_ts(resfile, filenames, methods, colors):
    """Build the interactive RMSE time-series comparison.

    Parameters
    ----------
    resfile : not used at the moment (saving the figure is disabled).
    filenames : time-series statistics files, one per method; the 'diff'
        group of each file is read.
    methods : labels, one per file; the curves are named ``rms_<method>``.
    colors : line colours; the first ``len(filenames)`` entries are used.

    Returns
    -------
    A one-column layout with the RMSE curves on top and the number of
    observations (taken from the first file) below.
    """
    # Only the 'diff' group is needed; no other group is opened, so no unused
    # file handles are left behind.
    ds1 = [xr.open_dataset(file, group='diff') for file in filenames]

    # RMS of the differences, restricted to time steps with more than 10 points.
    rmse_score = [ stats['rms'].dropna(dim='time').where(stats['count'] > 10, drop=True)
                   for stats in ds1 ]
 
    # One variable per method in a single dataset so they share one plot.
    rmse_score=xr.merge([score.to_dataset().rename({"rms":"rms_"+method})
                         for score, method in zip(rmse_score, methods[:len(filenames)])])
    plot1 = rmse_score.hvplot.line(x='time',y=['rms_'+methods[i] for  i in range(len(filenames)) ],ylabel='RMSE SCORE', shared_axes=True, color=colors[:len(filenames)])
    plot1.opts(legend_position='top',legend_cols=3)
    plot2 = ds1[0]['count'].dropna(dim='time').hvplot.step(ylabel='#Obs.', shared_axes=True, color='grey')
    figure = (plot1+plot2).cols(1)
    return figure
   
def itrcp_spectrum(resfile, filenames, methods, colors):
    """Compare the spectra and PSD scores of several methods in one figure.

    Left panel: reference spectrum (from the first file) and the spectrum of
    each reconstruction. Right panel: PSD score of each method, the 0.5
    threshold, and a vertical line plus shaded area marking the method with
    the smallest resolved scale.

    Parameters
    ----------
    resfile : path of the image file to write.
    filenames : spectrum files, one per method.
    methods : labels, one per file.
    colors : line colours, one per file.

    Returns
    -------
    None. The figure is saved to ``resfile`` and left open.
    """
    ds = [xr.open_dataset(file) for file in filenames]
    try:
        resolved_scales = [find_wavelength_05_crossing(file) for file in filenames]
           
        fig, (ax1, ax2) = plt.subplots(1,2,figsize=(10, 5))
       
        ax1.invert_xaxis()
        ax1.plot((1./ds[0].wavenumber), ds[0].psd_ref, label='reference', color='k')
        for i in range(len(filenames)):
            ax1.plot((1./ds[i].wavenumber), ds[i].psd_study, label='reconstruction_'+methods[i], color=colors[i])
        ax1.set_xlabel('wavelength [km]')
        ax1.set_ylabel('Power Spectral Density [m$^{2}$/cy/km]')
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.grid(which='both')
       
        ax2.invert_xaxis()   
        for i in range(len(filenames)):
            ax2.plot((1./ds[i].wavenumber), (1. - ds[i].psd_diff/ds[i].psd_ref), color=colors[i], lw=2)
        ax2.set_xlabel('wavelength [km]')
        ax2.set_ylabel('PSD Score [1. - PSD$_{err}$/PSD$_{ref}$]')
        ax2.set_xscale('log')
        ax2.hlines(y=0.5,
                  xmin=np.ma.min(np.ma.masked_invalid(1./ds[0].wavenumber)),
                  xmax=np.ma.max(np.ma.masked_invalid(1./ds[0].wavenumber)),
                  color='r',
                  lw=0.5,
                  ls='--')

        # Best method = smallest resolved scale. The marker line must use the
        # same index as its colour and the shaded area (not the index left
        # over from the loops above).
        imax = np.argmin(resolved_scales)
        ax2.vlines(x=resolved_scales[imax], ymin=0, ymax=1, lw=0.5, color=colors[imax])
        ax2.fill_betweenx((1. - ds[imax].psd_diff/ds[imax].psd_ref),
                         resolved_scales[imax],
                         np.ma.max(np.ma.masked_invalid(1./ds[imax].wavenumber)),
                         color=colors[imax],
                         alpha=0.3,
                         label=f'Best resolved scales \n $\\lambda$ > {int(resolved_scales[imax])}km')
        ax2.legend(loc='best')
        ax2.grid(which='both')

        # Legend of the left panel, placed above the figure.
        handles, labels = ax1.get_legend_handles_labels()
        fig.legend(handles, labels,  loc='upper left',fontsize=12,frameon=False,
                   bbox_to_anchor=(0,1,1,0.2),ncol=2,mode="expand")  
       
        logging.info(' ')
        logging.info(f'  Minimum spatial scale resolved = {int(resolved_scales[imax])}km')
       
        plt.savefig(resfile,bbox_inches="tight")
    finally:
        # Release the file handles whatever happened above.
        for dataset in ds:
            dataset.close()
   
    return None


### 2.1.2 Interactive plots

In [None]:
def intercomparison_temporal_statistics(list_of_filename, list_of_label):
    """Interactive time series of the RMSE score of several experiments.

    The score is ``1 - rms(diff) / rms(alongtrack)``, kept only where the
    'diff' count exceeds 10. The returned one-column layout shows the score
    per experiment on top and the observation count of the first experiment
    below.
    """
    def _stack_group(group_name):
        # One dataset per file, stacked along a labelled 'experiment' dimension.
        stacked = xr.concat(
            [xr.open_dataset(filename, group=group_name) for filename in list_of_filename],
            dim='experiment')
        stacked['experiment'] = list_of_label
        return stacked

    ds_diff = _stack_group('diff')
    ds_alongtrack = _stack_group('alongtrack')

    score = 1. - ds_diff['rms']/ds_alongtrack['rms']
    score = score.dropna(dim='time').where(ds_diff['count'] > 10, drop=True)

    score_panel = score.hvplot.line(x='time', y='rms', by='experiment', ylim=(0, 1), title='RMSE SCORE', shared_axes=True)
    count_panel = ds_diff['count'][0, :].dropna(dim='time').hvplot.step(ylabel='#Obs.', shared_axes=True, color='grey')

    layout = score_panel + count_panel
    return layout.cols(1)

import cartopy
import cartopy.crs as ccrs

def intercomparison_spatial_statistics(baseline_filename, list_of_filename, list_of_label,var='rmse'):
    """Interactive maps comparing the spatial statistics of several experiments.

    For ``var='rmse'`` the maps show the relative difference to the baseline
    in percent (colour limits -20..20); for any other variable the raw values
    are shown (colour limits 0..1000). Longitudes are wrapped to [-180, 180)
    and sorted before plotting. Returns a two-column layout.
    """
    baseline = xr.open_dataset(baseline_filename, group='diff')
    experiments = xr.concat(
        [xr.open_dataset(filename, group='diff') for filename in list_of_filename],
        dim='experiment')
    experiments['experiment'] = list_of_label

    if var != 'rmse':
        score, clim_ = experiments, (0, 1000)
    else:
        score, clim_ = 100*(experiments - baseline)/baseline, (-20, 20)

    wrapped_lon = np.mod(score.lon -180.0,360.0)-180.0
    score = score.assign_coords(lon=wrapped_lon)
    score = score.sortby(score.lon)

    figure = score[var].hvplot.image(x='lon', y='lat', z=var, clim=clim_, by='experiment',
                                     subplots=True, clabel='[%]', cmap='seismic')

    return figure.cols(2)

## 2.2 Domain #1

### 2.2.1 Yearly analysis

In [None]:
#file_dl = "/DATASET/mbeauchamp/DMI/4DVarNet_outputs/DMI-L4_GHRSST-SSTfnd-DMI_4DVarNet-NSEABALTIC_2021_all_baltic_wcoarse_wgeo_linweight_dt7_all.nc"
file_dl = "/DATASET/mbeauchamp/DMI/4DVarNet_outputs/DMI-L4_GHRSST-SSTfnd-DMI_4DVarNet-NSEABALTIC_2021_dt7_linweights_wcoarse_baltic_ext.nc"
file_oi = "/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L4_GHRSST-SSTfnd-DMI_OI-NSEABALTIC_2021_validation.nc"

# Resolution of the spatial / temporal statistics.
bin_lat_step = 0.1
bin_lon_step = 0.1
bin_time_step = '1D'

# The evaluation domain is the bounding box of the 4DVarNet grid. The file is
# opened once and closed as soon as the four bounds have been read.
with xr.open_dataset(file_dl) as _dl_grid:
    lon_min = np.min(_dl_grid.lon.data)
    lon_max = np.max(_dl_grid.lon.data)
    lat_min = np.min(_dl_grid.lat.data)
    lat_max = np.max(_dl_grid.lat.data)

time_min = '2021-01-15'
time_max = '2021-12-31'
sel_time = slice(time_min, time_max)

# Sampling parameters handed to the spectral analysis.
delta_t = 0.9434  # s
velocity = 6.77   # km/s
# NOTE: the product computed on the next line is immediately overridden by the
# hard-coded value that follows it, so 2.2 is the spacing actually used.
delta_x = velocity * delta_t
delta_x = 2.2
lenght_scale = 1000 # km

leaderboard_nrmse_dl, leaderboard_nrmse_std_dl, leaderboard_psds_score_dl = run_comparison(file_dl,"4DVarNet")
leaderboard_nrmse_oi, leaderboard_nrmse_std_oi, leaderboard_psds_score_oi = run_comparison(file_oi,"OI",crop=True)

In [None]:
# Output files of each method, listed in the same order as `methods`
# (the OI baseline comes first).
methods = ['OI','4DVarNet']
_out_dir = '/DATASET/mbeauchamp/'
spa_filenames = [_out_dir + 'spa_stat_' + name + '.nc' for name in methods]
TS_filenames = [_out_dir + 'TS_' + name + '.nc' for name in methods]
spectrum_filenames = [_out_dir + 'spectrum_' + name + '.nc' for name in methods]

In [None]:
# Interactive RMSE-score time series, one curve per method.
intercomparison_temporal_statistics(TS_filenames,methods)#,score=False)

In [None]:
# Interactive map of the 4DVarNet RMSE relative to the OI baseline (first file).
intercomparison_spatial_statistics(spa_filenames[0], spa_filenames[1:], ["4DVarNet"])

In [None]:
# Binned RMSE map of every method, with longitudes wrapped to [-180, 180),
# sorted, and cropped to the evaluation domain.
ds = [xr.open_dataset(filename, group='diff').rmse for filename in spa_filenames]
for i, rmse_map in enumerate(ds):
    rmse_map = rmse_map.assign_coords(lon=np.mod(rmse_map.lon -180.0,360.0)-180.0)
    rmse_map = rmse_map.sortby(rmse_map.lon)
    ds[i] = rmse_map.sel(lon=slice(lon_min,lon_max),lat=slice(lat_min,lat_max))

_panel_size = (20,10*len(methods)/2)

# RMSE maps of all methods.
itrcp_spa('Benchmark_spa_dm1.png', ds, ds[0].lon, ds[0].lat, methods,
          figsize=_panel_size)

# Relative difference of every non-baseline method to the first one.
itrcp_spa_diff('Benchmark_spa_diff_dm1.png', ds[0], ds[1:], ds[0].lon, ds[0].lat,
               methods[1:], figsize=_panel_size)

# Spectral comparison.
itrcp_spectrum('Benchmark_spectrum_dm1.png', spectrum_filenames, methods, ['red','blue'])

# Time-series comparison.
itrcp_ts('Benchmark_TS_dm1.html', TS_filenames, methods, ['red','blue'])

In [None]:
from PIL import Image
from IPython.display import display
# Show the two saved spatial figures, one after the other.
for _figure_file in ('Benchmark_spa_dm1.png', 'Benchmark_spa_diff_dm1.png'):
    img = Image.open(_figure_file)
    display(img)

In [None]:
# Displaying the time-series figure is disabled (the two lines below).
#img = Image.open('Benchmark_TS_dm1.png')
#display(img)
# Show the saved spectral comparison figure.
img = Image.open('Benchmark_spectrum_dm1.png')
display(img)

In [None]:
# Build the rows from native Python values. Wrapping mixed str/float rows in
# np.array coerces every cell to a string, which leaves the DataFrame without
# any numeric column.
data = [
    ['OI',
     float(leaderboard_nrmse_oi),
     float(leaderboard_nrmse_std_oi),
     int(leaderboard_psds_score_oi)],
    ['4DVarNet',
     float(leaderboard_nrmse_dl),
     float(leaderboard_nrmse_std_dl),
     int(leaderboard_psds_score_dl)],
]
Leaderboard = pd.DataFrame(data,
                           columns=['Method',
                                    "µ(RMSE) ",
                                    "σ(RMSE)",
                                    'λx (km)'])
print("Summary of the leaderboard metrics:")
print(Leaderboard.to_markdown())

In [None]:
# Delete the intermediate statistics files of both methods
# (spatial, time-series and spectral, in that order for each method).
for method_tag in ('OI', '4DVarNet'):
    for prefix in ('spa_stat', 'TS', 'spectrum'):
        os.remove(f'/DATASET/mbeauchamp/{prefix}_{method_tag}.nc')

### 2.2.2 With filtering

In [None]:
from py_eddy_tracker import data
from py_eddy_tracker.dataset.grid import RegularGridDataset
from py_eddy_tracker.observations.observation import EddiesObservations
import xarray as xr
import numpy as np

# 2-D boolean mask taken from the first entry of `analysed_sst` in the OI
# validation file: True where the value is finite. Despite the name, True
# presumably marks valid (sea) pixels rather than land -- confirm.
# It is used as the default `mask_land` argument of filter_dataset below.
mask_land = np.isfinite(xr.open_dataset('/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L4_GHRSST-SSTfnd-DMI_OI-NSEABALTIC_2021_validation.nc').analysed_sst[0])
def filter_dataset(file,new_file,mask_land=mask_land,wavelength=50,var='analysed_sst'):
    """Filter every time step of a gridded variable and save the result.

    Each 2-D field of ``var`` is wrapped in a py_eddy_tracker
    ``RegularGridDataset``, processed with ``bessel_high_filter`` (order 3)
    and the filtered fields are stacked and written to ``new_file`` together
    with the ``time``, ``lat`` and ``lon`` coordinates of the input.

    Parameters
    ----------
    file : str
        Input NetCDF file. It must provide ``time``, ``lat``, ``lon`` and
        ``var``, with ``var`` indexed as (time, lat, lon).
    new_file : str
        Path of the NetCDF file to write.
    mask_land : xarray.DataArray
        (lat, lon) array stored as the ``mask`` coordinate of the output.
        The default is the module-level ``mask_land`` as it was when this
        function was defined.
    wavelength : float
        Cut-off passed to ``bessel_high_filter`` (presumably in km -- confirm
        against the py_eddy_tracker documentation).
    var : str
        Name of the variable to filter; also the name of the output variable.
    """
    # Open the input in a context manager so the file handle is released once
    # all fields have been read (it was previously left open).
    with xr.open_dataset(file) as data_all:
        time = data_all.time.values
        lon = data_all.lon.values
        lat = data_all.lat.values
        filtered = []
        for i, step in enumerate(data_all[var]):
            print(i)
            # Load the field once; it is needed for both the data and the mask.
            field = step.values
            # The grid is declared as (lon, lat), hence the transposes.
            data = np.ma.array(field.T, mask=np.isnan(field).T)
            g = RegularGridDataset.with_array(
                coordinates=('lon', 'lat'),
                datas=dict(
                    lon=lon,
                    lat=lat,
                    pred=data),
                variables_description={'pred':{'units':'K'}})
            g.bessel_high_filter("pred", wavelength, order=3)
            # Back to (lat, lon) for stacking along time.
            filtered.append(g.grid("pred").data.T)
    # stack all days
    filtered = np.stack(filtered)
    nc = xr.Dataset(data_vars={var:(('time','lat','lon'),filtered)},
                    coords={'time':time,
                            'lon':lon,
                            'lat':lat})
    nc.coords['mask'] = (('lat', 'lon'), mask_land.values)
    nc.to_netcdf(new_file)

# Filter the 4DVarNet output, the OI analysis and the observations with the
# same wavelength (10); each result is written next to its input with a
# '_flt10' suffix before the extension.
for src_file, var_name in (
    # 4DVarNet
    ('/DATASET/mbeauchamp/DMI/4DVarNet_outputs/DMI-L4_GHRSST-SSTfnd-DMI_4DVarNet-NSEABALTIC_2021_all_baltic_wcoarse_wgeo_linweight_dt7_all.nc',
     'analysed_sst'),
    # OI
    ('/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L4_GHRSST-SSTfnd-DMI_OI-NSEABALTIC_2021_validation.nc',
     'analysed_sst'),
    # Obs
    ('/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L3S_GHRSST-SSTsubskin-night_SST_UHR_NRT-NSEABALTIC_validation_s1.nc',
     'sea_surface_temperature'),
):
    filter_dataset(src_file,
                   src_file[:-len('.nc')] + '_flt10.nc',
                   wavelength=10, var=var_name)

In [None]:
from pathlib import Path

# NOTE(review): the filter cell above processes the '..._dt7_all.nc' 4DVarNet
# output (the commented-out path below), whereas the active file_dl is
# '..._baltic_ext.nc'. Its '_flt10.nc' sibling used further down is therefore
# not produced by the visible filter calls -- confirm it exists.
#file_dl = "/DATASET/mbeauchamp/DMI/4DVarNet_outputs/DMI-L4_GHRSST-SSTfnd-DMI_4DVarNet-NSEABALTIC_2021_all_baltic_wcoarse_wgeo_linweight_dt7_all.nc"
file_dl = "/DATASET/mbeauchamp/DMI/4DVarNet_outputs/DMI-L4_GHRSST-SSTfnd-DMI_4DVarNet-NSEABALTIC_2021_dt7_linweights_wcoarse_baltic_ext.nc"
file_oi = "/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L4_GHRSST-SSTfnd-DMI_OI-NSEABALTIC_2021_validation.nc"
file_obs = '/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L3S_GHRSST-SSTsubskin-night_SST_UHR_NRT-NSEABALTIC_validation_s1_flt10.nc'

bin_lat_step = 0.1
bin_lon_step = 0.1
bin_time_step = '1D'

# Domain bounds come from the 4DVarNet grid. Open the file once and close it
# again instead of opening it four times and never releasing the handles.
with xr.open_dataset(file_dl) as _grid:
    lon_min = np.min(_grid.lon.data)
    lon_max = np.max(_grid.lon.data)
    lat_min = np.min(_grid.lat.data)
    lat_max = np.max(_grid.lat.data)

time_min = '2021-01-15'
time_max = '2021-12-31'
sel_time = slice(time_min, time_max)
delta_t = 0.9434  # s
velocity = 6.77   # km/s
delta_x = velocity * delta_t
delta_x = 2.2     # overrides the value computed on the previous line
lenght_scale = 1000 # km  (name kept as is: it may be read elsewhere as a global)

# Compare the filtered reconstructions ('<name>_flt10.nc') against the
# filtered observations.
leaderboard_nrmse_dl, leaderboard_nrmse_std_dl, leaderboard_psds_score_dl = run_comparison(str(Path(file_dl).with_suffix(''))+'_flt10.nc',
                                                                                           "4DVarNet",
                                                                                           file_obs=file_obs)
leaderboard_nrmse_oi, leaderboard_nrmse_std_oi, leaderboard_psds_score_oi = run_comparison(str(Path(file_oi).with_suffix(''))+'_flt10.nc',
                                                                                           "OI",crop=True,
                                                                                           file_obs=file_obs)

In [None]:
# Per-method result files, listed in the same order as `methods` (OI first).
methods = ['OI', '4DVarNet']
spa_filenames = [f'/DATASET/mbeauchamp/spa_stat_{name}.nc' for name in methods]
TS_filenames = [f'/DATASET/mbeauchamp/TS_{name}.nc' for name in methods]
spectrum_filenames = [f'/DATASET/mbeauchamp/spectrum_{name}.nc' for name in methods]
# The OI file is passed separately from the remaining file(s) and their labels.
intercomparison_spatial_statistics(spa_filenames[0], spa_filenames[1:], ["4DVarNet"])

### 2.2.3 Seasonal analysis

In [None]:
# Seasonal scores (DJF, MAM, JJA, SON) for OI and 4DVarNet.
# `sel_time` is rebound to a one-season list before each pair of calls but is
# never passed as an argument, so run_comparison presumably reads it as a
# global when byseason=True -- confirm in its definition. The statement order
# therefore matters.
sel_time = ["DJF"]
leaderboard_nrmse_oi_DJF, leaderboard_nrmse_std_oi_DJF, leaderboard_psds_score_oi_DJF = run_comparison(file_oi,"OI_DJF",crop=True,byseason=True)
leaderboard_nrmse_dl_DJF, leaderboard_nrmse_std_dl_DJF, leaderboard_psds_score_dl_DJF = run_comparison(file_dl,"4DVarNet_DJF",byseason=True)
sel_time = ["MAM"]
leaderboard_nrmse_oi_MAM, leaderboard_nrmse_std_oi_MAM, leaderboard_psds_score_oi_MAM = run_comparison(file_oi,"OI_MAM",crop=True,byseason=True)
leaderboard_nrmse_dl_MAM, leaderboard_nrmse_std_dl_MAM, leaderboard_psds_score_dl_MAM = run_comparison(file_dl,"4DVarNet_MAM",byseason=True)
sel_time = ["JJA"]
leaderboard_nrmse_oi_JJA, leaderboard_nrmse_std_oi_JJA, leaderboard_psds_score_oi_JJA = run_comparison(file_oi,"OI_JJA",crop=True,byseason=True)
leaderboard_nrmse_dl_JJA, leaderboard_nrmse_std_dl_JJA, leaderboard_psds_score_dl_JJA = run_comparison(file_dl,"4DVarNet_JJA",byseason=True)
sel_time = ["SON"]
leaderboard_nrmse_oi_SON, leaderboard_nrmse_std_oi_SON, leaderboard_psds_score_oi_SON = run_comparison(file_oi,"OI_SON",crop=True,byseason=True)
leaderboard_nrmse_dl_SON, leaderboard_nrmse_std_dl_SON, leaderboard_psds_score_dl_SON = run_comparison(file_dl,"4DVarNet_SON",byseason=True)

In [None]:
# One (method, season) pair per run, season by season with OI before 4DVarNet;
# every list below follows this order.
_runs = [(method, season)
         for season in ('DJF', 'MAM', 'JJA', 'SON')
         for method in ('OI', '4DVarNet')]
spa_filenames = [f'/DATASET/mbeauchamp/spa_stat_{m}_{s}.nc' for m, s in _runs]
TS_filenames = [f'/DATASET/mbeauchamp/TS_{m}_{s}.nc' for m, s in _runs]
spectrum_filenames = [f'/DATASET/mbeauchamp/spectrum_{m}_{s}.nc' for m, s in _runs]
# Labels and one colour per run.
methods = [f'{m} ({s})' for m, s in _runs]
colors = ['blue', 'violet', 'green', 'red', 'black', 'orange', 'gray', 'yellow']


# Load the RMSE map of every run (group 'diff'), wrap longitudes into
# [-180, 180), sort along lon and crop to the study domain.
ds = []
for fname in spa_filenames:
    rmse_map = xr.open_dataset(fname, group='diff').rmse
    rmse_map = rmse_map.assign_coords(lon=np.mod(rmse_map.lon - 180.0, 360.0) - 180.0)
    rmse_map = rmse_map.sortby(rmse_map.lon)
    ds.append(rmse_map.sel(lon=slice(lon_min, lon_max), lat=slice(lat_min, lat_max)))

grid_lon, grid_lat = ds[0].lon, ds[0].lat
panel_size = (22, 6*len(methods)/2)

# Even positions hold the OI runs, odd positions the 4DVarNet runs.
oi_maps = [ds[k] for k in (0, 2, 4, 6)]
dl_maps = [ds[k] for k in (1, 3, 5, 7)]
dl_labels = [methods[k] for k in (1, 3, 5, 7)]

itrcp_spa('Benchmark_spa_byseason_dm1.png', ds, grid_lon, grid_lat, methods,
          figsize=panel_size)

itrcp_spa_diff('Benchmark_spa_diff_byseason_dm1.png', oi_maps, dl_maps,
               grid_lon, grid_lat, dl_labels,
               figsize=panel_size)

itrcp_spectrum('Benchmark_spectrum_byseason_dm1.png', spectrum_filenames, methods, colors)

itrcp_ts('Benchmark_TS_byseason_dm1.png', TS_filenames, methods, colors)

In [None]:
from PIL import Image
from IPython.display import display
# Show the seasonal spatial figure and its difference counterpart.
for png_name in ('Benchmark_spa_byseason_dm1.png', 'Benchmark_spa_diff_byseason_dm1.png'):
    img = Image.open(png_name)
    display(img)

In [None]:
# Build the rows from native Python values. Wrapping mixed str/float rows in
# np.array coerces every cell to a string, which leaves the DataFrame without
# any numeric column.
data = [
    ['OI (DJF)', float(leaderboard_nrmse_oi_DJF), float(leaderboard_nrmse_std_oi_DJF), int(leaderboard_psds_score_oi_DJF)],
    ['4DVarNet (DJF)', float(leaderboard_nrmse_dl_DJF), float(leaderboard_nrmse_std_dl_DJF), int(leaderboard_psds_score_dl_DJF)],
    ['OI (MAM)', float(leaderboard_nrmse_oi_MAM), float(leaderboard_nrmse_std_oi_MAM), int(leaderboard_psds_score_oi_MAM)],
    ['4DVarNet (MAM)', float(leaderboard_nrmse_dl_MAM), float(leaderboard_nrmse_std_dl_MAM), int(leaderboard_psds_score_dl_MAM)],
    ['OI (JJA)', float(leaderboard_nrmse_oi_JJA), float(leaderboard_nrmse_std_oi_JJA), int(leaderboard_psds_score_oi_JJA)],
    ['4DVarNet (JJA)', float(leaderboard_nrmse_dl_JJA), float(leaderboard_nrmse_std_dl_JJA), int(leaderboard_psds_score_dl_JJA)],
    ['OI (SON)', float(leaderboard_nrmse_oi_SON), float(leaderboard_nrmse_std_oi_SON), int(leaderboard_psds_score_oi_SON)],
    ['4DVarNet (SON)', float(leaderboard_nrmse_dl_SON), float(leaderboard_nrmse_std_dl_SON), int(leaderboard_psds_score_dl_SON)],
]
Leaderboard_season = pd.DataFrame(data,
                                  columns=['Method',
                                           "µ(RMSE) ",
                                           "σ(RMSE)",
                                           'λx (km)'])
print("Summary of the seasonal leaderboard metrics:")
print(Leaderboard_season.to_markdown())

In [None]:
# Delete the intermediate statistics files of every season and method
# (spatial, time-series and spectral, in that order for each method).
for season in ['DJF','MAM','JJA','SON']:
    for method_tag in ('OI', '4DVarNet'):
        for prefix in ('spa_stat', 'TS', 'spectrum'):
            os.remove(f'/DATASET/mbeauchamp/{prefix}_{method_tag}_{season}.nc')

## 2.2 Domain #2

In [None]:
file_dl = "/DATASET/mbeauchamp/DMI/4DVarNet_outputs/DMI-L4_GHRSST-SSTfnd-DMI_4DVarNet-NSEABALTIC_2021_baltic_dm2.nc"
file_oi = "/DATASET/mbeauchamp/DMI/validation_dataset/DMI-L4_GHRSST-SSTfnd-DMI_OI-NSEABALTIC_2021_validation.nc"

bin_lat_step = 0.1
bin_lon_step = 0.1
bin_time_step = '1D'

# Domain bounds come from the 4DVarNet grid. Open the file once and close it
# again instead of opening it four times and never releasing the handles.
with xr.open_dataset(file_dl) as _grid:
    lon_min = np.min(_grid.lon.data)
    lon_max = np.max(_grid.lon.data)
    lat_min = np.min(_grid.lat.data)
    lat_max = np.max(_grid.lat.data)

time_min = '2021-01-15'
time_max = '2021-12-31'
sel_time = slice(time_min, time_max)
delta_t = 0.9434  # s
velocity = 6.77   # km/s
delta_x = velocity * delta_t
delta_x = 2.2     # overrides the value computed on the previous line
lenght_scale = 1000 # km  (name kept as is: it may be read elsewhere as a global)

leaderboard_nrmse_dl, leaderboard_nrmse_std_dl, leaderboard_psds_score_dl = run_comparison(file_dl,"4DVarNet")
leaderboard_nrmse_oi, leaderboard_nrmse_std_oi, leaderboard_psds_score_oi = run_comparison(file_oi,"OI",crop=True)

In [None]:
# Per-method result files, listed in the same order as `methods` (OI first).
methods = ['OI', '4DVarNet']
spa_filenames = [f'/DATASET/mbeauchamp/spa_stat_{name}.nc' for name in methods]
TS_filenames = [f'/DATASET/mbeauchamp/TS_{name}.nc' for name in methods]
spectrum_filenames = [f'/DATASET/mbeauchamp/spectrum_{name}.nc' for name in methods]

In [None]:
# Compare the time-series statistics files of all methods; the helper is not
# defined in this notebook (presumably it comes from the src.ose star imports).
intercomparison_temporal_statistics(TS_filenames,methods)

In [None]:
# The first file (OI) is passed separately from the remaining file(s) and their
# labels -- presumably OI is the baseline the others are compared against;
# confirm against the helper's signature.
intercomparison_spatial_statistics(spa_filenames[0], spa_filenames[1:], ["4DVarNet"])

In [None]:
# Seasonal scores (DJF, MAM, JJA, SON) for OI and 4DVarNet.
# `sel_time` is rebound to a one-season list before each pair of calls but is
# never passed as an argument, so run_comparison presumably reads it as a
# global when byseason=True -- confirm in its definition. The statement order
# therefore matters.
sel_time = ["DJF"]
leaderboard_nrmse_oi_DJF, leaderboard_nrmse_std_oi_DJF, leaderboard_psds_score_oi_DJF = run_comparison(file_oi,"OI_DJF",crop=True,byseason=True)
leaderboard_nrmse_dl_DJF, leaderboard_nrmse_std_dl_DJF, leaderboard_psds_score_dl_DJF = run_comparison(file_dl,"4DVarNet_DJF",byseason=True)
sel_time = ["MAM"]
leaderboard_nrmse_oi_MAM, leaderboard_nrmse_std_oi_MAM, leaderboard_psds_score_oi_MAM = run_comparison(file_oi,"OI_MAM",crop=True,byseason=True)
leaderboard_nrmse_dl_MAM, leaderboard_nrmse_std_dl_MAM, leaderboard_psds_score_dl_MAM = run_comparison(file_dl,"4DVarNet_MAM",byseason=True)
sel_time = ["JJA"]
leaderboard_nrmse_oi_JJA, leaderboard_nrmse_std_oi_JJA, leaderboard_psds_score_oi_JJA = run_comparison(file_oi,"OI_JJA",crop=True,byseason=True)
leaderboard_nrmse_dl_JJA, leaderboard_nrmse_std_dl_JJA, leaderboard_psds_score_dl_JJA = run_comparison(file_dl,"4DVarNet_JJA",byseason=True)
sel_time = ["SON"]
leaderboard_nrmse_oi_SON, leaderboard_nrmse_std_oi_SON, leaderboard_psds_score_oi_SON = run_comparison(file_oi,"OI_SON",crop=True,byseason=True)
leaderboard_nrmse_dl_SON, leaderboard_nrmse_std_dl_SON, leaderboard_psds_score_dl_SON = run_comparison(file_dl,"4DVarNet_SON",byseason=True)

In [None]:
# One (method, season) pair per run, season by season with OI before 4DVarNet;
# every list below follows this order.
_runs = [(method, season)
         for season in ('DJF', 'MAM', 'JJA', 'SON')
         for method in ('OI', '4DVarNet')]
spa_filenames = [f'/DATASET/mbeauchamp/spa_stat_{m}_{s}.nc' for m, s in _runs]
TS_filenames = [f'/DATASET/mbeauchamp/TS_{m}_{s}.nc' for m, s in _runs]
spectrum_filenames = [f'/DATASET/mbeauchamp/spectrum_{m}_{s}.nc' for m, s in _runs]
# Labels and one colour per run.
methods = [f'{m} ({s})' for m, s in _runs]
colors = ['blue', 'violet', 'green', 'red', 'black', 'orange', 'gray', 'yellow']

# Load the RMSE map of every run (group 'diff'), cropped to the study domain.
ds = []
for fname in spa_filenames:
    rmse_map = xr.open_dataset(fname, group='diff').rmse
    ds.append(rmse_map.sel(lon=slice(lon_min, lon_max),
                           lat=slice(lat_min, lat_max)))

grid_lon, grid_lat = ds[0].lon, ds[0].lat
panel_size = (20, 10*len(methods)/2)
map_crs = ccrs.PlateCarree()

# Even positions hold the OI runs, odd positions the 4DVarNet runs.
oi_maps = [ds[k] for k in (0, 2, 4, 6)]
dl_maps = [ds[k] for k in (1, 3, 5, 7)]
dl_labels = [methods[k] for k in (1, 3, 5, 7)]

# NOTE(review): this is the Domain #2 section but the output names still end
# in '_dm1', so the Domain #1 seasonal figures get overwritten -- confirm
# whether '_dm2' was intended (the display cell below opens the same names).
itrcp_spa('Benchmark_spa_byseason_dm1.png', ds, grid_lon, grid_lat, methods,
          figsize=panel_size,
          crs=map_crs)

itrcp_spa_diff('Benchmark_spa_diff_byseason_dm1.png', oi_maps, dl_maps,
               grid_lon, grid_lat, dl_labels,
               figsize=panel_size,
               crs=map_crs)

# Time-series plot currently disabled:
# itrcp_ts('Benchmark_TS_byseason_dm1.png',
#                          TS_filenames,
#                          methods,colors)

itrcp_spectrum('Benchmark_spectrum_byseason_dm1.png', spectrum_filenames, methods, colors)

In [None]:
from PIL import Image
from IPython.display import display
# Show the seasonal spatial figure and its difference counterpart.
for png_name in ('Benchmark_spa_byseason_dm1.png', 'Benchmark_spa_diff_byseason_dm1.png'):
    img = Image.open(png_name)
    display(img)

In [None]:
# Build the rows from native Python values. Wrapping mixed str/float rows in
# np.array coerces every cell to a string, which leaves the DataFrame without
# any numeric column.
data = [
    ['OI (DJF)', float(leaderboard_nrmse_oi_DJF), float(leaderboard_nrmse_std_oi_DJF), int(leaderboard_psds_score_oi_DJF)],
    ['4DVarNet (DJF)', float(leaderboard_nrmse_dl_DJF), float(leaderboard_nrmse_std_dl_DJF), int(leaderboard_psds_score_dl_DJF)],
    ['OI (MAM)', float(leaderboard_nrmse_oi_MAM), float(leaderboard_nrmse_std_oi_MAM), int(leaderboard_psds_score_oi_MAM)],
    ['4DVarNet (MAM)', float(leaderboard_nrmse_dl_MAM), float(leaderboard_nrmse_std_dl_MAM), int(leaderboard_psds_score_dl_MAM)],
    ['OI (JJA)', float(leaderboard_nrmse_oi_JJA), float(leaderboard_nrmse_std_oi_JJA), int(leaderboard_psds_score_oi_JJA)],
    ['4DVarNet (JJA)', float(leaderboard_nrmse_dl_JJA), float(leaderboard_nrmse_std_dl_JJA), int(leaderboard_psds_score_dl_JJA)],
    ['OI (SON)', float(leaderboard_nrmse_oi_SON), float(leaderboard_nrmse_std_oi_SON), int(leaderboard_psds_score_oi_SON)],
    ['4DVarNet (SON)', float(leaderboard_nrmse_dl_SON), float(leaderboard_nrmse_std_dl_SON), int(leaderboard_psds_score_dl_SON)],
]
Leaderboard_season = pd.DataFrame(data,
                                  columns=['Method',
                                           "µ(RMSE) ",
                                           "σ(RMSE)",
                                           'λx (km)'])
print("Summary of the seasonal leaderboard metrics:")
print(Leaderboard_season.to_markdown())