grad evaluation of odeint gradient slower than finite differences #561
Hmmm, there's not necessarily a problem, especially if the adjoint system is hard to solve. I would try reducing the error tolerances. But in general there's a lot of overhead when using autograd - if you want performance and a similar API, you should try JAX.
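For context on the tolerance suggestion: autograd's `odeint` wraps `scipy.integrate.odeint` and, as far as I can tell, forwards keyword arguments such as `rtol`/`atol` to it, so loosening the tolerances trades accuracy for speed. A minimal sketch with plain scipy on a toy decay system (function and variable names are my own, not from the thread):

```python
import numpy as np
from scipy.integrate import odeint

def rhs(y, t):
    # Simple linear decay dy/dt = -y, with exact solution exp(-t)
    return -y

t = np.linspace(0.0, 5.0, 50)
y0 = np.array([1.0])

# Tight tolerances: more right-hand-side evaluations, higher accuracy
tight = odeint(rhs, y0, t, rtol=1e-10, atol=1e-10)
# Looser tolerances: fewer evaluations, accuracy still near rtol
loose = odeint(rhs, y0, t, rtol=1e-4, atol=1e-6)
# Both stay close to the exact solution exp(-t)
```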
I will try JAX, but actually, autograd seems to be estimating the gradient wrong:

```python
import autograd.numpy as np
from autograd import grad
from autograd.builtins import tuple  # autograd's tuple wrapper for odeint args
from autograd.scipy.integrate import odeint

def simple(y, t, args):
    return y

def ode_pred(y0, t):
    return odeint(simple, y0, t, tuple((y0,)))[-1, :]

dims = 10
t = np.arange(20)
true_y0 = np.ones(dims)
true_y = ode_pred(true_y0, t)

def L1_loss(pred, targets):
    return np.mean(np.abs(pred - targets))

def train_loss(y0):
    pred = ode_pred(y0, t)
    return L1_loss(pred, true_y)

jac = grad(train_loss)

np.random.seed(1984)
init_y = np.random.random(dims) * 2
print('init_y:')
print(init_y)
print('Estimated jacobian:')
print(jac(init_y))
print('')
```

Gives

I'm no mathematician, but that seems... very wrong.
It might not be wrong - the gradient of L1 loss depends only on the sign. I would compare against finite differences.
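Following that suggestion, a central-difference check is easy to write with plain numpy. A minimal sketch (`fd_grad` is a hypothetical helper name, not part of autograd or scipy); it also illustrates the sign-only behavior of the L1 gradient:

```python
import numpy as np

def fd_grad(f, x, eps=1e-6):
    """Central-difference estimate of the gradient of scalar f at x."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return g

# L1 loss against a fixed target: each partial derivative is
# sign(x_i - target_i) / n, so only the sign of the error matters.
target = np.array([1.0, 2.0, 3.0])
loss = lambda x: np.mean(np.abs(x - target))
x = np.array([2.0, 0.0, 5.0])
print(fd_grad(loss, x))  # approximately [ 1/3, -1/3, 1/3 ]
```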
Okay, so JAX is significantly faster than autograd all around. Thanks for the tips! In case anyone else encounters similar difficulties, here's a somewhat verbose comparison of gradient evaluation and ODE parameter estimation using stock numpy, numba compilation, autograd, and JAX. Here, autograd took 1000x longer to estimate the gradient of a loss function than to evaluate the loss function itself. Ultimately, compiling the loss function with JAX and supplying the JAX-compiled gradient to `minimize` was the fastest approach.

Benchmarks:
Code:

```python
import numpy as onp
import autograd as ag
import jax
import autograd.numpy as anp
import jax.numpy as jnp
from jax.config import config
config.update('jax_enable_x64', True)
from jax.experimental.ode import odeint as jodeint
from autograd.builtins import tuple as agtuple
from autograd.scipy.integrate import odeint as agodeint
from numba import njit
from scipy.integrate import odeint
from scipy.optimize import minimize
import time

kwargs = dict(method='BFGS', options={'maxiter': 50})

# Stock numpy
def kinetics(C, t, ks):
    k1 = ks[0]
    k2 = ks[1]
    Ca, Cb, Cc = C
    dCadt = -k1 * Ca
    dCbdt = k1 * Ca - k2 * Cb
    dCcdt = k2 * Cb
    return onp.array([dCadt, dCbdt, dCcdt])

t = onp.arange(60.)
true_y0 = onp.array([1., 0., 0.], dtype=float)
true_ks = onp.array([1., 1.], dtype=float)
init_ks = onp.array([0., 0.], dtype=float)
true_y = odeint(kinetics, true_y0, t, args=(true_ks,))

def loss(params):
    pred = odeint(kinetics, true_y0, t, args=(params,))
    return onp.sqrt(onp.mean(onp.square(pred - true_y)))

print('STOCK NUMPY')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Minimization time:')
%timeit minimize(loss, init_ks, **kwargs)

# Numba compiled rhs
kinetics = njit(kinetics)
print('NUMBA COMPILED RHS')
print('')
print('RHS Compilation time:')
begin = time.time()
kinetics(true_y0, None, init_ks)
print(f'{time.time()-begin:.2f} s')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Minimization time:')
%timeit minimize(loss, init_ks, **kwargs)

# Autograd
def kinetics(C, t, ks):
    k1 = ks[0]
    k2 = ks[1]
    Ca, Cb, Cc = C
    dCadt = -k1 * Ca
    dCbdt = k1 * Ca - k2 * Cb
    dCcdt = k2 * Cb
    return anp.array([dCadt, dCbdt, dCcdt])

t = anp.arange(60.)
true_y0 = anp.array([1., 0., 0.], dtype=float)
true_ks = anp.array([1., 1.], dtype=float)
init_ks = anp.array([0., 0.], dtype=float)
true_y = agodeint(kinetics, true_y0, t, agtuple((true_ks,)))

def loss(params):
    pred = agodeint(kinetics, true_y0, t, agtuple((params,)))
    return anp.sqrt(anp.mean(anp.square(pred - true_y)))

print('Autograd')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Gradient Evaluation time:')
ajac = ag.grad(loss)
%timeit ajac(init_ks)
print('')
print('Minimization time without gradient:')
%timeit minimize(loss, init_ks, **kwargs)
print('')
print('Minimization time with gradient:')
begin = time.time()
minimize(loss, init_ks, jac=ajac, **kwargs)
print(f'{time.time()-begin:.2f} s')

# JAX
def kinetics(C, t, ks):
    k1 = ks[0]
    k2 = ks[1]
    Ca, Cb, Cc = C
    dCadt = -k1 * Ca
    dCbdt = k1 * Ca - k2 * Cb
    dCcdt = k2 * Cb
    return jnp.array([dCadt, dCbdt, dCcdt])

t = jnp.arange(60.)
true_y0 = jnp.array([1., 0., 0.], dtype=float)
true_ks = jnp.array([1., 1.], dtype=float)
init_ks = jnp.array([0., 0.], dtype=float)
true_y = jodeint(kinetics, true_y0, t, true_ks)

def loss(params):
    pred = jodeint(kinetics, true_y0, t, params)
    return jnp.sqrt(jnp.mean(jnp.square(pred - true_y)))

print('JAX')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Gradient Evaluation time:')
jjac = jax.grad(loss)
%timeit jjac(init_ks)
print('')
print('Gradient Compilation time:')
begin = time.time()
jjjac = jax.jit(jjac)
jjjac(init_ks)
print(f'{time.time()-begin:.2f} s')
print('')
print('Loss function compilation time:')
begin = time.time()
jloss = jax.jit(loss)
jloss(init_ks)
print(f'{time.time()-begin:.2f} s')
print('')
print('Minimization time, uncompiled loss, no gradient:')
%timeit minimize(loss, init_ks, **kwargs)
print('')
print('Minimization time, uncompiled loss, with uncompiled gradient:')
%timeit minimize(loss, init_ks, jac=jjac, **kwargs)
print('')
print('Minimization time, uncompiled loss, with compiled gradient:')
%timeit minimize(loss, init_ks, jac=jjjac, **kwargs)
print('')
print('Minimization time, compiled loss, no gradient:')
%timeit minimize(jloss, init_ks, **kwargs)
print('')
print('Minimization time, compiled loss, with uncompiled gradient:')
%timeit minimize(jloss, init_ks, jac=jjac, **kwargs)
print('')
print('Minimization time, compiled loss, with compiled gradient:')
%timeit minimize(jloss, init_ks, jac=jjjac, **kwargs)
```
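One further pattern that may help here (my own addition, not used in the benchmarks above): scipy's `minimize` accepts `jac=True` when the objective itself returns a `(value, gradient)` pair, which combines naturally with `jax.value_and_grad` so the loss and its gradient share a single compiled evaluation. A sketch on a toy quadratic rather than the ODE loss:

```python
import numpy as onp
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

def loss(x):
    # Toy quadratic with minimum at x = [3, 3]
    return jnp.sum((x - 3.0) ** 2)

# One compiled function returning (value, gradient) in a single pass
val_and_grad = jax.jit(jax.value_and_grad(loss))

def objective(x):
    v, g = val_and_grad(x)
    # Convert JAX arrays to plain Python/numpy types for scipy
    return float(v), onp.asarray(g)

# jac=True tells minimize the objective supplies its own gradient
res = minimize(objective, onp.zeros(2), jac=True, method='BFGS')
print(res.x)  # approximately [3., 3.]
```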
Thanks for including these detailed benchmarks. I wonder why autograd was so slow here.
Evaluating `grad` on a system of ODEs solved with autograd's `odeint` appears to be ~20x slower than estimating the gradient through finite differences. Is there something wrong here? Is it possible/necessary to calculate the Jacobian of the ODE itself and supply that to `odeint`?
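On the last question: scipy's `odeint` does accept a `Dfun` argument, the Jacobian of the right-hand side with respect to the state, which can speed up the implicit (stiff) steps of the integrator. Note it does not give you the gradient of the solution with respect to the parameters; that still requires autodiff, adjoints, or finite differences. A sketch using the kinetics system from the benchmarks above:

```python
import numpy as np
from scipy.integrate import odeint

def kinetics(C, t, ks):
    k1, k2 = ks
    Ca, Cb, Cc = C
    return np.array([-k1 * Ca, k1 * Ca - k2 * Cb, k2 * Cb])

def kinetics_jac(C, t, ks):
    # Jacobian d(dC/dt)/dC of the right-hand side: row i is the
    # partial derivatives of component i with respect to [Ca, Cb, Cc]
    k1, k2 = ks
    return np.array([[-k1, 0.0, 0.0],
                     [k1, -k2, 0.0],
                     [0.0, k2, 0.0]])

t = np.arange(60.0)
y0 = np.array([1.0, 0.0, 0.0])
ks = np.array([1.0, 1.0])

# Supplying Dfun lets the solver skip numerically approximating
# the state Jacobian when it takes stiff (BDF) steps
sol = odeint(kinetics, y0, t, args=(ks,), Dfun=kinetics_jac)
```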