In [1]:
# Ignore every FutureWarning issued after this point, keeping cell output to the results.
from warnings import simplefilter
simplefilter("ignore", category=FutureWarning)

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from jaxtyping import Array

# Project-local helper used to compare saved trajectories against the reference solution.
from test.helpers import path_l2_dist

# Turn on 64-bit floats; the `dtype=jnp.float64` arrays built below (jump_ts, save_ts)
# would otherwise be silently downcast to float32.
jax.config.update("jax_enable_x64", True)

## Set `jump_ts` so that some steps must be clipped very short

In [2]:
# Harmonic oscillator y'' = -y, written as the first-order system
# d/dt (y[0], y[1]) = (y[1], -y[0]), started from y[0] = 1, y[1] = 0.
y0 = jnp.array([1.0, 0.0])


def vf(t, y, args):
    """Oscillator vector field in diffrax's ``(t, y, args)`` calling convention.

    ``t`` and ``args`` are required by the signature but not used.
    """
    first, second = y[0], y[1]
    return jnp.array([second, -first])

term = diffrax.ODETerm(vf)

solver = diffrax.Dopri5()
t0 = 0
t1 = 17
dt0 = None  # no initial step size given; the adaptive controller chooses it

# A pair of jumps (i, i + JUMP_GAP) at every integer i = 1, ..., t1 - 1. The tiny
# gap inside each pair forces a step to be clipped to a very short length.
JUMP_GAP = 0.0001
jump_ts = jnp.array(
    [tj for i in range(1, int(t1)) for tj in (i, i + JUMP_GAP)],
    dtype=jnp.float64,
)

# Shared output grid: 101 equally spaced times on [t0, t1], both endpoints included.
save_ts = jnp.linspace(t0, t1, 101, endpoint=True, dtype=jnp.float64)
saveat = diffrax.SaveAt(ts=save_ts)

# Reference solution: tight absolute tolerance (atol=1e-8, rtol=0), with the
# controller additionally told to step onto every save time via `step_ts`.
ref_pid = diffrax.PIDController(atol=1e-8, rtol=0)
ref_contr = diffrax.JumpStepWrapper(ref_pid, step_ts=save_ts)
ref_sol = diffrax.diffeqsolve(
    term,
    solver,
    t0,
    t1,
    dt0,
    y0,
    stepsize_controller=ref_contr,
    saveat=saveat,
)

In [3]:
pid_controller = diffrax.PIDController(atol=1e-5, rtol=0)
my_controller = diffrax.JumpStepWrapper(pid_controller, jump_ts=jump_ts)
# NOTE(review): `use_patricks_version` is not a parameter of released diffrax's
# JumpStepWrapper — presumably an experimental flag on a development branch; confirm.
patricks_controller = diffrax.JumpStepWrapper(
    pid_controller, jump_ts=jump_ts, use_patricks_version=True
)


def solve_and_score(controller):
    """Solve the oscillator problem with ``controller`` and score it against ``ref_sol``.

    Uses the problem defined in the previous cell (``term``, ``solver``, ``t0``,
    ``t1``, ``dt0``, ``y0``, ``saveat``).

    Returns:
        ``(sol, error, num_steps, num_rejected)`` where ``error`` is
        ``path_l2_dist(ref_sol.ys, sol.ys)``, ``num_steps`` is
        ``sol.stats["num_steps"]`` and ``num_rejected`` is
        ``num_steps - sol.stats["num_accepted_steps"]``.
    """
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0,
        stepsize_controller=controller,
        saveat=saveat,
    )
    error = path_l2_dist(ref_sol.ys, sol.ys)
    num_steps = sol.stats["num_steps"]
    # Every step that was attempted but not accepted counts as a rejection.
    num_rejected = num_steps - sol.stats["num_accepted_steps"]
    return sol, error, num_steps, num_rejected


# Compare both solutions to ref_sol
my_sol, my_error, my_steps, my_rejected = solve_and_score(my_controller)
patricks_sol, patricks_error, patricks_steps, patricks_rejected = solve_and_score(
    patricks_controller
)

print(f"My error: {my_error:.5}, Patrick's error: {patricks_error:.5}")
print(f"My num_steps: {my_steps}, Patrick's num_steps: {patricks_steps}")
print(f"My num_rejected: {my_rejected}, Patrick's num_rejected: {patricks_rejected}")

My error: 3.5409e-05, Patrick's error: 3.4735e-05
My num_steps: 69, Patrick's num_steps: 85
My num_rejected: 0, Patrick's num_rejected: 16


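Patrick's version (`use_patricks_version=True`) rejects exactly one step after each pair `i, i + 0.0001` in `jump_ts`. There are 16 such pairs (`i = 1, ..., 16`), which accounts for all 16 of its rejected steps. The cause is that the proposed next step size `prev_dt * factor` is far too large. `prev_dt` was indeed chosen "optimally". However, `factor` was computed from the error of the short clipped step `t1 - t0`, where `t0` and `t1` are that step's own endpoints, not the notebook's global `t0 = 0`, `t1 = 17`. That computation is designed to make `(t1 - t0) * factor` close to optimal. Because `prev_dt >> t1 - t0`, the product `prev_dt * factor` is far too large as well, so the step is rejected.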