In [1]:
from dataclasses import dataclass
import scipy.io
from scipy.sparse import csc_matrix
import numpy as np


@dataclass
class MeshHyperelasticity3D:
    """3D hyperelasticity mesh together with its simulation/plotting parameters.

    Instances are built by ``load_mesh_hyperelasticity_3d`` from the MATLAB
    structs ``mesh`` and ``params`` stored in a ``.mat`` file. Scalar fields
    hold values unwrapped from MATLAB's (1, 1) arrays, so at runtime they are
    typically NumPy scalars and the ``int``/``float`` annotations are nominal.
    Fields described as "0-based" have been shifted from MATLAB's 1-based
    indexing by the loader.
    """

    # mesh attributes
    dim: int
    level: int  # mesh level as stored in the file
    nn: int  # NOTE(review): presumably the number of nodes — confirm
    ne: int  # NOTE(review): presumably the number of elements — confirm
    elems2nodes: np.ndarray  # 0-based element connectivity, 4 nodes per element (e.g. (1920, 4) at level 1)
    nodes2coord: np.ndarray  # reference nodal coordinates, one (x, y, z) row per node (e.g. (729, 3) at level 1)
    volumes: np.ndarray  # per-element volumes, used as integration weights by `energy` (e.g. (1920, 1) at level 1)
    dphi: list  # basis-function derivative arrays; entries 0, 1, 2 are used as dphix, dphiy, dphiz
    Hstr: csc_matrix  # NOTE(review): presumably the sparsity pattern of the Hessian — confirm
    nodesDirichlet: list  # NOTE(review): NOT shifted to 0-based by the loader, unlike the other index fields
    nodesMinim: np.ndarray  # 0-based node indices
    dofsDirichlet: np.ndarray  # 0-based DOF indices
    dofsMinim: np.ndarray  # 0-based indices of the DOFs the solver optimizes over
    dofsMinim_local: np.ndarray  # 0-based
    bfaces2elems: np.ndarray  # 0-based boundary-face -> element map
    bfaces2nodes: np.ndarray  # 0-based boundary-face -> node map; used as the triangle list when plotting

    # params attributes
    lx: float  # x-coordinate of the end face that the notebook rotates
    ly: float  # NOTE(review): presumably the extent in y — confirm
    lz: float  # NOTE(review): presumably the extent in z — confirm
    E: float  # NOTE(review): presumably Young's modulus — confirm
    nu: float  # NOTE(review): presumably Poisson's ratio — confirm
    turning: int
    timeSteps: int
    # draw: list  (params field 7, not loaded)
    visualizeLevels: np.ndarray  # kept as the raw loaded array (not unwrapped to a scalar)
    showFullDirichlet: int
    showDirichletTwo: int
    showDirichletOne: int
    # graphs: list  (params field 12, not loaded)
    azimuth: int
    elevation: int
    freq: int
    animations_count: int
    delay: float
    delay_first: float
    delay_last: float
    epsFDSS: float
    max_iters: int
    disp: str  # first element of the loaded MATLAB char array
    tf: float
    nbfn: int
    T: int
    lambda_: float  # trailing underscore because `lambda` is a Python keyword
    mu: float
    K: float
    C1: float  # coefficient of the (I1 - 3 - 2 ln J) term in `energy`
    D1: float  # coefficient of the (J - 1)^2 term in `energy`
    evaluation: int


def load_mesh_hyperelasticity_3d(level: int) -> MeshHyperelasticity3D:
    """Read ``hyperelasticity_mesh_level_{level}.mat`` into a MeshHyperelasticity3D.

    The file is looked up relative to the current working directory and must
    hold two MATLAB structs, ``mesh`` and ``params``. Their fields are read by
    POSITION, not by name, so the field order in the file has to match the
    indices used below.

    Conventions applied while unpacking:
      * MATLAB scalars arrive as (1, 1) arrays and are unwrapped to scalars;
      * 1-based MATLAB index arrays (connectivity, node/DOF lists, boundary
        faces) are shifted to 0-based;
      * params fields 7 and 12 are skipped.

    Args:
        level: mesh level; selects which file is read.

    Returns:
        The populated MeshHyperelasticity3D.
    """
    mat = scipy.io.loadmat(f"hyperelasticity_mesh_level_{level}.mat")
    mesh = mat['mesh'][0, 0]
    prm = mat['params'][0, 0]

    def scalar(record, pos):
        # Unwrap the (1, 1) array MATLAB uses to store a scalar.
        return record[pos][0, 0]

    def zero_based(record, pos):
        # Shift a 1-based MATLAB index array to Python's 0-based indexing.
        return record[pos] - 1

    return MeshHyperelasticity3D(
        # mesh struct
        dim=scalar(mesh, 0),
        level=scalar(mesh, 1),
        nn=scalar(mesh, 2),
        ne=scalar(mesh, 3),
        elems2nodes=zero_based(mesh, 4),
        nodes2coord=mesh[5],
        volumes=mesh[6],
        dphi=mesh[7][0].tolist(),
        Hstr=mesh[8],
        # NOTE(review): unlike the other index fields this one is not shifted
        # to 0-based — confirm whether consumers expect MATLAB numbering.
        nodesDirichlet=mesh[9][0].tolist(),
        nodesMinim=zero_based(mesh, 10),
        dofsDirichlet=zero_based(mesh, 11),
        dofsMinim=zero_based(mesh, 12),
        dofsMinim_local=zero_based(mesh, 13),
        bfaces2elems=zero_based(mesh, 14),
        bfaces2nodes=zero_based(mesh, 15),
        # params struct (field 7 'draw' and field 12 'graphs' are not loaded)
        lx=scalar(prm, 0),
        ly=scalar(prm, 1),
        lz=scalar(prm, 2),
        E=scalar(prm, 3),
        nu=scalar(prm, 4),
        turning=scalar(prm, 5),
        timeSteps=scalar(prm, 6),
        visualizeLevels=prm[8],
        showFullDirichlet=scalar(prm, 9),
        showDirichletTwo=scalar(prm, 10),
        showDirichletOne=scalar(prm, 11),
        azimuth=scalar(prm, 13),
        elevation=scalar(prm, 14),
        freq=scalar(prm, 15),
        animations_count=scalar(prm, 16),
        delay=scalar(prm, 17),
        delay_first=scalar(prm, 18),
        delay_last=scalar(prm, 19),
        epsFDSS=scalar(prm, 20),
        max_iters=scalar(prm, 21),
        disp=prm[22][0],
        tf=scalar(prm, 23),
        nbfn=scalar(prm, 24),
        T=scalar(prm, 25),
        lambda_=scalar(prm, 26),
        mu=scalar(prm, 27),
        K=scalar(prm, 28),
        C1=scalar(prm, 29),
        D1=scalar(prm, 30),
        evaluation=scalar(prm, 31),
    )


# Example usage: load the level-1 mesh and report the shapes of the main arrays
# (connectivity, nodal coordinates, element volumes).
level = 1
data = load_mesh_hyperelasticity_3d(level)
for arr in (data.elems2nodes, data.nodes2coord, data.volumes):
    print(arr.shape)

(1920, 4)
(729, 3)
(1920, 1)


In [2]:
from jax import config
# JAX defaults to 32-bit floats; enabling x64 lets the `dtype=jnp.float64`
# arrays created in later cells really be float64. This must run before any
# JAX arrays are created.
config.update("jax_enable_x64", True)

import jax.numpy as jnp


def energy(u, u0, dofsMinim, elems2nodes, dphix, dphiy, dphiz, vol, C1, D1):
    """Total hyperelastic energy of the mesh for the free degrees of freedom ``u``.

    The full nodal-position vector is ``u0`` with the entries ``dofsMinim``
    overwritten by ``u``; it is interleaved as (x, y, z) per node. For every
    element the deformation gradient F is assembled from the element's nodal
    positions and the basis-function gradients, and the density

        W = C1 * (tr(F^T F) - 3 - 2 ln det F) + D1 * (det F - 1)^2

    is summed with the weights ``vol``.

    Args:
        u: values of the free DOFs.
        u0: full position vector supplying the non-free entries.
        dofsMinim: indices into the full vector where ``u`` is written.
        elems2nodes: integer element connectivity (one row per element).
        dphix, dphiy, dphiz: basis-function derivatives, one row per element,
            matching the columns of ``elems2nodes``.
        vol: per-element weights (element volumes).
        C1, D1: material constants of the energy density.

    Returns:
        Scalar JAX array with the total energy. It is not finite when some
        element has det F <= 0, because of ln det F.
    """
    # Full nodal-position vector with the free entries replaced by u.
    positions = jnp.array(u0, dtype=jnp.float64).at[dofsMinim].set(u)

    # Element-local nodal values, one array per coordinate component.
    local = [positions[comp::3][elems2nodes] for comp in range(3)]
    basis_grads = (dphix, dphiy, dphiz)

    # Deformation gradient entries: F[i][j] = sum_a local[i][:, a] * dphi_j[:, a].
    F = [[jnp.sum(xi * dj, axis=1) for dj in basis_grads] for xi in local]
    (f11, f12, f13), (f21, f22, f23), (f31, f32, f33) = F

    # First invariant I1 = tr(F^T F), i.e. the squared Frobenius norm of F.
    I1 = sum(entry**2 for row in F for entry in row)
    # J = det F, written as the cofactor expansion along the first row.
    J = (f11 * f22 * f33 - f11 * f23 * f32 - f12 * f21 * f33
         + f12 * f23 * f31 + f13 * f21 * f32 - f13 * f22 * f31)

    density = C1 * (I1 - 3 - 2 * jnp.log(J)) + D1 * (J - 1)**2
    return jnp.sum(density * vol)
    

In [3]:
# Step 1 of the twist: start from the reference configuration and rotate the
# end face x == lx about the x-axis (through y = z = 0) by `alpha`.
# u0 holds nodal positions interleaved as (x0, y0, z0, x1, y1, z1, ...).
u0 = jnp.array(data.nodes2coord, dtype=jnp.float64).ravel()
alpha = 0.10472  # rotation angle in radians (~pi/30, i.e. about 6 degrees)
# np.isclose rather than ==: an exact float comparison against lx can
# silently match no nodes, and then no rotation would be applied.
nodes = np.where(np.isclose(data.nodes2coord[:, 0], data.lx))[0]
assert nodes.size > 0, "no nodes found on the face x == lx"
y_ref = data.nodes2coord[nodes, 1]
z_ref = data.nodes2coord[nodes, 2]
u0 = u0.at[nodes * 3 + 1].set(np.cos(alpha) * y_ref + np.sin(alpha) * z_ref)
u0 = u0.at[nodes * 3 + 2].set(-np.sin(alpha) * y_ref + np.cos(alpha) * z_ref)

I0000 00:00:1699015506.339054  209603 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-11-03 13:45:06.355175: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:276] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# Indices of the DOFs the solver optimizes over (all other entries of the full
# position vector stay at their u0 values), and the initial guess restricted to them.
dofsMinim = jnp.array(data.dofsMinim.ravel())
u = u0[dofsMinim]

# Mesh data as JAX arrays; basis-function gradients and volumes in float64.
elems2nodes = jnp.array(data.elems2nodes)
dphix, dphiy, dphiz = (jnp.array(data.dphi[k], dtype=jnp.float64) for k in range(3))
vol = jnp.array(data.volumes.ravel(), dtype=jnp.float64)

# Constants of the energy density used by `energy`.
C1, D1 = data.C1, data.D1

In [5]:
# JIT-compiled energy and its gradient with respect to the free DOFs `u`.
from jax import jit, grad

energy_jit = jit(energy)
dfun = jit(grad(energy))  # grad differentiates w.r.t. argument 0 (`u`) by default

In [6]:
# Energy of the initial guess u, with the step-1 boundary values taken from u0.
e = energy_jit(u, u0, dofsMinim, elems2nodes, dphix, dphiy, dphiz, vol, C1, D1)
e

Array(0.28153964, dtype=float64)

In [None]:
# Gradient of the energy w.r.t. the free DOFs at the initial guess.
g = dfun(u, u0, dofsMinim, elems2nodes, dphix, dphiy, dphiz, vol, C1, D1)

In [None]:
import jax


def _fixed_args():
    """Arguments of `energy` other than the free DOFs.

    Looked up on every call, so rebinding the global ``u0`` (as the
    continuation loop does) changes the prescribed boundary values.
    """
    return (u0, dofsMinim, elems2nodes, dphix, dphiy, dphiz, vol, C1, D1)


def ff(x):
    """SciPy objective: energy at the free-DOF vector ``x`` (NumPy array)."""
    return energy_jit(jnp.array(x.ravel()), *_fixed_args())


def dff(x):
    """SciPy gradient: d(energy)/dx returned as a Fortran-ordered NumPy array."""
    grad_free = dfun(jnp.array(x.ravel()), *_fixed_args())
    return np.asfortranarray(np.array(grad_free.ravel()))

In [None]:
# Solve step 1 with Newton-CG. No hess/hessp is passed, so SciPy approximates
# Hessian-vector products by finite differences of `dff`.
from scipy.optimize import minimize

u_np = np.asfortranarray(np.array(u.ravel()))
# 'maxfun' is not an option of the Newton-CG method (SciPy emits
# "Unknown solver options: maxfun"), so only 'maxiter' is passed.
x = minimize(ff, u_np, jac=dff, method='Newton-CG', tol=1e-9,
             options={'maxiter': 100000})
x

In [None]:
# Full position vector after step 1: fixed entries from u0, free entries from the solver result.
res = u0.at[dofsMinim].set(x.x)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Deformed boundary surface after step 1, viewed along the y-axis.
res_np = np.asarray(res)  # one device -> host copy instead of one per coordinate
xs, ys, zs = res_np[0::3], res_np[1::3], res_np[2::3]

fig = plt.figure(figsize=(20, 20))
ax: Axes3D = fig.add_subplot(projection='3d')  # type: ignore
ax.view_init(elev=0, azim=90)

# Shaded boundary triangles (bfaces2nodes) with black edges.
ax.plot_trisurf(xs, ys, zs, triangles=data.bfaces2nodes, color="b", edgecolor='k',
                linewidth=0.5, antialiased=True, shade=True)

# Two fixed red reference markers at x = 0.2, y = 0, z = +/-0.05.
ax.scatter([0.2, 0.2], [0, 0], [0.05, -0.05], c='r', s=100)

# Hard-coded box aspect 0.4 : 0.08 : 0.08 (i.e. 5:1:1), not derived from the data.
ax.set_box_aspect([0.4, 0.08, 0.08])

ax.grid(False)
ax.set_axis_off()  # hide axes, ticks and panes
ax.set_xlim(ax.get_xlim()[::-1])  # mirror the x-axis

plt.tight_layout()
plt.show()

In [None]:
import os

# Continuation in the twist angle: step ii prescribes a rotation of
# ii * ALPHA_STEP on the end face x == lx and warm-starts Newton-CG from the
# previous converged configuration `res`. Step 1 was solved in the cells above.
ALPHA_STEP = 0.10472          # radians per step (~pi/30, about 6 degrees)
FIRST_STEP = 2
LAST_STEP = 4 * 48 * 10       # exclusive upper bound -> steps 2 .. 1919
FIG_DIR = "figs"

os.makedirs(FIG_DIR, exist_ok=True)  # savefig fails if the folder does not exist

# The twisted face and its reference (y, z) coordinates are the same for every
# step, so they are computed once; np.isclose avoids an exact float comparison.
nodes = np.where(np.isclose(data.nodes2coord[:, 0], data.lx))[0]
y_ref = data.nodes2coord[nodes, 1]
z_ref = data.nodes2coord[nodes, 2]


def _plot_state(res_vec):
    """Draw the deformed boundary surface for a full position vector; return the figure."""
    pos = np.asarray(res_vec)
    fig = plt.figure(figsize=(20, 20))
    ax: Axes3D = fig.add_subplot(projection='3d')  # type: ignore
    ax.view_init(elev=0, azim=90)
    ax.plot_trisurf(pos[0::3], pos[1::3], pos[2::3], triangles=data.bfaces2nodes,
                    color="b", edgecolor='k', linewidth=0.5, antialiased=True, shade=True)
    # Two fixed red reference markers at x = 0.2, y = 0, z = +/-0.05.
    ax.scatter([0.2, 0.2], [0, 0], [0.05, -0.05], c='r', s=100)
    ax.set_box_aspect([0.4, 0.08, 0.08])  # fixed 5:1:1 box so all frames line up
    ax.grid(False)
    ax.set_axis_off()
    ax.set_xlim(ax.get_xlim()[::-1])  # mirror the x-axis
    fig.tight_layout()
    return fig


res_u = []  # converged energy of each step

for ii in range(FIRST_STEP, LAST_STEP):
    # u0 is rebound at module level on purpose: ff/dff read the global u0.
    alpha = ALPHA_STEP * ii
    cos_a, sin_a = np.cos(alpha), np.sin(alpha)
    u0 = jnp.array(res).ravel()
    u0 = u0.at[nodes * 3 + 1].set(cos_a * y_ref + sin_a * z_ref)
    u0 = u0.at[nodes * 3 + 2].set(-sin_a * y_ref + cos_a * z_ref)
    u = u0[dofsMinim]
    u_np = np.asfortranarray(np.array(u.ravel()))
    x = minimize(ff, u_np, jac=dff, method='Newton-CG', tol=1e-7)
    res = u0.at[dofsMinim].set(x.x)
    res_u.append(x.fun)
    print(f"step {ii}: energy={float(x.fun):.10g}, iterations={x.nit}, success={x.success}")

    fig = _plot_state(res)
    # Save before show(): the frame is written regardless of the backend.
    fig.savefig(os.path.join(FIG_DIR, f"fig_{ii-1}.png"))
    plt.show()
    plt.close(fig)  # release the figure so ~1900 figures do not accumulate

In [None]:
import cv2
import os
import glob

# Assemble the frames written by the twisting loop (figs/fig_<n>.png) into an AVI.
image_folder = 'figs'
video_name = 'output_video.avi'
FPS = 30


def _frame_index(path):
    """Frame number encoded in a file name such as 'fig_12.png'."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.rsplit('_', 1)[-1])


# glob returns files in arbitrary order, and a plain string sort would put
# fig_10 before fig_2, so sort by the numeric frame index instead.
images = sorted(glob.glob(os.path.join(image_folder, 'fig_*.png')), key=_frame_index)
if not images:
    raise FileNotFoundError(f"no frames matching fig_*.png found in {image_folder!r}")

# Determine the width and height from the first frame.
frame = cv2.imread(images[0])
height, width, layers = frame.shape

fourcc = cv2.VideoWriter_fourcc(*'XVID')
video = cv2.VideoWriter(video_name, fourcc, FPS, (width, height))
try:
    for image in images:
        video.write(cv2.imread(image))
finally:
    # Always finalize the file. No GUI windows are opened here, so
    # destroyAllWindows() is unnecessary (and raises on headless OpenCV builds).
    video.release()

print(f'Video {video_name} is created with {len(images)} frames at {FPS}fps.')

In [None]:
# Converged total energy after each step of the twisting loop.
fig, ax = plt.subplots()
ax.plot(res_u)
ax.set(xlabel="loop step index", ylabel="energy", title="Energy along the twisting path")
plt.show()

In [None]:

# Final deformed state, with the box aspect matched to the body's extent along
# x, y and z so lengths are drawn to the same scale on every axis.
res_np = np.asarray(res)  # single device -> host copy
coords = [res_np[0::3], res_np[1::3], res_np[2::3]]

fig = plt.figure(figsize=(20, 20))
ax: Axes3D = fig.add_subplot(projection='3d')  # type: ignore
ax.view_init(elev=0, azim=90)

# Shaded boundary triangles with black edges.
ax.plot_trisurf(*coords, triangles=data.bfaces2nodes,
                color="b", edgecolor='k', linewidth=0.5, antialiased=True, shade=True)

# Extent (max - min) of the deformed body along x, y, z, used as the box aspect.
data_ranges = [np.ptp(c) for c in coords]
ax.set_box_aspect(data_ranges)

ax.grid(False)
ax.set_axis_off()  # hide axes, ticks and panes
ax.set_xlim(ax.get_xlim()[::-1])  # mirror the x-axis

plt.tight_layout()
plt.show()


In [None]:
# Extents (max - min) of the final deformed body along x, y, z, computed in the previous cell.
data_ranges