# Import Packages

In [None]:
import sys
import os

# Locations of the helper modules and of the input/output files. Each one can be
# overridden by an environment variable named like the variable in upper case
# (e.g. PATH_SAVE0); the defaults are the original author's paths.
PATH_with_functions = os.environ.get('PATH_WITH_FUNCTIONS', '/home/ardhuin/TOOLS/OPTOOLS/PYTHON/')
PATH_save0 = os.environ.get('PATH_SAVE0', '/home/ardhuin/PUBLI/2023_groups/Notebook_retrack/test/')
PATH_read0 = os.environ.get('PATH_READ0', '')
PATH_save = os.environ.get('PATH_SAVE', '')
PATH_save2 = os.environ.get('PATH_SAVE2', '')
# Add the helper directory only once, so re-running this cell does not keep
# growing sys.path.
if PATH_with_functions not in sys.path:
    sys.path.append(PATH_with_functions)

import glob
import xarray as xr
import numpy as np
import pandas as pd

#import cartopy.crs as ccrs
#import cartopy.feature as cfeature
import scipy.interpolate as spi
import scipy.integrate as spint
from scipy.ndimage import gaussian_filter, correlate
from scipy.signal import hilbert,hilbert2,fftconvolve

from functions_cfosat_env import *
from surface_simulation_functions import *
from altimetry_waveforms_functions import *
from matplotlib.dates import DateFormatter
# --- plotting and interactive stuff ----------------------
import matplotlib.pyplot as plt
# from matplotlib.ticker import AutoMinorLocator, FixedLocator

#from Envelope_convolution_functions import *

# mpl and mcolors are not imported explicitly above; without these lines they
# would only be available if one of the star imports happened to provide them.
import matplotlib as mpl
import matplotlib.colors as mcolors

cNorm = mcolors.Normalize(vmin=0, vmax=2)
jet = plt.get_cmap('jet')
mpl.rcParams.update({'figure.figsize':[10,6],'axes.grid' : True,'font.size': 14,'savefig.facecolor':'white'})

cmap0 = 'viridis'

# Preparation: plotting helpers, spectrum, sea surface, retracking and filters

In [None]:
# ------ Helpers formatting surface maps (and spectra, below) -------------
def custom_plots_surf(ax,im,iskm=1,alongT_isY=1,labelcb=None):
    """Format a map of a surface field drawn on ``ax``.

    Labels the axes ('Cross-track'/'Along-track' when ``alongT_isY`` is truthy,
    'X'/'Y' otherwise) with km units (``iskm`` truthy) or m, forces an equal
    aspect ratio and adds a colorbar for ``im``.

    Parameters
    ----------
    ax : matplotlib Axes containing the map.
    im : mappable returned by the plotting call (e.g. ``pcolormesh``).
    iskm : truthy -> axis labels in [km], falsy -> [m].
    alongT_isY : truthy -> y axis is along-track and x axis cross-track.
    labelcb : optional colorbar label (no label when None).

    Returns
    -------
    The same ``ax``.
    """
    if alongT_isY:
        xname, yname = 'Cross-track ', 'Along-track '
    else:
        xname, yname = 'X ', 'Y '
    unit = '[km]' if iskm else '[m]'

    ax.set_xlabel(xname+unit)
    ax.set_ylabel(yname+unit)
    ax.set_aspect('equal', 'box')

    # Attach the colorbar through the axes' own figure instead of pyplot's
    # "current figure", so the helper does not depend on global pyplot state.
    cb_kwargs = {} if labelcb is None else {'label': labelcb}
    ax.figure.colorbar(im, ax=ax, **cb_kwargs)
    return ax

def custom_plots_spec(ax,im,iswnb=1,alongT_isY=1,klim=None,labelcb=None):
    """Format a map of a wavenumber spectrum drawn on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes containing the spectrum.
    im : mappable returned by the plotting call (e.g. ``pcolormesh``).
    iswnb : truthy -> labels '$k_x$', '$k_y$' in [rad/m];
        falsy -> '$k/2\\pi$' in [km$^{-1}$].
    alongT_isY : truthy -> the k_y label is put on the x axis and the k_x label
        on the y axis; falsy -> k_x on x and k_y on y.
    klim : if not None, both axes are limited to (-klim, klim).
    labelcb : optional colorbar label (no label when None).

    Returns
    -------
    The same ``ax``.
    """
    if iswnb:
        xlb = '$k_x$ [rad/m]'
        ylb = '$k_y$ [rad/m]'   # previously mislabelled '$k_x$'
    else:
        # Raw strings: '\p' is not a valid escape sequence in a normal literal
        xlb = r'$k_x / 2 \pi$ [km$^{-1}$]'
        ylb = r'$k_y / 2 \pi$ [km$^{-1}$]'

    if alongT_isY:
        ax.set_xlabel(ylb)
        ax.set_ylabel(xlb)
    else:
        ax.set_xlabel(xlb)
        ax.set_ylabel(ylb)

    ax.set_aspect('equal', 'box')
    # Colorbar attached through the axes' own figure (no global pyplot state)
    cb_kwargs = {} if labelcb is None else {'label': labelcb}
    ax.figure.colorbar(im, ax=ax, **cb_kwargs)

    if klim is not None:
        ax.set_xlim((-klim,klim))
        ax.set_ylim((-klim,klim))
    return ax

## Read spectrum

In [None]:
# Read the spectra and wave heights of cases 9 and 35. DS stays open because the
# surface-generation cell reads the case-35 spectrum from it.
DS = xr.open_dataset(PATH_save2+'Spectrum_L2S_ind9_35.nc')
Hs_9 = DS['Hs_ind_9_box'].compute().data
Hs_9_L2S = DS['Hs_ind_9_L2S'].compute().data
Hs_35 = DS['Hs_ind_35_box'].compute().data
Hs_35_L2S = DS['Hs_ind_35_L2S'].compute().data
# Scaling factor applied to case 35 only, to test other wave heights (1 = as read)
Hsfac=1
Lc_9 = calc_footprint_diam(Hs_9)
print('ind 1: 9, Hs=',Hs_9,'m, Diameter Chelton =',Lc_9)
Lc_35 = calc_footprint_diam(Hs_35*Hsfac)
# Print the scaled Hs actually used for Lc_35 (it differs from Hs_35 when Hsfac != 1)
print('ind 2: 35, Hs=',Hs_35*Hsfac,'m, Diameter Chelton =',Lc_35)

## Generate or load sea surface (geometry only)

In [None]:
# iscompute = 1: generate the random surface from the case-35 spectrum and save it;
# iscompute = 0: reload the previously saved surface.
iscompute = 0
if iscompute:
    # Spectrum rescaled by Hs_35**2/Hs_35_L2S**2 (and Hsfac**2); presumably so that
    # its wave height matches Hs_35*Hsfac -- confirm
    Efth = DS['Spec_L2S_ind_35_1Sided'].compute().data*Hs_35**2/Hs_35_L2S**2*Hsfac**2
    th_vec = DS['phi_vector_L2S_ind_35'].compute().data
    f_vec = DS['k_vector'].compute().data
    nx_big = 2**12      # number of grid points along x
    ny_big = nx_big-2   # number of grid points along y
    dx_big=14           # grid spacing [m]
    dy_big=14
    seed=0              # seed passed to surface_from_Efth

    S_r_35_big,S_i_35_big,Xa_35_big,Ya_35_big,\
    rg,kX2_35_big,kY2_35_big,Ekxky_35_big,dkx2_35,dky2_35 = surface_from_Efth(Efth,f_vec,th_vec,
                                                                        seed=seed,nx=nx_big,
                                                                        ny=ny_big,dx=dx_big,dy=dy_big,iswvnb=1)

    np.savez(PATH_save2+'surface_good_for_images_ind35_L2S',seed=seed,dx=dx_big,dy=dy_big,X=Xa_35_big,Y=Ya_35_big,S_r =S_r_35_big ,S_i = S_i_35_big)
else:
    # The context manager closes the .npz file once the arrays have been read
    # (np.load on a .npz keeps the file open until NpzFile.close()).
    with np.load(PATH_save2+'surface_good_for_images_ind35_L2S.npz',allow_pickle=True) as data:
        dx_big = data['dx']
        dy_big = data['dy']
        Xa_35_big = data['X']
        Ya_35_big = data['Y']
        S_r_35_big = data['S_r']
        S_i_35_big = data['S_i']
        seed = data['seed']

# Envelope: modulus of the (S_r, S_i) pair
B_35_big = np.sqrt(S_r_35_big**2+S_i_35_big**2)
nx_big = len(Xa_35_big)
ny_big = len(Ya_35_big)
# Coordinates shifted by half their maximum, i.e. centred on 0 if they start at 0
Xa_35_big_cent = Xa_35_big - 0.5*Xa_35_big.max()
Ya_35_big_cent = Ya_35_big - 0.5*Ya_35_big.max()

## Define geometry of altimeter footprints

In [None]:
# --- constants and satellite parameters -----------------------------------
clight = 299792458     # speed of light [m/s]
Rearth = 6370*1e3      # Earth radius [m]
v_sat = 7*1e3          # satellite velocity [m/s]
freq_satsampl = 20     # waveform rate [Hz]

# Size (in pixels) of the largest square that fits in the simulated surface
nx_s = np.min((len(Xa_35_big), len(Ya_35_big)))

# Spacing between footprint centres in pixels; v_sat/freq_satsampl is the
# distance [m] flown between two consecutive waveforms
di = np.floor(v_sat/freq_satsampl/dx_big).astype(int)

# --- footprint sampling ---------------------------------------------------
# Margin kept around each footprint centre: a 10 km radius, in pixels
nxa0 = np.floor(1e4/dx_big).astype(int)
# Number of footprint centres fitting between the two margins
nsamp = np.floor((nx_s - 2*nxa0)/di).astype(int)

# Pixel indices of the footprint centres, and their coordinates [m]
iasamp = np.arange(nsamp, dtype=int)*di + nxa0
Xalts = Xa_35_big[iasamp]
# NOTE(review): one centre fewer along Y, presumably to match the shape of the
# retracked maps -- confirm
Yalts = Ya_35_big[iasamp[:-1]]

### Compute waveforms and perform 2D retracking (Hs, epoch)

In [None]:
nHs = 250   # number of Hs values in the reference waveform database
nze = 251   # number of epoch values in the reference waveform database
# --- Candidate altimeters: orbit altitude [m], name and bandwidth [Hz] ------
Altis = np.array([519*1e3, 781*1e3, 800*1e3, 891*1e3, 1340*1e3])
namesAltis = ['CFOSAT', 'SARAL', 'ENVISAT', 'SWOT', 'Jason3']
BW = np.array([(320*1e6),(500*1e6),(320*1e6),(320*1e6),(320*1e6)])

ialti = 0   # index of the simulated altimeter

iscompute = 1   # 1: compute and save; 0: reload a saved result
alti_sat = Altis[ialti]
bandwidth = BW[ialti]
nameSat = namesAltis[ialti]

if iscompute:
    print(nameSat)

    DiamChelton = calc_footprint_diam(Hs_35,Rorbit = alti_sat, pulse_width=1/bandwidth)

    # Range offset in metres (the 1D retracking cell uses 22.5 range gates instead)
    range_offset = 10.0
    print('offset:',range_offset)

    # Reference waveforms on the (Hs, epoch) grid; dr is the range-gate size returned
    wfm_ref, Hsm_ref, ze_ref, edges_ref, dr = generate_wvform_database_2D(nHs,nze,ne=None,bandwidth=bandwidth,edges_max=35,
                                                          Hs_max=25,offset=range_offset)

    # NOTE(review): computed from the footprint diameter although the message says
    # radius -- confirm what fly_over_track_only_retrack_2D expects.
    nxa = np.floor(DiamChelton/dx_big).astype(int)
    print('size of radius of footprint in pixel : ',nxa)

    # --- simulate the waveforms along the track and retrack them (Hs, epoch) ---
    Hs_retrack_2D,ze_retrack_2D,Xalt,Yalt,waveforms,_ = fly_over_track_only_retrack_2D(Xa_35_big,Ya_35_big,S_r_35_big,
                                                        nsamp,nxa0,nxa,di,wfm_ref,Hsm_ref,
                                                        edges_ref,ze_ref,range_shift=range_offset,
                                                        alti_sat=alti_sat)

    # The reference database is saved as well, so that reloading (iscompute = 0)
    # restores everything the check cell below uses.
    np.savez(PATH_save0+'Hs_ze_retrack_good_surface_'+nameSat,alti_sat = alti_sat,name_sat = nameSat,BWs = bandwidth,
             Xalt = Xalt,Yalt = Yalt,waveforms = waveforms, X_surf = Xa_35_big, Y_surf = Ya_35_big,S_r = S_r_35_big,
             S_i = S_i_35_big, Hs_retrack_2D = Hs_retrack_2D, ze_retrack_2D = ze_retrack_2D,
             wfm_ref = wfm_ref, Hsm_ref = Hsm_ref, ze_ref = ze_ref, edges_ref = edges_ref,
             dr = dr, range_offset = range_offset)
else:
    # NOTE(review): results are saved under PATH_save0 but reloaded from PATH_read0.
    # Each array of the file becomes a notebook variable of the same name. Going
    # through globals() (not exec of a string built from the keys) means a crafted
    # key cannot execute code; the file is closed once read.
    with np.load(PATH_read0+'Hs_ze_retrack_good_surface_'+nameSat+'.npz',allow_pickle=True) as data:
        for k in data.files:
            if not k.isidentifier():
                raise ValueError('unexpected array name in retrack file: '+repr(k))
            print(k+' = data["'+k+'"]')
            globals()[k] = data[k]

In [None]:
## Check the 2D retracking on one example (first footprint)
# Indices of the retracked (Hs, ze) in the reference grids. Rounding to the
# nearest index (instead of int() truncation) avoids landing one cell low when
# floating-point error gives e.g. 4.9999999; clipping keeps them inside wfm_ref.
ir = int(np.clip(np.rint((Hs_retrack_2D[0,0]-Hsm_ref[0])/(Hsm_ref[1]-Hsm_ref[0])), 0, wfm_ref.shape[0]-1))
jr = int(np.clip(np.rint((ze_retrack_2D[0,0]-ze_ref[0])/(ze_ref[1]-ze_ref[0])), 0, wfm_ref.shape[1]-1))
print('Indices in waveform database:',ir,jr)
fig,axs=plt.subplots(1,2,figsize=(12,6))

# Left: simulated waveform vs reference waveform, at bin centres (edges + dr/2)
ax=axs[0]
gate_centres = edges_ref[0:-1]+dr/2
line1=ax.plot(gate_centres,waveforms[0,0,:],color='k',label='waveform')
line2=ax.plot(gate_centres,wfm_ref[ir,jr,:],color='r',label='fitted wf')
ax.set_xlabel('range (m)')
ax.set_ylabel('waveform')
leg = ax.legend(loc='upper right')

# Right: map of the retracked Hs
ax=axs[1]
im=ax.pcolormesh(Xalts/1e3,Yalts/1e3,Hs_retrack_2D.T)
ax=custom_plots_surf(ax,im,iskm=1,alongT_isY=0,labelcb='$H_s$ retrack 2D [m]')
ax.set_title('$H_s$ retrack 2D [m]')


### Compute waveforms and perform 1D retracking (Hs only)

In [None]:
# Altimeter table repeated so that this cell can be run on its own
# (orbit altitude [m], name, bandwidth [Hz])
Altis = np.array([519*1e3, 781*1e3, 800*1e3, 891*1e3, 1340*1e3])
namesAltis = ['CFOSAT', 'SARAL', 'ENVISAT', 'SWOT', 'Jason3']
BW = np.array([(320*1e6),(500*1e6),(320*1e6),(320*1e6),(320*1e6)])

ialti = 0   # index of the simulated altimeter

iscompute = 1   # 1: compute; 0: reload a saved result
alti_sat = Altis[ialti]
bandwidth = BW[ialti]
nameSat = namesAltis[ialti]

if iscompute:
    print(nameSat)

    DiamChelton = calc_footprint_diam(Hs_35,Rorbit = alti_sat, pulse_width=1/bandwidth)

    # Range-gate size c/(2*bandwidth), used to express the offset in gates
    dr = clight * 1/(2*bandwidth)
    range_offset = 22.5*dr

    # 1D reference database (Hs only); dr is replaced by the value it returns
    wfm_ref, Hsm_ref, edges_ref,dr = generate_wvform_database(nHs,ne=None,bandwidth=bandwidth,edges_max=35,
                                                          Hs_max=25,offset=range_offset)

    # NOTE(review): computed from the footprint diameter although the message says radius
    nxa = np.floor(DiamChelton/dx_big).astype(int)
    print('size of radius of footprint in pixel : ',nxa)

    # --- simulate the waveforms and retrack Hs only (isepoch = 0) ---------------
    Hs_retrack_1D,Xalt,Yalt,_,_ = fly_over_track_only_retrack(Xa_35_big,Ya_35_big,S_r_35_big,
                                                        nsamp,nxa0,nxa,di,wfm_ref,Hsm_ref,
                                                        edges_ref,range_shift=range_offset,
                                                        alti_sat=alti_sat,isepoch = 0)

    # NOTE(review): saving is disabled, so the reload branch needs a file produced elsewhere.
#    np.savez(PATH_save0+'/Hs_retrack_good_surface_'+nameSat,alti_sat = alti_sat,names_sat = nameSat,BWs= bandwidth,
#             Xalt = Xalt,Yalt=Yalt,X_surf =Xa_35_big, Y_surf = Ya_35_big,S_r = S_r_35_big,
#             S_i = S_i_35_big, Hs_retrack_1D = Hs_retrack_1D)
else:
    # Each array of the file becomes a notebook variable of the same name, assigned
    # through globals() so that no code can be executed from a key; file closed after.
    with np.load(PATH_read0+'Hs_retrack_good_surface_'+nameSat+'.npz',allow_pickle=True) as data:
        for k in data.files:
            if not k.isidentifier():
                raise ValueError('unexpected array name in retrack file: '+repr(k))
            print(k+' = data["'+k+'"]')
            globals()[k] = data[k]

In [None]:
## Check the 1D retracking on one example (first footprint)
# Nearest index of the retracked Hs in the reference Hs grid (rounded rather than
# truncated, and clipped to the rows of wfm_ref)
ir = int(np.clip(np.rint((Hs_retrack_1D[0,0]-Hsm_ref[0])/(Hsm_ref[1]-Hsm_ref[0])), 0, wfm_ref.shape[0]-1))
print('Indices in waveform database:',ir)
fig,axs=plt.subplots(1,2,figsize=(12,6))

# NOTE(review): `waveforms` is only produced by the 2D retracking cell (range
# offset 10 m), while wfm_ref/edges_ref/dr come from the 1D database (offset
# 22.5 range gates): the two curves may be shifted in range.
ax=axs[0]
gate_centres = edges_ref[0:-1]+dr/2
line1=ax.plot(gate_centres,waveforms[0,0,:],color='k',label='waveform')
line2=ax.plot(gate_centres,wfm_ref[ir,:],color='r',label='fitted wf')
ax.set_xlabel('range (m)')
ax.set_ylabel('waveform')
leg = ax.legend(loc='upper right')

# Right: map of the 1D-retracked Hs (colorbar label previously said "2D")
ax=axs[1]
im=ax.pcolormesh(Xalts/1e3,Yalts/1e3,Hs_retrack_1D.T)
ax=custom_plots_surf(ax,im,iskm=1,alongT_isY=0,labelcb='$H_s$ retrack 1D [m]')
ax.set_title('$H_s$ retrack 1D [m]')


## Define filters to estimate (Hs, epoch) from the envelope map

In [None]:
def define_filter_annexA(Xa_c,Ya_c,DiamChelton,nkx_c,nky_c,dx_c,dy_c):
    """Build the Annex-A filter applied to the envelope to obtain an Hs-equivalent map.

    The filter is G + (delta - G) * J, where
      - G is an isotropic Gaussian of standard deviation rc = DiamChelton/2,
        divided by 2*pi*rc**2;
      - delta is a discrete Dirac of weight 1/(dx_c*dy_c) at pixel
        (nkx_c//2, nky_c//2);
      - J is the kernel below, written with the approximation
        r0**2/rc**2 = R0/Hs (same as J = A / (pi*h*Hs) * J in Annex A);
      - * is a 'same'-mode FFT convolution.

    Parameters
    ----------
    Xa_c, Ya_c : 1-D grid coordinates [m]; presumably centred on 0, since the
        kernels depend on r0 = sqrt(x**2 + y**2).
    DiamChelton : footprint diameter [m].
    nkx_c, nky_c : grid sizes, used to place the Dirac.
    dx_c, dy_c : grid spacings [m].

    Returns
    -------
    xarray.DataArray with dims ('x', 'y') and coordinates Xa_c, Ya_c.
    """
    rc = DiamChelton/2
    xx, yy = np.meshgrid(Xa_c, Ya_c, indexing='ij')
    r0 = np.sqrt((xx)**2+(yy)**2)

    # Gaussian filter scaled with rc
    gauss = np.exp(-0.5* r0**2 / (rc)**2 )/(rc**2*(2*np.pi))

    # Discrete Dirac at the central pixel
    dirac = np.zeros(gauss.shape)
    dirac[nkx_c//2, nky_c//2] = 1/(dx_c*dy_c)

    j_kernel = (4*dx_c*dy_c/(np.pi*rc**2)) * (r0/rc)**2 * (6 - ((2*r0/rc)**4)) * np.exp(- 4 * r0**4 / rc**4)

    filt = gauss + fftconvolve(dirac - gauss, j_kernel, mode='same')

    return xr.DataArray(filt, dims=['x','y'], coords={"x": Xa_c, "y": Ya_c})

In [None]:
# Function: filter J2 (Annex A)
def define_filter_J2_annexA(Xa_c,Ya_c,DiamChelton,nkx_c,nky_c,dx_c,dy_c,isplot=0):
    """Build the Annex-A J2 filter applied to the envelope to obtain an epoch-equivalent map.

    The filter is (delta - G) * J2, where
      - G is an isotropic Gaussian of standard deviation rc = DiamChelton/2,
        divided by 2*pi*rc**2;
      - delta is a discrete Dirac of weight 1/(dx_c*dy_c) at pixel
        (nkx_c//2, nky_c//2);
      - J2 is the kernel below, written with the approximation
        r0**2/rc**2 = R0/Hs (same as J200 = -A / (4*2*pi*h*Hs) * J2);
      - * is a 'same'-mode FFT convolution.

    Parameters
    ----------
    Xa_c, Ya_c : 1-D grid coordinates [m]; presumably centred on 0, since the
        kernels depend on r0 = sqrt(x**2 + y**2).
    DiamChelton : footprint diameter [m].
    nkx_c, nky_c : grid sizes, used to place the Dirac.
    dx_c, dy_c : grid spacings [m].
    isplot : if truthy, plot the cut at index nky_c//2 of G, J2 and the filter.

    Returns
    -------
    xarray.DataArray with dims ('x', 'y') and coordinates Xa_c, Ya_c.
    """
    twopi = 2*np.pi
    rc = DiamChelton/2
    [Xa_c2,Ya_c2] = np.meshgrid(Xa_c, Ya_c, indexing='ij')
    r0 = np.sqrt((Xa_c2)**2+(Ya_c2)**2)
    # Gaussian filter scaled with rc
    G_Lc20 = np.exp(-0.5* r0**2 / (rc)**2 )
    G_Lc2 = G_Lc20/(rc**2*twopi)
    # Discrete Dirac at the central pixel
    Id = np.zeros(np.shape(G_Lc2))
    Id[nkx_c//2,nky_c//2]=1/(dx_c*dy_c)
    # J2 kernel (uses approximation r0**2/rc**2 = R0/Hs)
    J200 = -(dx_c*dy_c/(4*np.pi*rc**2)) * (2 - 16*((r0/rc)**4)) * np.exp(- 4 * r0**4 / rc**4)
    Jr2 = fftconvolve((Id-G_Lc2),J200,mode='same')
    Filter_new = (Jr2)
    #print('Sum:',np.sum(np.abs(Filter_new)*dx_c*dy_c),dx_c*dy_c)
    #Filter_new = -Filter_new/np.sum(np.abs(Filter_new)*dx_c*dy_c)
    if isplot:
        fig, ax = plt.subplots()
        ax.plot(Xa_c,G_Lc2[:,nky_c//2],label='$G_{Lc}$')
        # was `J20`, an undefined name (NameError whenever isplot was set)
        ax.plot(Xa_c,J200[:,nky_c//2],label='J200')
        ax.plot(Xa_c,Jr2[:,nky_c//2],label='Jr2')
        ax.set_xlabel('x [m]')
        ax.grid(True)
        ax.legend()
    phi_x0 = xr.DataArray(Filter_new,
                dims=['x','y'],
                coords={
                    "x" : Xa_c,
                    "y" : Ya_c,
                    },
                )
    return phi_x0

## Compute filters

In [None]:
# Both filters are evaluated on the centred surface grid, using the footprint
# diameter Lc_35 from the spectrum cell.
# NOTE(review): the retracking cells use DiamChelton (depends on the selected
# altimeter's orbit and bandwidth) instead of Lc_35 -- confirm which is intended.
phi_J_AnnexA = define_filter_annexA(Xa_c=Xa_35_big_cent, Ya_c=Ya_35_big_cent, DiamChelton=Lc_35,
                                    nkx_c=nx_big, nky_c=ny_big, dx_c=dx_big, dy_c=dy_big)
phi_J2_A = define_filter_J2_annexA(Xa_c=Xa_35_big_cent, Ya_c=Ya_35_big_cent, DiamChelton=Lc_35,
                                   nkx_c=nx_big, nky_c=ny_big, dx_c=dx_big, dy_c=dy_big)

## Apply filters

In [None]:
# Factor converting the filtered envelope into an Hs-equivalent value
coeff=4*np.sqrt(2/np.pi)

# (row, column) index grid of the footprint centres: rows iasamp[:-1], columns iasamp
sample_idx = np.ix_(iasamp[:-1], iasamp)

# Envelope convolved with the J filter (Annex A), sampled at the footprint
# centres and transposed -> Hs-equivalent map
B11 = fftconvolve(B_35_big,phi_J_AnnexA,mode='same')*dx_big*dy_big
B12 = B11[sample_idx].T
equiv_Hs = coeff * B12

# Same with the J2 filter -> epoch-equivalent map
B31 = fftconvolve(B_35_big,phi_J2_A,mode='same')*dx_big*dy_big
B32 = B31[sample_idx].T
equiv_ze = coeff * B32


# PLOTS

## Compare 1D vs 2D

In [None]:
# Mean retracked Hs from the 2D and 1D retrackers
print('mean Hs retrack 2D, 1D [m]:',np.mean(Hs_retrack_2D),np.mean(Hs_retrack_1D))
fig,axs=plt.subplots(1,3,figsize=(21,6))

# Panel 1: 2D retracking; its colour limits are reused for panel 2
ax=axs[0]
im=ax.pcolormesh(Xalts/1e3,Yalts/1e3,Hs_retrack_2D.T)
ax=custom_plots_surf(ax,im,iskm=1,alongT_isY=0,labelcb='$H_s$ retrack 2D [m]')
[vmin,vmax]=im.get_clim()
ax.set_title('$H_s$ retrack 2D [m]')

# Panel 2: 1D retracking, same colour scale
ax=axs[1]
im=ax.pcolormesh(Xalts/1e3,Yalts/1e3,Hs_retrack_1D.T,vmin=vmin,vmax=vmax)
ax=custom_plots_surf(ax,im,iskm=1,alongT_isY=0,labelcb='$H_s$ retrack 1D [m]')
ax.set_title('$H_s$ retrack 1D [m]')

# Panel 3: signed difference. A diverging colormap with limits symmetric
# about 0 keeps the sign of 1D - 2D readable (a sequential one hides it).
dHs_1D_2D = Hs_retrack_1D-Hs_retrack_2D
dHs_max = np.nanmax(np.abs(dHs_1D_2D))
diff_clim = {'vmin': -dHs_max, 'vmax': dHs_max} if np.isfinite(dHs_max) and dHs_max > 0 else {}
ax=axs[2]
im=ax.pcolormesh(Xalts/1e3,Yalts/1e3,dHs_1D_2D.T,cmap='RdBu_r',**diff_clim)
ax=custom_plots_surf(ax,im,iskm=1,alongT_isY=0,labelcb='difference [m]')
ax.set_title('$H_s$ retrack 1D - 2D [m]')

plt.tight_layout()

## Compare 2D vs equivalent

In [None]:
fig,axs = plt.subplots(2,2,figsize=(21,16))

# One row per quantity: retracked map (left) and filter-based estimate (right);
# the right panel reuses the colour limits of the left one.
panel_rows = [
    (Hs_retrack_2D, equiv_Hs, '$H_s$ retrack 2D [m]', '$H_s$ from smooth [m]'),
    (ze_retrack_2D, equiv_ze, '$z_e$ retrack 2D [m]', '$z_e$ from smooth [m]'),
]
for row, (retracked, smoothed, label_retracked, label_smoothed) in enumerate(panel_rows):
    im = axs[row,0].pcolormesh(Xalts/1e3,Yalts/1e3,retracked.T)
    ax = custom_plots_surf(axs[row,0],im,iskm=1,alongT_isY=0,labelcb=label_retracked)
    vmin,vmax = im.get_clim()

    im = axs[row,1].pcolormesh(Xalts/1e3,Yalts/1e3,smoothed.T,vmin=vmin,vmax=vmax)
    ax = custom_plots_surf(axs[row,1],im,iskm=1,alongT_isY=0,labelcb=label_smoothed)

plt.tight_layout()

In [None]:
from scipy import stats

# Ratio of retracked to filter-based epochs, both offset by 1
# NOTE(review): the +1 offset presumably avoids dividing by near-zero epochs -- confirm.
ratio_ze = (ze_retrack_2D+1)/(equiv_ze +1)

# Linear fit of the filter-based epochs against the retracked ones
ze_retrack_flat = ze_retrack_2D.flatten()
ze_smooth_flat = equiv_ze.flatten()
res = stats.linregress(ze_retrack_flat,ze_smooth_flat)
rval = res.rvalue
rinter = res.intercept
rslope = res.slope

# Scatter plot with the 1:1 line and the fitted line
fig, ax = plt.subplots()
ax.plot(ze_retrack_flat,ze_smooth_flat,'.k')
ax.set_xlabel('ze retrack')
ax.set_ylabel('ze from smooth')
ax.axline([0,0],slope=1,color='k',linestyle='--',label='1:1')
ax.axline([0,rinter],slope=rslope,color='b',lw=3,
          label=f'fit: slope={rslope:.3f}, intercept={rinter:.3f}, r={rval:.3f}')
ax.legend()
print(res)

In [None]:
# Standard deviation of the retracked epochs over the whole map
print('std of ze is :', ze_retrack_2D.std())