In [1]:
from jax_cubature_1d import jax_cubature

import jax
import jax.numpy as jnp
import numpy as np
from cubature import cubature

In [2]:
#Tests from https://github.com/Areustle/cubepy/blob/main/tests/test_integrand.py
#https://www.sciencedirect.com/science/article/pii/0771050X7690005X

def test_polynomial():
    """Integrate a quartic polynomial over [-1, 1] with jax_cubature.

    The estimate is compared with the difference of the closed-form
    antiderivative at the two bounds, and the returned error estimate is
    checked against 1e-5. Both outcomes are printed.
    """
    def integrand(x):
        return 2.0 * jnp.pi * x**4 - jnp.e * x**3 + 3.0 * x**2 - 4.0 * x + 8.0

    def antiderivative(x):
        # Term-by-term primitive of `integrand`.
        return (
            (2.0 * jnp.pi / 5.0) * x**5
            - jnp.e / 4.0 * x**4
            + x**3
            - 2.0 * x**2
            + 8.0 * x
        )

    left, right = jnp.array([-1.0]), jnp.array([1.0])
    estimate, err = jax_cubature(integrand, left, right, 1)
    expected = antiderivative(right) - antiderivative(left)

    print("Test polynomial:")
    print(jnp.allclose(estimate, expected))
    print(jnp.all(err < 1e-5))
    print("")

def test_high_polynomial():
    """Integrate a polynomial with an x**20 term over [-1, 1] with jax_cubature.

    The estimate is compared with the difference of the closed-form
    antiderivative at the two bounds, and the returned error estimate is
    checked against 1e-5. Both outcomes are printed.
    """
    def integrand(x):
        return 2.0 * jnp.pi * x**20 - jnp.e * x**3 + 3.0 * x**2 - 4.0 * x + 8.0

    def antiderivative(x):
        # Term-by-term primitive of `integrand`.
        return (
            (2.0 * jnp.pi / 21.0) * x**21
            - jnp.e / 4.0 * x**4
            + x**3
            - 2.0 * x**2
            + 8.0 * x
        )

    left = jnp.array([-1.0])
    right = jnp.array([1.0])
    estimate, err = jax_cubature(integrand, left, right, 1)
    expected = antiderivative(right) - antiderivative(left)

    print("Test high polynomial:")
    print(jnp.allclose(estimate, expected))
    print(jnp.all(err < 1e-5))
    print("")

def test_pi():
    """Integrate 2*sqrt(1 - x**2) over [-1, 1] and compare the estimate with pi.

    Prints whether the estimate is close to jnp.pi and whether the returned
    error estimate is below 1e-5.
    """
    def doubled_semicircle(x):
        return 2*jnp.sqrt(1 - x**2)

    left_edge, right_edge = jnp.array([-1.0]), jnp.array([1.0])
    estimate, err = jax_cubature(doubled_semicircle, left_edge, right_edge,1)

    matches_pi = jnp.allclose(estimate, jnp.pi)
    within_tolerance = jnp.all(err < 1e-5)
    print("Test pi:")
    print(matches_pi)
    print(within_tolerance)
    print("")

def test_sinc():
    """Integrate jnp.sinc over [-1000, 1000] with a raised maxpts budget.

    The estimate is compared with the reference value 0.9997973576737792 and
    the returned error estimate is checked against 1e-5. Both outcomes are
    printed.
    """
    reference_value = 0.9997973576737792

    def oscillating(x):
        return jnp.sinc(x)

    left_edge = jnp.array([-1000.0])
    right_edge = jnp.array([1000.0])
    estimate, err = jax_cubature(oscillating, left_edge, right_edge,1,maxpts=100000)

    print("Test sinc:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_quadratic():
    """Integrate the sum of squared coordinates over the square [-1, 1] x [-1, 1].

    Prints whether the estimate is close to 8/3 and whether the returned
    error estimate is close to 0.
    """
    def sum_of_squares(x):
        return jnp.sum(x**2)

    corner_low = jnp.array([-1,-1])
    corner_high = jnp.array([1,1])
    estimate, err = jax_cubature(functn=sum_of_squares, a=corner_low, b=corner_high,ndim=2)

    value_ok = jnp.allclose(estimate, 8/3)
    error_ok = jnp.allclose(err, 0)
    print("Test quadratic:")
    print(value_ok)
    print(error_ok)
    print("")

def test_brick():
    """Integrate the constant 1 over the box [0, 1] x [0, 2] x [0, 3].

    The estimate is compared with the product of the upper bounds, and the
    returned error estimate is checked against 1e-5. Both outcomes are printed.
    """
    def constant_one(x):
        return 1

    box_low = jnp.array([0,0,0])
    box_high = jnp.array([1,2,3])
    # The lower corner is the origin, so the expected value is the product of the upper bounds.
    expected_volume = jnp.prod(box_high,axis=0)

    estimate, err = jax_cubature(functn=constant_one, a=box_low, b=box_high, ndim=3)

    print("Test brick:")
    print(jnp.allclose(estimate, expected_volume))
    print(jnp.all(err < 1e-5))
    print("")

def test_sphere():
    """Integrate r**2 * sin(theta) over [0, 1] x [0, pi] x [0, 2*pi].

    The estimate is compared with 4/3 * pi * r**3 evaluated at the upper
    radial bound, and the returned error estimate is checked against 1e-5.
    Both outcomes are printed.
    """
    lower = jnp.array([0,0,0])
    upper = jnp.array([1,jnp.pi,jnp.pi*2])

    def exact_sphere(r):
        return 4/3*jnp.pi*r**3

    def integrand_sphere(x):
        # x = (r, theta, phi); the integrand does not depend on phi (x[2]).
        r = x[0]
        theta = x[1]
        return r**2*jnp.sin(theta)

    result, error = jax_cubature(functn=integrand_sphere, a=lower, b=upper, ndim=3)
    print("Test sphere:")
    print(jnp.allclose(result, exact_sphere(upper[0])))
    print(jnp.all(error<1e-5))
    print("")

def test_ellipsoid():
    """Integrate a*b*c * rho**2 * sin(theta) over [0, 1] x [0, 2*pi] x [0, pi].

    The scale factors (a, b, c) are forwarded to the integrand through the
    fifth positional argument of jax_cubature. The estimate is compared with
    4/3 * pi * a*b*c and the returned error estimate is checked against 1e-5.
    Both outcomes are printed.
    """
    lower = jnp.array([0,0,0])
    upper = jnp.array([1,2*jnp.pi,jnp.pi])

    def exact_ellipsoid(r):
        return 4/3*jnp.pi*jnp.prod(r,axis=0)

    def integrand_ellipsoid(x,args):
        # x = (rho, phi, theta); the integrand does not depend on phi (x[1]).
        rho = x[0]
        theta = x[2]

        a = args[0]
        b = args[1]
        c = args[2]

        return a*b*c* rho**2 *jnp.sin(theta)

    args = jnp.array([1,2,3])
    result, error = jax_cubature(integrand_ellipsoid,lower,upper, 3, args)
    print("Test ellipsoid:")
    print(jnp.allclose(result, exact_ellipsoid(args)))
    print(jnp.all(error<1e-5))
    print("")

def test_van_dooren_de_riddler_1():
    """Six-dimensional benchmark: x0 * x1**2 * sin(x2) / (4 + x3 + x4 + x5).

    Integrated over [0, 2] x [0, 1] x [0, pi/2] x [-1, 1]^3 and compared with
    the reference value 1.434761888397263; the returned error estimate is
    checked against 1e-5. Both outcomes are printed.
    """
    reference_value = 1.434761888397263

    def integrand(point):
        x0, x1, x2 = point[0], point[1], point[2]
        x3, x4, x5 = point[3], point[4], point[5]
        numerator = x0*x1**2 *jnp.sin(x2)
        denominator = 4+ x3 + x4 + x5
        return numerator/denominator

    box_low = jnp.array([0.0, 0.0, 0.0, -1.0, -1.0, -1.0])
    box_high = jnp.array([2.0, 1.0, (jnp.pi / 2.0), 1.0, 1.0, 1.0])

    # Runs with the default point budget; no maxpts override is passed.
    estimate, err = jax_cubature(integrand, box_low, box_high, 6)

    print("Test van_dooren_de_riddler_1:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_van_dooren_de_riddler_2():
    """Four-dimensional benchmark: x2**2 * x3 * exp(x2*x3) * (1 + x0 + x1)**-2.

    Integrated over [0, 1]^3 x [0, 2] with maxpts=20000 and compared with the
    reference value 0.5753641449035616; the returned error estimate is
    checked against 1e-5. Both outcomes are printed.
    """
    reference_value = 0.5753641449035616

    def integrand(point):
        x0, x1 = point[0], point[1]
        x2, x3 = point[2], point[3]
        return x2**2 * x3 * jnp.exp(x2 * x3) * (1 + x0 + x1) ** -2

    box_low = jnp.array([0.0, 0.0, 0.0, 0.0])
    box_high = jnp.array([1.0, 1.0, 1.0, 2.0])
    estimate, err = jax_cubature(integrand, box_low, box_high, 4,maxpts=20000)

    print("Test van_dooren_de_riddler_2:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_van_dooren_de_riddler_3():
    """Three-dimensional benchmark: 8 / (1 + 2*(x0 + x1 + x2)) over the unit cube.

    The estimate is compared with the reference value 2.152142832595894 and
    the returned error estimate is checked against 1e-5. Both outcomes are
    printed.
    """
    reference_value = 2.152142832595894

    def integrand(point):
        x0, x1, x2 = point[0], point[1], point[2]
        return 8 / (1 + 2*(x0+x1+x2))

    cube_low = jnp.array([0.0, 0.0, 0.0])
    cube_high = jnp.array([1.0, 1.0, 1.0])
    estimate, err = jax_cubature(integrand, cube_low, cube_high, 3)

    print("Test van_dooren_de_riddler_3:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_van_dooren_de_riddler_4():
    """Five-dimensional benchmark: cos(x0 + x1 + x2 + x3 + x4).

    Integrated over [0, pi]^4 x [0, pi/2] with maxpts=150000 and compared
    with the reference value 16.0; the returned error estimate is checked
    against 1e-5. Both outcomes are printed.
    """
    def integrand(point):
        x0, x1, x2 = point[0], point[1], point[2]
        x3, x4 = point[3], point[4]
        return jnp.cos(x0+x1+x2+x3+x4)

    box_low = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0])
    box_high = jnp.array([jnp.pi, jnp.pi, jnp.pi, jnp.pi, 0.5*jnp.pi])
    estimate, err = jax_cubature(integrand, box_low, box_high, 5, maxpts=150000)

    value_ok = jnp.allclose(estimate, 16.0)
    error_ok = jnp.all(err < 1e-5)
    print("Test van_dooren_de_riddler_4:")
    print(value_ok)
    print(error_ok)
    print("")

def test_van_dooren_de_riddler_5():
    """Four-dimensional benchmark: sin(10*x0) over the unit hypercube [0, 1]^4.

    The integrand depends only on the first coordinate. The estimate is
    compared with the reference value 0.1839071529076452 and the returned
    error estimate is checked against 1e-5. Both outcomes are printed.
    """
    lower = jnp.array([0.0, 0.0, 0.0, 0.0])
    upper = jnp.array([1.0, 1.0, 1.0, 1.0])

    def integrand_van_dooren_de_riddler_5(x_arr):
        # Only x_arr[0] enters the integrand; the other three coordinates are unused.
        x0 = x_arr[0]
        return jnp.sin(10*x0)

    result, error = jax_cubature(integrand_van_dooren_de_riddler_5, lower, upper, 4)
    print("Test van_dooren_de_riddler_5:")
    print(jnp.allclose(result, 0.1839071529076452))
    print(jnp.all(error<1e-5))
    print("")

def test_van_dooren_de_riddler_6():
    """Two-dimensional benchmark: cos(x0 + x1) over [0, 3*pi] x [0, 3*pi].

    The estimate is compared with the reference value -4.0 and the returned
    error estimate is checked against 1e-5. Both outcomes are printed.
    """
    def integrand(point):
        # Original author's note: very, very slow convergence.
        x0, x1 = point[0], point[1]
        return jnp.cos(x0+x1)

    square_low = jnp.array([0.0, 0.0])
    square_high = jnp.array([3.0*jnp.pi, 3.0*jnp.pi])

    # Runs with the default point budget; no maxpts override is passed.
    estimate, err = jax_cubature(integrand, square_low, square_high, 2)

    print("Test van_dooren_de_riddler_6:")
    print(jnp.allclose(estimate, -4.0))
    print(jnp.all(err < 1e-5))
    print("")
    
def test_van_dooren_de_riddler_7():
    """Three-dimensional benchmark: (x0 + x1 + x2)**-2 over the unit cube.

    Integrated with maxpts=20000 and compared with the reference value
    0.8630462173553432; the returned error estimate is checked against 1e-5.
    Both outcomes are printed.
    """
    reference_value = 0.8630462173553432

    def integrand(point):
        x0, x1, x2 = point[0], point[1], point[2]
        return (x0+x1+x2)**-2

    cube_low = jnp.array([0.0, 0.0, 0.0])
    cube_high = jnp.array([1.0, 1.0, 1.0])
    estimate, err = jax_cubature(integrand, cube_low, cube_high, 3, maxpts=20000)

    print("Test van_dooren_de_riddler_7:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_van_dooren_de_riddler_8():
    """Two-dimensional benchmark over the unit square.

    With s = 1 + 120*(1 - x1), the integrand is
    605*x1 / (s * (s**2 + 25 * x1**2 * x0**2)). The estimate is compared with
    the reference value 1.047591113142868 and the returned error estimate is
    checked against 1e-5. Both outcomes are printed.
    """
    reference_value = 1.047591113142868

    def integrand(point):
        x0, x1 = point[0], point[1]
        # Shared sub-expression that appears twice in the denominator.
        s = 1+120*(1-x1)
        return 605*x1 *(s*(s**2 +25*x1**2 *x0**2))**-1

    square_low = jnp.array([0.0, 0.0])
    square_high = jnp.array([1.0, 1.0])

    # Runs with the default point budget; no maxpts override is passed.
    estimate, err = jax_cubature(integrand, square_low, square_high, 2)

    print("Test van_dooren_de_riddler_8:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_van_dooren_de_riddler_9():
    """Two-dimensional benchmark: 1 / ((x0**2 + 1e-4) * ((x1 + 0.25)**2 + 1e-4)).

    Integrated over the unit square and compared with the reference value
    499.1249442241215; the returned error estimate is checked against 1e-5.
    Both outcomes are printed.
    """
    reference_value = 499.1249442241215

    def integrand(point):
        x0, x1 = point[0], point[1]
        first_factor = x0**2 + 1.0e-4
        second_factor = (x1 + 0.25)**2 +1.0e-4
        return 1/(first_factor * second_factor)

    square_low = jnp.array([0.0, 0.0])
    square_high = jnp.array([1.0, 1.0])
    estimate, err = jax_cubature(integrand, square_low, square_high, 2)

    print("Test van_dooren_de_riddler_9:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_van_dooren_de_riddler_10():
    """Two-dimensional benchmark: exp(|x0 + x1 - 1|) over the unit square.

    The estimate is compared with the reference value 1.436563656918090 and
    the returned error estimate is checked against 1e-5. Both outcomes are
    printed.
    """
    reference_value = 1.436563656918090

    def integrand(point):
        x0, x1 = point[0], point[1]
        return jnp.exp(jnp.abs((x0+x1) -1.0))

    square_low = jnp.array([0.0, 0.0])
    square_high = jnp.array([1.0, 1.0])

    # Runs with the default point budget; no maxpts override is passed.
    estimate, err = jax_cubature(integrand, square_low, square_high, 2)

    print("Test van_dooren_de_riddler_10:")
    print(jnp.allclose(estimate, reference_value))
    print(jnp.all(err < 1e-5))
    print("")

def test_gradient_1d():
    """Check the derivative of a 1-D integral with respect to its upper limit.

    The forward-mode derivative (jax.jacfwd) of the jax_cubature integral is
    compared with a central finite difference of the same integral computed
    by the reference `cubature` integrator. Both comparisons use an absolute
    tolerance of 1e-5, and both outcomes are printed.
    """
    def f_fun(x):
        return x**2 + jnp.log1p(x)**2.5 + x*jnp.log(2)

    def g_fun(x):
        # NumPy twin of f_fun for the reference integrator.
        return x**2 + np.log1p(x)**2.5 + x*np.log(2)

    def f(x):
        lower = jnp.array([0])
        upper = jnp.array([x])
        result, error = jax_cubature(f_fun, lower, upper, 1)
        return result

    def g(x):
        lower = np.array([0])
        upper = np.array([x])
        # The two positional 1s are presumably ndim and fdim -- TODO confirm against the cubature docs.
        result, error = cubature(g_fun, 1,1 , lower, upper, relerr=1e-8)
        return result

    point = 1.0
    h = 1e-5
    df = jax.jacfwd(f)
    # Evaluate the differentiated integral once; each call reruns the whole integration.
    df_at_point = df(point)
    dg = (g(point+h) - g(point-h)) / (2*h)
    print("Test gradient 1-D:")
    print(jnp.allclose(df_at_point, dg, atol=1e-5))
    print(jnp.all(jnp.abs(df_at_point - dg) < 1e-5))
    print("")
  
def test_gradient_nd():
    """Check the derivative of a 3-D integral with respect to one upper limit.

    The integral runs over [0, pi] x [0, pi] x [0, x]. The forward-mode
    derivative (jax.jacfwd) of the jax_cubature integral with respect to x is
    compared with a central finite difference of the same integral computed
    by the reference `cubature` integrator. Both comparisons use an absolute
    tolerance of 1e-5, and both outcomes are printed.
    """
    def f_fun(x_array):
        x = x_array[0]
        y = x_array[1]
        z = x_array[2]
        return x**2 +jnp.log10(y+2)**2.5 + x*z**jnp.log(2)

    def g_fun(x_array):
        # NumPy twin of f_fun for the reference integrator.
        x = x_array[0]
        y = x_array[1]
        z = x_array[2]
        return x**2 +np.log10(y+2)**2.5 + x*z**np.log(2)

    def f(x):
        lower = jnp.array([0, 0, 0])
        upper = jnp.array([jnp.pi, jnp.pi, x])
        result, error = jax_cubature(f_fun, lower, upper, 3)
        return result

    def g(x):
        lower = np.array([0, 0, 0])
        upper = np.array([np.pi, np.pi, x])
        # The positional 3 and 1 are presumably ndim and fdim -- TODO confirm against the cubature docs.
        result, error = cubature(g_fun, 3,1 , lower, upper, relerr=1e-8)
        return result

    point = 1.0
    h = 1e-5
    df = jax.jacfwd(f)
    # Evaluate the differentiated integral once; each call reruns the whole integration.
    df_at_point = df(point)
    dg = (g(point+h) - g(point-h)) / (2*h)
    print("Test gradient N-D:")
    print(jnp.allclose(df_at_point, dg, atol=1e-5))
    print(jnp.all(jnp.abs(df_at_point - dg) < 1e-5))
    print("")

In [3]:
# Test groups, executed in the order listed.
one_dimensional_tests = (
    test_polynomial,
    test_high_polynomial,
    test_pi,
    test_sinc,
)
n_dimensional_tests = (
    test_quadratic,
    test_brick,
    test_sphere,
    test_ellipsoid,
    test_van_dooren_de_riddler_1,
    test_van_dooren_de_riddler_2,
    test_van_dooren_de_riddler_3,
    test_van_dooren_de_riddler_4,
    test_van_dooren_de_riddler_5,
    test_van_dooren_de_riddler_6,
    test_van_dooren_de_riddler_7,
    test_van_dooren_de_riddler_8,
    test_van_dooren_de_riddler_9,
    test_van_dooren_de_riddler_10,
)
gradient_tests = (
    test_gradient_1d,
    test_gradient_nd,
)

# 1-D tests
print("1-D tests:")
for run_test in one_dimensional_tests:
    run_test()
print("\n")

# N-D tests
print("ND tests:")
for run_test in n_dimensional_tests:
    run_test()
print("\n")

# Gradient tests (no trailing separator after the last group)
print("Gradient tests:")
for run_test in gradient_tests:
    run_test()

1-D tests:
Test polynomial:
True
True

Test high polynomial:
True
True

Test pi:
True
True

Test sinc:
True
True



ND tests:
Test quadratic:
True
True

Test brick:
True
True

Test sphere:
True
True

Test ellipsoid:
True
True

Test van_dooren_de_riddler_1:
True
True

Test van_dooren_de_riddler_2:
True
True

Test van_dooren_de_riddler_3:
True
True

Test van_dooren_de_riddler_4:
True
True

Test van_dooren_de_riddler_5:
True
True

Test van_dooren_de_riddler_6:
True
True

Test van_dooren_de_riddler_7:
True
True

Test van_dooren_de_riddler_8:
True
True

Test van_dooren_de_riddler_9:
True
True

Test van_dooren_de_riddler_10:
True
True



Gradient tests:
Test gradient 1-D:
True
True

Test gradient N-D:
True
True



In [4]:
#