In [1]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client
import dask.dataframe as dd

# Describe one PBS job (cores, memory, project, walltime), then ask for 12 of them.
cluster = PBSCluster(
    cores=4,
    memory="20GB",
    project='PangeoKEspectrumeNATL60',
    walltime='04:00:00',
)
cluster.scale(12)
# Leave the cluster as the last expression so its widget is displayed.
cluster

VBox(children=(HTML(value='<h2>PBSCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .d…

In [2]:
# Attach a dask.distributed Client to the cluster created in the previous cell.
client = Client(cluster)
# NOTE(review): the recorded output reports "Workers: 0" at this point, presumably
# because the PBS jobs were still queued -- confirm workers have joined before
# launching the heavy computations below.
client

0,1
Client  Scheduler: tcp://10.120.43.56:50782  Dashboard: http://10.120.43.56:8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [3]:
import sys, glob
import numpy as np
import xarray as xr
import xscale.spectral.fft as xfft
import xscale 
import Wavenum_freq_spec_func as wfs
import time

In [4]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

import matplotlib.pyplot as plt
import matplotlib.cm as mplcm
from matplotlib.colors import LogNorm

# Default colormaps: a sequential one (Blues) and a diverging one (seismic).
seq_cmap, div_cmap = mplcm.Blues, mplcm.seismic


In [5]:
%%time
dsu=xr.open_zarr('/work/ALT/odatis/eNATL60/zarr/eNATL60-BLBT02-SSU-1h')
dsv=xr.open_zarr('/work/ALT/odatis/eNATL60/zarr/eNATL60-BLBT02-SSV-1h')


CPU times: user 10.3 s, sys: 592 ms, total: 10.9 s
Wall time: 10.9 s


In [6]:
# Display the U dataset to inspect its dimensions, variables and chunking.
dsu

<xarray.Dataset>
Dimensions:       (time_counter: 8760, x: 8354, y: 4729)
Coordinates:
  * time_counter  (time_counter) datetime64[ns] 2009-07-01T00:30:00 ... 2010-06-30T23:30:00
Dimensions without coordinates: x, y
Data variables:
    nav_lat       (y, x) float32 dask.array<shape=(4729, 8354), chunksize=(296, 1045)>
    nav_lon       (y, x) float32 dask.array<shape=(4729, 8354), chunksize=(296, 1045)>
    sozocrtx      (time_counter, y, x) float32 dask.array<shape=(8760, 4729, 8354), chunksize=(24, 120, 120)>
Attributes:
    Conventions:  CF-1.6
    NCO:          4.4.6
    TimeStamp:    08/01/2019 09:34:23 +0100
    description:  ocean U grid variables
    file_name:    eNATL60-BLBT02_1h_20090630_20090704_gridU-2D_20090701-20090...
    history:      Fri May 24 23:57:12 2019: ncks -O -F -v sozocrtx /store/CT1...
    ibegin:       0
    jbegin:       0
    name:         /scratch/tmp/3746956/eNATL60-BLBT02_1h_20090630_20090704_gr...
    ni:           8354
    nj:           9
    timeStam

In [7]:
# Display the V dataset to inspect its dimensions, variables and chunking.
dsv

<xarray.Dataset>
Dimensions:       (time_counter: 8760, x: 8354, y: 4729)
Coordinates:
  * time_counter  (time_counter) datetime64[ns] 2009-07-01T00:30:00 ... 2010-06-30T23:30:00
Dimensions without coordinates: x, y
Data variables:
    nav_lat       (y, x) float32 dask.array<shape=(4729, 8354), chunksize=(296, 1045)>
    nav_lon       (y, x) float32 dask.array<shape=(4729, 8354), chunksize=(296, 1045)>
    somecrty      (time_counter, y, x) float32 dask.array<shape=(8760, 4729, 8354), chunksize=(24, 120, 120)>
Attributes:
    Conventions:  CF-1.6
    NCO:          4.4.6
    TimeStamp:    08/01/2019 09:34:23 +0100
    description:  ocean V grid variables
    file_name:    eNATL60-BLBT02_1h_20090630_20090704_gridV-2D_20090701-20090...
    history:      Sat May 25 00:47:37 2019: ncks -O -F -v somecrty /store/CT1...
    ibegin:       0
    jbegin:       0
    name:         /scratch/tmp/3746956/eNATL60-BLBT02_1h_20090630_20090704_gr...
    ni:           8354
    nj:           9
    timeStam

In [8]:
%%time
lat=dsu['nav_lat']
lon=dsu['nav_lon']
 
latmin = 40.0; latmax = 45.0;
lonmin = -40.0; lonmax = -35.0;

domain = (lonmin<lon) * (lon<lonmax) * (latmin<lat) * (lat<latmax)
where = np.where(domain)

#get indice
jmin = np.min(where[0][:])
jmax = np.max(where[0][:])
imin = np.min(where[1][:])
imax = np.max(where[1][:])

latbox=lat[jmin:jmax,imin:imax]
lonbox=lon[jmin:jmax,imin:imax]


CPU times: user 796 ms, sys: 83 ms, total: 879 ms
Wall time: 1.49 s


In [9]:
%%time

print('Select dates')
u_JFM=dsu.sel(time_counter=slice('2010-01-01','2010-03-31'))['sozocrtx']
v_JFM=dsv.sel(time_counter=slice('2010-01-01','2010-03-31'))['somecrty']


print('Select box area')
u_JFM_box=u_JFM[:,jmin:jmax,imin:imax].chunk({'time_counter':10,'x':120,'y':120})
v_JFM_box=v_JFM[:,jmin:jmax,imin:imax].chunk({'time_counter':10,'x':120,'y':120})



# - get dx and dy
print('get dx and dy')
dx_JFM,dy_JFM = wfs.get_dx_dy(u_JFM_box[0],lonbox,latbox)


#... Detrend data in all dimension ...
print('Detrend data in all dimension')
u_JFM = wfs.detrendn(u_JFM_box,axes=[0,1,2])
v_JFM = wfs.detrendn(v_JFM_box,axes=[0,1,2])

#... Apply hanning windowing ...') 
print('Apply hanning windowing')
u_JFM = wfs.apply_window(u_JFM, u_JFM.dims, window_type='hanning')
v_JFM = wfs.apply_window(v_JFM, v_JFM.dims, window_type='hanning')


# - get derivatives
derivatives_JFM = wfs.velocity_derivatives(u_JFM, v_JFM, xdim='x', ydim='y', dx={'x': dx_JFM, 'y': dy_JFM})
dudx_JFM = derivatives_JFM['u_x']; dudy_JFM = derivatives_JFM['u_y']
dvdx_JFM = derivatives_JFM['v_x']; dvdy_JFM = derivatives_JFM['v_y']

# - compute terms
phi1_JFM = u_JFM*dudx_JFM + v_JFM*dudy_JFM
phi2_JFM = u_JFM*dvdx_JFM + v_JFM*dvdy_JFM

u_JFMhat = xfft.fft(u_JFM, dim=('time_counter', 'x', 'y'), dx={'x': dx_JFM, 'y': dx_JFM}, sym=True)
v_JFMhat = xfft.fft(v_JFM, dim=('time_counter', 'x', 'y'), dx={'x': dx_JFM, 'y': dx_JFM}, sym=True)

phi1_JFM_hat = xfft.fft(phi1_JFM, dim=('time_counter', 'x', 'y'), dx={'x': dx_JFM, 'y': dx_JFM}, sym=True)
phi2_JFM_hat = xfft.fft(phi2_JFM, dim=('time_counter', 'x', 'y'), dx={'x': dx_JFM, 'y': dx_JFM}, sym=True)

tm1_JFM = (u_JFMhat.conj())*phi1_JFM_hat
tm2_JFM = (v_JFMhat.conj())*phi2_JFM_hat

# - computer transfer
Nk_JFM,Nj_JFM,Ni_JFM = u_JFM.shape
transfer_2D_JFM = -1.0*(tm1_JFM + tm2_JFM)/np.square(Ni_JFM*Nj_JFM)
transfer_term_JFM = transfer_2D_JFM.real

#... Get frequency and wavenumber ... 
print('Get frequency and wavenumber')
ffrequency_JFM = u_JFMhat.f_time_counter
kx_JFM = u_JFMhat.f_x
ky_JFM = u_JFMhat.f_y

#... Get istropic wavenumber ... 
print('Get istropic wavenumber')
wavenumber_JFM,kradial_JFM = wfs.get_wavnum_kradial(kx_JFM,ky_JFM)

#... Get numpy array ... 
print('Get numpy array')
var_psd_np_JFM = transfer_term_JFM.values

#... Get 2D frequency-wavenumber field ... 
print('Get transfer')
transfer_JFM = wfs.get_f_k_in_2D(kradial_JFM,wavenumber_JFM,var_psd_np_JFM) 

print('Get flux')
flux_JFM = wfs.get_flux_in_1D(kradial_JFM,wavenumber_JFM,var_psd_np_JFM)


Select dates
Select box area
get dx and dy
Detrend data in all dimension
Apply hanning windowing
Get frequency and wavenumber
Get istropic wavenumber
Get numpy array
Get transfer
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
Get flux
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
CPU times: user 16min 42s, sys: 1min 15s, total: 17min 58s
Wall time: 17min 57s


In [10]:
# Save to NetCDF file
# - build one DataArray per diagnostic, both on the (frequency, wavenumber) axes
print('Save to Netscdf file')
transfer_JFM_da = xr.DataArray(transfer_JFM, dims=['frequency', 'wavenumber'], name="transfer", coords=[ffrequency_JFM, wavenumber_JFM])
flux_JFM_da = xr.DataArray(flux_JFM, dims=['frequency', 'wavenumber'], name="flux", coords=[ffrequency_JFM, wavenumber_JFM])
transfer_JFM_da.attrs['Name'] = 'KE_Transfer_Flux_JFM_w_k_from_1h_eNATL60-BLBT02.nc'

# - write both variables in a single pass. The previous write-then-append
#   sequence left a file containing only "transfer" if the second write failed.
xr.Dataset({da.name: da for da in (transfer_JFM_da, flux_JFM_da)}).to_netcdf(
    path='KE_Transfer_Flux_JFM_w_k_from_1h_eNATL60-BLBT02.nc', mode='w', engine='scipy')


Save to Netscdf file


In [11]:
%%time
print('Select dates')
u_JAS_pre09=dsu.sel(time_counter=slice('2009-07-01','2009-09-02'))['sozocrtx']
v_JAS_pre09=dsv.sel(time_counter=slice('2009-07-01','2009-09-02'))['somecrty']

u_JAS_post09=dsu.sel(time_counter=slice('2009-09-08','2009-09-16'))['sozocrtx']
v_JAS_post09=dsv.sel(time_counter=slice('2009-09-08','2009-09-16'))['somecrty']

print('Select box area')
u_JAS_box_pre09=u_JAS_pre09[:,jmin:jmax,imin:imax].chunk({'time_counter':10,'x':120,'y':120})
v_JAS_box_pre09=v_JAS_pre09[:,jmin:jmax,imin:imax].chunk({'time_counter':10,'x':120,'y':120})

u_JAS_box_post09=u_JAS_post09[:,jmin:jmax,imin:imax].chunk({'time_counter':10,'x':120,'y':120})
v_JAS_box_post09=v_JAS_post09[:,jmin:jmax,imin:imax].chunk({'time_counter':10,'x':120,'y':120})

u_JAS_box=xr.concat([u_JAS_box_pre09,u_JAS_box_post09],dim='time_counter')
v_JAS_box=xr.concat([v_JAS_box_pre09,v_JAS_box_post09],dim='time_counter')

# - get dx and dy
print('get dx and dy')
dx_JAS,dy_JAS = wfs.get_dx_dy(u_JAS_box[0],lonbox,latbox)


#... Detrend data in all dimension ...
print('Detrend data in all dimension')
u_JAS = wfs.detrendn(u_JAS_box,axes=[0,1,2])
v_JAS = wfs.detrendn(v_JAS_box,axes=[0,1,2])

#... Apply hanning windowing ...') 
print('Apply hanning windowing')
u_JAS = wfs.apply_window(u_JAS, u_JAS.dims, window_type='hanning')
v_JAS = wfs.apply_window(v_JAS, v_JAS.dims, window_type='hanning')


# - get derivatives
print('velocity derivatives')
derivatives_JAS = wfs.velocity_derivatives(u_JAS, v_JAS, xdim='x', ydim='y', dx={'x': dx_JAS, 'y': dy_JAS})
dudx_JAS = derivatives_JAS['u_x']; dudy_JAS = derivatives_JAS['u_y']
dvdx_JAS = derivatives_JAS['v_x']; dvdy_JAS = derivatives_JAS['v_y']

# - compute terms
print('computer terms')
phi1_JAS = u_JAS*dudx_JAS + v_JAS*dudy_JAS
phi2_JAS = u_JAS*dvdx_JAS + v_JAS*dvdy_JAS

print('fft u v ')
u_JAShat = xfft.fft(u_JAS, dim=('time_counter', 'x', 'y'), dx={'x': dx_JAS, 'y': dx_JAS}, sym=True)
v_JAShat = xfft.fft(v_JAS, dim=('time_counter', 'x', 'y'), dx={'x': dx_JAS, 'y': dx_JAS}, sym=True)

print('fft phi')
phi1_JAS_hat = xfft.fft(phi1_JAS, dim=('time_counter', 'x', 'y'), dx={'x': dx_JAS, 'y': dx_JAS}, sym=True)
phi2_JAS_hat = xfft.fft(phi2_JAS, dim=('time_counter', 'x', 'y'), dx={'x': dx_JAS, 'y': dx_JAS}, sym=True)

print('multiply')
tm1_JAS = (u_JAShat.conj())*phi1_JAS_hat
tm2_JAS = (v_JAShat.conj())*phi2_JAS_hat

# - computer transfer
print('Compute transfer')
Nk_JAS,Nj_JAS,Ni_JAS = u_JAS.shape
transfer_2D_JAS = -1.0*(tm1_JAS + tm2_JAS)/np.square(Ni_JAS*Nj_JAS)
transfer_term_JAS = transfer_2D_JAS.real

#... Get frequency and wavenumber ... 
print('Get frequency and wavenumber')
ffrequency_JAS = u_JAShat.f_time_counter
kx_JAS = u_JAShat.f_x
ky_JAS = u_JAShat.f_y

#... Get istropic wavenumber ... 
print('Get istropic wavenumber')
wavenumber_JAS,kradial_JAS = wfs.get_wavnum_kradial(kx_JAS,ky_JAS)

#... Get numpy array ... 
print('Get numpy array')
var_psd_np_JAS = transfer_term_JAS.values

#... Get 2D frequency-wavenumber field ... 
print('Get transfer')
transfer_JAS = wfs.get_f_k_in_2D(kradial_JAS,wavenumber_JAS,var_psd_np_JAS) 

print('Get flux')
flux_JAS = wfs.get_flux_in_1D(kradial_JAS,wavenumber_JAS,var_psd_np_JAS)


Select dates
Select box area
get dx and dy
Detrend data in all dimension
Apply hanning windowing
velocity derivatives
computer terms
fft u v 
fft phi
multiply
Compute transfer
Get frequency and wavenumber
Get istropic wavenumber
Get numpy array
Get transfer
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
Get flux
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
CPU times: user 13min 57s, sys: 1min 12s, total: 15min 10s
Wall time: 15min 25s


In [12]:
# Save to NetCDF file
# - build one DataArray per diagnostic, both on the (frequency, wavenumber) axes
print('Save to Netscdf file')
transfer_JAS_da = xr.DataArray(transfer_JAS, dims=['frequency', 'wavenumber'], name="transfer", coords=[ffrequency_JAS, wavenumber_JAS])
flux_JAS_da = xr.DataArray(flux_JAS, dims=['frequency', 'wavenumber'], name="flux", coords=[ffrequency_JAS, wavenumber_JAS])
transfer_JAS_da.attrs['Name'] = 'KE_Transfer_Flux_JAS_w_k_from_1h_eNATL60-BLBT02.nc'

# - write both variables in a single pass. The previous write-then-append
#   sequence left a file containing only "transfer" if the second write failed.
xr.Dataset({da.name: da for da in (transfer_JAS_da, flux_JAS_da)}).to_netcdf(
    path='KE_Transfer_Flux_JAS_w_k_from_1h_eNATL60-BLBT02.nc', mode='w', engine='scipy')


Save to Netscdf file
