In [1]:
import numpy as np
import jax.numpy as jnp
import jax
from jax import lax
from functools import partial

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

In [2]:
#@partial(jax.jit, static_argnames=("ndim","functn","a","b","minpts","maxpts","rel_tol" ))
@partial(jax.jit, static_argnames=("ndim","functn","maxpts"))
def jax_cubature(*, ndim : int, a : jnp.ndarray, b : jnp.ndarray, minpts : int, maxpts : int, functn : callable, rel_tol : float) -> tuple:
    
    if ndim < 2:
        raise ValueError("ndim must be greater than 2")
    
    rgnstr  = 2*ndim + 2
    divaxo  = 0
    divaxn  = 0

    #Initialize the cubature
    rulcls,twondim,ratio,lambdas,weights,weightsp = initialise(ndim)
    
    z      = jnp.zeros(ndim)
    center = jnp.zeros(ndim)
    width  = jnp.zeros(ndim)
    widthl = jnp.zeros(ndim)

    lenwrk = (2*ndim+3)*(1+maxpts//rulcls)//2
    wrkstr = jnp.zeros(lenwrk+1)
    subtmp = 0
    
    funcls = 0
    finest = 0.0
    ifail  = 3
    

    width  = (b-a)/2.0
    center = a + width

    finest = 0.0
    subrgn = rgnstr
    sbrgns = rgnstr+1
    divflg = 1  
    relerr = 1.0

    #maxcalls = maxpts//rulcls
    #maxcalls = 10

    #Initial call to basic rule
    finest, wrkstr, funcls, rgnerr, rgnval, divaxo, divaxn = basic_rule(functn,ndim,twondim,lenwrk,rulcls,center,width,widthl,lambdas,ratio,weights,weightsp,divaxo,divaxn,z,finest,wrkstr,funcls)
    #Order and store results of basic rule
    wrkstr = order_results(divflg,wrkstr,subrgn,sbrgns,rgnstr,rgnerr,rgnval,divaxn,center,width,ndim)
    #Check the convergence for possible termination.
    relerr = jnp.where(jnp.fabs(finest) != 0.0, wrkstr[lenwrk] / jnp.fabs(finest), 1.0)

    def loop_cond(state):
        relerr, rel_tol, funcls, divflg, finest, wrkstr, sbrgns, subrgn, center, width, divaxo, divaxn, z = state
        return relerr > rel_tol

    def update_state(state):
    #while relerr > rel_tol:
        relerr, rel_tol, funcls, divflg, finest, wrkstr, sbrgns, subrgn, center, width, divaxo, divaxn, z = state

        divflg, subrgn, wrkstr, finest, divaxo, subtmp, center, width = prepare_new_call(ndim, rgnstr, wrkstr, lenwrk, finest, subrgn, center, width, divaxo)

        finest, wrkstr, funcls, rgnerr, rgnval, divaxo, divaxn = basic_rule(functn, ndim, twondim, lenwrk, rulcls, center, width, widthl, lambdas, ratio, weights, weightsp, divaxo, divaxn, z, finest, wrkstr, funcls)

        wrkstr = order_results(divflg, wrkstr, subrgn, sbrgns, rgnstr, rgnerr, rgnval, divaxn, center, width, ndim)

        center = center.at[divaxo].set(center[divaxo] + 2.0 * width[divaxo])
        sbrgns = sbrgns + rgnstr + 1
        subrgn = sbrgns - 1

        finest, wrkstr, funcls, rgnerr, rgnval, divaxo, divaxn = basic_rule(functn, ndim, twondim, lenwrk, rulcls, center, width, widthl, lambdas, ratio, weights, weightsp, divaxo, divaxn, z, finest, wrkstr, funcls)

        wrkstr = order_results(divflg, wrkstr, subrgn, sbrgns, rgnstr, rgnerr, rgnval, divaxn, center, width, ndim)

        relerr = jnp.where(jnp.abs(finest) != 0.0, wrkstr[lenwrk] / jnp.abs(finest), 1.0)

        return (relerr, rel_tol, funcls, divflg, finest, wrkstr, sbrgns, subrgn, center, width, divaxo, divaxn, z)

    def no_op(state):
        relerr, rel_tol, funcls, divflg, finest, wrkstr, sbrgns, subrgn, center, width, divaxo, divaxn, z = state
        return (relerr, rel_tol, funcls, divflg, finest, wrkstr, sbrgns, subrgn, center, width, divaxo, divaxn, z)

    # Assuming the initial state is properly initialized
    initial_state = (relerr, rel_tol, funcls, divflg, finest, wrkstr, sbrgns, subrgn, center, width, divaxo, divaxn, z)

    for _ in range(maxpts):
        initial_state = jax.lax.cond(loop_cond(initial_state), update_state, no_op, initial_state)

    relerr, rel_tol, funcls, divflg, finest, wrkstr, sbrgns, subrgn, center, width, divaxo, divaxn, z = initial_state

    minpts = funcls
    wrkstr = wrkstr.at[lenwrk-1].set(sbrgns)
    return finest, relerr, minpts, ifail, wrkstr, funcls

In [3]:
#@partial(jax.jit, static_argnames=['ndim'])
def initialise(ndim):
    twondim = 2.0**ndim

    lambda5 = 9.0/19.0   
    if ndim<=15: 
    #if ndim <= 2:
        rulcls = 2**ndim + 2*ndim*ndim + 2*ndim +1
        lambda4 = 9.0/10.0
        lambda2 = 9.0/70.0
        weight5 = 1.0/(3.0*lambda5)**3 /twondim
    else:
        rulcls = 1 + (ndim*(12+(ndim-1)*(6+(ndim-2)*4)))//3
        ratio = (ndim-2)/9.0
        lambda4 = (1.0/5.0 -ratio)/(1.0/3.0 -ratio/lambda5)
        ratio = (1.0 -lambda4/lambda5)*(ndim-1)*ratio/6.0
        lambda2 = (1.0/7.0 -lambda4/5.0 -ratio)/(1.0/5.0 -lambda4/3.0 -ratio/lambda5)
        weight5 = 1.0/(6.0*lambda5)**3

    weight4 = (1.0/15.0 -lambda5/9.0)/(4.0*(lambda4-lambda5)*lambda4**2)
    weight3 = (1.0/7.0 -(lambda5+lambda2)/5.0 +lambda5*lambda2/3.0)/(2.0*lambda4*(lambda4-lambda5)*(lambda4-lambda2)) -2.0*(ndim-1)*weight4
    weight2 = (1.0/7.0 -(lambda5+lambda4)/5.0 +lambda5*lambda4/3.0)/(2.0*lambda2*(lambda2-lambda5)*(lambda2-lambda4)) 

    if ndim<=15:
        weight1 = 1.0 -2.0*ndim*(weight2+weight3+(ndim-1)*weight4)-twondim*weight5
    else:
        weight1 = 1.0 -ndim*(weight2+weight3+(ndim-1)*(weight4+2.0*(ndim-2)*weight5/3.0))

    weight4p = 1.0/(6.0*lambda4)**2
    weight3p = (1.0/5.0 -lambda2/3.0)/(2.0*lambda4*(lambda4-lambda2)) -2.0*(ndim-1)*weight4p
    weight2p = (1.0/5.0 -lambda4/3.0)/(2.0*lambda2*(lambda2-lambda4))
    weight1p = 1.0 -2.0*ndim*(weight2p+weight3p+(ndim-1)*weight4p)

    ratio = lambda2/lambda4

    lambda5 = jnp.sqrt(lambda5)
    lambda4 = jnp.sqrt(lambda4)
    lambda2 = jnp.sqrt(lambda2)

    lambdas  = jnp.array([lambda2, lambda4, lambda5])
    weights  = jnp.array([weight1, weight2, weight3, weight4, weight5])
    weightsp = jnp.array([weight1p, weight2p, weight3p, weight4p])

    return rulcls,twondim,ratio,lambdas,weights,weightsp

In [4]:
#@partial(jax.jit, static_argnames=['ndim'])
def prepare_new_call(ndim,rgnstr,wrkstr,lenwrk,finest,subrgn,center, width,divaxo):
    divflg = 0
    subrgn = rgnstr
    wrkstr = wrkstr.at[lenwrk].set(wrkstr[lenwrk] - wrkstr[subrgn])
    finest = finest - wrkstr[subrgn-1]
    divaxo = jnp.int64(wrkstr[subrgn-2])

    for j in range(ndim):
        subtmp = subrgn-2*(j+2)
        center = center.at[j].set(wrkstr[subtmp+1])
        width  = width.at[j].set(wrkstr[subtmp])
        
    width  = width.at[divaxo].set(width[divaxo]/2.0)
    center = center.at[divaxo].set(center[divaxo]-width[divaxo])

    return divflg,subrgn,wrkstr,finest,divaxo,subtmp,center,width

In [5]:
#def basic_rule(functn,ndim,twondim,lenwrk,rulcls,center,width,widthl,lambdas,ratio,weights,weightsp,divaxo,divaxn,z,finest,wrkstr,funcls):
#def basic_rule(*, functn : callable, ndim : int, twondim : float, lenwrk : int, rulcls : int, center : jnp.ndarray, width : jnp.ndarray, widthl : jnp.ndarray, lambdas : jnp.ndarray, ratio : float, weights : jnp.ndarray, weightsp : jnp.ndarray, divaxo : int, divaxn : int, z : jnp.ndarray, finest : float, wrkstr : jnp.ndarray, funcls : int) -> tuple:
#@partial(jax.jit, static_argnames=['ndim','functn'])
def basic_rule(functn,ndim,twondim,lenwrk,rulcls,center,width,widthl,lambdas,ratio,weights,weightsp,divaxo,divaxn,z,finest,wrkstr,funcls):
    rgnvol = twondim
    maxit = 34
    for j in range(ndim):
    #for j in range(ndim):
        rgnvol = rgnvol*width[j]
        z = z.at[j].set(center[j])

    sum1 = functn(z)
    #Compute the symetric sums of functn(lambda2,0,0,..0) and functn(lambda4,0,0,..0), and 
    #maximum fourth difference
    difmax = -1.0
    sum2 = 0.0
    sum3 = 0.0
    for j in range(ndim):
    #for j in range(3):
        z = z.at[j].set(center[j]-lambdas[0]*width[j])
        f1 = functn(z)
        z = z.at[j].set(center[j]+lambdas[0]*width[j])
        f2 = functn(z)
        widthl = widthl.at[j].set(lambdas[1]*width[j])
        z = z.at[j].set(center[j]-widthl[j])
        f3 = functn(z)
        z= z.at[j].set(center[j]+widthl[j])
        f4 = functn(z)
        sum2 = sum2 + f1 + f2
        sum3 = sum3 + f3 + f4
        df1 = f1+f2-2.0*sum1
        df2 = f3+f4-2.0*sum1
        dif = jnp.fabs(df1-ratio*df2)
        
        def _if_update(difmax,dif,divaxn,j):
            def _update(_):
                return jnp.int64(j)
            def _no_update(_):
                return jnp.int64(divaxn)
            divaxn = jax.lax.cond(difmax<dif,_update,_no_update,None)
            return divaxn
        divaxn = _if_update(difmax,dif,divaxn,j)

        def _if_update(difmax,dif):
            def _update(_):
                return dif
            def _no_update(_):
                return difmax
            difmax = jax.lax.cond(difmax<dif,_update,_no_update,None)
            return difmax
        
        difmax = _if_update(difmax,dif)
        z = z.at[j].set(center[j])

    def _if_cond(sum1,difmax,divaxo,divaxn,ndim):
        def _update(_):
            #return jnp.int64((divaxo+1)%ndim)
            return jnp.int64((divaxo+1)%3)
        def _no_update(_):
            return jnp.int64(divaxn)
        divaxn = jax.lax.cond(sum1 == sum1+difmax/8.0,_update,_no_update,None)
        return divaxn
    divaxn = _if_cond(sum1,difmax,divaxo,divaxn,ndim)
    
    #Compute the symetric sums of functn(lambda4,lambda4,0,..0)
    sum4 = 0.0
    for j in range(1,ndim):
        for k in range(j,ndim):
    #for j in range(1,3):
        #for k in range(j,3):
            for l in range(2):
                widthl = widthl.at[j-1].set(-widthl[j-1])
                z = z.at[j-1].set(center[j-1]+widthl[j-1])

                for m in range(2):
                    widthl = widthl.at[k].set(-widthl[k])
                    z = z.at[k].set(center[k]+widthl[k])
                    f1 = functn(z)
                    sum4 = sum4 + f1
            
            z = z.at[k].set(center[k])
        z = z.at[j-1].set(center[j-1])

    #Compute symmetric sum of functn(lambda5,lambda5,lambda5,0,0...0)
    sum5 = 0.0
    
    if ndim<=15:
    #if ndim<=2:
    #if False:
        widthl = -lambdas[2]*width
        z = center+widthl
        
        shrink = True
        #while shrink:
        #for i in range(maxit):
        def _loop(k,main_state):
            #if shrink:
            def _outer_if_true(state):
                widthl,z,sum5,shrink = state
                shrink = False
                f1 = functn(z)
                sum5 = sum5 + f1

                def _body_loop(j,state):
                    #widthl,z,flag,shrink = state
                    def _if_false(state):

                        def _inner_if_true(tmp):
                            widthl,z,flag,shrink,j = tmp
                            widthl = widthl.at[j].set(-widthl[j])
                            z = z.at[j].set(center[j]+widthl[j])
                            tmp = (widthl,z,flag,shrink,j)
                            return tmp

                        def _inner_if_false(tmp):
                            return tmp
                        
                        widthl,z,flag,shrink,j = state
                        tmp = (widthl,z,flag,shrink,j)
                        tmp = jax.lax.cond(flag,_inner_if_true,_inner_if_false,tmp)
                        widthl,z,flag,shrink,j = tmp
                        state = (widthl,z,flag,shrink,j)
                        return state    

                    def _if_true(state):
                        widthl,z,flag,shrink,j = state
                        
                        def _inner_if_true(tmp):
                            widthl,z,flag,shrink,j = tmp
                            widthl = widthl.at[j].set(-widthl[j])
                            z = z.at[j].set(center[j]+widthl[j])
                            flag = False
                            shrink = True
                            return (widthl,z,flag,shrink,j)

                        def _inner_if_false(tmp):
                            return tmp
                        
                        tmp = (widthl,z,flag,shrink,j)
                        tmp = jax.lax.cond(flag,_inner_if_true,_inner_if_false,tmp)
                        widthl,z,flag,shrink,j = tmp
                        state = (widthl,z,flag,shrink,j)
                        return state
                    
                    widthl,z,flag,shrink = state
                    j_state = (widthl,z,flag,shrink,j)
                    tmp = jax.lax.cond(widthl[j]<0.0,_if_true,_if_false,j_state)
                    widthl,z,flag,shrink,j = tmp
                    state = (widthl,z,flag,shrink)
                    return state


                flag = True
                state = (widthl,z,flag,shrink)
                state = jax.lax.fori_loop(0,ndim,_body_loop,state)
                widthl,z,flag,shrink = state
                return (widthl,z,sum5,shrink)
            
            def _outer_if_false(state):
                return state
            
            widthl,z,sum5,shrink = main_state
            state = (widthl,z,sum5,shrink)
            state = jax.lax.cond(shrink,_outer_if_true,_outer_if_false,state)
            return state 


        main_state = (widthl,z,sum5,shrink)
        main_state = jax.lax.fori_loop(0,maxit,_loop,main_state)
        widthl,z,sum5,shrink = main_state

    else:
        for j in range(ndim):
            widthl = widthl.at[j].set(lambdas[2]*width[j])
        for i in range(2,ndim):
            for j in range(i,ndim):
                for k in range(j,ndim):
                    for l in range(2):
                        widthl = widthl.at[i-2].set(-widthl[i-2])
                        z = z.at[i-2].set(center[i-2]+widthl[i-2])
                        for m in range(2):
                            widthl = widthl.at[j-1].set(-widthl[j-1])
                            z = z.at[j-1].set(center[j-1]+widthl[j-1])
                            for n in range(2):
                                widthl = widthl.at[k].set(-widthl[k])
                                z = z.at[k].set(center[k]+widthl[k])
                                f1 = functn(z)
                                sum5 = sum5 + f1
                    
                        z = z.at[k].set(center[k])
                    z = z.at[j-1].set(center[j-1])
                z = z.at[i-2].set(center[i-2])

    #Compute fifth and seventh degree rules and error.
    rgncmp = rgnvol *(weightsp[0]*sum1 + weightsp[1]*sum2 + weightsp[2]*sum3 + weightsp[3]*sum4)
    rgnval = rgnvol *(weights[0]*sum1 + weights[1]*sum2 + weights[2]*sum3 + weights[3]*sum4 + weights[4]*sum5)
    rgnerr = jnp.fabs(rgnval-rgncmp)

    finest = finest+rgnval
    wrkstr = wrkstr.at[lenwrk].set(wrkstr[lenwrk]+rgnerr)
    funcls = funcls+rulcls

    return finest, wrkstr, funcls, rgnerr, rgnval, divaxo, divaxn

In [6]:
#@partial(jax.jit, static_argnames=['ndim'])
def order_results(divflg,wrkstr,subrgn,sbrgns,rgnstr,rgnerr,rgnval,divaxn,center,width,ndim):
    
    def _place_first(state):
    #When divflg=0, start at top of list and move down
    #list tree to find correct position for results from 
    #first half of recently divided subregion
        subrgn,sbrgns,rgnstr,rgnerr,wrkstr = state
        subtmp = 2*subrgn +1
        state = (subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn)
        maxiter = 100

        def _body_while(_,state):
        #while (subtmp<sbrgns and rgnerr<wrkstr[subtmp]): 
            subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn = state           
            new_state = (subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn)
            def _outer_while_true(_):
                def _while_true(_):
                    subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn = new_state           
                    tmp_state = (subtmp,rgnstr,wrkstr)
        
                    def _true(_):
                        subtmp,rgnstr,wrkstr = tmp_state
                        sbtmpp = subtmp+rgnstr+1          
                        def _true_statement(_):
                            return sbtmpp
                        def _false_statement(_):
                            return subtmp
                        subtmp = jax.lax.cond(wrkstr[subtmp]<wrkstr[sbtmpp],_true_statement,_false_statement,None)               
                        return subtmp
                    
                    def _false(_):
                        subtmp,rgnstr,wrkstr = tmp_state
                        return subtmp
        
                    subtmp = jax.lax.cond(subtmp!=sbrgns-1,_true,_false,tmp_state)

                    def loop_body(k,tmp_state):
                        wrkstr,subrgn,subtmp = tmp_state
                        wrkstr = wrkstr.at[subrgn-k].set(wrkstr[subtmp-k])
                        return (wrkstr,subrgn,subtmp)
                    tmp_state = (wrkstr,subrgn,subtmp)
                    wrkstr,subrgn,subtmp = jax.lax.fori_loop(0,rgnstr+1,loop_body,tmp_state)        

                    subrgn = subtmp
                    subtmp = 2*subrgn +1
                    state = (subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn) 
                    return state
                
                
                def _while_false(_):
                    state = (subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn) 
                    return state

                cond2 = subtmp<sbrgns 
                state = jax.lax.cond(cond2,_while_true,_while_false,new_state)
                return state
            

            def _outer_while_false(_):
                    state = (subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn) 
                    return state


            cond1 = rgnerr<wrkstr[subtmp]
            state = jax.lax.cond(cond1,_outer_while_true,_outer_while_false,new_state)
            #subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn = state
            return state

        state = jax.lax.fori_loop(0,maxiter,_body_while,state)
        subtmp,sbrgns,rgnstr,rgnerr,wrkstr,subrgn = state
        return wrkstr,subrgn,subtmp


    def _place_second(state):
    #When divflg=1, start at bottom right branch and move
    #up list tree to find correct position for results from
    #second half of recently divided subregion
        subrgn,sbrgns,rgnstr,rgnerr,wrkstr = state
        subtmp = ((subrgn+1)//(2*(rgnstr+1)))*(rgnstr+1)-1
        maxiter = 1000
        cond1 = subtmp>=rgnstr
        cond2 = rgnerr > wrkstr[subtmp]
        state = (wrkstr,subrgn,subtmp,rgnstr)

        #while subtmp>=rgnstr and rgnerr > wrkstr[subtmp]:
            #for k in range(rgnstr+1):
        def _body_while(_,state):
            wrkstr,subrgn,subtmp,rgnstr = state
            def _outer_while_true(state):
                wrkstr,subrgn,subtmp,rgnstr = state
                def _while_true(state):
                    wrkstr,subrgn,subtmp,rgnstr = state
                    def _loop_body(k,tmp_state):
                        wrkstr,subrgn,subtmp = tmp_state
                        wrkstr = wrkstr.at[subrgn-k].set(wrkstr[subtmp-k])
                        return (wrkstr,subrgn,subtmp)
                    tmp_state = (wrkstr,subrgn,subtmp)
                    wrkstr,subrgn,subtmp = jax.lax.fori_loop(0,rgnstr+1,_loop_body,tmp_state)

                    
                    subrgn = subtmp
                    subtmp = ((subrgn+1)//(2*(rgnstr+1)))*(rgnstr+1)-1
                    state = (wrkstr,subrgn,subtmp,rgnstr)
                    return state
                def _while_false(state):
                    return state
                cond1 = subtmp>=rgnstr
                state = jax.lax.cond(cond1,_while_true,_while_false,state)

                return state
            def _outer_while_false(_):
                return state
            
            cond2 = rgnerr > wrkstr[subtmp]
            state = jax.lax.cond(cond2,_outer_while_true,_outer_while_false,state)
            return state


        state = jax.lax.fori_loop(0,maxiter,_body_while,state)
        wrkstr,subrgn,subtmp,rgnstr = state
        return wrkstr,subrgn,subtmp

 

    #if  divflg !=1:
    #    wrkstr,subrgn,subtmp = _place_first(subrgn,sbrgns,rgnstr,rgnerr,wrkstr)    

    #else: 
    #    wrkstr,subrgn,subtmp = _place_second(subrgn,rgnstr,rgnerr,wrkstr)
    state = (subrgn,sbrgns,rgnstr,rgnerr,wrkstr)
    cond = divflg!=1
    wrkstr,subrgn,subtmp = jax.lax.cond(cond,_place_first,_place_second,state)

    #Store results of basic rule in correct position in list
    wrkstr = wrkstr.at[subrgn].set(rgnerr)
    wrkstr = wrkstr.at[subrgn-1].set(rgnval)
    wrkstr = wrkstr.at[subrgn-2].set(divaxn)
    for j in range(ndim):
    #for j in range(3):
        subtmp = subrgn-2*(j+2)
        wrkstr = wrkstr.at[subtmp+1].set(center[j])
        wrkstr = wrkstr.at[subtmp].set(width[j])

    return wrkstr

In [14]:
from cubature import cubature
from numba import jit

@jit(nopython=True)
def test_function(x_array):
    x = x_array[0]
    y = x_array[1]
    z = x_array[2]
    return x**2 +np.log10(y+2)**2.5 + x*z**np.log(2)
#@jax.jit
def jax_test_function(x_array):
    x = x_array[0]
    y = x_array[1]
    z = x_array[2]
    return x**2 +jnp.log10(y+2)**2.5 + x*z**jnp.log(2)

# Integration limits
a = np.array([1, 0, 0])
b = np.array([np.pi, np.pi, 1])

a_jax = jnp.array([1, 0, 0])
b_jax = jnp.array([jnp.pi, jnp.pi, 1])

In [15]:
result_cb, error_cb = cubature(test_function, ndim=3, fdim=1, xmin=np.array(a), xmax=np.array(b),
                               relerr=1e-12)
print(f'Result with cubature: {result_cb}, Estimated error: {error_cb}')

Result with cubature: [41.20097826], Estimated error: [9.72913711e-09]


In [16]:
# Uso de numpy_cubature
minpts = 1
maxpts = 20
rel_tol = 1e-4
ndim = 3
#finest, relerr, minpts, ifail, wrkstr, funcls = numpy_cubature(a,b, minpts, maxpts, test_function, rel_tol)
finest, relerr, minpts, ifail, wrkstr, funcls = jax_cubature(ndim=ndim,a=a_jax,b=b_jax, minpts=minpts, maxpts=maxpts, functn=jax_test_function, rel_tol=rel_tol)

print(f'Result with jax_cubature: {finest}, Estimated error: {relerr}, ifail : {ifail}')
print(funcls)

Result with jax_cubature: 41.201256904079976, Estimated error: 0.0, ifail : 3
99


In [17]:
def f(x):
    a_jax = jnp.array([1, 0, 0])
    b_jax = jnp.array([jnp.pi, jnp.pi, x])
    minpts = 1
    maxpts = 20
    rel_tol = 1e-4
    ndim = 3
    finest, relerr, minpts, ifail, wrkstr, funcls = jax_cubature(ndim=ndim,a=a_jax,b=b_jax, minpts=minpts, maxpts=maxpts, functn=jax_test_function, rel_tol=rel_tol)
    print(f'Result with jax_cubature: {finest}, Estimated error: {relerr}, ifail : {ifail}')

    return  finest

df = jax.jacfwd(f)
grad = df(1.0)
print(grad)

Result with jax_cubature: Traced<ConcreteArray(41.201256904079976, dtype=float64)>with<JVPTrace(level=2/0)> with
  primal = Array(41.2012569, dtype=float64)
  tangent = Traced<ShapedArray(float64[])>with<BatchTrace(level=1/0)> with
    val = Array([46.9051262], dtype=float64)
    batch_dim = 0, Estimated error: Traced<ConcreteArray(0.0, dtype=float64)>with<JVPTrace(level=2/0)> with
  primal = Array(0., dtype=float64)
  tangent = Traced<ShapedArray(float64[])>with<BatchTrace(level=1/0)> with
    val = Array([0.], dtype=float64)
    batch_dim = 0, ifail : 3
46.90512620152504


In [19]:
def g(x):
    # Integration limits    
    a = np.array([1, 0, 0])
    b = np.array([np.pi, np.pi, x])
    result_cb, error_cb = cubature(test_function, ndim=3, fdim=1, xmin=np.array(a), xmax=np.array(b),
                               relerr=1e-12)
    #print(f'Result with cubature: {result_cb}, Estimated error: {error_cb}')
    return result_cb

dg = (g(1+1e-4)-g(1-1e-4))/(2e-4)
print(dg)

[46.90465464]


In [18]:
grad

Array(46.9051262, dtype=float64)

In [10]:
"""
@jit(nopython=True)
def test_function(x_array):
    x = x_array[0]
    y = x_array[1]
    z = x_array[2]
    n = x_array[3]
    return x**1/2*n**2 +np.log(y*n+2)**2.5 + x*z**np.log(2) 
@jax.jit
def jax_test_function(x_array):
    x = x_array[0]
    y = x_array[1]
    z = x_array[2]
    n = x_array[3]
    return x**1/2*n**2 +jnp.log(y*n+2)**2.5 + x*z**jnp.log(2)

# Integration limits
a = np.array([0, 0, 0,0])
b = np.array([np.pi, np.pi, 1, 1])

a_jax = jnp.array([0, 0, 0,0])
b_jax = jnp.array([jnp.pi, jnp.pi, 1,1])
"""

'\n@jit(nopython=True)\ndef test_function(x_array):\n    x = x_array[0]\n    y = x_array[1]\n    z = x_array[2]\n    n = x_array[3]\n    return x**1/2*n**2 +np.log(y*n+2)**2.5 + x*z**np.log(2) \n@jax.jit\ndef jax_test_function(x_array):\n    x = x_array[0]\n    y = x_array[1]\n    z = x_array[2]\n    n = x_array[3]\n    return x**1/2*n**2 +jnp.log(y*n+2)**2.5 + x*z**jnp.log(2)\n\n# Integration limits\na = np.array([0, 0, 0,0])\nb = np.array([np.pi, np.pi, 1, 1])\n\na_jax = jnp.array([0, 0, 0,0])\nb_jax = jnp.array([jnp.pi, jnp.pi, 1,1])\n'

In [11]:
"""
result_cb, error_cb = cubature(test_function, ndim=4, fdim=1, xmin=np.array(a), xmax=np.array(b),
                               relerr=1e-16)
print(f'Result with cubature: {result_cb}, Estimated error: {error_cb}')
"""

"\nresult_cb, error_cb = cubature(test_function, ndim=4, fdim=1, xmin=np.array(a), xmax=np.array(b),\n                               relerr=1e-16)\nprint(f'Result with cubature: {result_cb}, Estimated error: {error_cb}')\n"

In [12]:
"""
# Uso de numpy_cubature
minpts = 10
maxpts = 10000
rel_tol = 1e-5
ndim = 4
#finest, relerr, minpts, ifail, wrkstr, funcls = numpy_cubature(a,b, minpts, maxpts, test_function, rel_tol)
finest, relerr, minpts, ifail, wrkstr, funcls = jax_cubature(ndim=ndim,a=a_jax,b=b_jax, minpts=minpts, maxpts=maxpts, functn=jax_test_function, rel_tol=rel_tol)

print(f'Result with numpy_cubature: {finest}, Estimated error: {relerr}, ifail : {ifail}')
print(funcls)
"""

"\n# Uso de numpy_cubature\nminpts = 10\nmaxpts = 10000\nrel_tol = 1e-5\nndim = 4\n#finest, relerr, minpts, ifail, wrkstr, funcls = numpy_cubature(a,b, minpts, maxpts, test_function, rel_tol)\nfinest, relerr, minpts, ifail, wrkstr, funcls = jax_cubature(ndim=ndim,a=a_jax,b=b_jax, minpts=minpts, maxpts=maxpts, functn=jax_test_function, rel_tol=rel_tol)\n\nprint(f'Result with numpy_cubature: {finest}, Estimated error: {relerr}, ifail : {ifail}')\nprint(funcls)\n"