In [1]:
import numpy as onp
import jax.numpy as np
import matplotlib.pyplot as plt
import jax

In [2]:
class Quadrature:
    def __init__(self, dim, p, f):
        self.dim = dim
        self.p = p
        self.f = f


    def rule(self):
        if self.p<= 1:
            self.n = 1
            self.x = [0.]
            self.w = [2.]
        elif self.p<= 3:
            self.n = 2
            self.x = [1/np.sqrt(3), -1/np.sqrt(3)]
            self.w = [1.,1.]
        elif self.p<= 5:
            self.n = 3
            self.x = [np.sqrt(0.6), 0, -np.sqrt(0.6)]
            self.w = [5/9, 8/9, 5/9]



    def integrate(self):
        self.rule()
        gauss_quadrature = 0

        if self.dim == 1:
            for i in range(self.n): 
                gauss_quadrature += self.w[i]*self.f(self.x[i])
            print(gauss_quadrature)


        elif self.dim == 2:
            for i in range(self.n):
                for j in range(self.n):
                    gauss_quadrature += self.w[i]* self.w[j]*self.f(self.x[i],self.x[j])
        

        else: 
            raise NotImplementedError(f"Integration for dim={self.dim} not implemented.")
        return gauss_quadrature
        
        

In [3]:
def shape_functions_bilinear(xi, eta):
    N1 = 1 / 4 * (1 - xi) * (1 - eta)
    N2 = 1 / 4 * (1 + xi) * (1 - eta)
    N3 = 1 / 4 * (1 + xi) * (1 + eta)
    N4 = 1 / 4 * (1 - xi) * (1 + eta)
    return np.array([N1, N2, N3, N4])

# Wrap inputs in a single vector for differentiation
def shape_fn_wrapped(xi_eta):
    xi, eta = xi_eta
    return shape_functions_bilinear(xi, eta)


# Compute the Jacobian: each row is dN_i/d[xi, eta]
shape_fn_jacobian = jax.jacfwd(shape_fn_wrapped)  # or jax.jacrev
# shape_functions_grads_bilinear = jax.grad(shape_functions_bilinear)

xi = 0.25
eta = 0.3
T = 2

shape_functions_bilinear(xi, eta)
grads = shape_fn_jacobian(np.array([xi, eta])) # Shape (4, 2): each row is [dN_i/dxi, dN_i/deta]
print(grads)

# def get_integrand(xi, eta, X, Y, T):
#     # del_N_del_X
#     grads = shape_fn_jacobian(np.array([xi, eta])) # Shape (4, 2): each row is [dN_i/dxi, dN_i/deta]
#     del_y_del_eta = (grads[:,1].reshape(-1,1).T @ Y).item()
#     del_x_del_eta = (grads[:,1].reshape(-1,1).T @ X).item()
#     del_y_del_xi = (grads[:,0].reshape(-1,1).T @ Y).item()
#     del_x_del_xi = (grads[:,0].reshape(-1,1).T @ X).item()
#     J = np.array([[del_x_del_xi, del_x_del_eta], [del_y_del_xi,del_y_del_eta]])
#     det_J = np.linalg.det(J)




#     del_ksi_del_x = 1 / det_J * del_y_del_eta
#     del_eta_del_x = -1 / det_J * del_y_del_xi

#     del_ksi_del_y = -1 / det_J * del_x_del_eta
#     del_eta_del_y = 1 / det_J * del_x_del_xi

#     del_N_del_x = (grads[:,0].reshape(-1,1)) * del_ksi_del_x + (grads[:,1].reshape(-1,1)) * del_eta_del_x
#     del_N_del_y = (grads[:,0].reshape(-1,1)) * del_ksi_del_y + (grads[:,1].reshape(-1,1)) * del_eta_del_y

#     print(del_N_del_x)

#     # print(grads[:,0]  )
#     # print(del_ksi_del_x)
#     integrand = T * (del_N_del_x @ del_N_del_x.T + del_N_del_y @ del_N_del_y.T) 



#     return integrand
import jax.numpy as jnp

def get_integrand(xi, eta, X, Y, T):
    grads = shape_fn_jacobian(jnp.array([xi, eta]))  # (4, 2): rows = dN/dξ, dN/dη

    # Geometry Jacobian J = d(x,y)/d(ξ,η)
    dN_dxi = grads[:, 0].reshape(-1, 1)   # (4,1)
    dN_deta = grads[:, 1].reshape(-1, 1)  # (4,1)

    dx_dxi  = (dN_dxi.T @ X)[0, 0]
    dx_deta = (dN_deta.T @ X)[0, 0]
    dy_dxi  = (dN_dxi.T @ Y)[0, 0]
    dy_deta = (dN_deta.T @ Y)[0, 0]

    J = jnp.array([[dx_dxi, dx_deta],
                   [dy_dxi, dy_deta]])  # (2,2)
    
    det_J = jnp.linalg.det(J)
    inv_J = jnp.linalg.inv(J)  # [[dξ/dx, dξ/dy], [dη/dx, dη/dy]]

    # Chain rule: dN/dx = dN/dξ * dξ/dx + dN/dη * dη/dx
    # Or: [dN/dx, dN/dy] = [dN/dξ, dN/dη] @ inv(J)
    dN_dref = grads  # (4,2)
    dN_dphys = dN_dref @ inv_J  # (4,2)

    dN_dx = dN_dphys[:, 0].reshape(-1, 1)  # (4,1)
    dN_dy = dN_dphys[:, 1].reshape(-1, 1)  # (4,1)

    print(f"helele {dN_dx}")

    # Final integrand
    integrand = T * (dN_dx @ dN_dx.T + dN_dy @ dN_dy.T) * det_J

    return integrand




X = np.array([0,1,1,0]).reshape(-1,1)
Y = np.array([0,0,1,1]).reshape(-1,1)

import time
time_start = time.time()
val = get_integrand(xi, eta, X, Y, T)
time_end = time.time()
time_total = time_end - time_start
print(f'Total time {time_total}')
print(val)

[[-0.175  -0.1875]
 [ 0.175  -0.3125]
 [ 0.325   0.3125]
 [-0.325   0.1875]]
helele [[-0.35]
 [ 0.35]
 [ 0.65]
 [-0.65]]
Total time 0.7084929943084717
[[ 0.1315625   0.0559375  -0.2309375   0.0434375 ]
 [ 0.0559375   0.2565625  -0.0815625  -0.2309375 ]
 [-0.2309375  -0.0815625   0.40656248 -0.09406248]
 [ 0.0434375  -0.2309375  -0.09406248  0.28156248]]


In [4]:
def shape_functions_bilinear(xi, eta):
    N1 = 1 / 4 * (1 - xi) * (1 - eta)
    N2 = 1 / 4 * (1 + xi) * (1 - eta)
    N3 = 1 / 4 * (1 + xi) * (1 + eta)
    N4 = 1 / 4 * (1 - xi) * (1 + eta)
    return np.array([N1, N2, N3, N4])

# Wrap inputs in a single vector for differentiation
def shape_fn_wrapped(xi_eta):
    xi, eta = xi_eta
    return shape_functions_bilinear(xi, eta)


# Compute the Jacobian: each row is dN_i/d[xi, eta]
shape_fn_jacobian = jax.jacfwd(shape_fn_wrapped)  # or jax.jacrev
# shape_functions_grads_bilinear = jax.grad(shape_functions_bilinear)


X = np.array([0,1,1,0]).reshape(-1,1)
Y = np.array([0,0,1,1]).reshape(-1,1)

shape_functions_bilinear(xi, eta)
grads = shape_fn_jacobian(np.array([xi, eta])) # Shape (4, 2): each row is [dN_i/dxi, dN_i/deta]
print(grads)


def integrand_modularized(xi, eta, X, Y, T):
    grads = shape_fn_jacobian(np.array([xi, eta]))  # (4, 2): rows = dN/dξ, dN/dη
    physical_points = np.concatenate([X, Y], axis=1).T
    J =  physical_points @ grads

    physical_shape_grads = grads @ np.linalg.inv(J)
    dN_dx = physical_shape_grads[:, 0].reshape(-1, 1)  # (4,1)
    dN_dy = physical_shape_grads[:, 1].reshape(-1, 1)  # (4,1)

   
    integrand = T * (dN_dx @ dN_dx.T + dN_dy @ dN_dy.T) * np.linalg.det(J)
    return integrand

print(integrand_modularized(xi,eta,X,Y,T)-val)


## TODO: Test this
dim = 2
p = 1    

def make_integrand_wrapper(X, Y, T):
    def f(xi, eta):
        return integrand_modularized(xi, eta, X, Y, T)
    return f

Q = Quadrature(2, 2, make_integrand_wrapper(X, Y, T))
ke = Q.integrate()


# element_stifness = Quadrature(dim, p,integrand)
# result = element_stifness.integrate()



[[-0.175  -0.1875]
 [ 0.175  -0.3125]
 [ 0.325   0.3125]
 [-0.325   0.1875]]
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]


In [None]:
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Nodal coordinates (1D bar)
nodes = jnp.array([0.0, 0.33, 0.66, 1.0])
num_nodes = nodes.shape[0]
dof_per_node = 1

# Connectivity: each element has 2 nodes
elements = jnp.array([
    [0, 1],
    [1, 2],
    [2, 3]
])
num_elements = elements.shape[0]

# Total DOFs
num_dofs = num_nodes * dof_per_node

# --------------------------------------
# Step 1: Element stiffness function
# --------------------------------------
def element_stiffness(Xe):
    """Returns local stiffness matrix for a 1D linear element"""
    length = Xe[1] - Xe[0]
    k = 1.0  # constant stiffness
    Ke = k / length * jnp.array([[1.0, -1.0],
                                 [-1.0, 1.0]])
    return Ke

# --------------------------------------
# Step 2: DOF map (global DOF indices per element)
# --------------------------------------
def get_element_dof_map(element_nodes, dof_per_node):
    return jnp.reshape(element_nodes[:, None] * dof_per_node + jnp.arange(dof_per_node), -1)

element_dof_map = jax.vmap(get_element_dof_map, in_axes=(0, None))(elements, dof_per_node)

# shape (num_elements, 2)

# --------------------------------------
# Step 3: Compute all element stiffness matrices
# --------------------------------------
X_elems = nodes[elements]  # shape (num_elements, 2)
Ke_all = jax.vmap(element_stiffness)(X_elems)  # shape (num_elements, 2, 2)

# --------------------------------------
# Step 4: Assemble global stiffness matrix (COO format)
# --------------------------------------
nel_dof = 2

# Compute global row and column indices
row_idx = jnp.repeat(element_dof_map, nel_dof, axis=1)  # shape (num_elements, 4)
col_idx = jnp.tile(element_dof_map, (1, nel_dof))       # shape (num_elements, 4)

data = Ke_all.reshape(-1)          # (num_elements * 4,)
rows = row_idx.reshape(-1)         # (num_elements * 4,)
cols = col_idx.reshape(-1)         # (num_elements * 4,)
indices = jnp.stack([rows, cols], axis=1)

# Construct sparse global matrix
K_global = BCOO((data, indices), shape=(num_dofs, num_dofs)).todense()

# Print for inspection
print("Global stiffness matrix:\n", K_global)


Global stiffness matrix:
 [[ 3.0303028 -3.0303028  0.         0.       ]
 [-3.0303028  6.0606055 -3.0303028  0.       ]
 [ 0.        -3.0303028  5.9714794 -2.9411767]
 [ 0.         0.        -2.9411767  2.9411767]]


In [7]:
element_dof_map

Array([[0, 1],
       [1, 2],
       [2, 3]], dtype=int32)

In [6]:
X_elems

Array([[0.  , 0.33],
       [0.33, 0.66],
       [0.66, 1.  ]], dtype=float32)