In [11]:
%load_ext autoreload
%autoreload 2
import os
import sys

parent_dir = os.path.abspath("..")
sys.path.append(parent_dir)

import jax
import jax.numpy as jnp
from jax import custom_vjp
from functools import partial
from typing import Any, Tuple , Optional
import numpy as np
from diffrax import ODETerm , diffeqsolve , SaveAt , RecursiveCheckpointAdjoint
from equinox.internal import while_loop
import equinox as eqx
from tools.integrate import integrate , scan_integrate
from tools.semi_implicite_euler import SemiImplicitEuler

def check_tree(x , y):
    return jax.tree.all(jax.tree.map(lambda x , y : jnp.allclose(x , y) , x , y))

def f(t, x, z):
    return x + z * t


def g(t, y, z):
    return y + z * t


def diffrax_integrate(ode_terms , solver , y0 , args, t0 , t1 , dt0 , saveat , checkpoints=2):
    sol = diffeqsolve(ode_terms , solver , t0 , t1 , dt0 , y0 , args , saveat=saveat , adjoint=RecursiveCheckpointAdjoint(checkpoints=checkpoints))
    return sol.ys

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from itertools import product

y0 = (1.0, 1.0)
args = 2.0

t0 = 0.0
t1 = 10.0
dt0 = 1
ode_terms = (ODETerm(g) , ODETerm(f))
solver = SemiImplicitEuler()

for save_t0 , save_t1 in product([True, False], repeat=2):
  # Starting and ending t0 and t1
  saveat_t0_t1 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting t0 and ending before t1
  saveat_t0_tx = SaveAt(ts=jnp.arange(t0 , t1 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending at t1
  saveat_tx_t1 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending before t1
  saveat_tx_tx = SaveAt(ts=jnp.arange(t0 + dt0 , t1 - dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 before t1
  saveat_t0_tx_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Gettings steps starting after t0 and jumping steps by 2 * dt0 before t1
  saveat_tx_tx_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 at t1
  saveat_t0_t1_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting after t0 and jumping steps by 2 * dt0 at t1
  saveat_tx_t1_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)

  for saveat in [saveat_t0_t1 , saveat_t0_tx , saveat_tx_t1 , saveat_tx_tx , saveat_t0_tx_2 , saveat_tx_tx_2 , saveat_t0_t1_2 , saveat_tx_t1_2]:

      jax_fwd = scan_integrate(y0, args, ode_terms , solver , t0, t1, dt0 , saveat)
      diffrax_fwd = diffrax_integrate(ode_terms , solver , y0, args, t0, t1, dt0 , saveat)
      my_fwd = integrate(y0, args, ode_terms , solver , t0, t1, dt0 , saveat)

      assert check_tree(jax_fwd , diffrax_fwd)
      assert check_tree(jax_fwd , my_fwd)
      print(f"ok for saveat {saveat.subs.ts} with t0={saveat.subs.t0} and t1={saveat.subs.t1}")

  print("=====================================")


ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=True
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=False
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=False
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=False
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=False
ok for saveat [ 0.  2.  4. 

In [5]:
for save_t0 , save_t1 in product([True, False], repeat=2):
  # Starting and ending t0 and t1
  saveat_t0_t1 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting t0 and ending before t1
  saveat_t0_tx = SaveAt(ts=jnp.arange(t0 , t1 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending at t1
  saveat_tx_t1 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending before t1
  saveat_tx_tx = SaveAt(ts=jnp.arange(t0 + dt0 , t1 - dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 before t1
  saveat_t0_tx_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Gettings steps starting after t0 and jumping steps by 2 * dt0 before t1
  saveat_tx_tx_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 at t1
  saveat_t0_t1_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting after t0 and jumping steps by 2 * dt0 at t1
  saveat_tx_t1_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)

  for saveat in [saveat_t0_t1 , saveat_t0_tx , saveat_tx_t1 , saveat_tx_tx , saveat_t0_tx_2 , saveat_tx_tx_2 , saveat_t0_t1_2 , saveat_tx_t1_2]:


    jax_bwd = jax.jacrev(scan_integrate , argnums=(0 , 1))(y0, args, ode_terms , solver , t0, t1, dt0 , saveat)
    diffrax_bwd = jax.jacrev(diffrax_integrate , argnums=(2 ,3))(ode_terms , solver , y0, args, t0, t1, dt0 , saveat)
    my_bwd = jax.jacrev(integrate , argnums=(0 , 1))(y0, args, ode_terms , solver , t0, t1, dt0 , saveat)

    assert check_tree(jax_bwd , diffrax_bwd)
    assert check_tree(jax_bwd , my_bwd)
    print(f"ok for saveat {saveat.subs.ts} with t0={saveat.subs.t0} and t1={saveat.subs.t1}")


  print("=====================================")


ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=True
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=False
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=False
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=False
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=False
ok for saveat [ 0.  2.  4. 

In [6]:
y0 = (1.0, 1.0)
args = 2.0

t0 = 0.0
t1 = 10.0
dt0 = 1

for save_t0 , save_t1 in product([True, False], repeat=2):
  # Starting and ending t0 and t1
  saveat_t0_t1 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting t0 and ending before t1
  saveat_t0_tx = SaveAt(ts=jnp.arange(t0 , t1 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending at t1
  saveat_tx_t1 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending before t1
  saveat_tx_tx = SaveAt(ts=jnp.arange(t0 + dt0 , t1 - dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 before t1
  saveat_t0_tx_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Gettings steps starting after t0 and jumping steps by 2 * dt0 before t1
  saveat_tx_tx_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 at t1
  saveat_t0_t1_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting after t0 and jumping steps by 2 * dt0 at t1
  saveat_tx_t1_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)

  for saveat in [saveat_t0_t1 , saveat_t0_tx , saveat_tx_t1 , saveat_tx_tx , saveat_t0_tx_2 , saveat_tx_tx_2 , saveat_t0_t1_2 , saveat_tx_t1_2]:

    def fn(t , y , args):
      y =  jax.tree.map(lambda x : 2 * x + t * args , y)

      return y[0] + y[1]

    diffrax_fwd = diffrax_integrate(ode_terms , solver , y0, args, t0, t1, dt0 , saveat)
    my_fwd = integrate(y0, args, ode_terms , solver , t0, t1, dt0 , saveat)

    assert check_tree(my_fwd , diffrax_fwd)
    print(f"ok for saveat {saveat.subs.ts} with t0={saveat.subs.t0} and t1={saveat.subs.t1}")



ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=True
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=False
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=False
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=False
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=False
ok for saveat [ 0.  2.  4. 

In [7]:
y0 = (1.0, 1.0)
args = 2.0

t0 = 0.0
t1 = 10.0
dt0 = 1


for save_t0 , save_t1 in product([True, False], repeat=2):
  # Starting and ending t0 and t1
  saveat_t0_t1 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting t0 and ending before t1
  saveat_t0_tx = SaveAt(ts=jnp.arange(t0 , t1 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending at t1
  saveat_tx_t1 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Starting after t0 and ending before t1
  saveat_tx_tx = SaveAt(ts=jnp.arange(t0 + dt0 , t1 - dt0 , dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 before t1
  saveat_t0_tx_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Gettings steps starting after t0 and jumping steps by 2 * dt0 before t1
  saveat_tx_tx_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting from t0 and jumping steps by 2 * dt0 at t1
  saveat_t0_t1_2 = SaveAt(ts=jnp.arange(t0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)
  # Getting steps starting after t0 and jumping steps by 2 * dt0 at t1
  saveat_tx_t1_2 = SaveAt(ts=jnp.arange(t0 + dt0 , t1 + dt0 , 2 * dt0) , t0=save_t0 , t1=save_t1)

  for saveat in [saveat_t0_t1 , saveat_t0_tx , saveat_tx_t1 , saveat_tx_tx , saveat_t0_tx_2 , saveat_tx_tx_2 , saveat_t0_t1_2 , saveat_tx_t1_2]:

    def fn(t , y , args):
      y =  jax.tree.map(lambda x : 2 * x + t * args , y)

      return y[0] + y[1]

    diffrax_bwd = jax.jacrev(diffrax_integrate , argnums=(2 ,3))(ode_terms , solver , y0, args, t0, t1, dt0 , saveat)
    my_bwd = jax.jacrev(integrate , argnums=(0 , 1))(y0, args, ode_terms , solver , t0, t1, dt0 , saveat)

    assert check_tree(my_bwd , diffrax_bwd)
    print(f"ok for saveat {saveat.subs.ts} with t0={saveat.subs.t0} and t1={saveat.subs.t1}")



ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=True
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=True
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=True
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=True
ok for saveat [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] with t0=True and t1=False
ok for saveat [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.] with t0=True and t1=False
ok for saveat [1. 2. 3. 4. 5. 6. 7. 8.] with t0=True and t1=False
ok for saveat [ 0.  2.  4.  6.  8. 10.] with t0=True and t1=False
ok for saveat [1. 3. 5. 7. 9.] with t0=True and t1=False
ok for saveat [ 0.  2.  4. 

In [48]:
def fn(t , y , args):
  y =  jax.tree.map(lambda x : 2 * x + t * args**4 , y)

  return y[0] + y[1]
  
saveat = SaveAt(ts=[3 , 8] , t1=save_t1 , fn=fn)

t0 = 0.0
t1 = 10.0
dt0 = 0.5

@jax.jit
@partial(jax.jacrev , argnums=(0 , 1))
def wrapped_integrate(y0 , args):
    return integrate(y0 , args , ode_terms , solver , t0 , t1 , dt0 , saveat)

@jax.jit
@partial(jax.jacrev , argnums=(0 , 1))
def wrapped_jax_integrate(y0 , args):
    return scan_integrate( y0 , args , ode_terms , solver  , t0 , t1 , dt0 , saveat)  


@jax.jit
@partial(jax.jacrev , argnums=(0 , 1))
def wrapped_diffrax_integrate(y0 , args):
    return diffrax_integrate(ode_terms , solver  , y0 , args , t0 , t1 , dt0 , saveat , checkpoints=10)

In [38]:
%timeit wrapped_jax_integrate(y0 , args)[0][-1].block_until_ready()  
%timeit wrapped_integrate(y0 , args)[0][-1].block_until_ready()  
%timeit wrapped_diffrax_integrate(y0 , args)[0][-1].block_until_ready()  

432 μs ± 81.6 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.19 ms ± 462 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
86.4 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
jax_mem = wrapped_jax_integrate.lower(y0 , args).compile().memory_analysis().temp_size_in_bytes
my_mem = wrapped_integrate.lower(y0 , args).compile().memory_analysis().temp_size_in_bytes
diffrax_mem = wrapped_diffrax_integrate.lower(y0 , args).compile().memory_analysis().temp_size_in_bytes

print(f"Jax memory is {jax_mem} , My memory is {my_mem} , Diffrax memory is {diffrax_mem}")

Jax memory is 2736 , My memory is 2689 , Diffrax memory is 16080


In [45]:

jax_mem = wrapped_jax_integrate.lower(y0 , args).compile().memory_analysis().temp_size_in_bytes
my_mem = wrapped_integrate.lower(y0 , args).compile().memory_analysis().temp_size_in_bytes
diffrax_mem = wrapped_diffrax_integrate.lower(y0 , args).compile().memory_analysis().temp_size_in_bytes

print(f"Jax memory is {jax_mem} , My memory is {my_mem} , Diffrax memory is {diffrax_mem}")

Jax memory is 2872 , My memory is 2561 , Diffrax memory is 16344


In [49]:
wrapped_integrate(y0 , args)

((Array([  33.674805, 4748.88    ], dtype=float32, weak_type=True),
  Array([  43.101074, 6082.253   ], dtype=float32, weak_type=True)),
 Array([ 236.93848, 8933.692  ], dtype=float32, weak_type=True))

In [None]:
wrapped_diffrax_integrate(y0 , args)

((Array([  33.674805, 4748.88    ], dtype=float32, weak_type=True),
  Array([  43.101074, 6082.253   ], dtype=float32, weak_type=True)),
 Array([ 236.93848, 8933.692  ], dtype=float32, weak_type=True))

: 