In [1]:
import xarray as xr
import geocat.comp as gc
import numpy as np

In [2]:
conda list geocat

# packages in environment at /glade/work/oero/miniconda3/envs/geocat-notebooks:
#
# Name                    Version                   Build  Channel
geocat-comp               2022.04.0          pyha770c72_0    conda-forge
geocat-datafiles          2022.03.0          pyha770c72_0    conda-forge
geocat-f2py               2021.04.0        py38h3d0eb6f_0    ncar
geocat-viz                2022.05.0          pyhd8ed1ab_0    conda-forge

Note: you may need to restart the kernel to use updated packages.


In [3]:
isccp_climo_filename = "ISCCP-Basic.HGG.GLOBAL.10KM.climo.198307-201706.nc"

nds = xr.open_dataset("/glade/work/brianpm/observations/isccp/climo/"+isccp_climo_filename)

x = nds['CLDTOT_ISCCP']

# 2° grid
olon = xr.DataArray(np.linspace(0, 358, 180), dims=["lon"])
olon.name = "lon"
olon.attrs["units"] = "degrees_east"
olat = xr.DataArray(np.linspace(-88, 88, 89), dims=["lat"])
olat.name = "lat"
olat.attrs["units"] = "degrees_north"

In [4]:
x

In [5]:
olon

In [6]:
olat

In [7]:
# interpolate with GeoCAT's linint2
xnew = gc.linint2(x, olon, olat, xi=x.lon, yi=x.lat, icycx=1)

TypeError: 'NoneType' object is not iterable

In [None]:
def linint_wrap(fi, xo, yo):
    """Wraps GeoCAT's linint2 function. Adds a cyclic point to ensure periodic longitude."""
    print("1")
    addpt = fi.isel(lon=0).copy()
    print("2")
    addpt['lon'] = fi.lon[-1]+(fi.lon[1]-fi.lon[0])
    print("3")
    xiwrap = xr.concat([fi, addpt], dim='lon')
    print("4")
    # if xiwrap.chunks is not None:
    #     print(f"xiwrap has chunks: {xiwrap.chunks}")
    # else:
    #     print("xiwrap is NOT chunked. --> chunk it")
    #     xiwrap = xiwrap.chunk({"lat":-1,"lon":-1})
    #     print(xiwrap.chunks)        
    return gc.linint2(xiwrap, xo, yo, xi=xiwrap.lon, yi=xiwrap.lat, icycx=1, msg_py=None)

In [None]:

# interpolated = linint_wrap( x, olon, olat)


In [None]:
x.values.dtype

In [None]:
oshape = (12, len(olat), len(olat))
from geocat.f2py.linint2_wrapper import _linint2 as f2pylinint2
# _linint2(xi, yi, fi, xo, yo, icycx, msg_py, shape):
nowrap = f2pylinint2(x['lon'].values, x['lat'].values, x[0,:,:].values, olon.values, olat.values, icycx=1, msg_py=None, shape=oshape)

In [None]:
#
# what if we did it with just numpy
#

mlon, mlat = np.meshgrid(olon, olat)

In [None]:
# problem: nan values mess up interpolation
# try: fill in nan values and later mask those region
xfill = x.interpolate_na(dim='lon')
nmissing = np.count_nonzero(np.isnan(xfill))
print(f"N missing = {nmissing}")
if nmissing > 0:
    print("do latitude next")
    xfill = xfill.interpolate_na(dim='lat')
nmissing = np.count_nonzero(np.isnan(xfill))
print(f"N missing = {nmissing}")
if nmissing > 0:
    print("do time last")
    xfill = xfill.interpolate_na(dim='time')
nmissing = np.count_nonzero(np.isnan(xfill))
print(f"N missing = {nmissing}")
if nmissing > 0:
    xfill = xfill.ffill(dim='lon')
    xfill = xfill.bfill(dim='lon')
nmissing = np.count_nonzero(np.isnan(xfill))
print(f"N missing = {nmissing}")
if nmissing > 0:
    xfill = xfill.ffill(dim='lat')
    xfill = xfill.bfill(dim='lat')
nmissing = np.count_nonzero(np.isnan(xfill))
print(f"N missing = {nmissing}")
if nmissing > 0:
    xfill = xfill.ffill(dim='time')
    xfill = xfill.bfill(dim='time')
    print(f"N missing = {nmissing}")

# to mask final field:
msgmsk = xr.where(np.isnan(x), 0, 1)
    

# from scipy.interpolate import interp2d
from scipy.interpolate import RectSphereBivariateSpline
colatitude = np.radians(x.lat) + (np.pi / 2)  # has to be colatitude (0, pi)
ifunc = RectSphereBivariateSpline(colatitude, np.radians(x.lon), xfill[0,:,:])
# apply to new level and latitude
new_lats, new_lons = np.meshgrid(np.radians(olat.values) + (np.pi / 2), np.radians(olon.values))
xnew = ifunc.ev(new_lats, new_lons).reshape((len(olon), len(olat))).T

ifunc = RectSphereBivariateSpline(colatitude, np.radians(x.lon), msgmsk[0,:,:])
mask = ifunc.ev(new_lats, new_lons).reshape((len(olon), len(olat))).T

# convert to DataArray

xnew = xr.DataArray(xnew, dims=("lat","lon"), coords={"lat":olat, "lon":olon})

mask =  xr.DataArray(mask, dims=("lat","lon"), coords={"lat":olat, "lon":olon})

In [None]:
import matplotlib.pyplot as plt

In [None]:
xnew_ma = xr.where(mask < 0.99, np.nan, xnew)

In [None]:
xnew_ma.plot.pcolormesh()

# here 
is one

## two 

another one


### three
check
### another three
checker