In [1]:
import jax
import jax.numpy as jnp
import jax.numpy.linalg as linalg
from jax import vmap
jax.config.update("jax_enable_x64", True)

In [21]:
class GaussLegendrePiecewise():
    """Composite (piecewise) Gauss-Legendre quadrature on uniform 1d / 2d meshes.

    The reference rule is built once in ``__init__`` on [-1, 1]. The ``*_quadpts``
    methods map it onto every cell of a uniform mesh and return an
    ``integrator(f)`` closure that evaluates the quadrature sum.

    The stored reference weights are normalised to sum to 1 (textbook weights
    divided by 2). The integrators therefore multiply by the cell measure
    (h in 1d, hx*hy in 2d) rather than by h/2.
    """

    def __init__(self, npts):
        """ The cartesian Gauss-Legendre quadrature information on the reference 1d interval [-1,1].
            INPUT:
                npts:  the number of quadrature points (>= 1); (n+1) points are exact
                       for polynomials of degree (2*n+1).
            SETS:
                self.quadpts: (npts, 1) nodes in [-1, 1], in ascending order.
                self.weights: (npts, 1) weights summing to 1 (textbook weights / 2).
            RAISES:
                ValueError: if npts < 1.
            Device: cuda/cpu
        """
        npts = int(npts)
        if npts < 1:
            raise ValueError(f"npts must be >= 1, got {npts}")
        index = npts - 1

        if index == 0:
            # One-point rule (midpoint rule). The textbook weight is 2; it is halved
            # to match the normalisation used by the general branch below.
            self.quadpts = jnp.array([[0.]], dtype=jnp.float64)
            self.weights = jnp.array([[1.]], dtype=jnp.float64)

        else:
            # Golub-Welsch: the nodes are the eigenvalues of the symmetric tridiagonal
            # Jacobi matrix of the Legendre recurrence, whose off-diagonal entries
            # are k / sqrt(4k^2 - 1), k = 1..npts-1.
            k = jnp.arange(1, npts, dtype=jnp.float64)
            beta = k / jnp.sqrt(4 * k**2 - 1)
            J = jnp.diag(beta, 1) + jnp.diag(beta, -1)
            # eigh (not eig): J is symmetric, so eigenpairs are real and come back
            # sorted ascending; the general eig path returns complex, unsorted
            # results and is not supported on every JAX backend.
            D, V = linalg.eigh(J)

            self.quadpts = D.reshape(-1, 1)
            # textbook weight = 2 * V[0, i]**2, halved -> V[0, i]**2
            self.weights = (V[0, :] ** 2).reshape(-1, 1)

    def _seperate(self, pts, weis, K):
        """Split points/weights along axis 0 into at most K contiguous, near-equal chunks.

        Chunking only bounds how many points ``f`` is evaluated on at once; it does
        not change the quadrature sum. Empty chunks (K larger than the number of
        points) are dropped so ``f`` is never called on an empty batch.

        Returns:
            (list of point chunks, list of matching weight chunks)
        Raises:
            ValueError: if K < 1.
        """
        K = int(K)
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        n = pts.shape[0]
        step = n / K
        # pin the last bound to n so rounding can never drop trailing points
        bounds = [0] + [round(step * (i + 1)) for i in range(K - 1)] + [n]
        spans = [(lb, rb) for lb, rb in zip(bounds[:-1], bounds[1:]) if rb > lb]
        _pts = [pts[lb:rb, ...] for lb, rb in spans]
        _weis = [weis[lb:rb, ...] for lb, rb in spans]
        return _pts, _weis

    def _cell_quadpts(self, a, b, h):
        """Map the reference rule onto each cell of a uniform mesh of [a, b].

        The number of cells is round((b - a) / h), at least 1. Rounding rather
        than truncating avoids losing a cell to floating-point error
        (e.g. 0.3 / 0.1 == 2.9999999999999996). The width actually used,
        (b - a) / n_cells, is returned so the integration weights always match
        the mesh, even when h does not divide b - a exactly.

        Returns:
            pts:   (npts, n_cells) array; column j holds the points of cell j.
            h_eff: the effective cell width.
        Raises:
            ValueError: if h is not a positive number.
        """
        if not float(h) > 0:
            raise ValueError(f"mesh size h must be positive, got {float(h)}")
        n_cells = max(int(round(float((b - a) / h))), 1)
        nodes = jnp.linspace(a, b, n_cells + 1)
        h_eff = (b - a) / n_cells
        # affine map [-1, 1] -> [x_l, x_r]: x = (xi * h + x_l + x_r) / 2
        pts = (self.quadpts * h_eff + nodes[:-1] + nodes[1:]) / 2
        return pts, h_eff

    @staticmethod
    def _make_integrator(quad_info, area):
        """Build the closure that sums f(points) * weight * area over every chunk.

        ``f`` receives an (m, d) chunk of points and must return an array whose
        leading axis has length m. Trailing axes are kept, so vector/matrix valued
        integrands are integrated component-wise.
        """
        _pts, _weis = quad_info

        def integrator(f):
            int_val = 0
            for pts, weis in zip(_pts, _weis):
                f_val = f(pts)
                # reshape the (m, 1) weights so they broadcast over f's trailing axes
                wei = jnp.reshape(weis, (weis.shape[0],) + (1,) * (f_val.ndim - 1))
                int_val += jnp.sum(f_val * (wei * area), axis=0)
            return int_val

        return integrator

    def interval_quadpts(self, interval, h, K=1):
        """ The Gauss-Legendre quadrature information on a uniform mesh of the 1d interval [a,b].
            INPUT:
                interval: array of shape (1, 2), i.e. [[a, b]]
                       h: mesh size; a scalar or an array whose first entry is used
                          (shape=[1] in the example). The mesh has round((b-a)/h)
                          cells of width (b-a)/n_cells.
                       K: number of chunks the points are split into when f is
                          evaluated (limits memory, does not change the result).
            OUTPUT: integrator(f), where f maps an (m, 1) array of points to an
                    array with leading axis m (e.g. a vmapped scalar function);
                    returns the sum of f(p) * w * cell width over all points.
            Examples
            -------
            interval = jnp.array([[0, 1]], dtype=jnp.float64)
            h = jnp.array([1/100], dtype=jnp.float64)
        """
        h0 = jnp.ravel(jnp.asarray(h))[0]
        pts, h_eff = self._cell_quadpts(interval[0][0], interval[0][1], h0)

        quadpts = pts.flatten().reshape(-1, 1)
        weights = jnp.tile(self.weights, (1, pts.shape[1])).flatten().reshape(-1, 1)
        quad_info = self._seperate(quadpts, weights, K)
        return self._make_integrator(quad_info, h_eff)

    def rectangle_quadpts(self, rectangle, h, K=1):
        """ The Gauss-Legendre quadrature information on a uniform mesh of the 2d rectangle [a,b]*[c,d].
            INPUT:
                rectangle: array of shape (2, 2), i.e. [[a, b], [c, d]]
                        h: mesh sizes [hx, hy]; a single value (scalar or shape=[1])
                           is used for both axes. Each axis gets round(length/h)
                           cells, as in interval_quadpts.
                        K: number of chunks the points are split into when f is
                           evaluated (limits memory, does not change the result).
            OUTPUT: integrator(f), where f maps an (m, 2) array of points to an
                    array with leading axis m; returns the sum of
                    f(p) * w * (hx * hy) over all points.
            Examples
            -------
            rectangle = jnp.array([[0, 1], [0, 1]], dtype=jnp.float64)
            h = jnp.array([0.01, 0.01], dtype=jnp.float64)
        """
        # Broadcast so a one-element h means "same size on both axes". Indexing
        # h[1] on a length-1 JAX array would silently clamp to h[0] while the area
        # would still be computed from h alone.
        hx, hy = jnp.broadcast_to(jnp.ravel(jnp.asarray(h)), (2,))
        xp, hx_eff = self._cell_quadpts(rectangle[0][0], rectangle[0][1], hx)
        yp, hy_eff = self._cell_quadpts(rectangle[1][0], rectangle[1][1], hy)

        xpt, ypt = jnp.meshgrid(xp.flatten(), yp.flatten())
        quadpts = jnp.stack((xpt.flatten(), ypt.flatten()), axis=1)

        weights_x = jnp.tile(self.weights, (1, xp.shape[1])).flatten()
        weights_y = jnp.tile(self.weights, (1, yp.shape[1])).flatten()
        # outer(wy, wx)[i, j] = wy[i] * wx[j], matching the meshgrid ordering above
        weights = jnp.outer(weights_y, weights_x).reshape(-1, 1)

        quad_info = self._seperate(quadpts, weights, K)
        return self._make_integrator(quad_info, hx_eff * hy_eff)

In [22]:
# Two-point Gauss-Legendre rule per cell, applied on [-1, 1] with cell width 1/20
# (40 cells); the points are evaluated in 41 chunks.
quad_rule = GaussLegendrePiecewise(npts=2)
interval = jnp.array([[-1.0, 1.0]])
h = jnp.array([1 / 20])
test_quadrature_1d = quad_rule.interval_quadpts(interval, h, K=41)

In [23]:
# Integrand cos^2(pi x): its exact integral over [-1, 1] is 1, so the result
# doubles as a check of the 1d quadrature rule.
target = lambda x: jnp.reshape(jnp.cos(jnp.pi * x)**2, ())
v_target = vmap(target, in_axes=0)
I1 = test_quadrature_1d(v_target)
print(I1)
assert jnp.isclose(I1, 1.0), f"expected ~1.0, got {I1}"

0.9999999999999999


In [24]:
# rectangle_quadpts expects one mesh size per axis. With a single entry, JAX
# clamps the out-of-bounds h[1] to h[0] for the mesh, but the cell area becomes
# prod(h) = 0.01 instead of 1e-4, scaling every 2d integral by 100.
h = jnp.array([1/100, 1/100])
rectangle = jnp.array([[-1.,1.],[-1.,1.]])
test_quadrature_2d = quad_rule.rectangle_quadpts(rectangle, h, K=100)

In [27]:
def target(p):
    """Separable 2d integrand cos(pi x) * cos(pi y) for a single point p = (x, y)."""
    return jnp.reshape(jnp.cos(jnp.pi * p[..., 0]) * jnp.cos(jnp.pi * p[..., 1]), ())

# map the per-point integrand over the leading (point) axis of a batch
v_target = vmap(target, in_axes=0)