In [1]:
import uuid

import numpy as np
import dedalus.public as d3

In [2]:
# Domain: 3D Cartesian box with a RealFourier basis of size N_MODES on
# BOUNDS = (-1, 1) along each of x, y, z, distributed with float64 data.
N_MODES = 64
BOUNDS = (-1, 1)

coords = d3.CartesianCoordinates('x', 'y', 'z')
dist = d3.Distributor(coords, dtype=np.float64)
xbasis, ybasis, zbasis = (
    d3.RealFourier(coords[axis], N_MODES, bounds=BOUNDS) for axis in ('x', 'y', 'z')
)

In [3]:
# Fields on the full 3D basis:
#   f  - the primal input of the operator G below
#   df - the tangent direction used for the JVP / Frechet differential
#   dg - the cotangent direction used for the VJP
domain_bases = (xbasis, ybasis, zbasis)
f, df, dg = (
    dist.Field(name=field_name, bases=domain_bases) for field_name in ('f', 'df', 'dg')
)

In [4]:
# Nested operator with large expression swell: G = (2f)^8, built by repeated squaring
f2 = 2 * f
f4 = f2 * f2
f8 = f4 * f4
G = f8 * f8

f2.name = 'f2'
f4.name = 'f4'
f8.name = 'f8'
G.name = 'G'

# Current symbolic Frechet derivative, with expression swell: the printed
# expression repeats every shared subexpression instead of reusing it.
dG = G.frechet_differential([f], [df])
print("Symbolic derivative:", dG)

# Test JVP vs symbolic Frechet derivative on random inputs.
# Fixed seeds make this check (and the inner products in the next cell) reproducible.
# f and df get different seeds so the perturbation is not identical to f.
# The ['c'] / ['g'] accesses round-trip each field through coefficient space and
# back to grid space (presumably to keep only what the basis can represent).
f.fill_random('g', seed=1); f['c']; f['g']
df.fill_random('g', seed=2); df['c']; df['g']

g, dg_fwd = G.evaluate_jvp({f: df})
# Copy the results so later evaluations cannot overwrite them
# (presumably the operators reuse their output fields between evaluations).
dg_fwd = dg_fwd.copy()
dg_sym = dG.evaluate()
dg_sym = dg_sym.copy()

jvp_matches = np.allclose(dg_fwd['g'], dg_sym['g'])
print("JVP matches symbolic:", jvp_matches)
# Fail loudly instead of only printing False, so Restart & Run All catches regressions.
assert jvp_matches, "JVP disagrees with the symbolic Frechet derivative"

Symbolic derivative: ((2*df*2*f + 2*f*2*df)*2*f*2*f + 2*f*2*f*(2*df*2*f + 2*f*2*df))*2*f*2*f*2*f*2*f + 2*f*2*f*2*f*2*f*((2*df*2*f + 2*f*2*df)*2*f*2*f + 2*f*2*f*(2*df*2*f + 2*f*2*df))
JVP matches symbolic: True


In [5]:
# Test VJP on random inputs via the dot-product (adjoint) test:
# with tangent R1 = df and cotangent R2 = dg, <R2 | J @ R1> must equal <J.T @ R2 | R1>.
def inner(a, b):
    """Return the grid-space inner product sum(a * b) of two fields."""
    return np.sum(a['g'] * b['g'])
dg.fill_random('g', seed=3); dg['c']; dg['g']

# A fresh uuid cannot repeat an earlier evaluation id (randint(0, 1000) could),
# so no stale result can be reused (assumes evaluate_vjp caches on id — confirm).
g, df_rev = G.evaluate_vjp(dg, id=uuid.uuid4())

# dg_fwd = J @ df comes from the JVP cell above; df_rev[f] = J.T @ dg.
# The original printed each value under the other one's label; they are now paired correctly.
lhs = inner(dg, dg_fwd)       # <R2 | J @ R1>
rhs = inner(df, df_rev[f])    # <J.T @ R2 | R1>
print("Inner products <random | Jacobian | random>:")
print("<R2 | J @ R1>  :", lhs)
print("<J.T @ R2 | R1>:", rhs)
assert np.isclose(lhs, rhs, rtol=1e-8), "VJP fails the dot-product test against the JVP"

Inner products <random | Jacobian | random>:
<R2 | J @ R1>  : 40309726.89599003
<J.T @ R2 | R1>: 40309726.89598999


In [6]:
print("Time forward evaluation:")
%timeit G.evaluate(id=np.random.randint(0, 1000000))

Time forward evaluation:
411 µs ± 6.66 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
print("Time JVP evaluation:")
%timeit G.evaluate_jvp({f: df}, id=np.random.randint(0, 1000000))

Time JVP evaluation:
1.57 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
print("Time symbolic derivative evaluation:")
%timeit dG.evaluate(id=np.random.randint(0, 1000000))

Time symbolic derivative evaluation:
9.55 ms ± 188 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
print("Time VJP evaluation:")
%timeit g, df_rev = G.evaluate_vjp(dg, id=np.random.randint(0, 1000000))

Time VJP evaluation:
686 µs ± 17.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
