diff --git a/torchdyn/numerics/solvers/ode.py b/torchdyn/numerics/solvers/ode.py index 7406fb9..aa066c6 100644 --- a/torchdyn/numerics/solvers/ode.py +++ b/torchdyn/numerics/solvers/ode.py @@ -144,7 +144,7 @@ def __init__(self, dtype=torch.float32): def step(self, f, x, t, dt, k1=None, args=None) -> Tuple: c, a, bsol, berr = self.tableau - if k1 == None: k1 = f(t, x) + if k1 is None: k1 = f(t, x) k2 = f(t + c[0] * dt, x + dt * a[0] * k1) k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2)) k4 = f(t + c[2] * dt, x + dt * a[2][0] * k1 + dt * a[2][1] * k2 + dt * a[2][2] * k3)