# Comparison of the available Zernike radial function implementations

In [1]:
import sys
import os

# Put the current working directory at the front of the import search path
# (highest priority) and add its parent directory at the very end (lowest
# priority).
sys.path[:0] = [os.path.abspath(os.curdir)]
sys.path.extend([os.path.abspath(os.pardir)])

In [2]:
# from zernipax import set_device
# set_device("gpu")

In [3]:
import numpy as np
# Array display settings: 3 digits of precision, never summarise long arrays
# with "..." (threshold is the largest possible int), and keep scientific
# notation for small values.
np.set_printoptions(
    precision=3,
    threshold=sys.maxsize,
    suppress=False,
)
import mpmath
import matplotlib
import matplotlib.pyplot as plt
from zernipax.zernike import *
from zernipax.basis import ZernikePolynomial, FourierZernikeBasis
from zernipax.plotting import plot_basis, plot_comparison
from zernipax.backend import jax

using JAX backend, jax version=0.4.28, jaxlib version=0.4.28, dtype=float64
Using device: CPU, with 6.88 GB available memory


In [9]:
# One resolution value is used for all three basis arguments (L, M and N).
res = 12
# Alternative basis, kept for reference:
# basis = ZernikePolynomial(L=res, M=res, spectral_indexing="ansi", sym="cos")
basis = FourierZernikeBasis(L=res, M=res, N=res)
# 1000 evenly spaced evaluation points covering [0, 1], endpoints included.
r = np.linspace(start=0, stop=1, num=1000)

In [10]:
# Benchmark the radial-function implementations at derivative order `dr`.
dr = 0
# Drop JAX's compilation caches so every implementation starts from the same state.
jax.clear_caches()

# Warm-up: call each implementation once before timing — presumably so that
# any one-time compilation cost after the cache clear is not counted below.
_ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr)
_ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr)
_ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr)
_ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr)

# Warm timings. `block_until_ready()` waits for JAX's asynchronously
# dispatched result, so each measurement covers the whole computation.
# Note that `zernike_radial_old_desc` is passed `r` with an extra trailing axis.
%timeit zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
%timeit zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
%timeit zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
%timeit zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()

# Disabled cold-cache variant: clears the JAX caches inside every timed
# iteration (presumably to include recompilation in the measurement).
# %timeit -n 10 zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready(); jax.clear_caches()
# %timeit -n 10 zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready(); jax.clear_caches()
# %timeit zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready(); jax.clear_caches()
# %timeit -n 10 zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready(); jax.clear_caches()

12.3 ms ± 314 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.5 ms ± 436 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
53.7 ms ± 843 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Time `zernike_radial_poly` (inexact mode, exact=False) for derivative
# orders 0 through 3. `r` is passed with an extra trailing axis, and no
# `block_until_ready()` is used here.
# NOTE(review): presumably this function returns a NumPy array rather than a
# JAX array — confirm.
print("zernike_radial_poly, 0th derivative")
%timeit _ = zernike_radial_poly(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr=0, exact=False)
print("zernike_radial_poly, 1st derivative")
%timeit _ = zernike_radial_poly(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr=1, exact=False)
print("zernike_radial_poly, 2nd derivative")
%timeit _ = zernike_radial_poly(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr=2, exact=False)
print("zernike_radial_poly, 3rd derivative")
%timeit _ = zernike_radial_poly(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr=3, exact=False)

In [None]:
# Compute derivative orders 0-3 of the radial functions with four different
# implementations so that they can be compared against each other:
#   zt* : reference values from exact coefficients evaluated with mpmath
#   zr* : zernike_radial (newest function)
#   zd* : zernike_radial_old_desc (old DESC version)
#   zp* : zernike_radial_poly with exact=False


def _eval_exact(coeffs, x):
    """Evaluate every coefficient row at ``x`` with mpmath and cast to float.

    Each row of ``coeffs`` is passed to ``mpmath.polyval``; the stacked
    results are transposed before being returned.
    """
    return np.array(
        [np.asarray(mpmath.polyval(list(ci), x), dtype=float) for ci in coeffs]
    ).T


# Exact computation. `workdps` raises mpmath's working precision to 100
# digits only inside this block and restores the previous setting on exit,
# rather than leaving the global precision at a hard-coded lower value.
with mpmath.workdps(100):
    c = zernike_radial_coeffs(basis.modes[:, 0], basis.modes[:, 1], exact=True)
    zt0 = _eval_exact(c, r)
    zt1 = _eval_exact(polyder_vec(c, 1, exact=True), r)
    zt2 = _eval_exact(polyder_vec(c, 2, exact=True), r)
    zt3 = _eval_exact(polyder_vec(c, 3, exact=True), r)

# Newest function
zr0 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 0)
zr1 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 1)
zr2 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 2)
zr3 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 3)

# Old Desc version (called with `r` carrying an extra trailing axis)
zd0 = zernike_radial_old_desc(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 0
)
zd1 = zernike_radial_old_desc(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 1
)
zd2 = zernike_radial_old_desc(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 2
)
zd3 = zernike_radial_old_desc(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 3
)

# Polynomial computation in inexact mode (exact=False)
zp0 = zernike_radial_poly(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], dr=0, exact=False
)
zp1 = zernike_radial_poly(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], dr=1, exact=False
)
zp2 = zernike_radial_poly(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], dr=2, exact=False
)
zp3 = zernike_radial_poly(
    r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], dr=3, exact=False
)

In [None]:
# Compare the three 3rd-derivative results against the mpmath reference.
# NOTE(review): the meaning of the positional `1` and of "absolute" is
# defined by zernipax.plotting.plot_comparison — presumably a derivative/plot
# selector and the error type; confirm against its signature.
plot_comparison(
    zt3,
    (zr3, zd3, zp3),
    basis,
    1,
    "absolute",
)

In [7]:
# Time each radial-function variant at derivative order `dr`. The printed
# headings label the timings that follow them. `block_until_ready()` waits
# for JAX's asynchronously dispatched result so the full computation is timed.
dr = 0
print(f"zernike_radial, derivative order: {dr}")

print("# With no duplicate modes (might have lacking modes)")
%timeit _ = zernike_radial_unique(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()

print("# With all the checks necessary but no reverse mode AutoDiff capable")
%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()

print("# With all the checks necessary and reverse mode AutoDiff capable")
# Only the jvp variant is enabled; the switch/GPU variants are disabled.
# NOTE(review): the saved output lists two timings under this heading, so it
# appears to come from a run with one more of these lines enabled.
# %timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
# %timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
# %timeit _ = zernike_radial_jvp_gpu(r, basis.modes[:,0], basis.modes[:,1], dr, repeat=13).block_until_ready()

print("# With all the checks necessary but less efficient")
# The old DESC version is called with `r` carrying an extra trailing axis.
%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()

zernike_radial, derivative order: 0
# With no duplicate modes (might have lacking modes)
81 µs ± 2.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With all the checks necessary but no reverse mode AutoDiff capable
85 µs ± 341 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With all the checks necessary and reverse mode AutoDiff capable
85.1 µs ± 259 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
253 µs ± 652 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# With all the checks necessary but less efficient
303 µs ± 426 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Same benchmark as for order 0, now for the first derivative. The printed
# headings label the timings that follow them. `block_until_ready()` waits
# for JAX's asynchronously dispatched result so the full computation is timed.
dr = 1
print(f"zernike_radial, derivative order: {dr}")

print("# With no duplicate modes (might have lacking modes)")
%timeit _ = zernike_radial_unique(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()

print("# With all the checks necessary but no reverse mode AutoDiff capable")
%timeit _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()

print("# With all the checks necessary and reverse mode AutoDiff capable")
# Three variants are timed under this heading: switch, switch_gpu and jvp.
%timeit _ = zernike_radial_switch(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
%timeit _ = zernike_radial_switch_gpu(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()
%timeit _ = zernike_radial_jvp(r, basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()

print("# With all the checks necessary but less efficient")
# The old DESC version is called with `r` carrying an extra trailing axis.
%timeit _ = zernike_radial_old_desc(r[:,np.newaxis], basis.modes[:,0], basis.modes[:,1], dr).block_until_ready()