In [14]:
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -jnp.exp(-t)

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

In [15]:
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return {"e":-jnp.exp(-t)}

term = ODETerm(f)
solver = Dopri5()
y0 = {"e": jnp.array([2., 3.])}
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

In [16]:
print(solution)

Solution(
  t0=f32[],
  t1=f32[],
  ts=f32[1],
  ys={'e': f32[1,2]},
  interpolation=None,
  stats={
    'max_steps': 4096,
    'num_accepted_steps': weak_i32[],
    'num_rejected_steps': weak_i32[],
    'num_steps': weak_i32[]
  },
  result=EnumerationItem(_value=i32[], _enumeration=diffrax._solution.RESULTS),
  solver_state=None,
  controller_state=None,
  made_jump=None,
  event_mask=None
)


In [None]:
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller)

print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
print(sol.ys)  # DeviceArray([1.   , 0.368, 0.135, 0.0498])

[0. 1. 2. 3.]
[1.         0.3678826  0.13533902 0.04978956]


In [1]:
# pip install diffrax jax jaxlib  # if needed
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt

# dy/dt = y, y(0) = 1  → exact solution is y(t) = exp(t)
def f(t, y, args):
    return y

term   = ODETerm(f)
solver = Tsit5()

t0, t1 = 0.0, 1.0
y0     = jnp.array(1.0)
dt0    = 0.1                      # initial step size guess

ts = jnp.linspace(t0, t1, 11)     # 11 uniform samples between 0 and 1
sol = diffeqsolve(term, solver, t0=t0, t1=t1, dt0=dt0, y0=y0, saveat=SaveAt(ts=ts))

print("times:", sol.ts)           # -> [0. , 0.1, ..., 1.0]
print("values:", sol.ys)          # -> approximates exp(ts)




times: [0.         0.1        0.2        0.3        0.4        0.5
 0.6        0.7        0.8        0.90000004 1.        ]
values: [1.        1.105171  1.2214029 1.3498588 1.4918247 1.6487212 1.8221189
 2.0137525 2.2255409 2.4596033 2.7182817]


In [2]:
jnp.exp(sol.ts)  # exact values for comparison

Array([1.       , 1.105171 , 1.2214028, 1.3498588, 1.4918246, 1.6487212,
       1.8221189, 2.0137527, 2.2255409, 2.4596033, 2.7182817],      dtype=float32)

In [6]:
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr
import equinox as eqx
from diffrax import ODETerm, Tsit5, diffeqsolve, SaveAt

# MLP that maps [y, t] -> dy/dt (same shape as y)
key = jr.PRNGKey(0)
mlp = eqx.nn.MLP(
    in_size=2, out_size=1, width_size=32, depth=2,
    activation=jnn.tanh,
    # NOTE: do NOT set final_activation=None; default is identity.
    final_activation=lambda x: x,  # (optional explicit)
    key=key
)

def vf(t, y, net):
    # concatenate state and time to make a time-dependent vector field
    x = jnp.concatenate([jnp.atleast_1d(y), jnp.atleast_1d(t)])  # shape (2,)
    return net(x)  # must match y's shape; here (1,)

term   = ODETerm(vf)
solver = Tsit5()

t0, t1 = 0.0, 1.0
y0     = jnp.array([1.0])            # vector state (shape (1,))
ts     = jnp.linspace(t0, t1, 101)

sol = diffeqsolve(
    term, solver, t0=t0, t1=t1, dt0=1e-2, y0=y0,
    args=mlp, saveat=SaveAt(ts=ts)
)

print(sol.ts.shape, sol.ys.shape)
print(sol.ts)  # time points
print(sol.ys)  # solution at time points


(101,) (101, 1)
[0.         0.01       0.02       0.03       0.04       0.05
 0.06       0.07       0.08       0.09       0.09999999 0.11
 0.12       0.13       0.14       0.14999999 0.16       0.17
 0.17999999 0.19       0.19999999 0.21       0.22       0.22999999
 0.24       0.25       0.26       0.26999998 0.28       0.29
 0.29999998 0.31       0.32       0.32999998 0.34       0.35
 0.35999998 0.37       0.38       0.39       0.39999998 0.41
 0.42       0.42999998 0.44       0.45       0.45999998 0.47
 0.48       0.48999998 0.5        0.51       0.52       0.53
 0.53999996 0.55       0.56       0.57       0.58       0.59
 0.59999996 0.61       0.62       0.63       0.64       0.65
 0.65999997 0.66999996 0.68       0.69       0.7        0.71
 0.71999997 0.72999996 0.74       0.75       0.76       0.77
 0.78       0.78999996 0.79999995 0.81       0.82       0.83
 0.84       0.84999996 0.85999995 0.87       0.88       0.89
 0.9        0.90999997 0.91999996 0.93       0.94       0.95
 0