In [33]:
import logging
from numba import guvectorize
import numpy as np
import xarray as xr

# Module-level logger used by zonalmean3d for its unit-guessing messages.
# The logger's own level is lowered to INFO so those records are not filtered here.
# NOTE(review): no handler is configured in the visible cells — confirm that
# INFO records are actually displayed in the environment this notebook runs in.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [2]:
# Open every file matching the glob as a single dataset.
# NOTE(review): the "??" wildcard presumably matches the two-digit month of
# model year 2097 — confirm against the data directory.
# The path is relative to the notebook's working directory.
ds = xr.open_mfdataset("../../data/deg5/mres_b.e10.B2000_CAM5.5deg.001.cam2.h0.2097-??.nc")

In [34]:
# Names of the variables and dimensions that zonalmean3d looks up in the
# input dataset. Adjust them here if the input files use different names.
PRES = 'P'    # 3D air pressure
HYAM = 'hyam' # hybrid a coefficient (only used if P not present)
HYBM = 'hybm' # hybrid b coefficient (see HYAM)
PS = 'PS'     # surface pressure (see HYAM)
P0 = 'P0'     # reference pressure hybrid coordinate (see HYAM)
VDIM = 'lev'  # vertical dimension (will be interpolated)
XDIM = 'lon'  # longitudinal dimension (will be averaged over)
YDIM = 'lat'  # latitudinal dimension
TDIM = 'time' # time dimension


@guvectorize("(p), (p), (pi) -> (pi)", nopython=True)
def _vertinterp_pressure1d_gu(f, p, pi, out):
    """Interpolate one column f(p) onto the target pressures pi, linearly in ln(p).

    f   : field values on the source levels
    p   : source pressures, ascending, same length as f
    pi  : target pressures, ascending
    out : output buffer of the same length as pi, filled in place

    Targets that lie below p[0] or above p[-1] are set to NaN (no extrapolation).
    Both p and pi must be ascending because the targets are consumed in a
    single forward sweep over the source levels.
    """
    n_target = len(pi)
    k = 0  # index of the next target level still to be filled

    # lower end of the current source bracket
    p_lo = p[0]
    f_lo = f[0]

    # targets at lower pressure than the first source level cannot be interpolated
    while k < n_target and pi[k] < p_lo:
        out[k] = np.nan
        k += 1

    # walk through the source brackets [p_lo, p_hi] and fill every target inside
    for j in range(1, len(p)):
        p_hi = p[j]
        f_hi = f[j]
        while k < n_target and pi[k] <= p_hi:
            out[k] = (f_hi - f_lo) / np.log(p_hi / p_lo) * np.log(pi[k] / p_lo) + f_lo
            k += 1
        p_lo = p_hi
        f_lo = f_hi

    # whatever is left lies beyond the last source level
    while k < n_target:
        out[k] = np.nan
        k += 1

In [50]:
# NOTE(review): nopython=None is forwarded to numba unchanged — confirm which
# compilation mode is intended for this kernel.
@guvectorize([
    "(float64[:,:], float64[:,:], float64[:], float64[:])",
    "(float32[:,:], float32[:,:], float64[:], float32[:])"
], "(p,x), (p,x), (pi) -> (pi)", nopython=None)
def _vertinterp_and_zonalmean_gu(f, p, pi, out):
    """Interpolate every column of f to the pressures pi, then average over columns.

    f   : field on (source level, column)
    p   : pressure on (source level, column), same shape as f
    pi  : target pressures
    out : output buffer of the same length as pi, filled in place

    Each column is interpolated with _vertinterp_pressure1d_gu; the mean over
    the columns ignores NaN entries (np.nanmean).
    """
    n_columns = f.shape[1]
    n_levels = len(pi)

    # per-column interpolated profiles, one column at a time
    profiles = np.zeros((n_levels, n_columns), dtype=f.dtype)
    for col in range(n_columns):
        _vertinterp_pressure1d_gu(f[:, col], p[:, col], pi, profiles[:, col])

    # NaN-aware mean across the columns for each target level
    for lev in range(n_levels):
        out[lev] = np.nanmean(profiles[lev, :])


# NOTE(review): eval_doc is not defined or imported in the visible cells; it
# presumably fills the placeholders in the docstring from the module-level
# name constants — confirm it is available before this cell runs.
@eval_doc()
def zonalmean3d(ds, plevs):
    """Interpolate dataset to pressure levels and take zonal mean

    Consider chunking ds before calling this function.

    ds : xr.Dataset
        dataset to take 3D zonal mean of
        must contain {PRES} or {HYAM}, {HYBM}, {P0} and {PS},
        and dimensions {VDIM} and {XDIM}.
        Variables lacking {VDIM} or {XDIM} are copied to the result unchanged.
    plevs : Iterable[float]
        new pressure levels in hPa (sorted ascending internally)

    returns: xr.Dataset
        dataset zonally averaged at the specified pressure levels
        Target levels outside the pressure range of a column are NaN for
        that column and are skipped by the zonal mean.
    """

    # sort by vertical dimension and check units
    ds = ds.sortby(VDIM)
    # reduce once with xarray instead of iterating element-wise with builtin max()
    vmax = float(ds[VDIM].max())
    if vmax < 1100:
        logger.info(f"max. value of {VDIM} is {vmax:.2f}, assuming hPa units")
        ds[VDIM] = ds[VDIM] * 100
    else:
        logger.info(f"max. value of {VDIM} is {vmax:.2f}, assuming Pa units")

    # calculate 3D pressure in Pa if needed
    if PRES not in ds:
        logger.info(f"calculating {PRES} from hybrid parameters")
        P = ds[HYAM] * ds[P0] + ds[HYBM] * ds[PS]
    else:
        P = ds[PRES]

    # Evaluate the maximum exactly once and as a plain float: a lazy
    # (dask-backed) 0-d DataArray cannot be formatted with a float format spec,
    # so formatting P.max() directly would fail just when the warning is needed.
    # This reduction still computes over the whole pressure field.
    p_max = float(P.max())
    if p_max < 1100:
        logger.warning(f"max. value of {PRES} is {p_max:.1f}, expecting Pa units")

    # create new pressure coordinate in hPa
    plev = xr.DataArray(
        data = np.array(sorted(plevs), dtype='float64'),
        dims = 'plev',
        name = 'plev',
        attrs = {'standard_name': 'air_pressure',
                 'long_name': 'air pressure',
                 'units': 'hPa'}
    )
    ds = ds.assign_coords({'plev': plev})

    # execute interpolation and zonal averaging
    # (target levels are handed to the kernel in Pa, matching P)
    result = xr.apply_ufunc(
        _vertinterp_and_zonalmean_gu,
        ds,
        P,
        plev*100,
        input_core_dims=[[VDIM,XDIM], [VDIM,XDIM], list(plev.dims)],
        output_core_dims=[list(plev.dims)],
        dask="parallelized",
        keep_attrs=True,
        on_missing_core_dim='copy',
    )

    # put time first and the new pressure dimension second; use the configured
    # time dimension name rather than a hard-coded literal
    return result.transpose(TDIM, *plev.dims, ..., missing_dims='ignore')

In [37]:
# Example: interpolate the Z3 field to the 250 and 900 hPa levels and average
# over longitude. The hybrid coefficients, P0 and PS are passed along so that
# zonalmean3d can derive the 3D pressure itself.
dsi = zonalmean3d( ds[['hyam','hybm','P0','PS','Z3']], [250,900])
dsi

Unnamed: 0,Array,Chunk
Bytes,0.94 kiB,80 B
Shape,"(12, 10)","(1, 10)"
Dask graph,12 chunks in 38 graph layers,12 chunks in 38 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 0.94 kiB 80 B Shape (12, 10) (1, 10) Dask graph 12 chunks in 38 graph layers Data type float64 numpy.ndarray",10  12,

Unnamed: 0,Array,Chunk
Bytes,0.94 kiB,80 B
Shape,"(12, 10)","(1, 10)"
Dask graph,12 chunks in 38 graph layers,12 chunks in 38 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.94 kiB,80 B
Shape,"(12, 10)","(1, 10)"
Dask graph,12 chunks in 38 graph layers,12 chunks in 38 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 0.94 kiB 80 B Shape (12, 10) (1, 10) Dask graph 12 chunks in 38 graph layers Data type float64 numpy.ndarray",10  12,

Unnamed: 0,Array,Chunk
Bytes,0.94 kiB,80 B
Shape,"(12, 10)","(1, 10)"
Dask graph,12 chunks in 38 graph layers,12 chunks in 38 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,96 B,8 B
Shape,"(12,)","(1,)"
Dask graph,12 chunks in 1 graph layer,12 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 96 B 8 B Shape (12,) (1,) Dask graph 12 chunks in 1 graph layer Data type float64 numpy.ndarray",12  1,

Unnamed: 0,Array,Chunk
Bytes,96 B,8 B
Shape,"(12,)","(1,)"
Dask graph,12 chunks in 1 graph layer,12 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,106.03 kiB,232 B
Shape,"(12, 39, 58)","(1, 1, 58)"
Dask graph,468 chunks in 26 graph layers,468 chunks in 26 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 106.03 kiB 232 B Shape (12, 39, 58) (1, 1, 58) Dask graph 468 chunks in 26 graph layers Data type float32 numpy.ndarray",58  39  12,

Unnamed: 0,Array,Chunk
Bytes,106.03 kiB,232 B
Shape,"(12, 39, 58)","(1, 1, 58)"
Dask graph,468 chunks in 26 graph layers,468 chunks in 26 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.31 kiB,16 B
Shape,"(12, 2, 39)","(1, 2, 1)"
Dask graph,468 chunks in 147 graph layers,468 chunks in 147 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 7.31 kiB 16 B Shape (12, 2, 39) (1, 2, 1) Dask graph 468 chunks in 147 graph layers Data type float64 numpy.ndarray",39  2  12,

Unnamed: 0,Array,Chunk
Bytes,7.31 kiB,16 B
Shape,"(12, 2, 39)","(1, 2, 1)"
Dask graph,468 chunks in 147 graph layers,468 chunks in 147 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
