In [105]:
import numpy as np
from scipy.integrate import quad, tplquad
import vtk
import pyvista as pv
from scipy.sparse import csc_matrix, coo_matrix,  diags
from scipy.sparse.linalg import spsolve
import matplotlib.pyplot as plt
from tqdm import tqdm

import jax 
import jax.numpy as jnp
from functools import partial


import networkx as nx


# Use 64 bit floats for jax
jax.config.update("jax_enable_x64", True)

---

# Problem description

The goal is to solve for the displacement field of a linear elastic solid under a given load. 
Instead of employing the Garlekin method of weighted residuals, we will use this time the hamiltonian principle of stationary action: 

The action functional is given by:
$$
S = \int_{t_0}^{t_f} L(q, \dot{q}) dt
$$

Where $L$ is the Lagrangian of the system, which is given by: $L = T- V$. Where $V$ is the potential energy of the system and $T$ is the kinetic energy of the system. $q$ is the generalized coordinate of the system and $\dot{q}$ is the generalized velocity of the system.

The stationary action principle states that the true motion of the system is such that the action is stationary, i.e. the variation of the action is zero:

$$
\delta S = 0
$$

By applying the calculus of variations to $\delta S$, the equations of motion for the system are derived. This leads to the Euler-Lagrange equations:

$$
\frac{d}{dt} \left( \frac{\partial L}{\partial \dot{q}} \right) - \frac{\partial L}{\partial q} = 0
$$

The first step of this approach is to derive the Lagrangian of the system. 

---

In [3]:
# Create a face class to store the information of each face

class Face: 
    def __init__(self, vertices, cell_1 = -1, cell_2 = -1):
        self.vertices   = vertices
        self.cell_1     = cell_1
        self.cell_2     = cell_2

    def add_cell(self, cell):
        if self.cell_1 == None:
            self.cell_1 = cell
        else:
            self.cell_2 = cell

    def is_boundary_face(self):
        return self.cell_2 == -1
    
    def __hash__(self):
        return tuple(sorted(self.vertices))
    
    def __eq__(self, other):
        return self.__hash__() == other.__hash__()

In [4]:
# Read the mesh and find the points and faces located on the boundary

reader = vtk.vtkUnstructuredGridReader()
reader.SetFileName("tetrahedralized_cube_high_res.vtk")
reader.Update()
u_grid = reader.GetOutput()

# Collect the positions of the nodes
n_nodes = u_grid.GetNumberOfPoints()
node_lst = [u_grid.GetPoint(i) for i in range(n_nodes)]
point_ar = np.array(node_lst)
n_vertices = point_ar.shape[0]

cell_lst = []
#Loop over the cells in the VTK mesh (not all vtk cells are tetrahedra)
for cell_id in range(u_grid.GetNumberOfCells()):

    cell = u_grid.GetCell(cell_id)
    n_cell_pts = cell.GetNumberOfPoints()

    #If the cell is a tetrahedron
    if n_cell_pts == 4:

        # Get the points of the cell
        cell_lst.append([cell.GetPointId(i) for i in range(n_cell_pts)])

# Store the faces of the mesh in this dictionary
face_dic = {}

# Loop over the tetrahedra in the mesh
for cell_id, cell_pts in enumerate(cell_lst):
    face_0 = [cell_pts[0], cell_pts[1], cell_pts[2]]
    face_1 = [cell_pts[0], cell_pts[1], cell_pts[3]]
    face_2 = [cell_pts[0], cell_pts[2], cell_pts[3]]
    face_3 = [cell_pts[1], cell_pts[2], cell_pts[3]]

    #Create the faces
    face_0 = Face(face_0, cell_id)
    face_1 = Face(face_1, cell_id)
    face_2 = Face(face_2, cell_id)
    face_3 = Face(face_3, cell_id)

    # Check if the faces are already in the dictionary
    face_0_dic = face_dic.get(face_0.__hash__(), None)
    face_1_dic = face_dic.get(face_1.__hash__(), None)
    face_2_dic = face_dic.get(face_2.__hash__(), None)
    face_3_dic = face_dic.get(face_3.__hash__(), None)

    if face_0_dic is None: face_dic[face_0.__hash__()] = face_0
    else: face_0_dic.add_cell(cell_id)

    if face_1_dic is None: face_dic[face_1.__hash__()] = face_1
    else: face_1_dic.add_cell(cell_id)
        
    if face_2_dic is None: face_dic[face_2.__hash__()] = face_2
    else: face_2_dic.add_cell(cell_id)

    if face_3_dic is None: face_dic[face_3.__hash__()] = face_3
    else: face_3_dic.add_cell(cell_id)


# The point IDs of the tetrahedral cells
cell_ar = np.array(cell_lst)
n_cells = cell_ar.shape[0]

# The point IDs of the faces
face_ar =       np.array([f.vertices for f in face_dic.values()])
n_faces = face_ar.shape[0]

# The cell sharing each face
face_cell_ar =  np.array([[f.cell_1, f.cell_2] for f in face_dic.values()])

# Store for each tetrahedron the faces that it contains
cell_face_lst = [[] for _ in range(cell_ar.shape[0])]

for face_id, (cell_1, cell_2) in enumerate(face_cell_ar):
    if cell_1 != -1: cell_face_lst[cell_1].append(face_id)
    if cell_2 != -1: cell_face_lst[cell_2].append(face_id)

cell_face_ar = np.array(cell_face_lst)

# We need to collect all the boundary points
boundary_faces_idx = np.argwhere(face_cell_ar[:, 1] == -1).flatten()
boundary_face_mask = np.zeros(n_faces, dtype=bool)
boundary_face_mask[boundary_faces_idx] = True

# Collect the points of the boundary faces
boundary_points = np.unique(face_ar[boundary_faces_idx].flatten())
boundary_point_mask = np.zeros(n_vertices, dtype=bool)
boundary_point_mask[boundary_points] = True

<br>
<br>
<br>
<br>
<br>

---

## State variables interpolation in the mesh elements

The displacement field as well as the stress and strain fields are linearly interpolated in the mesh elements (tetrahedra in this case) based on the values at the element nodes. The position inside an element can be defined based on the position of the element nodes and the barycentric coordinates of the point inside the element. This gives 
the following relationship: 

$$
X = \phi_0(X) X_0 + \phi_1(X) X_1 + \phi_2(X) X_2 + \phi_3(X) X_3
$$

Here $X$ denotes the position of a point in the reference/underformed geometry. The position of the same point in the deformed geometry is denoted by $x(X)$. The functions $\phi_i(X)$ are the barycentric coordinates of the point $X$ inside the element. This interpolation given above is also valid for any function such as the position in the deformed geometry $x(X)$, the displacement field $u(X)$, the strain field $\varepsilon(X)$, the stress field $\sigma(X)$, etc. 


We can rewrite the above equation into a matrix vector form: 

$$
X = 
\begin{bmatrix}
 \mid& \mid&  \mid& \mid  \\
 X_0&  X_1&  X_2&  X_3 \\
 \mid&  \mid&  \mid& \mid \\
\end{bmatrix}  

\begin{bmatrix}
\phi_0(X)\\
\phi_1(X)\\
\phi_2(X)\\
\phi_3(X)\\
\end{bmatrix}
$$

This system is not solvable, but with the fact that the sum of the barycentric coordinates is always equal to 1, we can rewrite the above equation as:

$$
X - X_0 = 
\underbrace{
\begin{bmatrix}
 \mid& \mid&  \mid \\
  X_1 - X_0&  X_2- X_0&  X_3- X_0 \\
 \mid&  \mid&  \mid&\\
\end{bmatrix}}_{=\;T}

\begin{bmatrix}
\phi_1(X)\\
\phi_2(X)\\
\phi_3(X)\\
\end{bmatrix}
$$

and 

$$
\phi_0(X) = 1 - \phi_1(X) - \phi_2(X) - \phi_3(X)
$$


we can therefore write that

$$
\begin{bmatrix}
\phi_1(X)\\
\phi_2(X)\\
\phi_3(X)\\
\end{bmatrix} = 

T^{-1} (X - X_0)
$$

and 

$$
\phi_0(X) = 1 - \vec{1}^T T^{-1} (X - X_0)
$$

where $\vec{1}$ is a vector of ones.

Assembling everything together, we obtain the following expression for the barycetric position of a point inside an element:


$$ 
\begin{bmatrix}
\phi_0(X)\\
\phi_1(X)\\
\phi_2(X)\\
\phi_3(X)\\
\end{bmatrix} =  


\begin{bmatrix}
1\\
0\\
0\\
0\\
\end{bmatrix} 
+
\underbrace{
\begin{bmatrix}
& & & -\vec{1}^{T} \; T^{-1} & & & \\
-&-&-&-&-&-& \\
& & & & & & \\
& & &T^{-1}& & & \\
& & & & & & \\
\end{bmatrix}}_{=D \;(4 \times 3)}


\underbrace{
(X - X_0)}_{(3 \times 1)}


$$

The position in the deformed geometry is given by:

$$
x(X) = \underbrace{
\begin{bmatrix}
 \mid& \mid&  \mid& \mid  \\
 x_0&  x_1&  x_2&  x_3 \\
 \mid&  \mid&  \mid& \mid \\
\end{bmatrix}}_{(3 \times 4)} \cdot

\begin{bmatrix}
\phi_0(X)\\
\phi_1(X)\\
\phi_2(X)\\
\phi_3(X)\\
\end{bmatrix}
$$

$$
x(X) = \vec{x_0} +

\underbrace{
\begin{bmatrix}
 \mid& \mid&  \mid& \mid  \\
 x_0&  x_1&  x_2&  x_3 \\
 \mid&  \mid&  \mid& \mid \\
\end{bmatrix}}_{(3 \times 4)} 

\underbrace{
\begin{bmatrix}
& & & -\vec{1}^{T} \; T^{-1} & & & \\
-&-&-&-&-&-& \\
& & & & & & \\
& & &T^{-1}& & & \\
& & & & & & \\
\end{bmatrix}}_{=D \; (4 \times 3)}

\underbrace{
(X - X_0)}_{(3 \times 1)}
$$

This equation is pretty useful as it allows to easily compute the deformation gradient tensor:

$$
F = \underbrace{\frac{\partial}{\partial X} x(X,t)}_{(3 \times 3)} = \underbrace{
\begin{bmatrix}
 \mid& \mid&  \mid& \mid  \\
 x_0&  x_1&  x_2&  x_3 \\
 \mid&  \mid&  \mid& \mid \\
\end{bmatrix}}_{(3 \times 4)} 

\underbrace{
\begin{bmatrix}
& & & -\vec{1}^{T} \; T^{-1} & & & \\
-&-&-&-&-&-& \\
& & & & & & \\
& & &T^{-1}& & & \\
& & & & & & \\
\end{bmatrix}}_{=D \; (4 \times 3)}


$$

---


In [5]:
# Those are the basis functions and their gradients on the reference tetrahedron

def basis_fn_1(u, v, w): return 1 - u - v - w 
def basis_fn_2(u, v, w): return u
def basis_fn_3(u, v, w): return v
def basis_fn_4(u, v, w): return w

grad_basis_fn_1 = np.array([-1, -1, -1])
grad_basis_fn_2 = np.array([ 1,  0,  0])
grad_basis_fn_3 = np.array([ 0,  1,  0])
grad_basis_fn_4 = np.array([ 0,  0,  1])

reference_basis_fn_lst =      [basis_fn_1, basis_fn_2, basis_fn_3, basis_fn_4]
reference_grad_basis_fn_lst = [grad_basis_fn_1, grad_basis_fn_2, grad_basis_fn_3, grad_basis_fn_4]
n_local_basis_fn = len(reference_basis_fn_lst)

In [6]:
# Compute the T and D matrices for all the tetrhedra in the mesh
x0_ar = point_ar[cell_ar[:, 0]]
x1_ar = point_ar[cell_ar[:, 1]]
x2_ar = point_ar[cell_ar[:, 2]]
x3_ar = point_ar[cell_ar[:, 3]]

#Compute the T matrix for all the tetrahedra
T = np.column_stack([x1_ar - x0_ar, x2_ar - x0_ar, x3_ar - x0_ar]).reshape(-1, 3, 3)
T = np.transpose(T, (0, 2, 1))

#Compute the inverse of the T matrix
T_inv = np.linalg.inv(T)

#Compute the D matrix for all the tetrahedra
D = np.hstack([-np.sum(T_inv, axis=1)[:, np.newaxis], T_inv])

# Compute the volume of each tetrahedron in the reference configuration
ref_volume_ar = np.abs(np.linalg.det(T)) / 6

In [7]:

#Make sure that the D matrix is correct
def test_D_matrix(D): 

    # Create virtual deformed point coordinates
    def_point_ar = point_ar * np.array([2, 1, 1])

    # Get the point coordinates of the cells
    def_x0_ar = def_point_ar[cell_ar[:, 0]]
    def_x1_ar = def_point_ar[cell_ar[:, 1]]
    def_x2_ar = def_point_ar[cell_ar[:, 2]]
    def_x3_ar = def_point_ar[cell_ar[:, 3]]

    # Compute the T matrix in the deformed configuration
    T_def = np.column_stack([def_x1_ar - def_x0_ar, def_x2_ar - def_x0_ar, def_x3_ar - def_x0_ar]).reshape(-1, 3, 3)
    T_def = np.transpose(T_def, (0, 2, 1))

    # Compute the correct deformation gradient tensor
    F_correct = np.matmul(T_def, T_inv)

    # Compute the deformation gradient tensor using the D matrix
    M = np.transpose(def_point_ar[cell_ar], (0, 2, 1))
    F_computed = np.matmul(M, D)

    # Check that the matrices are the same
    np.testing.assert_allclose(F_correct, F_computed, atol=1e-5)

test_D_matrix(D)

<br>
<br>
<br>
<br>
<br>

---

## Kinetic energy

The kinetic energy of the system is given by:

$$
T = \sum_{i=1}^{N} \frac{1}{2} \int_{\Omega_{i,0}} \dot{q}(t, X)^T \, \dot{q}(t, X) \; d\Omega_{i,0}
$$

Where the sum runs over all the element cells of the mesh (tetrahedra in this case). $q(X,t)$ is the generalized coordinate of the system. In this case, the generalized coordinate is the displacement field $u(X,t)$. The dot denotes the time derivative. Note that the kinetic energy of the system is computed in the reference configuration (Lagrangian formulation).

For each element, we can write the displacement of a given point as: 

$$

u(X,t) = 

\underbrace{
\begin{bmatrix}
 & & & \\
\mathrm{I}\phi_0(X) & \mathrm{I}\phi_1(X) & \mathrm{I}\phi_2(X) & \mathrm{I}\phi_3(X)\\
 & & & \\
\end{bmatrix}}_{= N(X) \; (3 \times 12)}


\underbrace{
\begin{bmatrix}
\mid \\
u_0(t) \\
\mid \\
 \\
\mid\\
u_1(t) \\
\mid \\
... \\
\end{bmatrix}}_{=u(t) \; (1 \times 12)}
$$

The kinetic energy of the system is therefore given by:

$$
\sum_{i=1}^{N} \frac{1}{2} \dot{u}(t)^{T} \underbrace{\int_{\Omega_{i,0}} N(X)^T  N(X) \, d\Omega_{i,0}}_{= \mathrm{M}_0 (12 \times 12)} \; \dot{u}(t)
$$


Where $\mathrm{M}_0$ is the local mass matrix of the system: 

$$
\mathrm{M}_0 = 
\int_{\Omega_{i,0}}
\begin{bmatrix}
\mathrm{I} \phi_0(X) \phi_0(X) & \mathrm{I} \phi_0(X) \phi_1(X) & \mathrm{I} \phi_0(X) \phi_2(X) & \mathrm{I} \phi_0(X) \phi_3(X) \\
\mathrm{I} \phi_1(X) \phi_0(X) & \mathrm{I} \phi_1(X) \phi_1(X) & \mathrm{I} \phi_1(X) \phi_2(X) & \mathrm{I} \phi_1(X) \phi_3(X) \\
\mathrm{I} \phi_2(X) \phi_0(X) & \mathrm{I} \phi_2(X) \phi_1(X) & \mathrm{I} \phi_2(X) \phi_2(X) & \mathrm{I} \phi_2(X) \phi_3(X) \\
\mathrm{I} \phi_3(X) \phi_0(X) & \mathrm{I} \phi_3(X) \phi_1(X) & \mathrm{I} \phi_3(X) \phi_2(X) & \mathrm{I} \phi_3(X) \phi_3(X) \\
\end{bmatrix} d\Omega_{i,0}
$$

This local mass matrix can be computed like done in the previous notebooks by numerically integrating the shape functions in a reference element and 
then transforming the region of integration. The global matrix can also be assembled in a similar way as done in the previous notebooks.

---

In [8]:
#Compute the mass matrix for a reference tetrahedron

#Integrate the product of the basis functions over the reference tetrhedron
M_reference = np.zeros((n_local_basis_fn, n_local_basis_fn))

for i in range(n_local_basis_fn):
    for j in range(i, n_local_basis_fn):
        
        M_reference[i, j] = tplquad(
            lambda u, v, w: reference_basis_fn_lst[i](u, v, w) * reference_basis_fn_lst[j](u, v, w), 
            0, 1, 
            0, lambda u: 1 - u, 
            0, lambda u, v: 1 - u - v
        )[0]
        
        M_reference[j, i] = M_reference[i, j] 


# For each tetrahedron we need to compute the jacobian that maps the element coordinates to the reference coordinates
p0_ar = point_ar[cell_ar[:, 0]]
p1_ar = point_ar[cell_ar[:, 1]]
p2_ar = point_ar[cell_ar[:, 2]]
p3_ar = point_ar[cell_ar[:, 3]]
jacobian_ar = np.column_stack([p0_ar - p3_ar, p1_ar - p3_ar, p2_ar - p3_ar]).reshape(-1, 3, 3)

#Compute the determinant of each Jacobian matrix
det_jacobian_ar = np.abs(np.linalg.det(jacobian_ar))
local_mass_matrix_ar  = np.zeros((n_cells, n_local_basis_fn, n_local_basis_fn))

#Compute the local mass matrices
for i in range(n_local_basis_fn):
    for j in range(i, n_local_basis_fn):
 
        #body_force matrix
        local_mass_matrix_ar[:, i, j] = det_jacobian_ar * M_reference[i, j]
        local_mass_matrix_ar[:, j, i] = local_mass_matrix_ar[:, i, j] 


# we assume that the density of the material is constant at 1 kg/m^3 
density_ar = np.ones(n_cells)
local_mass_matrix_ar = local_mass_matrix_ar * density_ar[:, np.newaxis, np.newaxis]

#Now we need to assemble the global mass matrix
row_id = np.repeat(cell_ar, n_local_basis_fn).flatten()
col_id = np.tile(cell_ar,   n_local_basis_fn).flatten()
mass_data = local_mass_matrix_ar.flatten()

# Since we are solving for a 3D displacement field, and not just one scalar field, we need to expand the local mass matrix
# to account for the three components of the displacement field
row_id = np.column_stack([row_id*3, row_id*3 + 1, row_id*3 + 2]).flatten()
col_id = np.column_stack([col_id*3, col_id*3 + 1, col_id*3 + 2]).flatten()
mass_data = np.repeat(mass_data, 3)

# Create the sparse global mass matrix
global_mass_matrix = csc_matrix((mass_data.flatten(), (row_id, col_id)), shape=(n_vertices*3, n_vertices*3))
global_mass_matrix.sum_duplicates()

In [9]:
# We test that the global mass matrix has correctly been built
def test_global_mass_matrix(global_mass_matrix):
    # We also assume that the velocity field is constant and equal to 1 m/s

    velocity_field_ar = np.repeat([[1,0,0]], n_vertices, axis=0)
    velocity_field_ar = velocity_field_ar.flatten()
    kinetic_energy = 0.5 * np.dot(velocity_field_ar, global_mass_matrix.dot(velocity_field_ar))

    # The total kinetic energy of the object is given by: 0.5 * mass * norm(velocity)^2. In this case it's equal to 0.5 Joules. 
    np.testing.assert_allclose(kinetic_energy, 0.5, rtol=1e-6)

test_global_mass_matrix(global_mass_matrix)

In [10]:
# Create the strain density energy function of a compressible Neo-Hookean material


#-----------------------------------------------------------------------------------
@partial(jax.jit, static_argnums=(4, 5))
def compressible_neo_hookean_strain_density_energy_func(
        u, 
        ref_point_ar, 
        ref_volume_ar, 
        D, 
        C1 = 1, 
        C2 = 1
    ):

    # Compute the point coordinates in the deformed configuration
    def_point_ar = ref_point_ar + u.reshape(-1, 3)

    # Compute the deformation gradient tensor
    cell_def_point_ar = jnp.transpose(def_point_ar[cell_ar], (0, 2, 1))
    F_ar = jnp.matmul(cell_def_point_ar, D)
    F_T_ar = jnp.transpose(F_ar, (0, 2, 1))

    # Compute the right Cauchy-Green tensor
    C_ar = jnp.matmul(F_T_ar, F_ar)

    # Compute the invariants of the right Cauchy-Green tensor
    J_ar = jnp.linalg.det(F_ar)
    I1_ar = jnp.trace(C_ar, axis1=1, axis2=2)

    # Compute the strain density energy function of each cell
    W = C1 * (I1_ar -3 - 2 * jnp.log(J_ar)) + C2 * (J_ar - 1)**2

    # Integrate the strain density energy function over the cell volumes
    E = ref_volume_ar * J_ar * W

    # Sum the contributions of all the cells
    return np.sum(E)
#-----------------------------------------------------------------------------------

# Create a dummy displacement field
def_point_ar = point_ar * np.array([1, 1, 2])
u_ar = jnp.array(def_point_ar - point_ar).flatten()



compressible_neo_hookean_strain_density_energy_func(
    u_ar,
    point_ar, 
    ref_volume_ar, 
    D, 
    C1 = 1, 
    C2 = 1
)

# Compute the gradient of the strain density energy function with respect to the 
# displacement field, which corresponds to the internal forces acting on the object
internal_forces_func = jax.grad(compressible_neo_hookean_strain_density_energy_func, argnums=0)

internal_forces = internal_forces_func(
    u_ar,
    point_ar, 
    ref_volume_ar, 
    D, 
    C1 = 1, 
    C2 = 1
)

print(internal_forces)


[-2.24146692e-02 -2.24146692e-02  2.13747069e-02 ...  1.33573708e-16
  5.37764278e-17 -4.89192020e-16]


In [50]:
# We need to construct the Jacobian of the forces wrt to the displacement field
# We cannot simply call the jax.jacobian function, since the jacobian matrix 
# is very sparse and too large to fit in memory for most meshes.


# This first naive implementation took 15 minutes to compute the Jacobian matrix for 
# a mesh with ~2000 vertices.
def compute_sparse_jacobian_naive(f, x, sparsity_pattern):
    """
    Compute the sparse Jacobian of f at x given a known sparsity pattern.
    
    Args:
        f: Function whose Jacobian is to be computed.
        x: Point at which to compute the Jacobian.
        sparsity_pattern: Sparse matrix indicating the structure of the Jacobian.
    
    Returns:
        A sparse matrix (CSR format) representing the Jacobian.
    """
    # Get sparsity structure
    row_indices, col_indices = sparsity_pattern.nonzero()
    
    # Compute each nonzero element of the Jacobian
    values = []
    for i, j in tqdm(zip(row_indices, col_indices), total=len(row_indices)):
        # Create a standard basis vector e_j
        e_j = jnp.zeros_like(x)
        e_j = e_j.at[j].set(1.0)
        
        # Compute J[i, j] using JVP
        _, J_row = jax.jvp(f, (x,), (e_j,))
        values.append(J_row[i])
    
    # Create a sparse Jacobian matrix
    return csc_matrix((values, (row_indices, col_indices)), shape=(len(f(x)), len(x)))





# We need to know the sparsity pattern of the Jacobian matrix of the internal forces. 
# If 2 nodes are connected by an edge, then their positions affect the elastic forces applied 
# to each other. This means that the Jacobian matrix is non-zero at the positions corresponding
# the indices of the 2 nodes connected by an edge. 

edge_ar = np.vstack(
    [
        face_ar[:, [0, 1]],
        face_ar[:, [1, 2]],
        face_ar[:, [2, 1]],
    ]
)

# Sort the edge array
edge_ar = np.sort(edge_ar, axis=1)

# Remove duplicate edges
edge_ar = np.unique(edge_ar, axis=0)

# This is the adjacency matrix of the mesh
adjacency_matrix_1D_indices = np.vstack(
    [
        edge_ar[:, [0, 1]], 
        edge_ar[:, [1, 0]], 
        np.repeat(np.arange(n_vertices).reshape(-1, 1), 2, axis=1), # Add the diagonal elements
    ]
)

# Create the sparse adjacency matrix
adjacency_matrix_1D = csc_matrix((np.ones(adjacency_matrix_1D_indices.shape[0]), (adjacency_matrix_1D_indices[:, 0], adjacency_matrix_1D_indices[:, 1])), shape=(n_vertices, n_vertices))


# We need to extend the adjacency matrix to account for the 3 components of the displacement field
# At each row, for each pair of indices (i, j) we need to add:
# (i*3 + 0, j*3 + 0), 
# (i*3 + 1, j*3 + 0), 
# (i*3 + 0, j*3 + 1), 

# (i*3 + 1, j*3 + 1),
# (i*3 + 2, j*3 + 0),
# (i*3 + 0, j*3 + 2),

# (i*3 + 2, j*3 + 2),
# (i*3 + 1, j*3 + 2),
# (i*3 + 2, j*3 + 1)


sparsity_pattern_indices = np.vstack(
    [
        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3    , adjacency_matrix_1D_indices[:, 1]*3    ]),
        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3 + 1, adjacency_matrix_1D_indices[:, 1]*3    ]),
        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3    , adjacency_matrix_1D_indices[:, 1]*3 + 1]),

        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3 + 1, adjacency_matrix_1D_indices[:, 1]*3 + 1]),
        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3 + 2, adjacency_matrix_1D_indices[:, 1]*3    ]),
        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3    , adjacency_matrix_1D_indices[:, 1]*3 + 2]),

        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3 + 2, adjacency_matrix_1D_indices[:, 1]*3 + 2]),
        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3 + 1, adjacency_matrix_1D_indices[:, 1]*3 + 2]),
        np.column_stack([adjacency_matrix_1D_indices[:, 0]*3 + 2, adjacency_matrix_1D_indices[:, 1]*3 + 1]),
    ]
)


row_id = sparsity_pattern_indices[:, 0]
col_id = sparsity_pattern_indices[:, 1]
data = np.ones(sparsity_pattern_indices.shape[0])

# Create the sparse adjacency matrix
sparsity_pattern = csc_matrix((data, (row_id, col_id)), shape=(n_vertices*3, n_vertices*3))


# Compute the sparse Jacobian of the internal forces
#jac_internal_forces = compute_sparse_jacobian_naive(
#    partial(internal_forces_func, ref_point_ar=point_ar, ref_volume_ar=ref_volume_ar, D=D, C1=1, C2=1),
#    u_ar,
#    adjacency_matrix
#)









In [57]:

def distance_2_coloring(graph):
    """
    Perform distance-2 coloring on the input graph.

    Parameters:
        graph (networkx.Graph): Input graph.

    Returns:
        dict: A dictionary where keys are nodes and values are the assigned colors.
    """
    # Create the square graph (G^2)
    square_graph = nx.Graph()
    square_graph.add_nodes_from(graph.nodes)
    
    for node in graph.nodes:
        # Get neighbors within distance 2
        distance_2_neighbors = set(nx.single_source_shortest_path_length(graph, node, cutoff=2).keys())
        distance_2_neighbors.discard(node)  # Remove self-loops
        for neighbor in distance_2_neighbors:
            square_graph.add_edge(node, neighbor)
    
    # Perform greedy coloring on the square graph
    coloring = nx.coloring.greedy_color(square_graph, strategy="largest_first")
    return coloring

# Example usage
# Create a sample graph
G = nx.Graph()
G.add_edges_from(sparsity_pattern_indices.tolist())
vertex_color_dic = distance_2_coloring(G)

# We need to reverse the dictionary to get a list of vertices for each color
color_vertex_dic = {}
for vertex, color in vertex_color_dic.items():
    color_vertex_dic[color] = color_vertex_dic.get(color, []) + [vertex]

# Transform the list of vertices into a numpy array
for color, vertex_lst in color_vertex_dic.items():
    color_vertex_dic[color] = np.array(vertex_lst)

In [106]:
# We can optimize the computation of the Jacobian by obtaining collecting the entries 
# of the Jacobian matrix for several nodes at once. This way we can compute the Jacobian
# with fewer calls to the jax.jvp function. The trick is to sum the column vectors 
# where there is no entry on the same row. After summing the column vectors, we already 
# know from the sparsity pattern to which column the entry at each row corresponds.

def compute_sparse_jacobian_optimized(f, x, sparsity_pattern, color_vertex_dic):
    """
    Compute the sparse Jacobian of f at x given a known sparsity pattern.
    
    Args:
        f: Function whose Jacobian is to be computed.
        x: Point at which to compute the Jacobian.
        sparsity_pattern: Sparse matrix indicating the structure of the Jacobian.
    
    Returns:
        A sparse matrix (CSR format) representing the Jacobian.
    """


    # Store the indices and entries of the sparse jacobian matrix 
    jacobian_row_id_lst = []
    jacobian_col_id_lst = []
    jacobian_data_lst   = []

    for selected_col_id_ar in tqdm(color_vertex_dic.values(), total=len(color_vertex_dic), desc="Constructing Sparse Jacobian"):

        # Create a vector of zeros
        e_j = jnp.zeros_like(x)

        # Set the values of the vector to 1 at the positions corresponding to the vertices
        e_j = e_j.at[selected_col_id_ar].set(1.0)

        # Compute the Jacobian vector product 
        _, summed_cols_vector = jax.jvp(f, (x,), (e_j,))

        # Now the trick is to find for each row of summed_col_vectors the index of the column where the entry is located in the Jacobian matrix
        # Start making all columns in the adjacency matrix zero except for the selected columns
        diag_matrix = diags(np.array(e_j), format='csc')
        adjacency_matrix_subset =  sparsity_pattern.dot(diag_matrix)

        # Extract the indices non-zero entries of this subset of the adjacency matrix
        subset_row_indices, subset_col_indices = adjacency_matrix_subset.nonzero()

        # The row indices of the Jacobian matrix are the same as the row indices of the adjacency matrix
        jacobian_row_id_lst.extend(subset_row_indices)
        jacobian_col_id_lst.extend(subset_col_indices)
        jacobian_data_lst.extend(summed_cols_vector[subset_row_indices])


    # Create the sparse Jacobian matrix
    return csc_matrix((jacobian_data_lst, (jacobian_row_id_lst, jacobian_col_id_lst)), shape=(len(f(x)), len(x)))


jac_internal_forces_optimized = compute_sparse_jacobian_optimized(
    partial(internal_forces_func, ref_point_ar=point_ar, ref_volume_ar=ref_volume_ar, D=D, C1=1, C2=1),
    u_ar,
    sparsity_pattern,
    color_vertex_dic
)


    



Constructing Sparse Jacobian: 100%|██████████| 93/93 [00:02<00:00, 42.05it/s]


Building the sparse Jacobian matrix
