# Look at along-swath PSD using xrft

In [1]:
import matplotlib as mpl
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import box, LineString, Point
import os,argparse
import cmocean as cm
import numpy as np

# Add the path to Tatsu's SWOT library
import sys
sys.path.append('../src/')
import swot_utils
import data_loaders
import download_swaths
import plotting_scripts


#turn off warnings
import warnings
warnings.filterwarnings("ignore")


In [2]:
# Directory with the Kuroshio-subsetted SWOT L3 "Unsmoothed" files
L3_kuroshio_path = "../SWOT_L3/Unsmoothed_kuroshio"

# Rough East-of-Japan (Kuroshio-ish) domain; corners are [lon, lat]
kuroshio_sw_corner = [140, 15]
kuroshio_ne_corner = [170, 40]
lat_lims = [kuroshio_sw_corner[1], kuroshio_ne_corner[1]]

# Mission phase: use the 1-day-repeat (Cal/Val) orbit file, sph_calval_swath.
# Cycles 474-578 belong to the 1-day repeat; only cycles 474-479 are loaded here.
path_to_sph_file = "../orbit_data/sph_calval_swath.zip"
cycles = [f"{c_num:03d}" for c_num in range(474, 480)]

# Pass IDs returned by download_swaths.find_swaths for this domain
# (presumably the passes whose swaths intersect it -- see that helper).
pass_IDs_list = download_swaths.find_swaths(
    kuroshio_sw_corner,
    kuroshio_ne_corner,
    path_to_sph_file=path_to_sph_file,
)

# Load the requested fields for each cycle, keyed by the zero-padded cycle string
cycle_data = {}
for cycle in cycles:
    cycle_data[cycle] = data_loaders.load_cycle(
        L3_kuroshio_path,
        fields=["time", "ssha", "ssha_unedited", "ssha_noiseless", "sigma0"],
        cycle=cycle,
        pass_ids=pass_IDs_list,
        subset=False,
        lats=lat_lims,
    )


Loading SWOT_L3_LR_SSH_Unsmoothed_474_019_20230329T132147_20230329T141138_v1.0.2_kuroshio.nc
Loading SWOT_L3_LR_SSH_Unsmoothed_475_004_20230330T002601_20230330T011706_v1.0.2_kuroshio.nc
Loading SWOT_L3_LR_SSH_Unsmoothed_476_019_20230331T130302_20230331T135253_v1.0.2_kuroshio.nc
Loading SWOT_L3_LR_SSH_Unsmoothed_478_004_20230401T235753_20230402T004859_v1.0.2_kuroshio.nc
Loading SWOT_L3_LR_SSH_Unsmoothed_478_019_20230402T124418_20230402T133408_v1.0.2_kuroshio.nc
Loading SWOT_L3_LR_SSH_Unsmoothed_479_004_20230402T234831_20230403T001052_v1.0.2_kuroshio.nc
Loading SWOT_L3_LR_SSH_Unsmoothed_479_019_20230403T123455_20230403T132446_v1.0.2_kuroshio.nc


In [3]:
################################################################
# Helper scripts
################################################################
def nearest(items, pivot):
    """Return the element of ``items`` closest to ``pivot``.

    Closeness is ``abs(item - pivot)``; on ties the first such element wins.
    """
    def _distance_to_pivot(candidate):
        return abs(candidate - pivot)

    return min(items, key=_distance_to_pivot)

import datetime
def avg(dates):
    """Return the mean of a sequence of ``numpy.datetime64`` values.

    Computed as ``ref + sum(date - ref) / len(dates)`` with ``ref`` being the
    earliest date floored to whole seconds. The offsets are therefore small,
    and the result keeps at least second resolution.

    A fixed 1900-01-01 reference was used previously. For present-day
    ``datetime64[ns]`` values each offset is ~3.9e18 ns. Summing three or more
    of them overflows int64 (max ~9.2e18) and silently wraps.

    Raises
    ------
    ValueError
        If ``dates`` is empty.
    """
    dates = list(dates)
    if not dates:
        raise ValueError("avg() requires at least one date")
    # Floor the reference to seconds so coarse inputs (e.g. daily dates) are
    # still averaged at second resolution, and every offset is non-negative.
    ref_date = np.asarray(min(dates)).astype("datetime64[s]")[()]
    return ref_date + sum([date - ref_date for date in dates], np.timedelta64()) / len(dates)

################################################################
# This script overlays SWOT swaths on DUACS maps and adds along-swath PSD plots.
# It is standalone and temporary; not meant to be maintained beyond a week or two.
################################################################


def plot_swaths_save(cmin, cmax, save_dir=None):
    """Overlay SWOT swaths on the nearest DUACS map and plot along-swath PSDs.

    One figure is built per cycle in ``range(cmin, cmax)`` (``cmax`` is
    exclusive). Kuroshio passes "004" and "019" are loaded with
    ``data_loaders.load_cycle``. Cycles for which no swath loads are reported
    and skipped. Each figure has two parts:

    * Left: SSHA and geostrophic-speed map panels. The DUACS field whose time
      is closest to the mean swath time is drawn beneath the swaths.
    * Right: along-swath power spectra from
      ``swot_utils.compute_power_spectra_xrft``. One curve is for the full
      swath. The other uses ``subset=True, lim0=40, lim1=200``, which is
      intended to exclude the swath edges. Dashed k^-4 / k^-2 reference
      slopes are added.

    Parameters
    ----------
    cmin, cmax : int
        First (inclusive) and last (exclusive) cycle number.
    save_dir : str or None, optional
        If given, each figure is written to
        ``<save_dir>/<cycle>_Kuroshio_SSHA_DUACS_overlay_PSD.png`` (dpi 200,
        tight bbox) and then closed, so figures do not pile up in memory.
        If None (default), figures are left open and nothing is written.
    """
    # `xr` and `plt` were used here but never imported in this notebook,
    # so the function previously failed with NameError.
    import matplotlib.pyplot as plt
    import xarray as xr

    # DUACS gridded SLA / geostrophic velocities for the Kuroshio region
    duacs = xr.open_dataset('../tmp/copernicus-data/cmems_duacs-0.25deg_P1D_Kuroshio1.nc')
    duacs["uvgos"] = (duacs.ugos**2 + duacs.vgos**2)**(1/2)

    # NOTE(review): the notebook's data-loading cell reads
    # "../SWOT_L3/Unsmoothed_kuroshio". Confirm which relative path is right
    # for the directory this is run from.
    l3_path = "../../../SWOT_L3/Unsmoothed_kuroshio"

    # Map extent passed to cartopy's set_extent: [x0, x1, y0, y1] in PlateCarree
    extent_lims = [151.5, 164., 19, 41.0]

    cycles = [str(c_num).zfill(3) for c_num in range(cmin, cmax)]

    subplot_kw = {'projection': ccrs.PlateCarree()}
    ssha_plot_kw = {"cmap": cm.cm.balance,
                    "transform": ccrs.PlateCarree(),
                    "vmin": -0.4, "vmax": 0.4}
    uvgos_plot_kw = {"cmap": cm.cm.deep_r,
                     "transform": ccrs.PlateCarree(),
                     "vmin": 0, "vmax": 2}

    # Other regions used before:
    #   Agulhas:    pass_ids ["001", "016"], lats [-44, -32]
    #   California: pass_ids ["013", "026"], lats [29, 38.5]
    pass_ids = ["004", "019"]  # Kuroshio
    lats = [20, 40]            # Kuroshio
    load_fields = ["time", "ssha_unedited", "ssha", "ugos", "vgos"]
    fields = ["ssha", "uvgos"]

    for cycle in cycles:
        print(cycle)
        cycle_data = data_loaders.load_cycle(l3_path, fields=load_fields,
                                             cycle=cycle, pass_ids=pass_ids,
                                             subset=True, lats=lats)
        if len(cycle_data) < 1:
            print(f"Failed to load swaths for {cycle}")
            continue

        # Add geostrophic speed to each swath
        for swath in cycle_data:
            swath["uvgos"] = (swath["ugos"]**2 + swath["vgos"]**2)**(0.5)

        # Along-swath power spectra per swath: full width, then with the
        # swath edges excluded to reduce noise
        test_spectra = swot_utils.compute_power_spectra_xrft(cycle_data)
        test_spectra_ne = swot_utils.compute_power_spectra_xrft(cycle_data, subset=True, lim0=40, lim1=200)
        print("Did spectra calc")

        # DUACS snapshot closest in time to the mean swath time
        avg_time = avg([swath.time.mean().values for swath in cycle_data])
        print(f"avg_time {avg_time}")
        nearest_duacs_time = nearest(duacs.time, avg_time)
        print(f"nearest duacs time: {nearest_duacs_time}")
        duacs_plot = duacs.sel(time=nearest_duacs_time)

        fig = plt.figure(figsize=(10*len(fields)+10, 20), dpi=300)
        # Swath maps go in the left subfigure, the PSD panel in the right one
        subfigs = fig.subfigures(1, 2, wspace=0.08, width_ratios=[2, 0.9])
        axsLeft = subfigs[0].subplots(1, 2, subplot_kw=subplot_kw)
        axsRight = subfigs[1].subplots(3, 1, height_ratios=[1, 1.8, 1])
        # Only the middle right-hand axis is drawn; the hidden ones pad it vertically
        axsRight[0].set_visible(False)
        axsRight[2].set_visible(False)

        for ax in axsLeft:
            ax.add_feature(cfeature.COASTLINE.with_scale('10m'))
            ax.add_feature(cfeature.LAND, edgecolor='none', facecolor='lightgray')

        # SWOT swaths on the left-hand map panels
        plotting_scripts.plot_cycle(
            cycle_data, title=cycle, fields=fields,
            vmins=[ssha_plot_kw["vmin"], uvgos_plot_kw["vmin"]],
            vmaxes=[ssha_plot_kw["vmax"], uvgos_plot_kw["vmax"]],
            ssha_plot_kw={"transform": ccrs.PlateCarree(),
                          "s": 2, "marker": ".",
                          "alpha": 1, "linewidths": 0},
            cmaps=[cm.cm.balance, cm.cm.deep_r],
            cbar_titles=["SSHA (m)", r"$\sqrt{U_{g}^{2}+V_{g}^{2}}$"],
            dpi=300, set_extent=False, extent_lims=None,
            plot_bathymetry=False,
            axes=axsLeft)

        # DUACS underneath the swaths (zorder=1)
        axsLeft[0].pcolor(duacs_plot.longitude, duacs_plot.latitude,
                          duacs_plot.sla, **ssha_plot_kw, zorder=1, alpha=1)
        axsLeft[1].pcolor(duacs_plot.longitude, duacs_plot.latitude,
                          duacs_plot.uvgos, **uvgos_plot_kw, zorder=1, alpha=1)

        duacs_label = f"DUACS SSHA {nearest_duacs_time.values.astype('datetime64[s]')}"
        for ax in axsLeft:
            ax.set_extent(extent_lims, crs=ccrs.PlateCarree())
            txt = ax.text(extent_lims[0], extent_lims[3], duacs_label,
                          fontsize=15, weight='bold', zorder=100, color="grey")
            txt.set_bbox(dict(facecolor='white', alpha=1, edgecolor='k'))

        _plot_psd_panel(axsRight[1], test_spectra, test_spectra_ne, cycle_data, cycle)

        fig.tight_layout()

        if save_dir is not None:
            save_fig_name = f"{cycle}_Kuroshio_SSHA_DUACS_overlay_PSD"
            fig.savefig(os.path.join(save_dir, f"{save_fig_name}.png"),
                        bbox_inches='tight', dpi=200)
            # Free the (large, dpi=300) figure once it is on disk
            plt.close(fig)


def _plot_psd_panel(plt_axes, test_spectra, test_spectra_ne, cycle_data, cycle):
    """Draw per-swath along-swath PSDs plus reference guides on ``plt_axes``.

    ``test_spectra[0][i]`` / ``test_spectra[1][i]`` are the frequencies /
    amplitudes for swath ``i``, and likewise for ``test_spectra_ne``. Each is
    averaged over axis 0 before plotting.
    """
    # NOTE(review): assumes spectrum i belongs to cycle_data[i]. Previously
    # every curve was labelled with cycle_data[0]'s pass ID.
    for i, freqs in enumerate(test_spectra[0]):
        swath = cycle_data[i]
        label = f"Cycle{swath.cycle} PID{swath.pass_ID}"
        line, = plt_axes.loglog(freqs.mean(axis=0),
                                test_spectra[1][i].mean(axis=0),
                                label=label, alpha=0.6, linewidth=2)
        # Edge-excluded spectrum in the same colour, thinner
        plt_axes.loglog(test_spectra_ne[0][i].mean(axis=0),
                        test_spectra_ne[1][i].mean(axis=0),
                        label=f"{label} No Edges",
                        alpha=1, color=line.get_color(), linewidth=0.5)

    # Reference power-law slopes (offsets chosen by eye)
    x = np.linspace(0.01, 1, 100)
    plt_axes.loglog(x*.15, (x**-4)/10**2.5, color='blue', label=r"$k^{-4}$", linestyle="dashed", alpha=0.7)
    plt_axes.loglog(x*.2, (x**-2)/10**.8, color='red', label=r"$k^{-2}$", linestyle="dashed", alpha=0.7)

    plt_axes.hlines([2, 10, 20, 200], 10**-3, 10, color="grey", linestyle="--", alpha=0.7)
    plt_axes.vlines([1/15, 1/5, 1, 2], 0, 10**5, color="grey", linestyle="--", alpha=0.7)
    plt_axes.vlines([1/150, 1/100, 1/50, 1/20, 1/15, 1/10, 1/1, 1/.5], 0, 10**5, linestyle="--", alpha=0.7)

    plt_axes.set_title(f"PSDs for cycle {cycle}", fontsize=45)
    plt_axes.set_xlabel('Wavenumber (cpkm)', fontsize=20)
    plt_axes.set_ylabel('Power Spectral Density $cm^{2}/cpkm$', fontsize=20)
    plt_axes.tick_params(axis="both", labelsize=20)

    plt_axes.legend(fontsize=15)
    plt_axes.set_ylim(10**-1.5, 10**4)
    plt_axes.set_xlim(10**-3, 3)
    plt_axes.grid()

############################################################
# Argument parser to do sepcific cycles
############################################################
# I really don't know what I'm doing here...
# See https://docs.python.org/dev/library/argparse.html
parser = argparse.ArgumentParser(prog="Program name",
                                 description="what do I do?",
                                 epilog="like tears in the rain")
parser.add_argument("cmin")
parser.add_argument("cmax")
args = parser.parse_args()

plot_swaths_save(int(args.cmin), int(args.cmax))




usage: Program name [-h] cmin cmax
Program name: error: the following arguments are required: cmin, cmax


SystemExit: 2