In [281]:
import numpy as np
import jax.numpy as jnp
import jax
from functools import partial

# Enable 64-bit types: JAX otherwise defaults to 32-bit floats, which is not
# enough for the tight relative tolerances used further down.
jax.config.update("jax_enable_x64", True)
# Request the CPU backend for all computations in this notebook.
jax.config.update("jax_platform_name", "cpu")

In [282]:
# NOTE: this driver is intentionally not wrapped in jax.jit. It and its helpers
# branch in Python on array values (e.g. convergence checks), which would have
# to be rewritten with jax.lax control flow before jit could be applied.
def jax_cubature(*, ndim : int, a : jnp.ndarray, b : jnp.ndarray, minpts : int, maxpts : int, functn : callable, rel_tol : float) -> tuple:
    """Adaptively integrate ``functn`` over the box ``[a, b]``.

    The region with the largest error estimate is repeatedly halved along one
    axis; both halves are re-evaluated with ``basic_rule`` and re-inserted in
    the region list kept in ``wrkstr`` until ``check_convergence`` stops the
    loop.

    Parameters
    ----------
    ndim : int
        Number of integration variables. Must be greater than 2.
    a, b : array-like of length ``ndim``
        Lower and upper integration limits.
    minpts : int
        Minimum number of integrand evaluations before convergence is accepted.
    maxpts : int
        Evaluation budget; also determines the size of the work array.
    functn : callable
        Integrand, called with a length-``ndim`` array and returning a scalar.
    rel_tol : float
        Requested relative accuracy.

    Returns
    -------
    tuple
        ``(finest, relerr, minpts, ifail, wrkstr, funcls)`` where ``finest`` is
        the integral estimate, ``relerr`` the estimated relative error,
        ``minpts`` and ``funcls`` are both the number of integrand evaluations
        actually used, ``wrkstr`` is the final work array and ``ifail`` is
        0 (converged), 1 (``maxpts`` too small) or 2 (work array too small).

    Raises
    ------
    ValueError
        If ``ndim`` is not greater than 2, or if the rule produced by
        ``initialise`` has non-finite abscissae for this ``ndim``.
    """
    # The error message always said "greater than 2"; the test used to let
    # ndim == 2 through, for which the rule is degenerate and every result NaN.
    if ndim < 3:
        raise ValueError("ndim must be greater than 2")

    rgnstr = 2*ndim + 2          # a region record occupies rgnstr+1 slots
    divaxo = 0

    # Initialize the cubature rule
    rulcls, twondim, ratio, lambdas, weights, weightsp = initialise(ndim)

    # For some dimensions (e.g. ndim = 4) the squared abscissa lambda2 is
    # negative, its square root is NaN, and the integration would silently
    # burn the whole budget to return NaN. Fail early instead.
    if not bool(jnp.all(jnp.isfinite(lambdas))):
        raise ValueError(f"no real-valued integration rule is available for ndim={ndim}")

    # Scratch vectors handed to basic_rule
    z      = jnp.zeros(ndim)
    widthl = jnp.zeros(ndim)

    lenwrk = (2*ndim+3)*(1+maxpts//rulcls)//2
    wrkstr = jnp.zeros(lenwrk+1)  # last slot accumulates the total error

    funcls = 0
    finest = 0.0
    ifail  = 3
    relerr = 1.0

    # Start from the whole box, described by its centre and half-widths.
    # asarray lets callers pass NumPy arrays or lists as well.
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    width  = (b-a)/2.0
    center = a + width

    subrgn = rgnstr
    sbrgns = rgnstr+1
    divflg = 1

    first_pass = True
    while ifail > 2:
        if not first_pass:
            # Pop the region with the largest error and set up its first half
            divflg, subrgn, wrkstr, finest, divaxo, _subtmp, center, width = prepare_new_call(ndim,rgnstr,wrkstr,lenwrk,finest,subrgn,center, width,divaxo)
        first_pass = False

        # At most two rule applications per pass: first half, then second half
        for _half in range(2):
            finest, wrkstr, funcls, rgnerr, rgnval, divaxo, divaxn = basic_rule(functn,ndim,twondim,lenwrk,rulcls,center,width,widthl,lambdas,ratio,weights,weightsp,divaxo,z,finest,wrkstr,funcls)

            # Order and store results of basic rule
            wrkstr = order_results(divflg,wrkstr,subrgn,sbrgns,rgnstr,rgnerr,rgnval,divaxn,center,width,ndim)

            if divflg == 1:
                break

            # divflg == 0: move to the second half and append a new record
            center = center.at[divaxo].set(center[divaxo] + 2.0*width[divaxo])
            sbrgns = sbrgns + rgnstr + 1
            subrgn = sbrgns - 1
            divflg = 1

        # Check the convergence for possible termination.
        ifail, relerr = check_convergence(ifail,wrkstr,lenwrk,finest,relerr,rel_tol,sbrgns,rgnstr,maxpts,minpts,funcls,rulcls)

    minpts = funcls
    wrkstr = wrkstr.at[lenwrk-1].set(sbrgns)
    return finest, relerr, minpts, ifail, wrkstr, funcls
       
        
            
        
        



In [283]:
def initialise(ndim):
    """Build the abscissae and weights of the embedded degree-7/degree-5 rule.

    The rule uses the centre, points (±l2,0,..), (±l4,0,..), (±l4,±l4,0,..)
    and (±l5,±l5,±l5,0,..) in every axis combination.

    Parameters
    ----------
    ndim : int
        Number of integration variables.

    Returns
    -------
    tuple
        ``(rulcls, twondim, ratio, lambdas, weights, weightsp)``:
        ``rulcls`` is the number of integrand evaluations per rule application,
        ``twondim`` is ``2.0**ndim``, ``ratio`` is ``lambda2/lambda4`` (before
        the square roots), ``lambdas`` holds ``sqrt`` of (lambda2, lambda4,
        lambda5), ``weights`` the five degree-7 weights and ``weightsp`` the
        four degree-5 weights.

    Notes
    -----
    ``lambda2`` comes out negative for some ``ndim`` (e.g. 4); its square root
    is then NaN and the rule is unusable for that dimension.
    """
    twondim = 2.0**ndim
    rulcls = 1 + (ndim*(12+(ndim-1)*(6+(ndim-2)*4)))//3

    lambda5 = 9.0/19.0
    ratio = (ndim-2)/9.0
    lambda4 = (1.0/5.0 -ratio)/(1.0/3.0 -ratio/lambda5)
    ratio = (1.0 -lambda4/lambda5)*(ndim-1)*ratio/6.0
    lambda2 = (1.0/7.0 -lambda4/5.0 -ratio)/(1.0/5.0 -lambda4/3.0 -ratio/lambda5)

    weight5 = 1.0/(6.0*lambda5)**3
    weight4 = (1.0/15.0 -lambda5/9.0)/(4.0*(lambda4-lambda5)*lambda4**2)
    weight3 = (1.0/7.0 -(lambda5+lambda2)/5.0 +lambda5*lambda2/3.0)/(2.0*lambda4*(lambda4-lambda5)*(lambda4-lambda2)) -2.0*(ndim-1)*weight4
    weight2 = (1.0/7.0 -(lambda5+lambda4)/5.0 +lambda5*lambda4/3.0)/(2.0*lambda2*(lambda2-lambda5)*(lambda2-lambda4))

    # The centre weight makes all weights (times their point counts) sum to 1.
    # The lambda5 points are the 8*C(ndim,3) sign/axis combinations summed in
    # basic_rule, not the 2**ndim cube vertices; the two counts agree only for
    # ndim == 3, where this expression reproduces the previous value exactly.
    npts5 = 4.0*ndim*(ndim-1)*(ndim-2)/3.0
    weight1 = 1.0 -2.0*ndim*(weight2+weight3+(ndim-1)*weight4)-npts5*weight5

    weight4p = 1.0/(6.0*lambda4)**2
    weight3p = (1.0/5.0 -lambda2/3.0)/(2.0*lambda4*(lambda4-lambda2)) -2.0*(ndim-1)*weight4p
    weight2p = (1.0/5.0 -lambda4/3.0)/(2.0*lambda2*(lambda2-lambda4))
    weight1p = 1.0 -2.0*ndim*(weight2p+weight3p+(ndim-1)*weight4p)

    # Ratio of the squared abscissae, used for the fourth-difference test
    ratio = lambda2/lambda4

    # Store the abscissae themselves (square roots of the solved quantities)
    lambda5 = jnp.sqrt(lambda5)
    lambda4 = jnp.sqrt(lambda4)
    lambda2 = jnp.sqrt(lambda2)

    lambdas  = jnp.array([lambda2, lambda4, lambda5])
    weights  = jnp.array([weight1, weight2, weight3, weight4, weight5])
    weightsp = jnp.array([weight1p, weight2p, weight3p, weight4p])

    return rulcls,twondim,ratio,lambdas,weights,weightsp


In [284]:
def prepare_new_call(ndim,rgnstr,wrkstr,lenwrk,finest,subrgn,center, width,divaxo):
    """Remove the top region from the running totals and set up its first half.

    The record ending at index ``rgnstr`` of ``wrkstr`` is read: its error is
    subtracted from the accumulated error in ``wrkstr[lenwrk]``, its value is
    subtracted from ``finest``, and its stored split axis, centres and
    half-widths are unpacked. The half-width along the split axis is then
    halved and the centre moved to the lower half. The incoming ``subrgn`` and
    ``divaxo`` values are not used; both are overwritten.

    Returns ``(divflg, subrgn, wrkstr, finest, divaxo, subtmp, center, width)``
    with ``divflg == 0`` and ``subrgn == rgnstr``.
    """
    top = rgnstr

    # Take this region's contribution out of the global error and estimate
    wrkstr = wrkstr.at[lenwrk].set(wrkstr[lenwrk] - wrkstr[top])
    finest = finest - wrkstr[top-1]
    split_axis = jnp.int32(wrkstr[top-2])

    # Per axis j the record holds the half-width at top-2*(j+2) and the
    # centre one slot above it
    axes = jnp.arange(ndim)
    slots = top - 2*(axes + 2)
    center = center.at[axes].set(wrkstr[slots + 1])
    width = width.at[axes].set(wrkstr[slots])
    last_slot = top - 2*(ndim + 1)

    # Halve along the split axis and move to the lower half
    width = width.at[split_axis].set(width[split_axis]/2.0)
    center = center.at[split_axis].set(center[split_axis] - width[split_axis])

    return 0, top, wrkstr, finest, split_axis, last_slot, center, width


In [285]:
def basic_rule(functn,ndim,twondim,lenwrk,rulcls,center,width,widthl,lambdas,ratio,weights,weightsp,divaxo,z,finest,wrkstr,funcls):
    """Apply the seventh/fifth degree rule pair to one subregion.

    The subregion is described by ``center`` and half-widths ``width``.
    ``functn`` is evaluated at the centre and at symmetric points offset by
    ``lambdas[k]*width`` along one, two and three axes; the weighted sums give
    the region value (``weights``) and a comparison value (``weightsp``) whose
    difference is the error estimate. ``z`` and ``widthl`` are scratch vectors;
    their updated contents are not returned.

    Returns
    -------
    tuple
        ``(finest, wrkstr, funcls, rgnerr, rgnval, divaxo, divaxn)``:
        ``finest`` increased by the region value, ``wrkstr`` with the region
        error added to ``wrkstr[lenwrk]``, ``funcls`` increased by ``rulcls``,
        the region error and value, ``divaxo`` unchanged, and ``divaxn``, the
        axis suggested for the next subdivision of this region.
    """
    # Region volume: twondim times the product of the half-widths
    rgnvol = twondim
    for j in range(ndim):
        rgnvol = rgnvol*width[j]
        z = z.at[j].set(center[j])

    # Integrand at the region centre
    sum1 = functn(z)
    #Compute the symmetric sums of functn(lambda2,0,0,..0) and functn(lambda4,0,0,..0), and 
    #maximum fourth difference
    difmax = -1.0
    sum2 = 0.0
    sum3 = 0.0
    for j in range(ndim):
        z = z.at[j].set(center[j]-lambdas[0]*width[j])
        f1 = functn(z)
        z = z.at[j].set(center[j]+lambdas[0]*width[j])
        f2 = functn(z)
        widthl = widthl.at[j].set(lambdas[1]*width[j])
        z = z.at[j].set(center[j]-widthl[j])
        f3 = functn(z)
        z= z.at[j].set(center[j]+widthl[j])
        f4 = functn(z)
        sum2 = sum2 + f1 + f2
        sum3 = sum3 + f3 + f4
        df1 = f1+f2-2.0*sum1
        df2 = f3+f4-2.0*sum1
        dif = jnp.fabs(df1-ratio*df2)

        # Remember the axis with the largest difference seen so far.
        # NOTE(review): this is a Python branch on an array value; a jit-compatible
        # port would need jax.lax.cond or jnp.where here.
        if difmax<dif:
            difmax = dif
            divaxn = j

        # Restore axis j to the centre before moving to the next axis
        z = z.at[j].set(center[j])

    # If the largest difference is negligible relative to the centre value,
    # fall back to cycling to the axis after the previously split one.
    # NOTE(review): also a Python branch on array values (see above).
    if sum1 == sum1+difmax/8.0:
        divaxn = (divaxo+1)%ndim


    #Compute the symmetric sums of functn(lambda4,lambda4,0,..0)
    # Axis pair is (j-1, k) with j-1 < k. Each inner pass flips the sign stored
    # in widthl, so the two nested 2-loops visit all four sign combinations and
    # leave widthl with its original signs.
    sum4 = 0.0
    for j in range(1,ndim):
        for k in range(j,ndim):
            for l in range(2):
                widthl = widthl.at[j-1].set(-widthl[j-1])
                z = z.at[j-1].set(center[j-1]+widthl[j-1])

                for m in range(2):
                    widthl = widthl.at[k].set(-widthl[k])
                    z = z.at[k].set(center[k]+widthl[k])
                    f1 = functn(z)
                    sum4 = sum4 + f1

            z = z.at[k].set(center[k])
        z = z.at[j-1].set(center[j-1])

    #Compute symmetric sum of functn(lambda5,lambda5,lambda5,0,0...0)
    sum5 = 0.0

    # Reload the scratch offsets with the lambda5 spacing
    for j in range(ndim):
        widthl = widthl.at[j].set(lambdas[2]*width[j])
    # Axis triple is (i-2, j-1, k); the three nested 2-loops visit all eight
    # sign combinations using the same sign-flipping scheme as above.
    for i in range(2,ndim):
        for j in range(i,ndim):
            for k in range(j,ndim):
                for l in range(2):
                    widthl = widthl.at[i-2].set(-widthl[i-2])
                    z = z.at[i-2].set(center[i-2]+widthl[i-2])
                    for m in range(2):
                        widthl = widthl.at[j-1].set(-widthl[j-1])
                        z = z.at[j-1].set(center[j-1]+widthl[j-1])
                        for n in range(2):
                            widthl = widthl.at[k].set(-widthl[k])
                            z = z.at[k].set(center[k]+widthl[k])
                            f1 = functn(z)
                            sum5 = sum5 + f1

                z = z.at[k].set(center[k])
            z = z.at[j-1].set(center[j-1])
        z = z.at[i-2].set(center[i-2])

    #Compute fifth and seventh degree rules and error.
    rgncmp = rgnvol *(weightsp[0]*sum1 + weightsp[1]*sum2 + weightsp[2]*sum3 + weightsp[3]*sum4)
    rgnval = rgnvol *(weights[0]*sum1 + weights[1]*sum2 + weights[2]*sum3 + weights[3]*sum4 + weights[4]*sum5)
    rgnerr = jnp.fabs(rgnval-rgncmp)

    # Fold this region into the running totals
    finest = finest+rgnval
    wrkstr = wrkstr.at[lenwrk].set(wrkstr[lenwrk]+rgnerr)
    funcls = funcls+rulcls

    return finest, wrkstr, funcls, rgnerr, rgnval, divaxo, divaxn


In [286]:
def order_results(divflg,wrkstr,subrgn,sbrgns,rgnstr,rgnerr,rgnval,divaxn,center,width,ndim):
    """Insert a freshly evaluated region into the error-ordered region list.

    ``wrkstr`` stores fixed-size records of ``rgnstr+1`` slots, arranged as a
    binary max-heap keyed on the region error. A record is addressed by the
    index of its last slot, which holds the error; below it come the value,
    the split axis and, per axis, centre and half-width.

    Parameters
    ----------
    divflg : int
        0 to insert from the top of the heap and sift down (first half of a
        split region); 1 to insert at ``subrgn`` and sift up (second half, or
        the initial region).
    wrkstr : jnp.ndarray
        Work array holding the heap.
    subrgn : int
        Index of the error slot of the free record to fill.
    sbrgns : int
        Number of slots currently occupied by records.
    rgnstr : int
        Record length minus one.
    rgnerr, rgnval, divaxn :
        Error estimate, value and suggested split axis of the new region.
    center, width : jnp.ndarray
        Centre and half-widths of the new region.
    ndim : int
        Number of axes to store.

    Returns
    -------
    jnp.ndarray
        The updated work array.
    """
    reclen = rgnstr + 1

    #When divflg=0, start at top of list and move down
    #list tree to find correct position for results from
    #first half of recently divided subregion
    if divflg != 1:
        subtmp = 2*subrgn + 1
        while subtmp < sbrgns:
            # Choose the child with the larger error BEFORE comparing against
            # the new region. Comparing with the left child first would stop
            # too early when left <= rgnerr < right and leave a parent with a
            # smaller error than its right child.
            if subtmp != sbrgns-1:
                sbtmpp = subtmp + reclen
                if wrkstr[subtmp] < wrkstr[sbtmpp]:
                    subtmp = sbtmpp

            if not (rgnerr < wrkstr[subtmp]):
                break

            # Move the child record up into the free slot (records never overlap)
            wrkstr = wrkstr.at[subrgn-rgnstr:subrgn+1].set(wrkstr[subtmp-rgnstr:subtmp+1])

            subrgn = subtmp
            subtmp = 2*subrgn + 1

    #When divflg=1, start at bottom right branch and move
    #up list tree to find correct position for results from
    #second half of recently divided subregion
    else:
        subtmp = ((subrgn+1)//(2*reclen))*reclen - 1

        while subtmp >= rgnstr and rgnerr > wrkstr[subtmp]:
            # Move the parent record down into the free slot
            wrkstr = wrkstr.at[subrgn-rgnstr:subrgn+1].set(wrkstr[subtmp-rgnstr:subtmp+1])
            subrgn = subtmp
            subtmp = ((subrgn+1)//(2*reclen))*reclen - 1

    #Store results of basic rule in correct position in list
    wrkstr = wrkstr.at[subrgn].set(rgnerr)
    wrkstr = wrkstr.at[subrgn-1].set(rgnval)
    wrkstr = wrkstr.at[subrgn-2].set(divaxn)
    for j in range(ndim):
        subtmp = subrgn-2*(j+2)
        wrkstr = wrkstr.at[subtmp+1].set(center[j])
        wrkstr = wrkstr.at[subtmp].set(width[j])

    return wrkstr

In [287]:
def check_convergence(ifail,wrkstr,lenwrk,finest,relerr,rel_tol,sbrgns,rgnstr,maxpts,minpts,funcls,rulcls):
    """Update the relative error estimate and the termination flag.

    Parameters
    ----------
    ifail : int
        Current status flag; returned unchanged if no condition below fires.
    wrkstr : jnp.ndarray
        Work array; ``wrkstr[lenwrk]`` holds the accumulated error estimate.
    lenwrk : int
        Index of the accumulated error slot.
    finest : scalar
        Current integral estimate.
    relerr : scalar
        Previous relative error; kept when ``finest`` is exactly zero.
    rel_tol : float
        Requested relative accuracy.
    sbrgns, rgnstr : int
        Slots used by region records and record length minus one.
    maxpts, minpts, funcls, rulcls : int
        Evaluation budget, minimum evaluations, evaluations so far and
        evaluations per rule application.

    Returns
    -------
    tuple
        ``(ifail, relerr)``. ``ifail`` becomes 2 when another pair of records
        would not fit, 1 when another subdivision would exceed ``maxpts``, and
        0 when ``relerr < rel_tol`` with at least ``minpts`` evaluations; later
        conditions override earlier ones. ``relerr`` is capped at 1.0.
    """
    # The accumulated error is maintained by adding and subtracting region
    # errors, so rounding can leave it slightly negative. Clamp it so that a
    # negative relative error is never reported. (The original assigned the
    # clamped array to a misspelled, unused variable.)
    errsum = jnp.maximum(wrkstr[lenwrk], 0.0)

    if jnp.fabs(finest) != 0.0:
        relerr = errsum/jnp.fabs(finest)

    if 1.0 < relerr:
        relerr = 1.0

    # Not enough room left in the work array for another subdivision
    if lenwrk<(sbrgns+rgnstr+2):
        ifail=2

    # Another subdivision (two rule applications) would exceed the budget
    if maxpts<(funcls+2*rulcls):
        ifail=1

    # Converged
    if (relerr<rel_tol) and (minpts<=funcls):
        ifail=0

    return ifail,relerr

In [288]:
from cubature import cubature
from numba import jit

@jit(nopython=True)
def test_function(x_array):
    """Three-variable test integrand, compiled with numba for the reference run.

    Evaluates ``x**2 + log10(y + 2)**2.5 + x * z**log(2)`` for
    ``(x, y, z) = x_array[:3]``.
    """
    x, y, z = x_array[0], x_array[1], x_array[2]
    quadratic = x**2
    log_term = np.log10(y+2)**2.5
    power_term = x*z**np.log(2)
    return quadratic + log_term + power_term

def jax_test_function(x_array):
    """JAX version of the three-variable test integrand.

    Evaluates ``x**2 + log10(y + 2)**2.5 + x * z**log(2)`` for
    ``(x, y, z) = x_array[:3]``.
    """
    x, y, z = x_array[0], x_array[1], x_array[2]
    quadratic = x**2
    log_term = jnp.log10(y+2)**2.5
    power_term = x*z**jnp.log(2)
    return quadratic + log_term + power_term

# Integration box [0, pi] x [0, pi] x [0, 1], as NumPy arrays for the
# reference integrator and as JAX arrays for jax_cubature
lower_limits = [0, 0, 0]
upper_limits = [np.pi, np.pi, 1]

a = np.array(lower_limits)
b = np.array(upper_limits)

a_jax = jnp.array(lower_limits)
b_jax = jnp.array(upper_limits)

In [289]:
import numpy as npy
# Reference value from the `cubature` package for the same integrand and box
xmin_cb = np.array(a)
xmax_cb = np.array(b)
result_cb, error_cb = cubature(
    test_function,
    ndim=3,
    fdim=1,
    xmin=xmin_cb,
    xmax=xmax_cb,
    relerr=1e-12,
)
print(f'Result with cubature: {result_cb}, Estimated error: {error_cb}')

Result with cubature: [43.89958661], Estimated error: [9.83865983e-09]


In [292]:
# Run jax_cubature on the same integral as the reference computation
minpts = 10       # minimum number of integrand evaluations
maxpts = 12000    # evaluation budget
rel_tol = 1e-10   # requested relative accuracy
ndim = 3
# Note: `minpts` is overwritten with the number of evaluations actually used.
finest, relerr, minpts, ifail, wrkstr, funcls = jax_cubature(ndim=ndim,a=a_jax,b=b_jax, minpts=minpts, maxpts=maxpts, functn=jax_test_function, rel_tol=rel_tol)

# The label now names the routine that was actually called; relerr is a relative error.
print(f'Result with jax_cubature: {finest}, Estimated relative error: {relerr}, ifail : {ifail}')
print(funcls)

Result with numpy_cubature: 43.89958660568408, Estimated error: 9.987851301261723e-11, ifail : 0
10263
