## MetPy but faster

This is a very loose notebook describing some thoughts and workflows that could make MetPy's calculations faster by speeding up or replacing the NumPy operations that power many of them. -Thomas Martin, Unidata

In [1]:
import numpy as np
import metpy 
from metpy.units import units
import metpy.calc as mpcalc

## Creating fake data

In [2]:
num_samples = 10 ** 7

u = np.random.uniform(-5, 5, num_samples)
v = np.random.uniform(-5, 5, num_samples)

In [3]:
num_samples_humid = 10 ** 8

temp_in_k = np.random.uniform(270, 300, num_samples_humid)
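The cells above draw from NumPy's legacy global random state, so each run produces different fake data. A minimal sketch of a reproducible alternative using NumPy's `Generator` API, assuming a fixed seed is acceptable for benchmarking (the helper name `make_fake_data` and the seed value are illustrative, not part of the original notebook):

```python
import numpy as np


def make_fake_data(num_samples, num_samples_humid, seed=42):
    """Return reproducible fake u, v (m/s) and temperature (K) arrays.

    Hypothetical helper; the default seed of 42 is arbitrary.
    """
    rng = np.random.default_rng(seed)
    u = rng.uniform(-5, 5, num_samples)
    v = rng.uniform(-5, 5, num_samples)
    temp_in_k = rng.uniform(270, 300, num_samples_humid)
    return u, v, temp_in_k
```

With the notebook's sizes this would be called as `u, v, temp_in_k = make_fake_data(10 ** 7, 10 ** 8)`. Note that 10 ** 8 float64 values take about 800 MB of memory.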


## MetPy

Wind Speed

In [4]:
%%time
metpy_ws = mpcalc.wind_speed(u * units('m/s'), v * units('m/s'))

CPU times: user 2.33 s, sys: 15.2 ms, total: 2.35 s
Wall time: 2.38 s


Saturation vapor pressure (a building block of relative humidity calculations)

In [5]:
%%time
svp = mpcalc.saturation_vapor_pressure(temp_in_k * units.degK)

CPU times: user 12 s, sys: 222 ms, total: 12.3 s
Wall time: 12.4 s


## JAX

JAX can run on a CPU or a GPU. Everything here was run on the CPU of my M2 MacBook Pro.
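Because precision matters for the comparisons later, here is a short sketch for checking which backend JAX is using and for switching it to 64-bit floats. `jax.devices`, `jax.default_backend`, and `jax.config.update("jax_enable_x64", True)` are standard JAX calls; whether x64 was enabled for the timings in this notebook is not stated, so treat that as an open assumption.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; a CPU-only install shows a CPU device.
print(jax.devices())
print(jax.default_backend())  # platform name, e.g. "cpu"

# JAX defaults to 32-bit floats. Enabling x64 makes new jnp arrays float64,
# matching NumPy's default precision (typically at some cost in speed and memory).
jax.config.update("jax_enable_x64", True)
x = jnp.linspace(270.0, 300.0, 5)
print(x.dtype)
```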

In [6]:
import jax
import jax.numpy as jnp

In [7]:
%%time
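# Note: this calls NumPy's np.hypot on plain (unitless) arrays, not jnp.hypot,
# so this timing measures NumPy without unit handling rather than JAX.
jax_ws = np.hypot(u, v)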

CPU times: user 20.7 ms, sys: 4.31 ms, total: 25 ms
Wall time: 23.8 ms


In [8]:
metpy_ws.magnitude == jax_ws

array([ True,  True,  True, ...,  True,  True,  True])
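The wind speed cell above never calls JAX. A minimal sketch of what a JAX version might look like, assuming `jnp.hypot` wrapped in `jax.jit` (the function name `wind_speed_jax` and the small array size are illustrative; this version was not timed in the notebook):

```python
import numpy as np
import jax
import jax.numpy as jnp


# jax.jit compiles the function with XLA on the first call for a given
# input shape and dtype; later calls with the same shape reuse that code.
@jax.jit
def wind_speed_jax(u, v):
    return jnp.hypot(u, v)


rng = np.random.default_rng(0)
u = rng.uniform(-5, 5, 1_000)
v = rng.uniform(-5, 5, 1_000)

# JAX dispatches work asynchronously; block_until_ready() waits for the
# result, which matters when timing.
ws = wind_speed_jax(u, v).block_until_ready()

# With JAX's default float32, compare against NumPy with a tolerance, not ==.
print(np.allclose(np.asarray(ws), np.hypot(u, v), rtol=1e-5))
```

For a fair benchmark, call the function once to compile, then time a second call.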

In [9]:
from JAX_thermo import saturation_vapor_pressure
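The `JAX_thermo` module is not included in this notebook. As a hypothetical sketch of what its `saturation_vapor_pressure` might contain, the function below applies the Bolton (1980) formula that MetPy has used for this calculation to plain, unitless arrays. Newer MetPy releases may use a different formulation, so check the version you compare against.

```python
import jax
import jax.numpy as jnp


@jax.jit
def saturation_vapor_pressure(temperature_k):
    """Saturation vapor pressure in Pa for temperature in kelvin (no units).

    Hypothetical sketch using the Bolton (1980) approximation:
    e_s = 611.2 Pa * exp(17.67 * (T - 273.15) / (T - 29.65))
    """
    return 611.2 * jnp.exp(17.67 * (temperature_k - 273.15) / (temperature_k - 29.65))
```

At 273.15 K the exponent is zero, so the result is 611.2 Pa. With JAX's default float32, results will differ slightly from MetPy's float64 output.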

In [10]:
%%time
jax_svp = saturation_vapor_pressure(temp_in_k)

CPU times: user 265 ms, sys: 199 ms, total: 464 ms
Wall time: 394 ms


In [13]:
# Check if the arrays have the same length
if len(svp.magnitude) != len(jax_svp):
    print("Arrays have different lengths.")
else:
    # Check whether each of the first 50,000 element pairs agrees within a
    # relative tolerance of 1e-6 (0.0001%)
    all_within_tolerance = all(
        abs(a - b) / max(abs(a), abs(b)) <= 0.000001
        for a, b in zip(svp.magnitude[:50000], jax_svp[:50000])
    )
    
    if all_within_tolerance:
        print("Arrays are within 0.0001% tolerance.")
    else:
        print("Arrays are not within 0.0001% tolerance.")

Arrays are within 0.0001% tolerance.


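So the outputs are similar, though not identical, because of floating-point differences and unit handling (the check above only compares the first 50,000 values). For saturation vapor pressure, the JAX version is roughly 30x faster in wall time (394 ms vs. 12.4 s), although part of that gap likely comes from skipping MetPy's unit handling rather than from JAX itself. The wind speed cell in this section called NumPy's `np.hypot` on unitless arrays, so its ~100x speedup (23.8 ms vs. 2.38 s) is not a JAX result.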

## Numba

Numba is a just-in-time (JIT) compiler that translates a subset of Python and NumPy code into machine code. More information: https://numba.pydata.org/

In [14]:
from numba import jit

In [15]:
@jit(nopython=True)
def hypot_numba(u, v):
    return np.sqrt(u**2 + v**2)

In [16]:
%%time
numba_ws = hypot_numba(u, v)

CPU times: user 241 ms, sys: 35 ms, total: 276 ms
Wall time: 397 ms


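Slower than the `np.hypot` cell in the JAX section (397 ms vs. 23.8 ms wall time). Note that the first call to a `@jit` function includes Numba's compilation time, so this timing overstates the steady-state cost.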
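Because the `%%time` cell above includes Numba's one-time compilation, a fairer comparison times calls after a warm-up. A minimal sketch of such a helper using only the standard library (the name `time_after_warmup` is hypothetical); the `block_until_ready` check covers JAX's asynchronous dispatch.

```python
import time


def time_after_warmup(func, *args, repeats=5):
    """Call func once to trigger any JIT compilation, then return the best
    wall time in seconds over `repeats` further calls."""

    def run():
        result = func(*args)
        # JAX arrays are computed asynchronously; wait for the result.
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        return result

    run()  # warm-up call: includes compilation for jit-decorated functions
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        best = min(best, time.perf_counter() - start)
    return best
```

For example, `time_after_warmup(hypot_numba, u, v)` would report the post-compilation cost of the Numba version.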

In [17]:
numba_ws == metpy_ws.magnitude

array([ True,  True,  True, ...,  True,  True,  True])