In [1]:
import glob
import numpy as np
import time as tic
from hdf5storage import savemat
from netCDF4 import Dataset
import multiprocessing as mp

In [None]:
def _index_of(values, target):
    """Return the single index where ``values`` is (numerically) equal to ``target``.

    Floating-point grids such as ``freq``/``wavenumber`` are built by
    multiplying by reciprocals, so exact ``==`` comparison is fragile;
    ``np.isclose`` is used instead.

    Raises:
        ValueError: if zero or more than one entry matches.
    """
    matches = np.flatnonzero(np.isclose(values, target))
    if matches.size != 1:
        raise ValueError(
            f'expected exactly one entry equal to {target}, found {matches.size}')
    return int(matches[0])


def _temperature_advection(T, U, V, omega, lev, lat, lon, radius):
    """Temperature advection -(u/(a cos(phi)) dT/dlambda + v/a dT/dphi + omega dT/dp).

    Args:
        T, U, V, omega: arrays indexed (time, lev, lat, lon).
        lev: pressure levels in hPa (converted to Pa here).
        lat, lon: latitude/longitude in degrees (converted to radians here).
        radius: planetary radius in meters.

    Returns:
        Array with the same (time, lev, lat, lon) shape as ``T``.
    """
    pres = lev * 100  # lev is in hPa, need Pa
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)

    # Derivatives along (lev, lat, lon) using the actual coordinate values, so
    # non-uniform pressure levels are handled by np.gradient's non-uniform stencil.
    # NOTE(review): np.gradient uses one-sided differences at the first/last
    # longitude; longitude is periodic, so a wrap-around difference would be
    # more accurate at the seam — consider np.roll if that matters.
    dT_dp, dT_dphi, dT_dlambda = np.gradient(T, pres, lat_rad, lon_rad, axis=(1, 2, 3))

    # Broadcast cos(latitude) over the (time, lev, lat, lon) layout.
    # NOTE(review): if lat includes +/-90, cos(lat) is ~0 there and the zonal
    # term is dominated by round-off — consider masking the pole rows.
    coslat = np.cos(lat_rad)[np.newaxis, np.newaxis, :, np.newaxis]

    return -(U * dT_dlambda / (radius * coslat)
             + V * dT_dphi / radius
             + omega * dT_dp)


def _dw1_component(advection, f_idx, k_idx):
    """Amplitude and phase of one (frequency, zonal wavenumber) Fourier component.

    A 2-D FFT over (time, lon) is taken at every (lev, lat) point and both
    axes are fftshift-ed, matching the ordering of the ``freq``/``wavenumber``
    grids built in the main cell.

    Args:
        advection: array indexed (time, lev, lat, lon).
        f_idx: index into the shifted time-frequency axis.
        k_idx: index into the shifted zonal-wavenumber axis.

    Returns:
        (amplitude, phase), each shaped (lat, lev).
    """
    ntime, nlon_pts = advection.shape[0], advection.shape[3]
    spectrum = np.fft.fftshift(np.fft.fft2(advection, axes=(0, 3)), axes=(0, 3))
    coeff = spectrum[f_idx, :, :, k_idx].T  # (lev, lat) -> (lat, lev)
    # The amplitude is doubled because only the positive-frequency half of the
    # symmetric spectrum is kept.
    amplitude = 2 * np.abs(coeff) / (ntime * nlon_pts)
    phase = -np.angle(coeff)
    return amplitude, phase


def wave_extract(year):
    """Extract the DW1 component of temperature advection for one year of WACCM-X files.

    Reads every daily ``f.e20.FXSD.f19_f19.001.cam.h1.<year>...nc`` file under
    ``filePath``. Days whose time axis does not have ``Nt`` samples only
    contribute ``gpz_mean``. For all other days, the advection is computed and
    the freq == 1 / wavenumber == 1 Fourier component is stored. The results
    are saved to ``dest + 'DW1/...<year>.mat'`` (MATLAB v7.3).

    Relies on module-level settings defined in the main cell:
    ``filePath``, ``dest``, ``nlat``, ``nlevs``, ``Nt``, ``a``, ``freq``, ``wavenumber``.

    Raises:
        FileNotFoundError: if no files match the year.
        ValueError: if ``freq``/``wavenumber`` have no unique entry equal to 1.
    """
    import os  # local import: only needed to create the output directory

    files = sorted(glob.glob(filePath + 'f.e20.FXSD.f19_f19.001.cam.h1.' + str(year) + '[-0-9]*.nc'))
    if not files:
        raise FileNotFoundError(f'no WACCM-X h1 files found for year {year} in {filePath}')

    ndays = len(files)
    # NaN-filled so skipped days stay marked as missing.
    gpz_mean = np.full([nlevs, ndays], np.nan)
    amp_advection_DW1 = np.full([nlat, nlevs, ndays], np.nan)
    phs_advection_DW1 = np.full([nlat, nlevs, ndays], np.nan)

    # Positions of the DW1 component in the shifted spectrum (loop-invariant).
    f_idx = _index_of(freq, 1)
    k_idx = _index_of(wavenumber, 1)

    for days, fname in enumerate(files):
        print('Processing {} day = {}\n'.format(year, days + 1))

        with Dataset(fname, 'r', format="NETCDF4") as nc_fid:
            time = nc_fid['time'][:].data
            lev = nc_fid['lev'][:].data
            lat = nc_fid['lat'][:].data
            lon = nc_fid['lon'][:].data

            # Geopotential height averaged over time, lat and lon, converted to km.
            gpz_mean[:, days] = np.mean(nc_fid['Z3'][:].data, axis=(0, 2, 3)) / 1000

            # Days without a full set of Nt time samples are not spectrally analysed.
            if len(time) != Nt:
                print('skipped {} day = {}\n'.format(year, days + 1))
                continue

            T = nc_fid['T'][:].data
            U = nc_fid['U'][:].data
            V = nc_fid['V'][:].data
            omega = nc_fid['OMEGA'][:].data

        advection = _temperature_advection(T, U, V, omega, lev, lat, lon, a)
        amp_advection_DW1[:, :, days], phs_advection_DW1[:, :, days] = _dw1_component(
            advection, f_idx, k_idx)

    # lev/lat come from the last file read (identical across files presumably).
    DW1 = {"amp_advection_DW1": amp_advection_DW1, "phs_advection_DW1": phs_advection_DW1,
           "lev": lev, "lat": lat, "gpz_mean": gpz_mean}
    out_dir = dest + 'DW1/'
    os.makedirs(out_dir, exist_ok=True)
    outfile_name = out_dir + 'WACCMX_advection_DW1_short_term_one_day_' + str(year) + '.mat'
    savemat(outfile_name, DW1, format='7.3')

In [None]:
if __name__ == "__main__":
    # Configuration read as module-level globals by wave_extract, then one
    # worker process per year.
    starttime = tic.time()

    filePath = '/data/avitharana/WACCMX/'  # origin
    dest = '/data/avitharana/WACCMX_heating/advection/'  # destination

    a = 6.37122e+6  # radius of Earth in meters

    Nt = 8  # Number of time steps in the 1-day window (3-hourly samples).

    nlat = 96
    nlevs = 145
    nlon = 144
    nsteps = 1  # number of days

    dz = 2.5  # Longitude increment in degrees (360/144)
    dt = 3  # time increment in Hours
    Nz = 144  # Number of samples available along longitude

    df = 1/(Nt*dt)  # frequency resolution, cycles per hour
    dk = 1/(Nz*dz)  # wavenumber resolution, cycles per degree

    # Zonal wavenumbers (cycles per 360 deg) in fftshift order: -Nz/2 ... Nz/2-1.
    wavenumber = (np.arange(1, Nz+1) - (Nz/2+1)) * dk * 360

    # Frequencies (cycles per day) in fftshift order: -Nt/2 ... Nt/2-1.
    # Only the positive frequency is used downstream because of symmetry.
    freq = (np.arange(1, Nt+1) - (Nt/2+1)) * df * 24
    with np.errstate(divide='ignore'):
        freq_d = 1./freq  # period in days; inf at zero frequency

    years = range(1980, 1981)
    # No point forking more worker processes than there are years to process.
    num_workers = max(1, min(mp.cpu_count(), len(years)))
    with mp.Pool(num_workers) as pool:
        pool.map(wave_extract, years)
        pool.close()
        pool.join()  # wait for workers to exit cleanly before the pool is torn down
    endtime = tic.time()
    print(f"Time taken {endtime - starttime} seconds")