# new INCF


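The INCF (integrated neighbouring contribution function) was updated on Aug 10th, 2021, because the previous implementation was computationally too heavy. The old version builds per-grid-index weight arrays of size $N_{gx} \times N$ and $N_{gy} \times N$ and accumulates their outer products line by line with `scan`. The new version scatter-adds the four bilinear weights of each line directly into the grid.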

In [1]:
import jax.numpy as jnp
import numpy as np
from jax import jit
from jax import vmap
from jax.lax import scan
from jax.ops import index_add
from jax.ops import index as joi

## Old code

In [2]:
@jit
def Xncf(i,x,xv):
    """Neighbouring contribution function (linear-interpolation weight) for grid index i.

    Each value in ``x`` is located on the grid ``xv`` by linearly interpolating
    the grid indices. If a value lies between xv[k] and xv[k+1] with fractional
    offset c, it contributes 1-c to index k and c to index k+1. This function
    returns the contribution of every x value to the single grid index ``i``.

    Args:
        i: grid index to evaluate the contribution for
        x: x values (array)
        xv: x-grid; presumably increasing, as jnp.interp requires

    Returns:
        Array with the same shape as ``x``: the weight each x value assigns to
        the i-th grid point. Values outside the grid range are clamped to the
        nearest edge by jnp.interp.
    """
    indarr=jnp.arange(len(xv))
    # fractional grid position of each x value (clamped to [0, len(xv)-1])
    pos = jnp.interp(x,xv,indarr)
    index = (pos).astype(int)  # lower neighbouring grid index
    cont = pos-index           # fractional distance from the lower grid point
    f=jnp.where(index==i,1.0-cont,0.0)  # i is the lower neighbour of x
    g=jnp.where(index+1==i,cont,0.0)    # i is the upper neighbour of x
    return f+g

In [22]:
from jax.lax import scan
from jax import jit,vmap

@jit
def inc2D_old(w, x,y,xv,yv):
    """Reference (slow) 2D integrated neighbouring contribution.

    Computes per-grid-index contribution weights with Xncf for every x and y
    grid index, then accumulates the weighted outer product one line at a time
    with lax.scan.

    Args:
        w: weights, one per line
        x: x values, one per line
        y: y values, one per line
        xv: x-grid
        yv: y-grid

    Returns:
        (len(xv), len(yv)) array of accumulated contributions
    """
    per_index = vmap(Xncf, (0, None, None), 0)
    weights_x = per_index(jnp.arange(len(xv)), x, xv)  # (Ngx, N) memory
    weights_y = per_index(jnp.arange(len(yv)), y, yv)  # (Ngy, N) memory
    # Materialising the full (Ngx, Ngy, N) product would need huge memory,
    # so the outer products are summed one line at a time instead.

    def accumulate(total, line):
        wx, wy, weight = line
        return total + weight*wx[:, None]*wy[None, :], None

    grid = jnp.zeros((len(xv), len(yv)))
    grid, _ = scan(accumulate, grid, (weights_x.T, weights_y.T, jnp.asarray(w)))
    return grid
    

In [23]:
# Small hand-made example: three lines with weights w at positions (x, y),
# binned on a 3-point x-grid and a 4-point y-grid.
w = jnp.array([1.0, 0.7, 0.4])
x = jnp.array([1.2, 2.3, 2.7])
y = jnp.array([0.5, 2.7, 0.1])
xv = jnp.linspace(0, 4, 3)
yv = jnp.linspace(0, 4, 4)

In [24]:
# Reference result from the old (scan-based) implementation.
b = inc2D_old(w, x, y, xv, yv)

## New code (Aug 10, 2021, after the 2nd Covid-19 vaccination)

In [25]:
def getix(x,xv):
    """Locate x on the grid xv as (fractional offset, lower grid index).

    Args:
        x: values to locate
        xv: grid; presumably increasing, as jnp.interp requires

    Returns:
        cont: distance from the lower grid point in units of grid cells
            (values outside the grid are clamped to its edges by jnp.interp)
        index: integer index of the lower neighbouring grid point
    """
    position = jnp.interp(x, xv, jnp.arange(len(xv)))
    lower = position.astype(int)
    return position - lower, lower

In [26]:
@jit
def inc2D(w,x,y,xv,yv):
    """2D integrated neighbouring contribution via scatter-add of bilinear weights.

    Each line at (x, y) with weight w is split over the four surrounding grid
    points of (xv, yv) with bilinear weights. The pieces are added into a
    (len(xv), len(yv)) array; repeated indices accumulate. No per-grid-index
    weight arrays are built, so memory does not grow with len(xv)*len(w).

    Args:
        w: weights, one per line
        x: x values, one per line
        y: y values, one per line
        xv: x-grid
        yv: y-grid

    Returns:
        (len(xv), len(yv)) array of accumulated contributions
    """
    cx,ix=getix(x,xv)
    cy,iy=getix(y,yv)
    a=jnp.zeros((len(xv),len(yv)))
    # `.at[...].add` replaces jax.ops.index_add / jax.ops.index, which are
    # deprecated and removed in recent JAX releases.
    # For a value on (or clamped to) the last grid node, ix+1 / iy+1 is out of
    # bounds; JAX drops out-of-bounds scatter updates, and their weight is 0.
    a=a.at[ix,iy].add(w*(1-cx)*(1-cy))
    a=a.at[ix,iy+1].add(w*(1-cx)*cy)
    a=a.at[ix+1,iy].add(w*cx*(1-cy))
    a=a.at[ix+1,iy+1].add(w*cx*cy)
    return a


In [28]:
# New scatter-add implementation on the same small example.
a = inc2D(w, x, y, xv, yv)
a

DeviceArray([[0.24999999, 0.14999999, 0.        , 0.        ],
             [0.6155    , 0.24450001, 0.580125  , 0.01487506],
             [0.12950002, 0.0105    , 0.10237497, 0.00262501]],            dtype=float32)

In [31]:
# Old implementation on the same example, for comparison.
b = inc2D_old(w, x, y, xv, yv)
b

DeviceArray([[0.24999999, 0.14999999, 0.        , 0.        ],
             [0.6155    , 0.24450001, 0.580125  , 0.01487506],
             [0.12950002, 0.0105    , 0.10237497, 0.00262501]],            dtype=float32)

In [32]:
# Both implementations must agree: fail loudly instead of only displaying the residual.
np.testing.assert_allclose(np.asarray(a), np.asarray(b), rtol=1e-5, atol=1e-6)
np.sum((a-b)**2)

DeviceArray(0., dtype=float32)

In [33]:
%timeit a=inc2D(w,x,y,xv,yv)

227 µs ± 48.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [34]:
%timeit b=inc2D_old(w,x,y,xv,yv)

285 µs ± 8.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


With only a few lines, both implementations take a similar amount of time. However, when the number of lines is large ...

In [35]:
# Large random example with Nl lines. Seeded so that Restart & Run All reproduces it.
Nl=10000
rng=np.random.default_rng(0)
w=jnp.array(rng.random(Nl))
x=jnp.array(rng.random(Nl)*3)
xv=jnp.linspace(0,4,3)

y=jnp.array(rng.random(Nl)*4)
yv=jnp.linspace(0,4,4)

In [36]:
%timeit a=inc2D(w,x,y,xv,yv)

207 µs ± 29.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [37]:
%timeit b=inc2D_old(w,x,y,xv,yv)

80.8 ms ± 3.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


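A significant difference! The scatter-based `inc2D` scales much better with the number of lines than the `scan`-based `inc2D_old`.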