In [3]:
import xarray as xr
import grib2io
import pandas as pd
import datetime
from glob import glob
from tqdm.auto import tqdm
import numpy as np
import multiprocessing
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import matplotlib
import scipy
matplotlib.rcParams.update({
 "savefig.facecolor": "w",
 "figure.facecolor" : 'w',
 "figure.figsize" : (8,6),
 "text.color": "k",
 "legend.fontsize" : 20,
 "font.size" : 30,
 "axes.edgecolor": "k",
 "axes.labelcolor": "k",
 "axes.linewidth": 3,
 "xtick.color": "k",
 "ytick.color": "k",
 "xtick.labelsize" : 25,
 "ytick.labelsize" : 25,
 "ytick.major.size" : 12,
 "xtick.major.size" : 12,
 "ytick.major.width" : 2,
 "xtick.major.width" : 2,
 "font.family": 'STIXGeneral',
 "mathtext.fontset" : "cm"})
from mpl_toolkits.basemap import Basemap
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from numba import njit

In [13]:
grib2io.__version__

'2.2.0'

In [14]:
numba.__version__

'0.59.1'

From my understanding, Numba's "just in time compiling" speeds up pure python code with a few simple (?) drop in decorators, allowing major speeds ups with numpy operations, for loops, etc. But that's the catch: pure python. Will this work nicely with grib2io interpolation / scipy minimization (both of which are build on Fortran, C, etc.)?

Below is an example of attempting to use numba to speed up some numpy calculations but gets hung up on the grib2io interp routine.

In [2]:
#---------------------#
#
# Original
#
#---------------------#
def evaluate_at_new_grid(new_grid, grid_shape, grid_def, ens_mem,min_lat_idx, min_lon_idx, c_ai):
    
    #interpolate ensemble member between displaced grid and gefs grid
    displaced_ens = grib2io.interpolate_to_stations(ens_mem, 'bilinear', grid_def, new_grid[1], new_grid[0])
    
    # interp_to_stations returns a list of points instead of the 2D field. just need to rearrange below
    coalesced_ensemble = np.zeros(np.shape(ens_mem))
    y,x = grid_shape
    for i in range(x):
        for j in range(y):
            point = i + x*j
            coalesced_ensemble[j+min_lat_idx, i+min_lon_idx] = displaced_ens[point] * (1. + c_ai[point])
    return coalesced_ensemble

In [11]:
#-----------------------#
#
# Numba version, @njit decorator
#
#-----------------------#

@njit()
def numba_eval_v1(new_grid, grid_shape, grid_def,ens_mem,min_lat_idx, min_lon_idx, c_ai):
    
    #interpolate ensemble member between displaced grid and gefs grid
    displaced_ens = grib2io.interpolate_to_stations(ens_mem, 'bilinear', grid_def, new_grid[1], new_grid[0])
    
    # interp_to_stations returns a list of points instead of the 2D field. just need to rearrange below
    coalesced_ensemble = np.zeros(np.shape(ens_mem))
    y,x = grid_shape
    for i in range(x):
        for j in range(y):
            point = i + x*j
            coalesced_ensemble[j+min_lat_idx, i+min_lon_idx] = displaced_ens[point] * (1. + c_ai[point])
    return coalesced_ensemble


In [5]:
#load data
date = '20231005'
lead_time = 120 #hours
dat_dir = f'/scratch2/STI/mdl-sti/Sidney.Lower/test_data/gefs/{date}/'
gefs_paths = sorted(glob(dat_dir+f'gefs*.t0z.f120'))
filters = dict(productDefinitionTemplateNumber=11, shortName='APCP')
gefs_data = xr.open_mfdataset(gefs_paths, chunks=None,engine='grib2io', filters=filters, 
                              parallel=False, concat_dim=[pd.Index(np.arange(len(gefs_paths)), name="member")],combine="nested")


In [6]:
def basis_functions(gefs_grid_points, m_k=None, n_k=None, xrange=(np.nan, np.nan), yrange=(np.nan, np.nan)):

    #transform grid points to Fourier space
    x = grid2fourier(gefs_grid_points[:,0], xrange)
    y = grid2fourier(gefs_grid_points[:,1], yrange)

    #calculate basis functions 
    b_ik = np.zeros((len(gefs_grid_points), len(m_k)))
    for mk in range(len(m_k)):
            b_ik[:, mk] = 2 * np.sin(m_k[mk] * gefs_grid_points[:,0]) * np.sin(n_k[mk] * gefs_grid_points[:,1])

    return b_ik

def grid2fourier(x, xrange):
    return (np.pi / np.diff(xrange)) * (x - xrange[0])

def basis_truncation(n=3):
    #select wavenumbers below truncation limit
    m_k = np.tile(np.arange(1, n+1),n)
    n_k = np.repeat(np.arange(1, n + 1), n)
    select = m_k**2 + n_k**2 <= n**2

    return m_k[select], n_k[select]

In [27]:
### set up

def find_nearest_lat_lon(lat_lon_range, lat_lon_arr):
    min_val, max_val = lat_lon_range[0], lat_lon_range[1]
    min_idx = (np.abs(np.array(lat_lon_arr)-min_val)).argmin()
    max_idx = (np.abs(np.array(lat_lon_arr)-max_val)).argmin()
    if min_idx > max_idx: #latitude
        trimmed_arr = lat_lon_arr[max_idx:min_idx+1]
        return trimmed_arr, max_idx, min_idx
    else:
        trimmed_arr = lat_lon_arr[min_idx:max_idx+1]
        return trimmed_arr, min_idx, max_idx

#--------------------------------------#
#
#  Get GEFS data: ensemble mean, ensemble members
#
#--------------------------------------#

ens_mean = gefs_data.APCP.mean(dim='member')
raw_mean = ens_mean.data.compute()
ens_members = gefs_data.APCP.data.compute()
ens_std = gefs_data.APCP.std(dim='member').data.compute()

lats = gefs_data.latitude.data.compute().T[0]
lons = gefs_data.longitude.data[0].compute()
y, min_lat, max_lat = find_nearest_lat_lon([10,60], lats)

lon_range = 180 - np.array([180, 55]) + 180 #convert from 0-360 to East / West
x, min_lon, max_lon = find_nearest_lat_lon(lon_range, lons)


m_k, n_k = basis_truncation(7) #truncating the basis functions B_k to wavenumbers < 7
gy, gx=np.meshgrid(y,x ,indexing='ij')
gefs_grid_points = np.reshape((gx, gy), (2, -1), order='C').T
#get displacement basis functions, transforms grid x,y to fourier space
b_ik = basis_functions(gefs_grid_points, m_k, n_k, [x[0], x[-1]], [y[0], y[-1]]
    )
d_o_f = len(m_k)
#basis function coefficients, the target of our minimization
c_k = {"x": np.repeat(0.1, d_o_f), "y": np.repeat(-0.1, d_o_f), 'a': np.random.uniform(0.5, 2.,size=d_o_f)}
grid_def = gefs_data.grib2io.griddef()
grid_shape = np.shape(gx)
c_xi = np.dot(b_ik, c_k['x'])
c_yi = np.dot(b_ik, c_k['y'])
n_x_i = gefs_grid_points[:,0] + c_xi
n_y_i = gefs_grid_points[:,1] + c_yi

#amplitude displacements
c_ai = np.dot(b_ik, c_k['a'])
ens_mem = ens_members[10]

In [12]:
test_numba  = numba_eval_v1((n_x_i, n_y_i), grid_shape, grid_def,ens_mem, min_lat,min_lon, c_ai)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1mnon-precise type pyobject[0m
[0m[1mDuring: typing of argument at /tmp/ipykernel_653795/4097335957.py (7)[0m
[1m
File "../../../../../tmp/ipykernel_653795/4097335957.py", line 7:[0m
[1m<source missing, REPL/exec in use?>[0m 

This error may have been caused by the following argument(s):
- argument 2: [1mCannot determine Numba type of <class 'grib2io._grib2io.Grib2GridDef'>[0m


So let's try a second version, where we hand the evaluation function the griddef func?

In [15]:
#-----------------------#
#
# Numba version, @njit decorator
#
#-----------------------#

@njit()
def numba_eval_v2(new_grid, grid_shape, grid_def_func,grid_def_arr,ens_mem,min_lat_idx, min_lon_idx, c_ai):

    grid_def = grid_def_func(grid_def_arr)
    
    #interpolate ensemble member between displaced grid and gefs grid
    displaced_ens = grib2io.interpolate_to_stations(ens_mem, 'bilinear', grid_def, new_grid[1], new_grid[0])
    
    # interp_to_stations returns a list of points instead of the 2D field. just need to rearrange below
    coalesced_ensemble = np.zeros(np.shape(ens_mem))
    y,x = grid_shape
    for i in range(x):
        for j in range(y):
            point = i + x*j
            coalesced_ensemble[j+min_lat_idx, i+min_lon_idx] = displaced_ens[point] * (1. + c_ai[point])
    return coalesced_ensemble


In [16]:
from grib2io import Grib2GridDef

grid_def_func = Grib2GridDef.from_section3
grid_def_arr = gefs_data.APCP.GRIB2IO_section3

test_numba  = numba_eval_v2((n_x_i, n_y_i), grid_shape, grid_def_func,grid_def_arr ,ens_mem, min_lat,min_lon, c_ai)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1mnon-precise type pyobject[0m
[0m[1mDuring: typing of argument at /tmp/ipykernel_653795/3884732420.py (7)[0m
[1m
File "../../../../../tmp/ipykernel_653795/3884732420.py", line 7:[0m
[1m<source missing, REPL/exec in use?>[0m 

This error may have been caused by the following argument(s):
- argument 2: [1mCannot determine Numba type of <class 'method'>[0m


Again, doesn't like non array/float/int args...
Even still, the interpolation that grib2io performs is not really the computational bottleneck. It's the minimization itself. 

I've looked into other packages for fast minimization/jit capabilities/numpy speed ups (JAX, interpax). JAX does fast differentiation of numpy arrays so could be a good solution to the scipy minimization bottleneck (where its approximating the objective function gradient potentially hundreds of times) but even this struggles a bit with the inner workings of grib2io and scipy. I found this [notebook](https://colab.research.google.com/drive/1CQvYpR-c-XAyAmHcdiSoC0wMNYW-ntGM?usp=sharing#scrollTo=KoGdlaBWpTDZ) that does a nice intro, but this person built their own classes for the optimization part.

In [17]:
import jax.numpy as jnp
from jax import jit, value_and_grad

In [28]:
#-------------------------#
#
# JAX version, just residual
#
#-------------------------#

def jax_eval1(new_grid, grid_shape,grid_def, ens_mem,min_lat_idx, min_lon_idx, c_ai):
    displaced_ens = grib2io.interpolate_to_stations(ens_mem, 'bilinear', grid_def, new_grid[1], new_grid[0])
    # interp_to_stations returns a list of points instead of the 2D field. just need to rearrange below
    coalesced_ensemble = jnp.zeros(jnp.shape(ens_mem))
    y,x = grid_shape
    for i in range(x):
        for j in range(y):
            point = i + x*j
            coalesced_ensemble[j+min_lat_idx, i+min_lon_idx] = displaced_ens[point] * (1. + c_ai[point])
    return coalesced_ensemble

def jax_residual_error(ck_arr,b_ik,ens_mean, ens_mem, obs_error, grid, grid_def, min_lat, min_lon):
    #grid point displacements from coeff * basis funcs
    ntrunc = jnp.shape(b_ik)[1]
    c_k = {"x": ck_arr[:ntrunc], "y": ck_arr[ntrunc:ntrunc*2], "a": ck_arr[ntrunc*2:]}
    gefs_grid_points = jnp.reshape(jnp.array([grid[0], grid[1]]), (2, -1), order='C').T
    grid_shape = jnp.shape(grid[0])
    c_xi = jnp.dot(b_ik, c_k['x'])
    c_yi = jnp.dot(b_ik, c_k['y'])
    n_x_i = gefs_grid_points[:,0] + c_xi
    n_y_i = gefs_grid_points[:,1] + c_yi
   
    #amplitude displacements
    c_ai = jnp.dot(b_ik, c_k['a'])

    displaced_ens = evaluate_at_new_grid(jnp.array([n_x_i, n_y_i]).T,jnp.array([gefs_grid_points[:,0], gefs_grid_points[:,1]]).T, 
                                         grid_shape,ens_mem, min_lat,min_lon, c_ai)

    max_lat = min_lat + grid_shape[0]
    max_lon = min_lon + grid_shape[1]
    before_sum = (ens_mean[min_lat:max_lat, min_lon:max_lon] - displaced_ens[min_lat:max_lat, min_lon:max_lon])**2
    return jnp.sum(before_sum / obs_error[min_lat:max_lat, min_lon:max_lon]**2)


#------ driver -------#
obj_and_grad = jit(value_and_grad(jax_residual_error))


# set up, using some of the stuff above
c_k = {"x": np.repeat(0.01, d_o_f), "y": np.repeat(0.01, d_o_f), "a": np.repeat(0.01, d_o_f)}
ck_arr = np.ravel(np.array([c_k['x'], c_k['y'], c_k['a']]))

grid_def = gefs_data.grib2io.griddef()
grid_shape = np.shape(gx)
grid = (gx, gy)
std = 0.3

In [29]:
test = obj_and_grad(ck_arr,b_ik,raw_mean, ens_mem, std, grid, grid_def, min_lat, min_lon)

TypeError: Cannot interpret value of type <class 'grib2io._grib2io.Grib2GridDef'> as an abstract array; it does not have a dtype attribute

Same error as above :/