In [7]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import diffrax

In [8]:
def lorenz_system(t, state, args):
    """Vector field of the Lorenz system, in diffrax ``ODETerm`` form ``f(t, y, args)``.

    Args:
        t: Current time. Not used (the equations have no explicit time
            dependence) but required by the ``ODETerm`` call signature.
        state: The current point, unpacked as ``(x, y, z)``.
        args: The parameters, unpacked as ``(sigma, rho, beta)``.

    Returns:
        A ``jnp`` array ``[dx/dt, dy/dt, dz/dt]``.
    """
    px, py, pz = state
    sig, r, b = args
    velocity = (
        sig * (py - px),      # dx/dt
        px * (r - pz) - py,   # dy/dt
        px * py - b * pz,     # dz/dt
    )
    return jnp.array(velocity)

# Standard Lorenz parameters and the integration window.
sigma = 10.0
rho = 28.0
beta = 8.0 / 3.0
initial_state = jnp.array([1.0, 1.0, 1.0])
t0 = 0.0
t1 = 20.0
dt = 0.01  # initial step size; the adaptive controller refines it from here
args = (sigma, rho, beta)
term = diffrax.ODETerm(lorenz_system)
# Adaptive step-size control. It must be passed to `diffeqsolve` as
# `stepsize_controller`; if omitted, diffrax falls back to a constant step of
# `dt0` and these tolerances are silently ignored.
controller = diffrax.PIDController(rtol=1e-10, atol=1e-10)
solver = diffrax.Tsit5()
t_span = (t0, t1)

# Dense output: `sol.evaluate(t)` interpolates the solution anywhere in
# [t0, t1] (this is not sampling at regular intervals).
save_at = diffrax.SaveAt(dense=True)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=dt,
    y0=initial_state,
    args=args,
    saveat=save_at,
    stepsize_controller=controller,
    # Budget of 10x the number of fixed-`dt` steps, as headroom for the
    # adaptive controller possibly taking steps smaller than `dt`.
    max_steps=int(10 * (t1 - t0) / dt),
)

In [9]:
# Print BibTeX references for the numerical techniques used in the solve above.
# The arguments mirror the `diffeqsolve` call, including the step-size
# controller, so that the adaptive controller is cited as well.
diffrax.citation(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=dt,
    y0=initial_state,
    args=args,
    saveat=save_at,
    stepsize_controller=controller,
    max_steps=int(10 * (t1 - t0) / dt),
)

% --- AUTOGENERATED REFERENCES PRODUCED USING `diffrax.citation(...)` ---
% The following references were found for the numerical techniques being used.
% This does not cover e.g. any modelling techniques being used.
% If you think a paper is missing from here then open an issue or pull request at
% https://github.com/patrick-kidger/diffrax

% You are using Diffrax, which is citable as:
@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

% You are using Equinox, which is citable as:
@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and
           filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing
             Systems 2021}
}

% You are using JAX, which is citable as:
@software{jax2018github,
  author = {Ja