In [19]:
%load_ext autoreload
%autoreload 2
import jax
import jax.numpy as jnp
from diffrax import (
    ODETerm, diffeqsolve,
    SaveAt, RecursiveCheckpointAdjoint
)
from tools.integrate import integrate 
from tools.semi_implicite_euler import SemiImplicitEuler

jax.config.update("jax_enable_x64", True)

def velocity_fn(t, a, args):
    """
    Example system: dv/dt = -alpha * a + sin(t).
    We incorporate a parameter alpha from 'args'.
    """
    return -args * a + jnp.sin(t)

def acceleration_fn(t, v, args):
    """
    Example system: da/dt = -alpha * v + sin(t).
    Also depends on the same parameter alpha.
    """
    return -args * v + jnp.sin(t)


# Global solver settings
t0 = 0.0
t1 = 1.
dt0 = 0.01

# We'll store solutions at a few points in time
saveat = SaveAt(t1=True)

# For demonstration, let's define a system for Diffrax usage:
terms = (ODETerm(velocity_fn), ODETerm(acceleration_fn))
diffrax_solver = SemiImplicitEuler()

y0 = (0.0, 0.0)
args = 2.


# For reference, a "Diffrax-based" solution function:
  
def diffrax_integrate(y0 , args):
  sol = diffeqsolve(
          terms,
          diffrax_solver,
          t0=t0, t1=t1, dt0=dt0,
          y0=y0,
          args=args,
          saveat=saveat,
          adjoint=RecursiveCheckpointAdjoint(checkpoints=10),
      )
  return jax.tree.map(lambda x: x[0], sol.ys)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
diffrax_integrate(y0, args)

(Array(0.26026448, dtype=float64), Array(0.24753947, dtype=float64))

In [21]:
integrate(y0, args, terms,solver=diffrax_solver, t0=t0, t1=t1, dt0=dt0)

(Array(0.26026448, dtype=float64), Array(0.24753947, dtype=float64))

In [22]:
jax.jacrev(diffrax_integrate)(y0, args)

((Array(3.72580927, dtype=float64, weak_type=True),
  Array(-3.62655369, dtype=float64, weak_type=True)),
 (Array(-3.62655369, dtype=float64, weak_type=True),
  Array(3.79834034, dtype=float64, weak_type=True)))

In [23]:
jax.jacrev(integrate)(y0, args, terms,solver=diffrax_solver, t0=t0, t1=t1, dt0=dt0)

((Array(3.72580927, dtype=float64, weak_type=True),
  Array(-3.62655369, dtype=float64, weak_type=True)),
 (Array(-3.62655369, dtype=float64, weak_type=True),
  Array(3.79834034, dtype=float64, weak_type=True)))

In [24]:
jax.jacrev(diffrax_integrate , argnums=1)(y0, args)

(Array(-0.05741368, dtype=float64, weak_type=True),
 Array(-0.06787917, dtype=float64, weak_type=True))

In [28]:
jax.jacrev(integrate , argnums=1)(y0, args, terms,solver=diffrax_solver, t0=t0, t1=t1, dt0=dt0)

(Array(-0.05741368, dtype=float64, weak_type=True),
 Array(-0.06787917, dtype=float64, weak_type=True))

In [5]:
from tools.integrate import integrate

integrate(y0, args , terms , diffrax_solver , t0 , t1 , dt0)

STEP t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))


(Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))

In [12]:
jax.jacrev(integrate)(y0, args , terms , diffrax_solver , t0 , t1 , dt0)

t0: 0.0, t1: 0.2, dt0: 0.1
STEP t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
REV t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
REV t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))
STEP t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))


((Array(1.04, dtype=float64, weak_type=True),
  Array(-0.408, dtype=float64, weak_type=True)),
 (Array(-0.408, dtype=float64, weak_type=True),
  Array(1.1216, dtype=float64, weak_type=True)))

In [297]:
def body(carry , t):
  jax.debug.print("t {a}" , a=t)
  return None , None


jax.lax.scan(body , None , jnp.arange(t1 , t0 , -dt0))

t 0.2
t 0.1


(None, None)

In [304]:
def jax_integrate(y0 , args):
  def forward_step(carry , t0):
      y, args = carry
      t1 = t0 + dt0
      y_next , _ , _ ,_ , _ = diffrax_solver.step(terms, t0, t1 ,  y, args , solver_state=None , made_jump=False)
      return (y_next, args), None


  init_carry = (y0, args)

  jax_sol , _ = jax.lax.scan(forward_step, init_carry, jnp.arange(t0, t1, dt0))

  return jax_sol[0]



@jax.custom_vjp
def integrate(y0 , args):
  def forward_step(carry , t0):
      y, args = carry
      t1 = t0 + dt0
      y_next , _ , _ ,_ , _ = diffrax_solver.step(terms, t0, t1 ,  y, args , solver_state=None , made_jump=False)
      return (y_next, args), None


  init_carry = (y0, args)

  jax_sol , _ = jax.lax.scan(forward_step, init_carry, jnp.arange(t0, t1, dt0))

  return jax_sol[0]


def integrate_fwd(y0 , args):
  return integrate(y0 , args) , (integrate(y0 , args), args)


def integrate_bwd(res , ct):

  y1 , args = res
  d_args = jax.tree.map(jnp.zeros_like , args)

  def step(carry , t1):

    y , dy , d_args = carry
    t0 = t1 - dt0
    y_prev  = diffrax_solver.reverse(terms, t0, t1 ,  y, args , solver_state=None , made_jump=False)

    xct , yct = dy

    
    def _to_vjp(y , z_args):
      y_next , _ , _ ,_ , _ = diffrax_solver.step(terms, t0, t1 ,  y, z_args , solver_state=None , made_jump=False)
      return y_next

    _ , f_vjp = jax.vjp(_to_vjp , y_prev , args)

    dy , dargs = f_vjp(dy)


    new_d_args = jax.tree.map(jnp.add , d_args , dargs)
    
    return (y_prev, dy ,  new_d_args), None

  init_carry = (y1, ct ,  d_args)
  (y , d_y , dargs) , _ = jax.lax.scan(step, init_carry, jnp.arange(t1, t0, -dt0))


  return d_y , d_args

integrate.defvjp(integrate_fwd , integrate_bwd)

In [299]:
y1 = integrate(y0, args)
y1

STEP t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))


(Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))

In [300]:
jax_integrate(y0, args)

STEP t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))


(Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))

In [303]:
jax.jacrev(jax_integrate , argnums=1)(y0, args)

(Array(0., dtype=float64, weak_type=True),
 Array(-0.00099833, dtype=float64, weak_type=True))

In [294]:
jax.jacrev(integrate , argnums=1)(y0, args)

STEP t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
STEP t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
REV t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
STEP t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
xct 0.0 yct 1.0 dargs -0.0009983341664682817 and d_args 0.0 args 2.0 new_d_args  -0.0009983341664682817 
xct 1.0 yc

(Array(0., dtype=float64, weak_type=True),
 Array(0., dtype=float64, weak_type=True))

In [233]:
jax.jacrev(diffrax_integrate, argnums=1)(y0, args)

({'alpha': Array(0., dtype=float64, weak_type=True)},
 {'alpha': Array(-0.00099833, dtype=float64, weak_type=True)})

In [271]:
def reverse_step(carry , t1):
    y, args = carry
    t0 = t1 - dt0
    y_next  = diffrax_solver.reverse(terms, t0, t1 ,  y, args , solver_state=None , made_jump=False)
    return (y_next, args), None


init_carry = (y1, args)

jax_sol , _ = jax.lax.scan(reverse_step, init_carry, jnp.arange(t1, t0, -dt0))


REV t0 0.1 , t1 0.2 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0.00998334, dtype=float64), Array(0.00798667, dtype=float64))
REV t0 0.0 , t1 0.1 , y0 (Array(0., dtype=float64), Array(0., dtype=float64)) , y1 (Array(0., dtype=float64), Array(0., dtype=float64))


In [79]:
jax_sol

((Array(-0.03840824, dtype=float64), Array(-0.03475898, dtype=float64)),
 {'alpha': Array(2., dtype=float64, weak_type=True)})