In [1]:
import jax.numpy as jnp
from jax import random, jit, config
import matplotlib.pyplot as plt
from functools import partial
config.update("jax_enable_x64", True)

%config InlineBackend.figure_format='retina'

In [2]:
! git init .
! git remote add origin https://github.com/VLSF/SDC
! git pull origin main

Reinitialized existing Git repository in /content/.git/
fatal: remote origin already exists.
From https://github.com/VLSF/SDC
 * branch            main       -> FETCH_HEAD
Already up to date.


In [3]:
from integrators import RK4, Explicit_Euler, Implicit_Euler

from sdc_integrators import RK4 as RK4_c
from sdc_integrators import Explicit_Euler as Explicit_Euler_c
from sdc_integrators import Implicit_Euler as Implicit_Euler_c
from misc import utils, Chebyshev, equations

# Test 1: errors are small, residuals are small

In [4]:
N_points = 20
N_sdc = 5
t0, t1 = 0.0, 1
t = (t1 - t0) * (Chebyshev.Chebyshev_grid(N_points) + 1)/2 + t0

integrators = {
    "RK4": RK4.integrator, 
    "Explicit Euler": Explicit_Euler.integrator,
    "Implicit Euler": Implicit_Euler.integrator,
    "Implicit Euler (jac)": Implicit_Euler.integrator_J,
}

for integrator in integrators:
    print(integrator)
    for equation in equations.equations_list:
        equation_data = equations.get_ODE(equation)
        if "exact" in equation_data:
            exact = equation_data["exact"]
            F = equation_data["F"]
            inv_dF = equation_data["inv_dF"]

            exact_solution = exact(t)
            exact_solution = jnp.expand_dims(exact_solution, 0) if exact_solution.ndim == 1 else exact_solution
            u0 = exact_solution[:, 0]

            if integrator == "Implicit Euler (jac)":
                values = integrators[integrator](u0, F, inv_dF, N_points, t0, t1, 1)
                for i in range(N_sdc):
                    values = Implicit_Euler_c.deferred_correction_J(values, F, inv_dF, t0, t1, 1)

            elif integrator == "Implicit Euler":
                values = integrators[integrator](u0, F, N_points, t0, t1, 1)
                for i in range(N_sdc):
                    values = Implicit_Euler_c.deferred_correction(values, F, t0, t1, 1)
            elif integrator == "Explitic Euler":
                values = integrators[integrator](u0, F, N_points, t0, t1)
                for i in range(N_sdc):
                    values = Explicit_Euler.deferred_correction(values, F, t0, t1)
            else:
                values = integrators[integrator](u0, F, N_points, t0, t1)
                for i in range(N_sdc):
                    values = RK4_c.deferred_correction(values, F, t0, t1)

            error = jnp.linalg.norm(jnp.ravel(values - exact_solution), ord=jnp.inf)
            res = jnp.linalg.norm(jnp.ravel(utils.residual(values, F, t0, t1)), ord=jnp.inf)

            print(equation)
            print("\terror ", error)
            print("\tresidual ", res)
        
        else:
            pass
    print("\n")



RK4
exp
	error  5.1958437552457326e-14
	residual  3.064215547965432e-14
Logistic
	error  1.1102230246251565e-16
	residual  5.551115123125783e-17
Harmonic oscillator
	error  9.564802323502875e-06
	residual  2.2796389846168934e-06
Prothero–Robinson
	error  1.0547118733938987e-14
	residual  1.3988810110276972e-14


Explicit Euler
exp
	error  6.5044769570477e-09
	residual  3.883880217436797e-09
Logistic
	error  1.454392162258955e-14
	residual  1.554312234475219e-14
Harmonic oscillator
	error  0.009383497396203477
	residual  0.005722407769449404
Prothero–Robinson
	error  7.836985504994232e-10
	residual  1.0692242091714377e-09


Implicit Euler
exp
	error  4.417052501537455e-08
	residual  3.035018325547867e-08
Logistic
	error  7.781553179597722e-13
	residual  8.666956041736285e-13
Harmonic oscillator
	error  0.050304683906040815
	residual  0.013359137559166675
Prothero–Robinson
	error  1.761027990099251e-10
	residual  2.2157548018597595e-10


Implicit Euler (jac)
exp
	error  4.417052501537455

# Test 2: order of convergence

In [5]:
def get_orders(N_sdc):
    N_points_1, N_points_2 = 20, 100
    t0, t1 = 0, 2
    t_1 = (t1 - t0) * (Chebyshev.Chebyshev_grid(N_points_1) + 1)/2 + t0
    t_2 = (t1 - t0) * (Chebyshev.Chebyshev_grid(N_points_2) + 1)/2 + t0

    integrators = {
        "RK4": RK4.integrator, 
        "Explicit Euler": Explicit_Euler.integrator,
        "Implicit Euler": Implicit_Euler.integrator,
        "Implicit Euler (jac)": Implicit_Euler.integrator_J,
    }

    for integrator in integrators:
        print(integrator)
        for equation in equations.equations_list:
            equation_data = equations.get_ODE(equation)
            if "exact" in equation_data:
                exact = equation_data["exact"]
                F = equation_data["F"]
                inv_dF = equation_data["inv_dF"]

                E = []
                for t, N_points in zip([t_1, t_2], [N_points_1, N_points_2]):
                    exact_solution = exact(t)
                    exact_solution = jnp.expand_dims(exact_solution, 0) if exact_solution.ndim == 1 else exact_solution
                    u0 = exact_solution[:, 0]

                    if integrator == "Implicit Euler (jac)":
                        values = integrators[integrator](u0, F, inv_dF, N_points, t0, t1, 1)
                        for i in range(N_sdc):
                            values = Implicit_Euler_c.deferred_correction_J(values, F, inv_dF, t0, t1, 1)

                    elif integrator == "Implicit Euler":
                        values = integrators[integrator](u0, F, N_points, t0, t1, 1)
                        for i in range(N_sdc):
                            values = Implicit_Euler_c.deferred_correction(values, F, t0, t1, 1)
                    elif integrator == "Explitic Euler":
                        values = integrators[integrator](u0, F, N_points, t0, t1)
                        for i in range(N_sdc):
                            values = Explicit_Euler.deferred_correction(values, F, t0, t1)
                    else:
                        values = integrators[integrator](u0, F, N_points, t0, t1)
                        for i in range(N_sdc):
                            values = RK4_c.deferred_correction(values, F, t0, t1)

                    error = jnp.linalg.norm(jnp.ravel(values - exact_solution), ord=jnp.inf)
                    E.append(error)

                slope = -jnp.log10(E[1]/E[0]) / jnp.log10(N_points_2/N_points_1)
                print(equation, ", ord =", jnp.round(slope.item(), decimals=2))
            
            else:
                pass
        print("\n")

## SDC(0)

In [6]:
get_orders(0)

RK4
exp , ord = 4.04
Logistic , ord = 4.12
Harmonic oscillator , ord = 4.04
Prothero–Robinson , ord = 4.13


Explicit Euler
exp , ord = 0.9500000000000001
Logistic , ord = 1.04
Harmonic oscillator , ord = 2.0300000000000002
Prothero–Robinson , ord = 1.04


Implicit Euler
exp , ord = 1.11
Logistic , ord = 0.9400000000000001
Harmonic oscillator , ord = 0.31
Prothero–Robinson , ord = 1.01


Implicit Euler (jac)
exp , ord = 1.11
Logistic , ord = 0.9400000000000001
Harmonic oscillator , ord = 0.31
Prothero–Robinson , ord = 1.01




## SDC(1)

In [7]:
get_orders(1)

RK4
exp , ord = 5.09
Logistic , ord = 5.08
Harmonic oscillator , ord = 5.04
Prothero–Robinson , ord = 5.14


Explicit Euler
exp , ord = 2.0
Logistic , ord = 1.99
Harmonic oscillator , ord = 2.91
Prothero–Robinson , ord = 2.06


Implicit Euler
exp , ord = 2.22
Logistic , ord = 1.95
Harmonic oscillator , ord = 0.85
Prothero–Robinson , ord = 2.0100000000000002


Implicit Euler (jac)
exp , ord = 2.22
Logistic , ord = 1.95
Harmonic oscillator , ord = 0.85
Prothero–Robinson , ord = 2.0100000000000002




## SDC(2)

In [8]:
get_orders(2)

RK4
exp , ord = 6.140000000000001
Logistic , ord = 6.04
Harmonic oscillator , ord = 6.0200000000000005
Prothero–Robinson , ord = 6.16


Explicit Euler
exp , ord = 3.0500000000000003
Logistic , ord = 2.97
Harmonic oscillator , ord = 3.97
Prothero–Robinson , ord = 3.0700000000000003


Implicit Euler
exp , ord = 3.33
Logistic , ord = 2.96
Harmonic oscillator , ord = 1.5
Prothero–Robinson , ord = 3.0100000000000002


Implicit Euler (jac)
exp , ord = 3.33
Logistic , ord = 2.96
Harmonic oscillator , ord = 1.5
Prothero–Robinson , ord = 3.0100000000000002




## SDC(3)

In [9]:
get_orders(3)

RK4
exp , ord = 7.19
Logistic , ord = 6.3100000000000005
Harmonic oscillator , ord = 6.98
Prothero–Robinson , ord = 7.17


Explicit Euler
exp , ord = 4.1
Logistic , ord = 3.95
Harmonic oscillator , ord = 4.86
Prothero–Robinson , ord = 4.09


Implicit Euler
exp , ord = 4.45
Logistic , ord = 3.97
Harmonic oscillator , ord = 2.16
Prothero–Robinson , ord = 4.0


Implicit Euler (jac)
exp , ord = 4.45
Logistic , ord = 3.97
Harmonic oscillator , ord = 2.16
Prothero–Robinson , ord = 4.0




## SDC(4)

In [10]:
get_orders(4)

RK4
exp , ord = 8.61
Logistic , ord = 3.66
Harmonic oscillator , ord = 7.96
Prothero–Robinson , ord = 8.040000000000001


Explicit Euler
exp , ord = 5.15
Logistic , ord = 4.93
Harmonic oscillator , ord = 5.89
Prothero–Robinson , ord = 5.1000000000000005


Implicit Euler
exp , ord = 5.5600000000000005
Logistic , ord = 4.97
Harmonic oscillator , ord = 2.84
Prothero–Robinson , ord = 4.98


Implicit Euler (jac)
exp , ord = 5.5600000000000005
Logistic , ord = 4.97
Harmonic oscillator , ord = 2.84
Prothero–Robinson , ord = 4.98




## SDC(5)

In [11]:
get_orders(5)

RK4
exp , ord = 6.49
Logistic , ord = 1.46
Harmonic oscillator , ord = 8.91
Prothero–Robinson , ord = 6.7


Explicit Euler
exp , ord = 6.2
Logistic , ord = 5.9
Harmonic oscillator , ord = 6.84
Prothero–Robinson , ord = 6.11


Implicit Euler
exp , ord = 6.68
Logistic , ord = 5.98
Harmonic oscillator , ord = 3.54
Prothero–Robinson , ord = 6.04


Implicit Euler (jac)
exp , ord = 6.68
Logistic , ord = 5.98
Harmonic oscillator , ord = 3.54
Prothero–Robinson , ord = 6.04




So, roughly we have $+1$ order of convergence per one SDC iteration.

Note that for RK4 we observe saturation. This is because already for $N=20$ points we already have too good result with enough SDC iterations.

The only exception is implicit Euler for harmonic oscillator. Need to read more on symplectic integrators to resolve this.