
grad evaluation of odeint gradient slower than finite differences #561

Closed

JohnGoertz opened this issue Aug 7, 2020 · 5 comments

Evaluating grad on a system of ODEs solved with autograd's odeint appears to be ~20x slower than estimating the gradient through finite differences. Is there something wrong here? Is it possible/necessary to calculate the Jacobian of the ODE itself and supply that to odeint? (A sketch of passing an analytic Jacobian to scipy's odeint via Dfun is included after the timing output below.)

import matplotlib.pyplot as plt
import numpy as npo
import scipy as sci
import scipy.optimize  # ensure sci.optimize.minimize below resolves

import autograd.numpy as np
from autograd import grad
from autograd import jacobian
from autograd.scipy.integrate import odeint
from autograd.builtins import tuple

import time

        
def kinetics(C, t, ks):
    k1,k2 = ks
    Ca,Cb,Cc = C
    dCadt = -k1 * Ca
    dCbdt = k1 * Ca - k2 * Cb
    dCcdt = k2 * Cb
    return np.array([dCadt,dCbdt,dCcdt])
    
def ode_pred(params, y0, t):
    return odeint(kinetics, y0, t, tuple((params,)))
    
t = np.linspace(0,10,5)
true_y0 = np.array([1.,0.,0.])
true_ks = np.array([1.,1.])
true_y = ode_pred(true_ks,true_y0,t)

def L1_loss(pred, targets):
    return np.mean(np.abs(pred - targets))

def train_loss(params):
    pred = ode_pred(params, true_y0, t)
    return L1_loss(pred, true_y)

jac = grad(train_loss)
init_ks = true_ks*2
jac(init_ks)
jac(np.array([0.5,1.5]))

begin = time.time()
fit = sci.optimize.minimize(train_loss,init_ks)
end = time.time()
print('Optimization without jacobian:')
print(f'Time elapsed: {end-begin} seconds')
print(fit)

# Optimization without jacobian:
    
# Time elapsed: 0.36577939987182617 seconds

#       fun: 5.396792486514866e-09
#  hess_inv: array([[106.09494345, 108.44951657],
#        [108.44951657, 112.80408753]])
#       jac: array([-0.01530422,  0.01716303])
#   message: 'Desired error not necessarily achieved due to precision loss.'
#      nfev: 120
#       nit: 1
#      njev: 36
#    status: 2
#   success: False
#         x: array([0.99999983, 1.00000016])

begin = time.time()
fit = sci.optimize.minimize(train_loss,init_ks,jac=jac)
end = time.time()

print('Optimization with jacobian:')
print(f'Time elapsed: {end-begin} seconds')
print(fit)

# Optimization with jacobian:
    
# Time elapsed: 16.415971755981445 seconds

#       fun: 2.001369694450681e-08
#  hess_inv: array([[104.68816163, 107.00018347],
#        [107.00018347, 111.31219715]])
#       jac: array([-0.01536626,  0.01710102])
#   message: 'Desired error not necessarily achieved due to precision loss.'
#      nfev: 76
#       nit: 1
#      njev: 64
#    status: 2
#   success: False
#         x: array([0.99999938, 1.00000062])
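
On the Jacobian question above: plain scipy.integrate.odeint accepts an analytic Jacobian of the RHS through its Dfun argument. A minimal sketch for the kinetics system (for illustration only; whether autograd's odeint wrapper forwards Dfun is not verified here, so this calls scipy directly):

import scipy.integrate

def kinetics_jac(C, t, ks):
    # d(dC/dt)/dC for the A -> B -> C kinetics above
    k1, k2 = ks
    return np.array([[-k1, 0.0, 0.0],
                     [ k1, -k2, 0.0],
                     [0.0,  k2, 0.0]])

# Dfun lets LSODA use the analytic Jacobian for its implicit (stiff) steps.
sol = scipy.integrate.odeint(kinetics, true_y0, t, args=(true_ks,),
                             Dfun=kinetics_jac)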
duvenaud (Contributor) commented Aug 7, 2020

Hmmm, there's not necessarily a problem, especially if the adjoint system is hard to solve. I would try reducing the error tolerances. But in general there's a lot of overhead when using autograd; if you want performance and a similar API, you should try JAX.
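
A minimal sketch of loosening the tolerances (assuming autograd's odeint, which wraps scipy.integrate.odeint, forwards scipy's rtol/atol keywords):

def ode_pred_loose(params, y0, t):
    # Looser-than-default tolerances reduce work in both the forward
    # solve and the adjoint/VJP pass; the values here are illustrative.
    return odeint(kinetics, y0, t, tuple((params,)), rtol=1e-5, atol=1e-7)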

JohnGoertz (Author) commented

I will try JAX, but actually, autograd seems to be estimating the gradient wrong:

def simple(y, t, args):
    # dy/dt = y, so y(T) = y0 * exp(T); args is unused
    return y

def ode_pred(y0, t):
    return odeint(simple, y0, t, tuple((y0,)))[-1, :]
    
dims = 10
t = np.arange(20)
true_y0 = np.ones(dims)
true_y = ode_pred(true_y0,t)

def L1_loss(pred, targets):
    return np.mean(np.abs(pred - targets))

def train_loss(y0):
    pred = ode_pred(y0, t)
    return L1_loss(pred, true_y)


jac = grad(train_loss)
np.random.seed(1984)
init_y = np.random.random(dims)*2
print('init_y:')
print(init_y)
print('Estimated jacobian:')
print(jac(init_y))
print('')

Gives

init_y:
[0.01751146 0.59749114 0.11282815 0.96537384 0.56368644 1.58854801
 0.58361514 1.46951699 0.44313355 1.4328588 ]

Estimated jacobian:
[-17848237.63382979 -17848237.63382979 -17848237.63382979
 -17848237.63382979 -17848237.63382979  17848237.63382979
 -17848237.63382979  17848237.63382979 -17848237.63382979
  17848237.63382979]

I'm no mathematician, but that seems... very wrong.

duvenaud (Contributor) commented Aug 7, 2020

It might not be wrong: the gradient of the L1 loss depends only on the sign. I would compare against finite differences.
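
A minimal sketch of such a check, using scipy.optimize.approx_fprime and reusing train_loss, jac, and init_y from the example above:

from scipy.optimize import approx_fprime

# Compare reverse-mode autograd against forward finite differences.
# Magnitude sanity check: dy/dt = y integrated from t=0 to t=19 gives
# y(T) = y0 * exp(19), so |dL/dy0_i| = exp(19) / dims ~= 1.78e7, which
# matches the magnitude autograd reports; the sign is sign(pred_i - target_i).
ag_grad = jac(init_y)
fd_grad = approx_fprime(init_y, train_loss, 1e-6)
print(np.max(np.abs(ag_grad - fd_grad)))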

JohnGoertz changed the title from "grad evaluation of odeint gradient slower that finite differences" to "grad evaluation of odeint gradient slower than finite differences" on Aug 7, 2020
JohnGoertz (Author) commented

Okay, so JAX is significantly faster than autograd all around. Thanks for the tips!

In case anyone else encounters similar difficulties, here's a somewhat verbose comparison of gradient evaluation and ODE parameter estimation using stock NumPy, numba compilation, autograd, and JAX. Here, autograd took ~1000x longer to evaluate the gradient of the loss function than the loss function itself. Ultimately, compiling the loss function with JAX and supplying the JAX-compiled gradient to minimize gives the best performance, about 2x-3x faster than numba-compiling just the RHS. Of course, the sacrifice is a few seconds of compilation time. In my real application, JAX compilation of the RHS scales very well and is preferable to NumPy, although JAX evaluation and compilation of the gradient of very large systems (~1000 equations) gets bogged down, so YMMV.

Benchmarks:

  • STOCK NUMPY
    • Loss function evaluation time:
      • 58.2 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    • Minimization time:
      • 301 ms ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • NUMBA COMPILED RHS
    • RHS Compilation time:
      • 0.21 s
    • Loss function evaluation time:
      • 31.4 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    • Minimization time:
      • 91.8 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
  • AUTOGRAD
    • Loss function evaluation time:
      • 158 µs ± 4.95 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    • Gradient Evaluation time:
      • 134 ms ± 16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    • Minimization time without gradient:
      • 953 ms ± 46.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    • Minimization time with gradient:
      • 48.16 s
  • JAX
    • Loss function evaluation time:
      • 1.28 ms ± 17.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    • Gradient Evaluation time:
      • 4.15 ms ± 676 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    • Gradient Compilation time:
      • 2.01 s
    • Loss function compilation time:
      • 0.65 s
    • Compiled loss function evaluation time:
      • 85 µs ± 575 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    • Minimization time, uncompiled loss, no gradient:
      • 388 ms ± 9.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    • Minimization time, uncompiled loss, with uncompiled gradient:
      • 479 ms ± 7.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    • Minimization time, uncompiled loss, with compiled gradient:
      • 160 ms ± 2.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    • Minimization time, compiled loss, no gradient:
      • 56.5 ms ± 790 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    • Minimization time, compiled loss, with uncompiled gradient:
      • 363 ms ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    • Minimization time, compiled loss, with compiled gradient:
      • 38.5 ms ± 746 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Code:

import numpy as onp
import autograd as ag
import jax
import autograd.numpy as anp
import jax.numpy as jnp
from jax.config import config
config.update('jax_enable_x64', True)
from jax.experimental.ode import odeint as jodeint
from autograd.builtins import tuple as agtuple
from autograd.scipy.integrate import odeint as agodeint
from numba import njit

from scipy.integrate import odeint
from scipy.optimize import minimize
import time

kwargs = dict(method='BFGS',options={'maxiter':50})

# Stock numpy
def kinetics(C, t, ks):
    k1 = ks[0]
    k2 = ks[1]
    Ca,Cb,Cc = C
    dCadt = -k1 * Ca
    dCbdt = k1 * Ca - k2 * Cb
    dCcdt = k2 * Cb
    return onp.array([dCadt,dCbdt,dCcdt])

t = onp.arange(60.)
true_y0 = onp.array([1.,0.,0.], dtype=float)
true_ks = onp.array([1.,1.], dtype=float)
init_ks = onp.array([0.,0.], dtype=float)
true_y = odeint(kinetics, true_y0, t, args=(true_ks,))

def loss(params):
    pred = odeint(kinetics, true_y0, t, args=(params,))
    return onp.sqrt(onp.mean(onp.square(pred - true_y)))

print('STOCK NUMPY')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Minimization time:')
%timeit minimize(loss,init_ks,**kwargs)

# Numba compiled rhs
kinetics = njit(kinetics)

print('NUMBA COMPILED RHS')
print('')
print('RHS Compilation time:')
begin = time.time()
kinetics(true_y0, None,init_ks)
print(f'{time.time()-begin:.2f} s')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Minimization time:')
%timeit minimize(loss,init_ks,**kwargs)

# Autograd
def kinetics(C, t, ks):
    k1 = ks[0]
    k2 = ks[1]
    Ca,Cb,Cc = C
    dCadt = -k1 * Ca
    dCbdt = k1 * Ca - k2 * Cb
    dCcdt = k2 * Cb
    return anp.array([dCadt,dCbdt,dCcdt])

t = anp.arange(60.)
true_y0 = anp.array([1.,0.,0.], dtype=float)
true_ks = anp.array([1.,1.], dtype=float)
init_ks = anp.array([0.,0.], dtype=float)
true_y = agodeint(kinetics, true_y0, t, agtuple((true_ks,)))

def loss(params):
    pred = agodeint(kinetics, true_y0, t, agtuple((params,)))
    return anp.sqrt(anp.mean(anp.square(pred - true_y)))

print('AUTOGRAD')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Gradient Evaluation time:')
ajac = ag.grad(loss)
%timeit ajac(init_ks)
print('')
print('Minimization time without gradient:')
%timeit minimize(loss,init_ks,**kwargs)
print('')
print('Minimization time with gradient:')
begin = time.time()
minimize(loss,init_ks,jac=ajac,**kwargs)
print(f'{time.time()-begin:.2f} s')


# JAX
def kinetics(C, t, ks):
    k1 = ks[0]
    k2 = ks[1]
    Ca,Cb,Cc = C
    dCadt = -k1 * Ca
    dCbdt = k1 * Ca - k2 * Cb
    dCcdt = k2 * Cb
    return jnp.array([dCadt,dCbdt,dCcdt])

t = jnp.arange(60.)
true_y0 = jnp.array([1.,0.,0.], dtype=float)
true_ks = jnp.array([1.,1.], dtype=float)
init_ks = jnp.array([0.,0.], dtype=float)
true_y = jodeint(kinetics, true_y0, t, true_ks)

def loss(params):
    pred = jodeint(kinetics, true_y0, t, params)
    return jnp.sqrt(jnp.mean(jnp.square(pred - true_y)))

print('JAX')
print('')
print('Loss function evaluation time:')
%timeit loss(init_ks)
print('')
print('Gradient Evaluation time:')
jjac = jax.grad(loss)
%timeit jjac(init_ks)
print('')
print('Gradient Compilation time:')
begin=time.time()
jjjac = jax.jit(jjac)
jjjac(init_ks)
print(f'{time.time()-begin:.2f} s')
print('')
print('Loss function compilation time:')
begin=time.time()
jloss = jax.jit(loss)
jloss(init_ks)
print(f'{time.time()-begin:.2f} s')
print('')


print('Minimization time, uncompiled loss, no gradient:')
%timeit minimize(loss,init_ks, **kwargs)
print('')
print('Minimization time, uncompiled loss, with uncompiled gradient:')
%timeit minimize(loss,init_ks,jac=jjac, **kwargs)
print('')
print('Minimization time, uncompiled loss, with compiled gradient:')
%timeit minimize(loss,init_ks,jac=jjjac, **kwargs)
print('')
print('Minimization time, compiled loss, no gradient:')
%timeit minimize(jloss,init_ks, **kwargs)
print('')
print('Minimization time, compiled loss, with uncompiled gradient:')
%timeit minimize(jloss,init_ks,jac=jjac, **kwargs)
print('')
print('Minimization time, compiled loss, with compiled gradient:')
%timeit minimize(jloss,init_ks,jac=jjjac, **kwargs)
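
One variant not benchmarked above: fuse the loss and gradient with jax.value_and_grad and jit, then pass jac=True so minimize gets both from a single compiled call. A sketch:

# Sketch only, not part of the timings above.
jloss_and_grad = jax.jit(jax.value_and_grad(loss))

def loss_and_grad(params):
    # minimize(..., jac=True) expects fun to return (value, gradient).
    v, g = jloss_and_grad(params)
    return float(v), onp.asarray(g)  # convert JAX arrays for SciPy

fit = minimize(loss_and_grad, init_ks, jac=True, **kwargs)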

duvenaud (Contributor) commented

Thanks for including these detailed benchmarks. I wonder why autograd was so slow here.
