In [202]:
import numpy as np

In [203]:
def numpy_cubature(a, b, minpts, maxpts, functn, rel_tol):
    
    ndim = len(a)

    if ndim < 2 or ndim > 100:
        raise ValueError("ndim must be between 2 and 100")
    
    twondim = 2.0**ndim
    rgnstr  = 2*ndim +3 
    divaxo  = 0

    #Initialize the cubature
    lambda5 = 9.0/19.0
    if ndim<=15:
        rulcls = 2**ndim + 2*ndim*ndim + 2*ndim +1
        lambda4 = 9.0/10.0
        lambda2 = 9.0/70.0
        weight5 = 1.0/(3.0*lambda5)**3 /twondim
    else:
        rulcls = 1 + (ndim*(12+(ndim-1)*(6+(ndim-2)*4)))/3
        ratio = (ndim-2)/9.0
        lambda4 = (1.0/5.0 -ratio)/(1.0/3.0 -ratio/lambda5)
        ratio = (1.0 -lambda4/lambda5)*(ndim-1)*ratio/6.0
        lambda2 = (1.0/7.0 -lambda4/5.0 -ratio)/(1.0/5.0 -lambda4/3.0 -ratio/lambda5)
        weight5 = 1.0/(6.0*lambda5)**3

    if rulcls > maxpts:
        raise ValueError("maxpts too small")    

    weight4 = (1.0/15.0 -lambda5/9.0)/(4.0*(lambda4-lambda5)*lambda4**2)
    weight3 = (1.0/7.0 -(lambda5+lambda2)/5.0 +lambda5*lambda2/3.0)/(2.0*lambda4*(lambda4-lambda5)*(lambda4-lambda2)) -2.0*(ndim-1)*weight4
    weight2 = (1.0/7.0 -(lambda5+lambda4)/5.0 +lambda5*lambda4/3.0)/(2.0*lambda2*(lambda2-lambda5)*(lambda2-lambda4)) 

    if ndim<=15:
        weight1 = 1.0 -2.0*ndim*(weight2+weight3+(ndim-1)*weight4)-twondim*weight5
    else:
        weight1 = 1.0 -ndim*(weight2+weight3+(ndim-1)*(weight4+2.0*(ndim-2)*weight5/3.0))
    
    weight4p = 1.0/(6.0*lambda4)**2
    weight3p = (1.0/5.0 -lambda2/3.0)/(2.0*lambda4*(lambda4-lambda2)) -2.0*(ndim-1)*weight4p
    weight2p = (1.0/5.0 -lambda4/3.0)/(2.0*lambda2*(lambda2-lambda4))
    weight1p = 1.0 -2.0*ndim*(weight2p+weight3p+(ndim-1)*weight4p)

    ratio = lambda2/lambda4

    lambda5 = np.sqrt(lambda5)
    lambda4 = np.sqrt(lambda4)
    lambda2 = np.sqrt(lambda2)

    lambdas = np.array([lambda2, lambda4, lambda5])
    weights = np.array([weight1, weight2, weight3, weight4, weight5])
    weightsp = np.array([weight1p, weight2p, weight3p, weight4p])


    width  = np.zeros(100)
    center = np.zeros(100)
    z      = np.zeros(100)
    lenwrk = (2*ndim+3)*(1+maxpts/rulcls)/2
    lenwrk = np.int32(lenwrk)
    wrkstr = np.zeros(lenwrk+1)
    widthl = np.zeros(100)
    funcls = 0
    subtmp = 0
    ifail  = 3
    finest = 0.0
    
    if minpts<0:
        subrgns = np.int32(wrkstr[lenwrk-1])
        divflg = 0
        wrkstr[lenwrk] = wrkstr[lenwrk] - wrkstr[subrgns]
        finest = finest - wrkstr[subrgns-1]
        divaxo = np.int32(wrkstr[subrgns-2])

        for j in range(ndim):
            subtmp = subrgns-2*(j+1)
            center[j] = wrkstr[subtmp]
            width[j] = wrkstr[subtmp]
        width[divaxo] = width[divaxo]/2.0
        center[divaxo] = center[divaxo]-width[divaxo]
    else:
        for j in range(ndim):
            width[j] = (b[j]-a[j])/2.0
            center[j] = a[j] + width[j]
    
        finest = 0.0
        subrgn = rgnstr
        sbrgns = rgnstr
        divflg = 1
    
    tmpflg = 0
    while True:
        while divflg == 0 or tmpflg==0:
            if tmpflg==0:
                tmpflg = 1

            #Start Basic Rule
            rgnvol = twondim
            for j in range(ndim):
                rgnvol = rgnvol*width[j]
                z[j] = center[j]

            sum1 = functn(z)
            #Compute the symetric sums of functn(lambda2,0,0,..0) and functn(lambda4,0,0,..0), and 
            #maximum fourth difference
            difmax = -1.0
            sum2 = 0.0
            sum3 = 0.0
            for j in range(ndim):
                z[j]=center[j]-lambdas[0]*width[j]
                f1 = functn(z)
                z[j]=center[j]+lambdas[0]*width[j]
                f2 = functn(z)
                widthl[j] = lambdas[1]*width[j]
                z[j]=center[j]-widthl[j]
                f3 = functn(z)
                z[j]=center[j]+widthl[j]
                f4 = functn(z)
                sum2 = sum2 + f1 + f2
                sum3 = sum3 + f3 + f4
                df1 = f1+f2-2.0*sum1
                df2 = f3+f4-2.0*sum1
                dif = np.fabs(df1-ratio*df2)
                if difmax<dif:
                    difmax = dif
                    divaxn = j
                z[j] = center[j]


            if sum1 == sum1+difmax/8.0:
                divaxn = np.mod(divaxo,ndim)+1
            
            #Compute the symetric sums of functn(lambda4,lambda4,0,..0)
            sum4 = 0.0
            for j in range(1,ndim):
                for k in range(j,ndim):
                    for l in range(2):
                        widthl[j-1]=-widthl[j-1]
                        z[j-1]=center[j-1]+widthl[j-1]

                        for m in range(2):
                            widthl[k]=-widthl[k]
                            z[k]=center[k]+widthl[k]
                            f1 = functn(z)
                            sum4 = sum4 + f1
                    
                    z[k] = center[k]
                z[j-1] = center[j-1]
            
            #If NDIM<16 compute symmetric sum of functn(lambda5,lambda5,lambda5,..lambda5)
            sum5 = 0.0
            if ndim<=15:
                for j in range(ndim):
                    widthl[j] = -lambdas[2]*width[j]
                    z[j] = center[j]+widthl[j]
                
                while True:
                    f1 = functn(z)
                    sum5 = sum5 + f1
                    
                    for j in range(ndim):
                        widthl[j] = -widthl[j]
                        z[j] = center[j]+widthl[j]
                        if widthl[j]>0.0:
                            break
                    else:
                        if j==ndim-1:
                            break
                        else:
                            continue
            else:
                #If 15<NDIM compute symmetric sum of functn(lambda5,lambda5,lambda5,0,0...0)
                for j in range(ndim):
                    widthl[j] = lambdas[2]*width[j]
                for i in range(2,ndim):
                    for j in range(i,ndim):
                        for k in range(j,ndim):
                            for l in range(2):
                                widthl[i-2] = -widthl[i-2]
                                z[i-2] = center[i-2]+widthl[i-2]
                                for m in range(2):
                                    widthl[j-1]=-widthl[j-1]
                                    z[j-1]=center[j-1]+widthl[j-1]
                                    for n in range(2):
                                        widthl[k]=-widthl[k]
                                        z[k]=center[k]+widthl[k]
                                        f1 = functn(z)
                                        sum5 = sum5 + f1
                            
                            z[k] = center[k]
                        z[j-1] = center[j-1]
                    z[i-2] = center[i-2]
            
            #Compute fifth and seventh degree rules and error.
            rgncmp = rgnvol *(weightsp[0]*sum1 + weightsp[1]*sum2 + weightsp[2]*sum3 + weightsp[3]*sum4)
            rgnval = rgnvol *(weights[0]*sum1 + weights[1]*sum2 + weights[2]*sum3 + weights[3]*sum4 + weights[4]*sum5)
            rgnerr = np.fabs(rgnval-rgncmp)

            finest = finest+rgnval
            wrkstr[lenwrk] = wrkstr[lenwrk]+rgnerr
            funcls = funcls+rulcls
            #End basic rule
            print(finest)

            #When divflg=0, start at top of list and move down
            #list tree to find correct position for results from 
            #first half of recently divided subregion
            if  divflg !=1:
                while True:
                    subtmp = 2*subrgn
                    if subtmp>sbrgns:
                        break
                    
                    if subtmp!=sbrgns:
                        sbtmpp = subtmp+rgnstr
                        if wrkstr[subtmp]<wrkstr[sbtmpp]:
                            subtmp = sbtmpp

                    if rgnerr>=wrkstr[subtmp]:
                        break  
                            
                    for k in range(rgnstr):
                        wrkstr[subrgn-k+1] = wrkstr[subtmp-k+1]
                    subrgn = subtmp

            #When divflg=1, start at bottom right branch and move
            #up list tree to find correct position for results from
            #second half of recently divided subregion
            else: 
                while True:
                    subtmp = np.int32((subrgn/(2*rgnstr))*rgnstr)
                    subtmp = (subrgn//(2*rgnstr))*rgnstr
                    if subtmp<rgnstr:
                        break               
                    if rgnerr <= wrkstr[subtmp]:
                        break
                        
                    for k in range(rgnstr):
                        wrkstr[subrgn-k+1] = wrkstr[subtmp-k+1]
                        #wrkstr[subrgn-k] = wrkstr[subtmp-k]
                    subrgn = subtmp
            
            #Store results of basic rule in correct position in list
            wrkstr[subrgn] = rgnerr
            wrkstr[subrgn-1] = rgnval
            wrkstr[subrgn-2] = divaxn
            for j in range(ndim):
                subtmp = subrgn-2*(j+1)
                wrkstr[subtmp+1] = center[j]
                wrkstr[subtmp] = width[j]

            #When divflg=0, prepare to call basic rule on second half of subregion
            print("divflg: ", divflg)
            if divflg != 1:
                center[divaxo] = center[divaxo] + 2.0*width[divaxo]
                sbrgns = sbrgns + rgnstr
                subrgn = sbrgns
                divflg = 1
                tmpflg = 0
            else:
                break

        #End ordering and storage of basic rule results.
        #Check the convergence for possible termination.
        print('Now I am here')
        relerr = 1.0
        if wrkstr[lenwrk] <= 0.0:
            wrkstr[lenwrk] = 0.0

        if np.fabs(finest) != 0.0:
            relerr = wrkstr[lenwrk]/np.fabs(finest)

        if 1.0 < relerr:
            relerr = 1.0
        
        if lenwrk<(sbrgns+rgnstr+2):
            ifail=2

        if maxpts<(funcls+2*rulcls):
            ifail=1

        if relerr<rel_tol and minpts<=funcls:
            ifail=0
        
        if ifail<3:
            minpts = funcls
            wrkstr[lenwrk-1] = sbrgns
            return finest, relerr, minpts, ifail, wrkstr, funcls
        
        #Prepare a new call to basic rule on each half of the subregion with the largest error
        divflg = 0
        subrgn = rgnstr
        wrkstr[lenwrk] = wrkstr[lenwrk] - wrkstr[subrgn]
        finest = finest - wrkstr[subrgn-1]
        divaxo = np.int32(wrkstr[subrgn-2])
        for j in range(ndim):
            subtmp = subrgn-2*(j+1)
            center[j] = wrkstr[subtmp+1]
            width[j] = wrkstr[subtmp]
            
        width[divaxo] = width[divaxo]/2.0
        center[divaxo] = center[divaxo]-width[divaxo]


In [204]:
from cubature import cubature
from numba import jit

@jit(nopython=True)
def test_function(x_array):
    x = x_array[0]
    y = x_array[1]
    z = x_array[2]
    return x**2 +np.log10(y+2)**2.5 + x*z**np.log(2)

# Integration limits
a = [0, 0, 0]
b = [np.pi, np.pi, 1] 

In [205]:
result_cb, error_cb = cubature(test_function, ndim=3, fdim=1, xmin=np.array(a), xmax=np.array(b),
                               relerr=1e-5)
print(f'Result with cubature: {result_cb}, Estimated error: {error_cb}')

Result with cubature: [43.8995899], Estimated error: [0.00020542]


In [206]:
# Uso de numpy_cubature
minpts = 10
maxpts = 3300
rel_tol = 1e-4
finest, relerr, minpts, ifail, wrkstr, funcls = numpy_cubature(a, b, minpts, maxpts, test_function, rel_tol)
print(f'Result with numpy_cubature: {finest}, Estimated error: {relerr}')
print(funcls)

43.899896782438006
divflg:  1
Now I am here
63.804028660648996
divflg:  0
86.22899675887393
divflg:  1
Now I am here
95.74753806135412
divflg:  0
107.08316853006028
divflg:  1
Now I am here
116.13904732490391
divflg:  0
126.98730040782954
divflg:  1
Now I am here
131.05350744116762
divflg:  0
136.75213454974056
divflg:  1
Now I am here
Result with numpy_cubature: 136.75213454974056, Estimated error: 9.757004026004413e-05
297


In [207]:
[i for i in range(1,3)]

[1, 2]

In [208]:
import jax.numpy as jnp


test_array = jnp.array([1,2,3])
print(test_array)
test_array = test_array.at[1].set(10)
print(test_array)

[1 2 3]
[ 1 10  3]


In [209]:
!pwd

  pid, fd = os.forkpty()


/home/adri/Documentos/Tesis/JAX_cubature
