In [2]:
import argparse 
import ast

import cf_xarray
#import cftime
import geocat.comp as gcomp
#import holoviews as hv
#import hvplot
#import hvplot.xarray
#import intake
import numpy as np
#import pop_tools
import xarray as xr
import xesmf as xe

#from distributed import Client
#from ncar_jobqueue import NCARCluster
#from pop_tools.grid import _compute_corners

import logging 
import netCDF4 as nc
from numba import vectorize, float64, jit, njit

## Questions: 
Do certain variables still need to be reversed?  
Is surface geopotential and lowest layer geopotential treated the same here? 
Does geocat's interp function work across time? 
Is it better to build an atomic function and use appy_ufunc or build in vectorization? 
If you use numpy functions on xarray dataarrays, does xarray intercept the function to correct the dimensions? 

TO DO:  
Implement weight reuse  
Implement pressure levels   
Fix wrong standard name for SST  
Add hooks for user-defined reference pressure?  
Add hook to keep metadata  
Fix up METGRID default value  
Implement unit converter (pint) or unit checker  
NOTES:   
Created a branch for a version that works on cf-compliant data  
 


In [None]:
#Command line option handling ----------------------------------------------------------------------------------
parser = argparse.ArgumentParser()  
logging.basicConfig(level=logging.DEBUG)
current_log_level = logging.getLogger().getEffectiveLevel() 

parser.add_argument('CASE',type=str, help='One of the following IPCC Climate Scenarios: 20THC/RCP85/RCP60/RCP45')
parser.add_argument('--o',type=str,help='Output directory path')
parser.add_argument('--mode','-m',type=str,help='Set logging mode: DEBUG/INFO/WARNING/ERROR/CRITICAL')
parser.add_argument('--plev',type=str, help="File name of desired output pressure levels")
parser.add_argument('--weights',type=str, help="File name if reusing regridding weights")

if current_log_level != 10: 
    args = parser.parse_args()

In [3]:
#File Handling ----------------------------------------------------------------------------------

logging.info("Opening data files...")

in_ta = xr.load_dataset("atmos_ta.nc")         # 6-hourly 3-d T
in_ua = xr.open_dataset("atmos_ua.nc")         # 6-hourly 3-d U
in_va = xr.open_dataset("atmos_va.nc")         # 6-hourly 3-d V
in_hus = xr.open_dataset("atmos_hus.nc")       # 6-hourly 3-d Q
in_ps = xr.load_dataset("atmos_ps.nc")         # 6-hourly surface pressure
in_zsfc = xr.open_dataset("atmos_zsfc.nc")     # static surface geopotential
#in_lmask = xr.open_dataset("atmos_lmask.nc")   # static land mask
#in_snw = xr.open_dataset("atmos_snw_1.nc")     # monthly SWE
#in_mrlsl = xr.open_dataset("atmos_mrlsl_1.nc") # monthly soil moisture
#in_ts = xr.open_dataset("atmos_ts_1.nc")       # monthly skin temp
#in_tsl = xr.open_dataset("atmos_tsl_1.nc")     # monthly soil temp
#in_tos = xr.open_dataset("atmos_tos_1.nc")     # daily SST on pop grid (gaussian)
#in_sic = xr.open_dataset("atmos_sic_1.nc")     # daily SEAICE % on POP grid (gaussian)

In [None]:
#Regrid SST and SEA ICE fields to CESM Atmospheric Domain ----------------------------------------------------------------------------------

logging.info('Converting Parallel Ocean Program data to coordinate system of atmospheric grid...')

SST = in_tos.cf['surface_temperature']
#Create a mask (not needed for interpolating to atmospheric grid, but just in case there are missing values)
#NOTE THAT THIS CF REFERENCE IS WRONG. SST IS THE CORRECT STANDARD NAME WHICH NEEDS TO BE CORRECTED IN THE DATA
in_tos["mask"] = ~SST.cf.isel(time=0).isnull()

#Regrids SST grid to whatever the atmospheric grid is automatically
regrid = xe.Regridder(in_tos, in_ta, method = 'bilinear', periodic=True, unmapped_to_nan=True)
regrid.to_netcdf('weights_gx1v6_latlon.nc') #write out weights for reuse 

regridded_SST = regrid(in_tos)

if current_log_level == 10: 
    print(regridded_SST)
#regridded_SST.to_netcdf('python_regrid.nc')
#
#use some sort of broadcasting or view here to clone to a 6-hrly variable

In [4]:
#Prepare Variables for Interpolation ----------------------------------------------------------------------------------

hyam = in_ta.cf['hyam'] 
hybm = in_ta.cf['hybm']
hyai = in_ta.cf['hyai']
hybi = in_ta.cf['hybi']

surf_pressure = in_ps.cf['PS']

phi_surf = in_zsfc['PHIS']
phi_surf.coords['lat'] = surf_pressure.coords['lat']
phi_surf.coords['lon'] = surf_pressure.coords['lon']
temp = in_ta["T"]    


Check that this is correct - it seems like the coordinate system is somewhat reversed.

In [5]:
#%%timeit -r 10

@vectorize([float64(float64,float64,float64,float64)],nopython=True)
def pres_on_hybrid_ccm_atomic(pressure_surf, hyam_k, hybm_k, ref_pressure): 
    return hyam_k*ref_pressure + hybm_k*pressure_surf 

#need assertion that missing values are correct 

def pres_on_hybrid_ccm(pressure_surf : xr.DataArray , hyam: xr.DataArray, hybm: xr.DataArray, ref_pressure =  100000): 
    return xr.apply_ufunc(pres_on_hybrid_ccm_atomic, pressure_surf, hyam, hybm, ref_pressure) 

#This cast was present in the original ncl script, so we are keeping it here. 
P_hybrid = pres_on_hybrid_ccm(surf_pressure,hyam,hybm).astype(np.single)
#print(P_hybrid)

ncl_P_hybrid = xr.open_dataarray('ncl_P_hybrid.nc')
ncl_P_hybrid = ncl_P_hybrid.rename({'ncl0':'time', 'ncl1':'lev', 'ncl2':'lat', 'ncl3':'lon'})
print(ncl_P_hybrid)

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


In [None]:
ncl_P_hybrid.coords['time'] = P_hybrid.coords['time']
ncl_P_hybrid.coords['lev'] = P_hybrid.coords['lev'][::-1]
ncl_P_hybrid.coords['lat'] = P_hybrid.coords['lat']
ncl_P_hybrid.coords['lon'] = P_hybrid.coords['lon']

print(P_hybrid.isel(time = 0, lat = 0, lon = 0))
print(ncl_P_hybrid.isel(time = 0, lat = 0, lon = 0, lev = slice(None,None,-1)).values)
np.max(P_hybrid-ncl_P_hybrid)


Piecewise, Piecewise with jit

Atomic execution with vectorize

In [30]:
%%timeit -r 10

@vectorize([float64(float64,float64,float64,float64)],nopython=True)
def pslec_atomic(temp_bot,phi_s,ps,pressure_bot): 

    LAPSE_RATE = 0.0065     #Kelvin per meter
    GRAV_CONST = 9.80616    #Meters per second per second
    SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

    ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
    temp_surf = temp_bot*(1 + ALPHA_0*(ps/pressure_bot - 1)) #3b.5
    temp_bot_lapse = temp_surf + LAPSE_RATE*phi_s/GRAV_CONST #denoted T_0 in doc, 3b.13

    #These cases are partitions - there is no overlap in cases here. 
    if abs(phi_s/GRAV_CONST) < 1e-4: 
        psl = ps
    elif temp_surf >= 255 and temp_bot_lapse <= 290.5: 
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/temp_surf 
        psl =  ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    elif temp_surf > 290.5 and temp_bot_lapse > 290.5: 
        T_star_modified = 1/2*(290.5+temp_surf) 
        psl = ps*np.exp(phi_s/SPEC_GAS_CONST/T_star_modified)
    elif temp_surf >=255 and temp_surf <= 290.5 and temp_bot_lapse > 290.5: 
        combo_term = 290.5-temp_surf
        psl = ps*np.exp(phi_s/SPEC_GAS_CONST/temp_surf*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    elif temp_surf < 255 and temp_bot_lapse <= 290.5: 
        T_star_modified = 1/2*(255+temp_surf)
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star_modified 
        psl =  ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    elif temp_surf < 255 and temp_bot_lapse > 290.5: 
        alpha = SPEC_GAS_CONST/phi_s*(290.5-temp_surf)
        T_star_modified = 1/2*(255+temp_surf)
        combo_term = alpha*phi_s/SPEC_GAS_CONST/T_star_modified 
        psl = ps*np.exp(combo_term/alpha*(1-1/2*(combo_term)+1/3*(combo_term)**2))    
    
    return psl 
    
def pslec(temp_bottom: xr.DataArray, phi_surf: xr.DataArray, pressure_surf: xr.DataArray, pressure_bot: xr.DataArray): 
    return xr.apply_ufunc(pslec_atomic,temp_bottom,phi_surf,pressure_surf,pressure_bot)


test_pressure = pslec(temp.isel(lev=-1), phi_surf, surf_pressure, P_hybrid.isel(lev=-1))

1.1 s ± 6.87 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


Partial where with no jit

In [11]:
#%%timeit -r 10

def _pslec(t_bot, phi_sfc, ps, pressure_bot):
    R_d = 287.04  # dry air gas constant
    g_inv = 1 / 9.80616  # inverse of gravity
    alpha = 0.0065 * R_d * g_inv

    tstar = t_bot * (1 + alpha * (ps / pressure_bot - 1))
    hgt = phi_sfc * g_inv
    t0 = tstar + 0.0065 * hgt

    alph = xr.where((tstar <= 290.5) & (t0 > 290.5),
                    R_d / phi_sfc * (290.5 - tstar), alpha)

    alph = xr.where((tstar > 290.5) & (t0 > 290.5), 0, alph)
    tstar = xr.where((tstar > 290.5) & (t0 > 290.5), 0.5 * (290.5 + tstar),
                     tstar)

    tstar = xr.where((tstar < 255), 0.5 * (tstar + 255), tstar)

    beta = phi_sfc/R_d/tstar
    return ps*np.exp(beta*(1-alph*beta/2+(alph*beta)**2/3))

test_pressure = _pslec(temp.isel(lev=-1), phi_surf, surf_pressure, P_hybrid.isel(lev=-1))


Where with jit

In [22]:
#%%timeit -r 10

@njit
def _pslec_jit(t_bot: np.array , phi_sfc: np.array , ps: np.array, pressure_bot: np.array):
    R_d = 287.04  # dry air gas constant
    g_inv = 1 / 9.80616  # inverse of gravity
    alpha = 0.0065 * R_d * g_inv

    tstar = t_bot * (1 + alpha * (ps / pressure_bot - 1))
    hgt = phi_sfc * g_inv
    t0 = tstar + 0.0065 * hgt

    alph = np.where((tstar <= 290.5) & (t0 > 290.5),
                    R_d / phi_sfc * (290.5 - tstar), alpha)

    alph = np.where((tstar > 290.5) & (t0 > 290.5), 0, alph)
    tstar = np.where((tstar > 290.5) & (t0 > 290.5), 0.5 * (290.5 + tstar),
                     tstar)

    tstar = np.where((tstar < 255), 0.5 * (tstar + 255), tstar)

    beta = phi_sfc/R_d/tstar
    return ps*np.exp(beta*(1-alph*beta/2+(alph*beta)**2/3))

temp_bc, phi_surf_bc, surf_pressure_bc, P_hybrid_bc = xr.broadcast(temp.isel(lev=-1),phi_surf,surf_pressure, P_hybrid.isel(lev=-1))
test_pressure = _pslec_jit(temp_bc.values, phi_surf_bc.values, surf_pressure_bc.values, P_hybrid_bc.values)

In [21]:
test_temp = temp.isel(lat=0,lon=0,time=0,lev=-1).values #* np.ones((2,2))
test_phi = phi_surf.isel(lon=0,lat=0).values #* np.ones((2,2))
test_surf = surf_pressure.isel(lon=0,lat=0,time=0).values# * np.ones((2,2))
test_hybrid = P_hybrid.isel(lat=0,lon=0,lev=-1,time=0).values# * np.ones((2,2))


print(test_temp)
print(test_phi)
print(test_surf)
print(test_hybrid)

test_pressure_atomic = pslec_atomic(test_temp, test_phi, test_surf, test_hybrid)
test_pressure_xr = pslec(test_temp,test_phi,test_surf,test_hybrid) 

print(test_pressure_atomic)
print(test_pressure_xr)
print(test_pressure.isel(lat=0,lon=0))


246.69228
27701.627573972324
68952.78
68439.5
99933.59998192343
99933.59998192343
<xarray.DataArray ()>
array(99933.59998192)
Coordinates:
    lat      float64 -90.0
    lev      float64 992.6
    lon      float64 0.0
    time     object 2006-01-01 00:00:00


For loop with no jit

In [None]:
def pslec_for(T_bot,phi_s,ps,pressure_bot): 

   LAPSE_RATE = 0.0065     #Kelvin per meter
   GRAV_CONST = 9.80616    #Meters per second per second
   SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

   ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
   psl = xr.zeros_like(ps)
   lats = pressure_bot.coords['lat']
   lons = pressure_bot.coords['lon']
   for lati in lats: 
      for long in lons: 
         if abs(phi_s.sel(lat=lati,lon=long)/GRAV_CONST):
            psl.loc[dict(lat=lati,lon=long)] = ps.sel(lat=lati,lon=long)
         else: 
            tstar = T_bot.sel(lat=lati,lon=long) * (1+ALPHA_0) * (ps.sel(lat=lati,lon=long)/pressure_bot.sel(lat=lati,lon=long)-1)
            T0 = tstar + LAPSE_RATE*phi_s.sel(lat=lati,lon=long)/GRAV_CONST
            if tstar <= 290.5 and T0 > 290.5: 
               ALPHA = SPEC_GAS_CONST/phi_s.sel(lat=lati,lon=long) * (290.5 - tstar)
            elif tstar > 290.5  and T0 > 290.5: 
               ALPHA = 0
               tstar = 0.5* (290.5+tstar)
            else: 
               ALPHA = ALPHA_0
            if tstar < 255: 
               tstar = 0.5* (255+tstar)
            BETA = phi_s.sel(lat=lati,lon=long)/ (SPEC_GAS_CONST*tstar)
            psl.loc[dict(lat=lati,lon=long)] = ps.sel(lat=lati,lon=long) * np.exp(BETA* (1-ALPHA*BETA/2+((ALPHA*BETA)**2)/3))
         
   return psl 

test_pressure = pslec_for(temp.isel(lev=-1,time=0), phi_surf, surf_pressure.isel(time=0), P_hybrid.isel(lev=-1,time=0))

For loop with jit

In [29]:
%%timeit -r 10

@njit
def pslec_for_jit(T_bot,phi_s,ps,pressure_bot): 

   LAPSE_RATE = 0.0065     #Kelvin per meter
   GRAV_CONST = 9.80616    #Meters per second per second
   SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

   ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
   ntim = ps.shape[0]
   nlat = ps.shape[1]
   nlon = ps.shape[2]
   psl = np.zeros_like(ps)
   for t in range(ntim):
      for lati in range(nlat): 
         for long in range(nlon): 
            if abs(phi_s[t,lati,long]/GRAV_CONST):
               psl[t,lati,long] = ps[t,lati,long]
            else: 
               tstar = T_bot[t,lati,long] * (1+ALPHA_0) * (ps[t,lati,long]/pressure_bot[t,lati,long]-1)
               T0 = tstar + LAPSE_RATE*phi_s[t,lati,long]/GRAV_CONST
               if tstar <= 290.5 and T0 > 290.5: 
                  ALPHA = SPEC_GAS_CONST/phi_s[t,lati,long] * (290.5 - tstar)
               elif tstar > 290.5  and T0 > 290.5: 
                  ALPHA = 0
                  tstar = 0.5* (290.5+tstar)
               else: 
                  ALPHA = ALPHA_0
               if tstar < 255: 
                  tstar = 0.5* (255+tstar)
               BETA = phi_s[t,lati,long]/ (SPEC_GAS_CONST*tstar)
               psl[t,lati,long] = ps[t,lati,long] * np.exp(BETA* (1-ALPHA*BETA/2+((ALPHA*BETA)**2)/3))
            
   return psl 

temp_bc, phi_surf_bc, surf_pressure_bc, P_hybrid_bc = xr.broadcast(temp.isel(lev=-1),phi_surf,surf_pressure, P_hybrid.isel(lev=-1))
test_pressure = pslec_for_jit(temp_bc.values, phi_surf_bc.values, surf_pressure_bc.values, P_hybrid_bc.values)

849 ms ± 8.56 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [11]:
ncl_pslec = xr.open_dataarray('ncl_pslec.nc')

print(test_pressure.isel(lat=0,lon=0))
print(ncl_pslec.isel(lat=0,lon=0).values)


<xarray.DataArray ()>
array(99933.59998192)
Coordinates:
    lat      float64 -90.0
    lev      float64 992.6
    lon      float64 0.0
    time     object 2006-01-01 00:00:00
100494.99


In [23]:
#Interpolate to Pressure Coordinates ----------------------------------------------------------------------------------
#logging.info("Interpolating variables to pressure coordinates...")

#if current_log_level == 10: print(temp); print(surf_pressure); print(in_zsfc['PHIS']); 
hyam = in_ta.cf['hyam'] 
hybm = in_ta.cf['hybm']
hyai = in_ta.cf['hyai']
hybi = in_ta.cf['hybi']
default_levs = np.array([1000.0, 975.0, 950.0, 925.0, 900.0, 850.0, 800.0, 750.0, 700.0, 650.0, 600.0, 550.0, 500.0, \
             450.0, 400.0, 350.0, 300.0, 250.0, 200.0, 150.0, 100.0, 70.0, 50.0, 30.0, 20.0, 10.0 ])


surf_pressure = in_ps.cf['PS']
phi_surf = in_zsfc['PHIS']
temp = in_ta["T"]    
fixed_phi_sfc = in_zsfc['PHIS']  
fixed_phi_sfc.coords['lat'] = surf_pressure.coords['lat']
fixed_phi_sfc.coords['lon'] = surf_pressure.coords['lon'] 

temp_interp = gcomp.interpolation.interp_hybrid_to_pressure(temp.isel(time=0),surf_pressure.isel(time=0),hyam,hybm, 
                                                            new_levels=default_levs, 
                                                            lev_dim = 'lev', 
                                                            method='log',
                                                            extrapolate=True,
                                                            variable='temperature',
                                                            t_bot=temp.isel(lev=-1,time=0),
                                                            phi_sfc=fixed_phi_sfc)

print(temp_interp)

<xarray.DataArray 'T' (plev: 26, lat: 192, lon: 288)>
dask.array<setitem, shape=(26, 192, 288), dtype=float32, chunksize=(26, 192, 288), chunktype=numpy.ndarray>
Coordinates:
    time     object 2006-01-01 00:00:00
  * lat      (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 87.17 88.12 89.06 90.0
  * plev     (plev) float64 1e+03 975.0 950.0 925.0 ... 50.0 30.0 20.0 10.0
  * lon      (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
Attributes:
    units:      K
    long_name:  Temperature
