In [1]:
import argparse 
import ast

import cf_xarray
#import cftime
#import geocat.comp as gcomp
#import holoviews as hv
#import hvplot
#import hvplot.xarray
#import intake
import numpy as np
#import pop_tools
import xarray as xr
import xesmf as xe

#from distributed import Client
#from ncar_jobqueue import NCARCluster
#from pop_tools.grid import _compute_corners

import logging 
#import netCDF4 as nc
from numba import vectorize, float64, jit, njit, guvectorize

## Questions: 
Do certain variables still need to be reversed?  
Is surface geopotential and lowest layer geopotential treated the same here? 
Does geocat's interp function work across time? 
Is it better to build an atomic function and use appy_ufunc or build in vectorization? 
If you use numpy functions on xarray dataarrays, does xarray intercept the function to correct the dimensions? 

TO DO:  
Implement weight reuse  
Implement pressure levels   
Fix wrong standard name for SST  
Add hooks for user-defined reference pressure?  
Add hook to keep metadata  
Fix up METGRID default value  
Implement unit converter (pint) or unit checker  
NOTES:   
Created a branch for a version that works on cf-compliant data  
 


In [16]:
#Command line option handling ----------------------------------------------------------------------------------
parser = argparse.ArgumentParser()  
logging.basicConfig(level=logging.ERROR)
current_log_level = logging.getLogger().getEffectiveLevel() 

parser.add_argument('CASE',type=str, help='One of the following IPCC Climate Scenarios: 20THC/RCP85/RCP60/RCP45')
parser.add_argument('--o',type=str,help='Output directory path')
parser.add_argument('--mode','-m',type=str,help='Set logging mode: DEBUG/INFO/WARNING/ERROR/CRITICAL')
parser.add_argument('--plev',type=str, help="File name of desired output pressure levels")
parser.add_argument('--weights',type=str, help="File name if reusing regridding weights")

if current_log_level != 10: 
    args = parser.parse_args()

In [4]:
#File Handling ----------------------------------------------------------------------------------

logging.info("Opening data files...")

in_ta = xr.load_dataset("atmos_ta.nc")         # 6-hourly 3-d T
in_ua = xr.open_dataset("atmos_ua.nc")         # 6-hourly 3-d U
in_va = xr.open_dataset("atmos_va.nc")         # 6-hourly 3-d V
in_hus = xr.open_dataset("atmos_hus.nc")       # 6-hourly 3-d Q
in_ps = xr.load_dataset("atmos_ps.nc")         # 6-hourly surface pressure
in_zsfc = xr.open_dataset("atmos_zsfc.nc")     # static surface geopotential
#in_lmask = xr.open_dataset("atmos_lmask.nc")   # static land mask
#in_snw = xr.open_dataset("atmos_snw_1.nc")     # monthly SWE
#in_mrlsl = xr.open_dataset("atmos_mrlsl_1.nc") # monthly soil moisture
#in_ts = xr.open_dataset("atmos_ts_1.nc")       # monthly skin temp
#in_tsl = xr.open_dataset("atmos_tsl_1.nc")     # monthly soil temp
in_tos = xr.open_dataset("atmos_tos_1.nc")     # daily SST on pop grid (gaussian)
in_sic = xr.open_dataset("atmos_sic_1.nc")     # daily SEAICE % on POP grid (gaussian)

INFO:root:Opening data files...


In [4]:
#Regrid SST and SEA ICE fields to CESM Atmospheric Domain ----------------------------------------------------------------------------------

logging.info('Converting Parallel Ocean Program data to coordinate system of atmospheric grid...')

SST = in_tos.cf['surface_temperature']
#Create a mask (not needed for interpolating to atmospheric grid, but just in case there are missing values)
#NOTE THAT THIS CF REFERENCE IS WRONG. SST IS THE CORRECT STANDARD NAME WHICH NEEDS TO BE CORRECTED IN THE DATA
in_tos["mask"] = ~SST.cf.isel(time=0).isnull()

#Regrids SST grid to whatever the atmospheric grid is automatically
regrid = xe.Regridder(in_tos, in_ta, method = 'bilinear', periodic=True, unmapped_to_nan=True)
regrid.to_netcdf('weights_gx1v6_latlon.nc') #write out weights for reuse 

regridded_SST = regrid(in_tos)

if current_log_level == 10: 
    print(regridded_SST)
#regridded_SST.to_netcdf('python_regrid.nc')

ICE_DAILY = in_sic['aice_d']*0.01
in_sic["mask"] = ~ICE_DAILY.isel(time=0).isnull()
regrid = xe.Regridder(in_sic, in_ta, method = 'bilinear', periodic=True, unmapped_to_nan=True)

regridded_ICE_DAILY = regrid(in_sic)
#
#use some sort of broadcasting or view here to clone to a 6-hrly variable

#This is a lazy regrid, need to call .ffill if you want to force an evaluation. 
#regridded_SST_6HR = regridded_SST.resample(time='6HR')
#regridded_ICE_6HR = regridded_ICE_DAILY.resample(time='6HR')

INFO:root:Converting Parallel Ocean Program data to coordinate system of atmospheric grid...
DEBUG:numba.core.entrypoints:Loading extension: EntryPoint(name='init', value='sparse._numba_extension:_init_extension', group='numba_extensions')
DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1037)
           2	LOAD_GLOBAL(arg=0, lineno=1053)
           4	LOAD_ATTR(arg=1, lineno=1053)
           6	LOAD_FAST(arg=3, lineno=1053)
           8	LOAD_DEREF(arg=0, lineno=1053)
          10	LOAD_CONST(arg=1, lineno=1053)
          12	CALL_FUNCTION_KW(arg=2, lineno=1053)
          14	STORE_FAST(arg=4, lineno=1053)
          16	LOAD_CONST(arg=2, lineno=1054)
          18	STORE_FAST(arg=5, lineno=1054)
>         20	LOAD_FAST(arg=5, lineno=1056)
          22	LOAD_GLOBAL(arg=2, lineno=1056)
          24	LOAD_FAST(arg=1, lineno=1056)
          26	CALL_FUNCTION(arg=1, lineno=1056)
          28	COMPARE_OP(arg=0, lineno=1056)
          30	POP_JUMP_IF_FALSE(arg=154, lineno=1056)
   

<xarray.Dataset>
Dimensions:   (vertices: 4, lat: 192, lon: 288, time: 90)
Coordinates:
  * time      (time) object 2006-01-01 12:00:00 ... 2006-03-31 12:00:00
  * lat       (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0
  * lon       (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
Dimensions without coordinates: vertices
Data variables:
    lon_bnds  (vertices, lat, lon) float32 nan nan nan nan ... 248.7 248.7 248.7
    lat_bnds  (vertices, lat, lon) float32 nan nan nan nan ... 89.65 89.65 89.65
    tos       (time, lat, lon) float32 nan nan nan nan ... 271.3 271.3 271.3
    mask      (lat, lon) bool True True True True True ... True True True True
Attributes:
    regrid_method:  bilinear




In [6]:
#Prepare Variables for Interpolation ----------------------------------------------------------------------------------

hyam = in_ta.cf['hyam'] 
hybm = in_ta.cf['hybm']
hyai = in_ta.cf['hyai']
hybi = in_ta.cf['hybi']

surf_pressure = in_ps.cf['PS']

phi_surf = in_zsfc['PHIS']
phi_surf.coords['lat'] = surf_pressure.coords['lat']
phi_surf.coords['lon'] = surf_pressure.coords['lon']
temp = in_ta["T"]    


Check that this is correct - it seems like the coordinate system is somewhat reversed.

In [18]:
%%timeit -r 10

@vectorize([float64(float64,float64,float64,float64)],nopython=True)
def pres_on_hybrid_ccm_atomic(pressure_surf, hyam_k, hybm_k, ref_pressure): 
    return hyam_k*ref_pressure + hybm_k*pressure_surf 

#need assertion that missing values are correct 

def pres_on_hybrid_ccm(pressure_surf : xr.DataArray , hyam: xr.DataArray, hybm: xr.DataArray, ref_pressure =  100000): 
    return xr.apply_ufunc(pres_on_hybrid_ccm_atomic, pressure_surf, hyam, hybm, ref_pressure) 

#This cast was present in the original ncl script, so we are keeping it here. 
P_hybrid = pres_on_hybrid_ccm(surf_pressure,hyam,hybm).astype(np.single)
#print(P_hybrid)

ncl_P_hybrid = xr.open_dataarray('ncl_P_hybrid.nc')
ncl_P_hybrid = ncl_P_hybrid.rename({'ncl0':'time', 'ncl1':'lev', 'ncl2':'lat', 'ncl3':'lon'})
print(ncl_P_hybrid)

DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon


DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_FAST(arg=1, lineno=3)
           4	LOAD_FAST(arg=3, lineno=3)
           6	BINARY_MULTIPLY(arg=None, lineno=3)
           8	LOAD_FAST(arg=2, lineno=3)
          10	LOAD_FAST(arg=0, lineno=3)
          12	BINARY_MULTIPLY(arg=None, lineno=3)
          14	BINARY_ADD(arg=None, lineno=3)
          16	RETURN_VALUE(arg=None, lineno=3)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_initial=0)
DEBUG:numba.core.byteflow:dispatch pc=0, inst=NOP(arg=None, lineno=1)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=2, inst=LOAD_FAST(arg=1, lineno=3)
DEBUG:numba.core.byteflow:stack []
DEBUG:numba.core.byteflow:dispatch pc=4, inst=LOAD_FAST(arg=3, lineno=3)
DEBUG:numba.core.byteflow:stack ['$hyam_k2.0']
DEBUG:numba.core.byteflow:dispatch pc=6, in

<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon
<xarray.DataArray 'P_hybrid' (time: 360, lev: 26, lat: 192, lon: 288)>
[517570560 values with dtype=float32]
Dimensions without coordinates: time, lev, lat, lon
4.7 s ± 74 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [9]:
ncl_P_hybrid.coords['time'] = P_hybrid.coords['time']
ncl_P_hybrid.coords['lev'] = P_hybrid.coords['lev'][::-1]
ncl_P_hybrid.coords['lat'] = P_hybrid.coords['lat']
ncl_P_hybrid.coords['lon'] = P_hybrid.coords['lon']

print(P_hybrid.isel(time = 0, lat = 0, lon = 0))
print(ncl_P_hybrid.isel(time = 0, lat = 0, lon = 0, lev = slice(None,None,-1)).values)
np.max(P_hybrid-ncl_P_hybrid)


<xarray.DataArray (lev: 26)>
array([  354.4638 ,   738.88135,  1396.7214 ,  2394.4624 ,  3723.029  ,
        5311.4604 ,  7005.915  ,  8310.233  ,  9309.202  , 10484.4375 ,
       11867.041  , 13493.604  , 15407.169  , 17658.379  , 20306.814  ,
       23422.557  , 27088.068  , 31400.35   , 36473.527  , 42441.863  ,
       48821.242  , 54845.027  , 60110.285  , 64251.     , 66961.59   ,
       68439.5    ], dtype=float32)
Coordinates:
    lat      float64 -90.0
    lon      float64 0.0
    time     object 2006-01-01 00:00:00
  * lev      (lev) float64 3.545 7.389 13.97 23.94 ... 867.2 929.6 970.6 992.6
[  354.4638    738.88135  1396.7214   2394.4624   3723.029    5311.4604
  7005.915    8310.233    9309.202   10484.4375  11867.041   13493.604
 15407.169   17658.379   20306.814   23422.557   27088.068   31400.35
 36473.527   42441.863   48821.242   54845.027   60110.285   64251.
 66961.59    68439.5    ]


Piecewise

In [19]:
%%timeit -r 10

def pslec_piecewise(temp_bottom: xr.DataArray, phi_surf: xr.DataArray, pressure_surf: xr.DataArray, pressure_bot: xr.DataArray): 
    #Based on the NCAR Technical Note "Vertical Interpolation and Truncation of Model-Coordinate Data"
    #By Trenberth, Berry, Buja; Dec 1993 
    #temp_surf = T_*, temp_bottom = T_NL, pressure_surf = p_s, pressure_bot = p_NL, phi_surf = \phi_s

    #coordinate check 

    LAPSE_RATE = 0.0065     #Kelvin per meter
    GRAV_CONST = 9.80616    #Meters per second per second
    SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

    ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
    phi_surf_expanded = phi_surf.broadcast_like(pressure_surf)
    temp_surf = temp_bottom*(1 + ALPHA_0*(pressure_surf/pressure_bot - 1)) #3b.5
    temp_bot_lapse = temp_surf + LAPSE_RATE*phi_surf/GRAV_CONST #denoted T_0 in doc, 3b.13

    is_near_zero =  np.absolute(phi_surf_expanded/GRAV_CONST) < 1e-4 #this has precedence over the others 
    is_hot_low =    np.logical_and(temp_surf >= 255,temp_bot_lapse <= 290.5, where = ~is_near_zero)  
    is_hot_high =   np.logical_and(temp_surf > 290.5, temp_bot_lapse > 290.5, where = ~is_near_zero) 
    is_mild_high =  np.logical_and(temp_surf <= 290.5, temp_bot_lapse > 290.5, where = ~is_near_zero) 
    is_cold_low =   np.logical_and(temp_surf < 255, temp_bot_lapse <= 290.5, where = ~is_near_zero) 
    is_cold_high =  np.logical_and(temp_surf < 255, temp_bot_lapse > 290.5, where = ~is_near_zero)  

    #create vectorized functions 
    def psl_hot_low(ps,phi_s,T_star): 
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star 
        return ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    def psl_hot_high(ps,phi_s,T_star): 
        T_star_modified = 1/2*(290.5+T_star) 
        return ps*np.exp(phi_s/SPEC_GAS_CONST/T_star_modified)
    def psl_mild_high(ps,phi_s,T_star):
        combo_term = 290.5-T_star
        return ps*np.exp(phi_s/SPEC_GAS_CONST/T_star*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    def psl_cold_low(ps,phi_s,T_star): 
        T_star_modified = 1/2*(255+T_star)
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star_modified 
        return ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    #Check this function - it may be wrong in the ncl code! 
    def psl_cold_high(ps,phi_s,T_star): 
        alpha = SPEC_GAS_CONST/phi_s*(290.5-T_star)
        T_star_modified = 1/2*(255+T_star)
        combo_term = alpha*phi_s/SPEC_GAS_CONST/T_star_modified 
        return ps*np.exp(combo_term/alpha*(1-1/2*(combo_term)+1/3*(combo_term)**2))    
    def psl_near_zero(ps,phi_s,T_star): return ps

    formulas =  [psl_near_zero,  psl_hot_low,    psl_hot_high,   psl_mild_high,  psl_cold_low,   psl_cold_high]
    cases =     [is_near_zero,   is_hot_low,     is_hot_high,    is_mild_high,   is_cold_low,    is_cold_high]

    psl = np.full_like(pressure_surf,np.nan)
    for where_case, formula in zip(cases,formulas): 
            psl[where_case] = formula(pressure_surf.to_numpy()[where_case],phi_surf_expanded.to_numpy()[where_case],temp_surf.to_numpy()[where_case])
    return psl 

 
test_pressure = pslec_piecewise(temp.isel(lev=-1), phi_surf, surf_pressure, P_hybrid.isel(lev=-1))



1.73 s ± 26.2 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)




with Jit (currently cannot be jitted)

In [13]:
%%timeit -r 10

pressure_surf = 0.0065     #Kelvin per meter
GRAV_CONST = 9.80616    #Meters per second per second
SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin
ALPHA_0 = pressure_surf*SPEC_GAS_CONST/GRAV_CONST

@vectorize([float64(float64,float64,float64)])
def psl_hot_low(ps,phi_s,T_star): 
    combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star 
    return ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
@vectorize([float64(float64,float64,float64)])
def psl_hot_high(ps,phi_s,T_star): 
    T_star_modified = 1/2*(290.5+T_star) 
    return ps*np.exp(phi_s/SPEC_GAS_CONST/T_star_modified)
@vectorize([float64(float64,float64,float64)])
def psl_mild_high(ps,phi_s,T_star):
    combo_term = 290.5-T_star
    return ps*np.exp(phi_s/SPEC_GAS_CONST/T_star*(1-1/2*(combo_term)+1/3*(combo_term)**2))
@vectorize([float64(float64,float64,float64)])
def psl_cold_low(ps,phi_s,T_star): 
    T_star_modified = 1/2*(255+T_star)
    combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star_modified 
    return ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
#Check this function - it may be wrong in the ncl code! 
@vectorize([float64(float64,float64,float64)])
def psl_cold_high(ps,phi_s,T_star): 
    alpha = SPEC_GAS_CONST/phi_s*(290.5-T_star)
    T_star_modified = 1/2*(255+T_star)
    combo_term = alpha*phi_s/SPEC_GAS_CONST/T_star_modified 
    return ps*np.exp(combo_term/alpha*(1-1/2*(combo_term)+1/3*(combo_term)**2))    
@vectorize([float64(float64,float64,float64)])
def psl_near_zero(ps,phi_s,T_star): return ps


@njit
def pslec_piecewise_jit(temp_bottom, phi_surf, pressure_surf, pressure_bot): 
    #Based on the NCAR Technical Note "Vertical Interpolation and Truncation of Model-Coordinate Data"
    #By Trenberth, Berry, Buja; Dec 1993 
    #temp_surf = T_*, temp_bottom = T_NL, pressure_surf = p_s, pressure_bot = p_NL, phi_surf = \phi_s

    #coordinate check 

    LAPSE_RATE = 0.0065     #Kelvin per meter
    GRAV_CONST = 9.80616    #Meters per second per second
    SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

    ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
    temp_surf = temp_bottom*(1 + ALPHA_0*(pressure_surf/pressure_bot - 1)) #3b.5
    temp_bot_lapse = temp_surf + LAPSE_RATE*phi_surf/GRAV_CONST #denoted T_0 in doc, 3b.13

    is_near_zero =  np.absolute(phi_surf/GRAV_CONST) < 1e-4 #this has precedence over the others 
    is_hot_low =    np.logical_and(temp_surf >= 255,temp_bot_lapse <= 290.5, where = ~is_near_zero)  
    is_hot_high =   np.logical_and(temp_surf > 290.5, temp_bot_lapse > 290.5, where = ~is_near_zero) 
    is_mild_high =  np.logical_and(temp_surf <= 290.5, temp_bot_lapse > 290.5, where = ~is_near_zero) 
    is_cold_low =   np.logical_and(temp_surf < 255, temp_bot_lapse <= 290.5, where = ~is_near_zero) 
    is_cold_high =  np.logical_and(temp_surf < 255, temp_bot_lapse > 290.5, where = ~is_near_zero)  

    #create vectorized functions 

    psl = np.full_like(pressure_surf,np.nan)

    psl[is_near_zero] = psl_near_zero(pressure_surf[is_near_zero],phi_surf[is_near_zero],temp_surf[is_near_zero])
    psl[is_hot_low] = psl_hot_low(pressure_surf[is_hot_low],phi_surf[is_hot_low],temp_surf[is_hot_low])
    psl[is_hot_high ] = psl_hot_high(pressure_surf[is_hot_high ],phi_surf[is_hot_high ],temp_surf[is_hot_high ])
    psl[is_mild_high] = psl_mild_high(pressure_surf[is_mild_high],phi_surf[is_mild_high],temp_surf[is_mild_high])
    psl[is_cold_low] = psl_cold_low(pressure_surf[is_cold_low],phi_surf[is_cold_low],temp_surf[is_cold_low])
    psl[is_cold_high] = psl_cold_high(pressure_surf[is_cold_high],phi_surf[is_cold_high],temp_surf[is_cold_high])
    
    return psl 



temp_bc, phi_surf_bc, surf_pressure_bc, P_hybrid_bc = xr.broadcast(temp.isel(lev=-1),phi_surf,surf_pressure, P_hybrid.isel(lev=-1))
test_pressure = pslec_piecewise_jit(temp_bc.values, phi_surf_bc.values, surf_pressure_bc.values, P_hybrid_bc.values)

DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=6)
           2	LOAD_DEREF(arg=0, lineno=8)
           4	LOAD_FAST(arg=1, lineno=8)
           6	BINARY_MULTIPLY(arg=None, lineno=8)
           8	LOAD_DEREF(arg=1, lineno=8)
          10	BINARY_TRUE_DIVIDE(arg=None, lineno=8)
          12	LOAD_FAST(arg=2, lineno=8)
          14	BINARY_TRUE_DIVIDE(arg=None, lineno=8)
          16	STORE_FAST(arg=3, lineno=8)
          18	LOAD_FAST(arg=0, lineno=9)
          20	LOAD_GLOBAL(arg=0, lineno=9)
          22	LOAD_METHOD(arg=1, lineno=9)
          24	LOAD_FAST(arg=3, lineno=9)
          26	LOAD_DEREF(arg=0, lineno=9)
          28	BINARY_TRUE_DIVIDE(arg=None, lineno=9)
          30	LOAD_CONST(arg=1, lineno=9)
          32	LOAD_CONST(arg=2, lineno=9)
          34	LOAD_FAST(arg=3, lineno=9)
          36	BINARY_MULTIPLY(arg=None, lineno=9)
          38	BINARY_SUBTRACT(arg=None, lineno=9)
          40	LOAD_CONST(arg=3, lineno=9)
          42	LOAD_FAST(arg=3, lineno=9)
        

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1mNo implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(float32, 3d, C), array(bool, 3d, C))
 
There are 22 candidate implementations:
[1m      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float32, 3d, C), array(bool, 3d, C))':[0m
[1m       No match.[0m
[1m      - Of which 2 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 209.
        With argument(s): '(array(float32, 3d, C), array(bool, 3d, C))':[0m
[1m       Rejected as the implementation raised a specific error:
         NumbaTypeError: [1mMulti-dimensional indices are not supported.[0m[0m
  raised from /glade/work/wukenton/conda-envs/python_correct/lib/python3.9/site-packages/numba/core/typing/arraydecl.py:89
[0m
[0m[1mDuring: typing of intrinsic-call at <magic-timeit> (62)[0m
[1m
File "<magic-timeit>", line 62:[0m
[1m<source missing, REPL/exec in use?>[0m


In [14]:

def psl_ecmwf(temp_bottom: xr.DataArray, phi_surf: xr.DataArray, pressure_surf: xr.DataArray, pressure_bot: xr.DataArray): 
    #Based on the NCAR Technical Note "Vertical Interpolation and Truncation of Model-Coordinate Data"
    #By Trenberth, Berry, Buja; Dec 1993 
    #temp_surf = T_*, temp_bottom = T_NL, pressure_surf = p_s, pressure_bot = p_NL, phi_surf = \phi_s

    #coordinate check 

    LAPSE_RATE = 0.0065     #Kelvin per meter
    GRAV_CONST = 9.80616    #Meters per second per second
    SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

    ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
    phi_surf_expanded = phi_surf.broadcast_like(pressure_surf)
    temp_surf = temp_bottom*(1 + ALPHA_0*(pressure_surf/pressure_bot - 1)) #3b.5
    temp_bot_lapse = temp_surf + LAPSE_RATE*phi_surf/GRAV_CONST #denoted T_0 in doc, 3b.13

    is_near_zero =  np.absolute(phi_surf_expanded/GRAV_CONST) < 1e-4 #this has precedence over the others 
    is_hot_low =    np.logical_and(temp_surf >= 255,temp_bot_lapse <= 290.5, where = ~is_near_zero)  
    is_hot_high =   np.logical_and(temp_surf > 290.5, temp_bot_lapse > 290.5, where = ~is_near_zero) 
    is_mild_high =  np.logical_and(temp_surf <= 290.5, temp_bot_lapse > 290.5, where = ~is_near_zero) 
    is_cold_low =   np.logical_and(temp_surf < 255, temp_bot_lapse <= 290.5, where = ~is_near_zero) 
    is_cold_high =  np.logical_and(temp_surf < 255, temp_bot_lapse > 290.5, where = ~is_near_zero)  

    #create vectorized functions 
    def psl_hot_low(ps,phi_s,T_star): 
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star 
        return ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    def psl_hot_high(ps,phi_s,T_star): 
        T_star_modified = 1/2*(290.5+T_star) 
        return ps*np.exp(phi_s/SPEC_GAS_CONST/T_star_modified)
    def psl_mild_high(ps,phi_s,T_star):
        combo_term = 290.5-T_star
        return ps*np.exp(phi_s/SPEC_GAS_CONST/T_star*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    def psl_cold_low(ps,phi_s,T_star): 
        T_star_modified = 1/2*(255+T_star)
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star_modified 
        return ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    #Check this function - it may be wrong in the ncl code! 
    def psl_cold_high(ps,phi_s,T_star): 
        alpha = SPEC_GAS_CONST/phi_s*(290.5-T_star)
        T_star_modified = 1/2*(255+T_star)
        combo_term = alpha*phi_s/SPEC_GAS_CONST/T_star_modified 
        return ps*np.exp(combo_term/alpha*(1-1/2*(combo_term)+1/3*(combo_term)**2))    
    def psl_near_zero(ps,phi_s,T_star): return ps

    formulas =  [psl_near_zero,  psl_hot_low,    psl_hot_high,   psl_mild_high,  psl_cold_low,   psl_cold_high]
    cases =     [is_near_zero,   is_hot_low,     is_hot_high,    is_mild_high,   is_cold_low,    is_cold_high]

    try: 
        assert np.logical_xor.reduce(cases,axis =0).all()
    except AssertionError: 
        assert np.any(cases,axis = 0).all(), "Underlap in cases"

    psl = np.full_like(pressure_surf,np.nan)
    for where_case, formula in zip(cases,formulas): 
            psl[where_case] = formula(pressure_surf.to_numpy()[where_case],phi_surf_expanded.to_numpy()[where_case],temp_surf.to_numpy()[where_case])
    return psl 


 
test_pressure = psl_ecmwf(temp.isel(lev=-1), phi_surf, surf_pressure, pressure_on_hybrid_ccm(surf_pressure,hyam,hybm).isel(lev=-1)

SyntaxError: unexpected EOF while parsing (2828060429.py, line 62)

Atomic execution with vectorize

In [20]:
%%timeit -r 10

@vectorize([float64(float64,float64,float64,float64)],nopython=True)
def pslec_atomic(temp_bot,phi_s,ps,pressure_bot): 

    LAPSE_RATE = 0.0065     #Kelvin per meter
    GRAV_CONST = 9.80616    #Meters per second per second
    SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

    ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
    temp_surf = temp_bot*(1 + ALPHA_0*(ps/pressure_bot - 1)) #3b.5
    temp_bot_lapse = temp_surf + LAPSE_RATE*phi_s/GRAV_CONST #denoted T_0 in doc, 3b.13

    #These cases are partitions - there is no overlap in cases here. 
    if abs(phi_s/GRAV_CONST) < 1e-4: 
        psl = ps
    elif temp_surf >= 255 and temp_bot_lapse <= 290.5: 
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/temp_surf 
        psl =  ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    elif temp_surf > 290.5 and temp_bot_lapse > 290.5: 
        T_star_modified = 1/2*(290.5+temp_surf) 
        psl = ps*np.exp(phi_s/SPEC_GAS_CONST/T_star_modified)
    elif temp_surf >=255 and temp_surf <= 290.5 and temp_bot_lapse > 290.5: 
        combo_term = 290.5-temp_surf
        psl = ps*np.exp(phi_s/SPEC_GAS_CONST/temp_surf*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    elif temp_surf < 255 and temp_bot_lapse <= 290.5: 
        T_star_modified = 1/2*(255+temp_surf)
        combo_term = ALPHA_0*phi_s/SPEC_GAS_CONST/T_star_modified 
        psl =  ps*np.exp(combo_term/ALPHA_0*(1-1/2*(combo_term)+1/3*(combo_term)**2))
    elif temp_surf < 255 and temp_bot_lapse > 290.5: 
        alpha = SPEC_GAS_CONST/phi_s*(290.5-temp_surf)
        T_star_modified = 1/2*(255+temp_surf)
        combo_term = alpha*phi_s/SPEC_GAS_CONST/T_star_modified 
        psl = ps*np.exp(combo_term/alpha*(1-1/2*(combo_term)+1/3*(combo_term)**2))    
    
    return psl 
    
def pslec(temp_bottom: xr.DataArray, phi_surf: xr.DataArray, pressure_surf: xr.DataArray, pressure_bot: xr.DataArray): 
    return xr.apply_ufunc(pslec_atomic,temp_bottom,phi_surf,pressure_surf,pressure_bot)


test_pressure = pslec(temp.isel(lev=-1), phi_surf, surf_pressure, P_hybrid.isel(lev=-1))

DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1)
           2	LOAD_CONST(arg=1, lineno=4)
           4	STORE_FAST(arg=4, lineno=4)
           6	LOAD_CONST(arg=2, lineno=5)
           8	STORE_FAST(arg=5, lineno=5)
          10	LOAD_CONST(arg=3, lineno=6)
          12	STORE_FAST(arg=6, lineno=6)
          14	LOAD_FAST(arg=4, lineno=8)
          16	LOAD_FAST(arg=6, lineno=8)
          18	BINARY_MULTIPLY(arg=None, lineno=8)
          20	LOAD_FAST(arg=5, lineno=8)
          22	BINARY_TRUE_DIVIDE(arg=None, lineno=8)
          24	STORE_FAST(arg=7, lineno=8)
          26	LOAD_FAST(arg=0, lineno=10)
          28	LOAD_CONST(arg=4, lineno=10)
          30	LOAD_FAST(arg=7, lineno=10)
          32	LOAD_FAST(arg=2, lineno=10)
          34	LOAD_FAST(arg=3, lineno=10)
          36	BINARY_TRUE_DIVIDE(arg=None, lineno=10)
          38	LOAD_CONST(arg=4, lineno=10)
          40	BINARY_SUBTRACT(arg=None, lineno=10)
          42	BINARY_MULTIPLY(arg=None, lineno=10)
          44	

1.68 s ± 77.5 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [15]:
%%timeit -r 20

@vectorize([float64(float64,float64,float64,float64)],nopython=True)
def pslec_atomic(temp_bot,phi_s,ps,pressure_bot): 

    LAPSE_RATE = 0.0065     #Kelvin per meter
    GRAV_CONST = 9.80616    #Meters per second per second
    SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

    ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
    tstar = temp_bot*(1 + ALPHA_0*(ps/pressure_bot - 1)) #3b.5
    T0 = tstar + LAPSE_RATE*phi_s/GRAV_CONST #denoted T_0 in doc, 3b.13

    if abs(phi_s/GRAV_CONST) < 1e-4:
            psl = ps
    else: 
            tstar = temp_bot * (1+ALPHA_0 * (ps/pressure_bot-1))
            T0 = tstar + LAPSE_RATE*phi_s/GRAV_CONST
            if tstar <= 290.5 and T0 > 290.5: 
               ALPHA = SPEC_GAS_CONST/phi_s * (290.5 - tstar)
            elif tstar > 290.5  and T0 > 290.5: 
               ALPHA = 0
               tstar = 0.5* (290.5+tstar)
            else: 
               ALPHA = ALPHA_0
            if tstar < 255: 
               tstar = 0.5* (255+tstar)
            BETA = phi_s/ (SPEC_GAS_CONST*tstar)
            psl = ps * np.exp(BETA* (1-ALPHA*BETA/2+((ALPHA*BETA)**2)/3))
    
    return psl 
    
def pslec(temp_bottom: xr.DataArray, phi_surf: xr.DataArray, pressure_surf: xr.DataArray, pressure_bot: xr.DataArray): 
    return xr.apply_ufunc(pslec_atomic,temp_bottom,phi_surf,pressure_surf,pressure_bot)


test_pressure = pslec(temp.isel(lev=-1), phi_surf, surf_pressure, P_hybrid.isel(lev=-1))

  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data = func(*input_data)
  result_data 

825 ms ± 26.8 ms per loop (mean ± std. dev. of 20 runs, 1 loop each)


  result_data = func(*input_data)
  result_data = func(*input_data)


Partial where with no jit

In [16]:
#%%timeit -r 10

def _pslec(t_bot, phi_sfc, ps, pressure_bot):
    R_d = 287.04  # dry air gas constant
    g_inv = 1 / 9.80616  # inverse of gravity
    alpha = 0.0065 * R_d * g_inv

    tstar = t_bot * (1 + alpha * (ps / pressure_bot - 1))
    hgt = phi_sfc * g_inv
    t0 = tstar + 0.0065 * hgt

    alph = xr.where((tstar <= 290.5) & (t0 > 290.5),
                    R_d / phi_sfc * (290.5 - tstar), alpha)

    alph = xr.where((tstar > 290.5) & (t0 > 290.5), 0, alph)
    tstar = xr.where((tstar > 290.5) & (t0 > 290.5), 0.5 * (290.5 + tstar),
                     tstar)

    tstar = xr.where((tstar < 255), 0.5 * (tstar + 255), tstar)

    beta = phi_sfc/R_d/tstar
    return ps*np.exp(beta*(1-alph*beta/2+(alph*beta)**2/3))

test_pressure = _pslec(temp.isel(lev=-1), phi_surf, surf_pressure, P_hybrid.isel(lev=-1))


Where with jit

In [17]:
#%%timeit -r 10

@njit(nogil=True)
def _pslec_jit(t_bot: np.array , phi_sfc: np.array , ps: np.array, pressure_bot: np.array):
    R_d = 287.04  # dry air gas constant
    g_inv = 1 / 9.80616  # inverse of gravity
    alpha = 0.0065 * R_d * g_inv

    tstar = t_bot * (1 + alpha * (ps / pressure_bot - 1))
    hgt = phi_sfc * g_inv
    t0 = tstar + 0.0065 * hgt

    alph = np.where((tstar <= 290.5) & (t0 > 290.5),
                    R_d / phi_sfc * (290.5 - tstar), alpha)

    alph = np.where((tstar > 290.5) & (t0 > 290.5), 0, alph)
    tstar = np.where((tstar > 290.5) & (t0 > 290.5), 0.5 * (290.5 + tstar),
                     tstar)

    tstar = np.where((tstar < 255), 0.5 * (tstar + 255), tstar)

    beta = phi_sfc/R_d/tstar
    return ps*np.exp(beta*(1-alph*beta/2+(alph*beta)**2/3))

temp_bc, phi_surf_bc, surf_pressure_bc, P_hybrid_bc = xr.broadcast(temp.isel(lev=-1),phi_surf,surf_pressure, P_hybrid.isel(lev=-1))
test_pressure = _pslec_jit(temp_bc.values, phi_surf_bc.values, surf_pressure_bc.values, P_hybrid_bc.values)

In [18]:
test_temp = temp.isel(lat=0,lon=0,time=0,lev=-1).values #* np.ones((2,2))
test_phi = phi_surf.isel(lon=0,lat=0).values #* np.ones((2,2))
test_surf = surf_pressure.isel(lon=0,lat=0,time=0).values# * np.ones((2,2))
test_hybrid = P_hybrid.isel(lat=0,lon=0,lev=-1,time=0).values# * np.ones((2,2))


print(test_temp)
print(test_phi)
print(test_surf)
print(test_hybrid)

test_pressure_atomic = pslec_atomic(test_temp, test_phi, test_surf, test_hybrid)
test_pressure_xr = pslec(test_temp,test_phi,test_surf,test_hybrid) 

print(test_pressure_atomic)
print(test_pressure_xr)
print(test_pressure.isel(lat=0,lon=0))


246.69228
27701.627573972324
68952.78
68439.5
99933.59998192343
99933.59998192343


AttributeError: 'numpy.ndarray' object has no attribute 'isel'

For loop with no jit

In [None]:
def pslec_for(T_bot,phi_s,ps,pressure_bot): 

   LAPSE_RATE = 0.0065     #Kelvin per meter
   GRAV_CONST = 9.80616    #Meters per second per second
   SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

   ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
   psl = xr.zeros_like(ps)
   lats = pressure_bot.coords['lat']
   lons = pressure_bot.coords['lon']
   for lati in lats: 
      for long in lons: 
         if abs(phi_s.sel(lat=lati,lon=long)/GRAV_CONST):
            psl.loc[dict(lat=lati,lon=long)] = ps.sel(lat=lati,lon=long)
         else: 
            tstar = T_bot.sel(lat=lati,lon=long) * (1+ALPHA_0) * (ps.sel(lat=lati,lon=long)/pressure_bot.sel(lat=lati,lon=long)-1)
            T0 = tstar + LAPSE_RATE*phi_s.sel(lat=lati,lon=long)/GRAV_CONST
            if tstar <= 290.5 and T0 > 290.5: 
               ALPHA = SPEC_GAS_CONST/phi_s.sel(lat=lati,lon=long) * (290.5 - tstar)
            elif tstar > 290.5  and T0 > 290.5: 
               ALPHA = 0
               tstar = 0.5* (290.5+tstar)
            else: 
               ALPHA = ALPHA_0
            if tstar < 255: 
               tstar = 0.5* (255+tstar)
            BETA = phi_s.sel(lat=lati,lon=long)/ (SPEC_GAS_CONST*tstar)
            psl.loc[dict(lat=lati,lon=long)] = ps.sel(lat=lati,lon=long) * np.exp(BETA* (1-ALPHA*BETA/2+((ALPHA*BETA)**2)/3))
         
   return psl 

test_pressure = pslec_for(temp.isel(lev=-1,time=0), phi_surf, surf_pressure.isel(time=0), P_hybrid.isel(lev=-1,time=0))

For loop with jit

In [22]:
%%timeit -r 10

@njit(nogil=True)
def pslec_for_jit(T_bot,phi_s,ps,pressure_bot): 

   LAPSE_RATE = 0.0065     #Kelvin per meter
   GRAV_CONST = 9.80616    #Meters per second per second
   SPEC_GAS_CONST = 287.04 #Joules per kilogram per Kelvin

   ALPHA_0 = LAPSE_RATE*SPEC_GAS_CONST/GRAV_CONST
    
   ntim = ps.shape[0]
   nlat = ps.shape[1]
   nlon = ps.shape[2]
   psl = np.zeros_like(ps)
   for t in range(ntim):
      for lati in range(nlat): 
         for long in range(nlon): 
            if abs(phi_s[t,lati,long]/GRAV_CONST) < 1e-4:
               psl[t,lati,long] = ps[t,lati,long]
            else: 
               tstar = T_bot[t,lati,long] * (1+ALPHA_0 * (ps[t,lati,long]/pressure_bot[t,lati,long]-1))
               T0 = tstar + LAPSE_RATE*phi_s[t,lati,long]/GRAV_CONST
               if tstar <= 290.5 and T0 > 290.5: 
                  ALPHA = SPEC_GAS_CONST/phi_s[t,lati,long] * (290.5 - tstar)
               elif tstar > 290.5  and T0 > 290.5: 
                  ALPHA = 0
                  tstar = 0.5* (290.5+tstar)
               else: 
                  ALPHA = ALPHA_0
               if tstar < 255: 
                  tstar = 0.5* (255+tstar)
               BETA = phi_s[t,lati,long]/ (SPEC_GAS_CONST*tstar)
               psl[t,lati,long] = ps[t,lati,long] * np.exp(BETA* (1-ALPHA*BETA/2+((ALPHA*BETA)**2)/3))
            
   return psl 

temp_bc, phi_surf_bc, surf_pressure_bc, P_hybrid_bc = xr.broadcast(temp.isel(lev=-1),phi_surf,surf_pressure, P_hybrid.isel(lev=-1))
test_pressure = pslec_for_jit(temp_bc.values, phi_surf_bc.values, surf_pressure_bc.values, P_hybrid_bc.values)
test_pressure = surf_pressure.copy(data=test_pressure)

840 ms ± 7.08 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


CZ2     


In [13]:
#@guvectorize([float64[:,:,:](float64[:,:],float64[:,:],float64[:,:,:],float64[:],float64[:],float64[:],float64[:],float64)],'(i,j),(i,j),(i,j,k),(k),(k),(k),(k),()->(i,j,k)', nopython=True)
@njit 
def cz2ccm(pressure_surf, phi_surf, temp_virtual, hyam, hybm, hyai, hybi, ref_pressure = 100000): 
    #This code mirrors the cz2ccm NCL function/DCz2 Fortran function (cz2ccm_dp.f) 
    #Loop formulas were attempted to be kept the same, with index conversion being done in the ranges
    # Parameters
    R = 287.04; G0 = 9.80616; RBYG = R / G0

    nlon, nlat, nlev = temp_virtual.shape 
    
    pmln = np.zeros((nlon,nlev+1))
    pterm = np.zeros((nlon,nlev))
    z2_slice = np.zeros((nlon,nlev))
    z2 = np.zeros((nlon,nlat,nlev)) 
    hyam_pad = np.zeros(nlev+1)
    hybm_pad = np.zeros(nlev+1)

    hyam_pad[1:] = hyam
    hybm_pad[1:] = hybm

    # Compute intermediate quantities using scratch space
    for j in range(nlat): 
        for i in range(nlon):
            pmln[i, 0] = np.log(ref_pressure * hyam_pad[nlev-1] + pressure_surf[i, j] * hybi[nlev-1])
            pmln[i, nlev] = np.log(ref_pressure * hyam_pad[0] + pressure_surf[i, j] * hybi[0])

        # Invert vertical loop
        for k in range(nlev, -1, -1):
            for i in range(nlon):
                ARG = ref_pressure * hyam_pad[ nlev - k] + pressure_surf[i, j] * hybm_pad[nlev - k]
                if ARG > 0.0:
                    pmln[i, k] = np.log(ARG)
                else:
                    pmln[i, k] = 0.0

        # Initialize z2 to sum of ground height and thickness of top half-layer
        for k in range(1,nlev-1):
            for i in range(nlon):
                pterm[i, k] = RBYG * temp_virtual[i, j, k] * 0.5 * (pmln[i, k+1] - pmln[i, k - 1])

        for k in range(nlev - 1):
            for i in range(nlon):
                z2_slice[i, k] = phi_surf[i,j] / G0 + RBYG * temp_virtual[i, j, k] * 0.5 * (pmln[i, k + 1] - pmln[i, k])

        # Eq 3.a.109.5 where l=K, k=K
        K = nlev-1
        for i in range(nlon):
            z2_slice[i, K] = phi_surf[i,j] / G0 + RBYG * temp_virtual[i, j, K] * (np.log(pressure_surf[i, j] * hybi[ 0]) - pmln[i, K])

        # Eq 3.a.109.4 where l=K, k<K
        for k in range(nlev-1):
            L = nlev-1
            for i in range(nlon):
                z2_slice[i, k] += RBYG * temp_virtual[i, j, L] * (np.log(pressure_surf[i, j] * hybi[ 0]) - 0.5 * (pmln[i, L -1] + pmln[i, L]))

        # Add thickness of the remaining full layers (i.e., integrate from ground to highest layer interface)
        for k in range(nlev - 2):
            for L in range(k, nlev-1):
                for i in range(nlon):
                    z2_slice[i, k] += pterm[i, L]
        z2[:,j,:] = z2_slice 
    return z2

#def wrap_cz2ccm(pressure_surf, phi_surf, temp_virtual, hyam, hybm, hyai, hybi, ref_pressure = 100000): 
#    return xr.apply_ufunc(cz2ccm,pressure_surf, phi_surf, temp_virtual, hyam, hybm, hyai, hybi, ref_pressure, input_core_dims=[[],[],['lon','lat','lev'],[],[],[],[],[]])

#cz2 = wrap_cz2ccm(surf_pressure, phi_surf, temp, hyam, hybm, hyai, hybi)

virtual_temp = temp * (1+0.61*in_hus['Q'])
cz2 = cz2ccm(surf_pressure.isel(time=0).values,phi_surf.values,virtual_temp.transpose(...,'lon','lat','lev').isel(time=0).values,hyam.values[::-1],hybm.values[::-1],hyai.values[::-1],hybi.values[::-1])
print(cz2[0,0,0])

DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=2)
           2	LOAD_CONST(arg=1, lineno=7)
           4	STORE_FAST(arg=8, lineno=7)
           6	LOAD_CONST(arg=2, lineno=7)
           8	STORE_FAST(arg=9, lineno=7)
          10	LOAD_FAST(arg=8, lineno=7)
          12	LOAD_FAST(arg=9, lineno=7)
          14	BINARY_TRUE_DIVIDE(arg=None, lineno=7)
          16	STORE_FAST(arg=10, lineno=7)
          18	LOAD_FAST(arg=2, lineno=9)
          20	LOAD_ATTR(arg=0, lineno=9)
          22	UNPACK_SEQUENCE(arg=3, lineno=9)
          24	STORE_FAST(arg=11, lineno=9)
          26	STORE_FAST(arg=12, lineno=9)
          28	STORE_FAST(arg=13, lineno=9)
          30	LOAD_GLOBAL(arg=1, lineno=11)
          32	LOAD_METHOD(arg=2, lineno=11)
          34	LOAD_FAST(arg=11, lineno=11)
          36	LOAD_FAST(arg=13, lineno=11)
          38	LOAD_CONST(arg=3, lineno=11)
          40	BINARY_ADD(arg=None, lineno=11)
          42	BUILD_TUPLE(arg=2, lineno=11)
          44	CALL_METHOD(arg=1, 

39353.190125429915


In [12]:
ncl_zccm = xr.open_dataarray('ncl_zccm.nc')
print(ncl_zccm[0,0,0,0].values)

2878.8882


In [None]:

def wrap_cz2ccm(pressure_surf, phi_surf, temp_virtual, hyam, hybm, hyai, hybi, ref_pressure = 100000): 
    return xr.apply_ufunc(cz2ccm,pressure_surf, phi_surf, temp_virtual, hyam, hybm, hyai, hybi, ref_pressure, input_core_dims=[[],[],['lon','lat','lev'],[],[],[],[],[]])

cz2 = wrap_cz2ccm(surf_pressure, phi_surf, temp, hyam, hybm, hyai, hybi)

In [57]:
ncl_pslec = xr.open_dataarray('ncl_pslec.nc')

print(test_pressure.isel(lat=0,lon=0,time=0))
#print(ncl_pslec.isel(lat=0,lon=0).values)
print(test_pressure.isel(time=0) - ncl_pslec)
print(np.max((test_pressure.isel(time = 0) - ncl_pslec)))


<xarray.DataArray 'PS' ()>
array(99933.6, dtype=float32)
Coordinates:
    lat      float64 -90.0
    lon      float64 0.0
    time     object 2006-01-01 00:00:00
Attributes:
    units:      Pa
    long_name:  Surface pressure
<xarray.DataArray 'PS' (lat: 192, lon: 288)>
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)
Coordinates:
  * lat      (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 87.17 88.12 89.06 90.0
  * lon      (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
    time     object 2006-01-01 00:00:00
<xarray.DataArray 'PS' ()>
array(0.0078125)
Coordinates:
    time     object 2006-01-01 00:00:00


In [23]:
#Interpolate to Pressure Coordinates ----------------------------------------------------------------------------------
#logging.info("Interpolating variables to pressure coordinates...")

#if current_log_level == 10: print(temp); print(surf_pressure); print(in_zsfc['PHIS']); 
hyam = in_ta.cf['hyam'] 
hybm = in_ta.cf['hybm']
hyai = in_ta.cf['hyai']
hybi = in_ta.cf['hybi']
default_levs = np.array([1000.0, 975.0, 950.0, 925.0, 900.0, 850.0, 800.0, 750.0, 700.0, 650.0, 600.0, 550.0, 500.0, \
             450.0, 400.0, 350.0, 300.0, 250.0, 200.0, 150.0, 100.0, 70.0, 50.0, 30.0, 20.0, 10.0 ])


surf_pressure = in_ps.cf['PS']
phi_surf = in_zsfc['PHIS']
temp = in_ta["T"]    
fixed_phi_sfc = in_zsfc['PHIS']  
fixed_phi_sfc.coords['lat'] = surf_pressure.coords['lat']
fixed_phi_sfc.coords['lon'] = surf_pressure.coords['lon'] 

temp_interp = gcomp.interpolation.interp_hybrid_to_pressure(temp.isel(time=0),surf_pressure.isel(time=0),hyam,hybm, 
                                                            new_levels=default_levs, 
                                                            lev_dim = 'lev', 
                                                            method='log',
                                                            extrapolate=True,
                                                            variable='temperature',
                                                            t_bot=temp.isel(lev=-1,time=0),
                                                            phi_sfc=fixed_phi_sfc)

print(temp_interp)

<xarray.DataArray 'T' (plev: 26, lat: 192, lon: 288)>
dask.array<setitem, shape=(26, 192, 288), dtype=float32, chunksize=(26, 192, 288), chunktype=numpy.ndarray>
Coordinates:
    time     object 2006-01-01 00:00:00
  * lat      (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 87.17 88.12 89.06 90.0
  * plev     (plev) float64 1e+03 975.0 950.0 925.0 ... 50.0 30.0 20.0 10.0
  * lon      (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
Attributes:
    units:      K
    long_name:  Temperature
