1. Import packages

In [None]:
import numpy as np
import pandas as pd
import pyvista as pv
import open3d as o3d
import networkx
import scipy.sparse
from scipy.linalg import eigh
from scipy.spatial import cKDTree
import trimesh
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import openpyxl

2. Split mesh (if the file contains meshes of multiple objects)

In [None]:
# ==== User Settings Area ====
input_mesh_path = r"PATH/TO/YOUR/INPUT_MESH.obj"  # Original mesh file path (replace with your actual path)
volume_threshold = 1000.0      # Volume threshold; components whose |volume| is larger are excluded
min_faces = 500                # Minimum number of faces; meshes with fewer faces will be skipped
output_base_dir = r"PATH/TO/YOUR/OUTPUT_DIRECTORY"  # Output directory base path (replace accordingly)
# ===========================

# Split a multi-object mesh file into connected components and export every
# component that has at least `min_faces` faces and an absolute volume no
# larger than `volume_threshold`. Kept components are written as
# <output_base_dir>/<base_name>/<base_name>_<k>.obj, where k counts only the
# components that were kept.

# Extract base file name without extension for creating subfolder
base_name = os.path.splitext(os.path.basename(input_mesh_path))[0]

# Create a subdirectory under the output base directory for the current mesh file
output_subdir = os.path.join(output_base_dir, base_name)
os.makedirs(output_subdir, exist_ok=True)

# Load mesh (force='mesh' concatenates multi-geometry files into one Trimesh)
print(f"Loading mesh from: {input_mesh_path}")
mesh = trimesh.load(input_mesh_path, force='mesh')

# Split into connected components
print("Splitting mesh into connected components...")
components = mesh.split(only_watertight=False)
print(f"Found {len(components)} components.")

saved_count = 0
for i, comp in enumerate(components):
    num_faces = comp.faces.shape[0]

    if num_faces < min_faces:
        print(f"Skipping component {i} (only {num_faces} faces, less than minimum {min_faces})")
        continue

    try:
        # trimesh reports a *signed* volume: a component with inverted face
        # winding yields a negative value, which would otherwise never exceed
        # the threshold. (For non-watertight components trimesh documents the
        # value as unreliable.)
        vol = abs(comp.volume)
    except Exception:
        vol = 0.0
        print(f"Warning: Failed to calculate volume for component {i}, assuming volume=0.")

    if vol > volume_threshold:
        print(f"Skipping component {i} (volume={vol:.2f} > threshold={volume_threshold})")
        continue

    output_path = os.path.join(output_subdir, f"{base_name}_{saved_count}.obj")
    comp.export(output_path)
    print(f"Saved component {i} as {output_path} (faces={num_faces}, volume={vol:.2f})")
    saved_count += 1

print(f"\n✅ Done. {saved_count} components saved to '{output_subdir}'")


3. Load and simplify mesh

In [None]:
# ---- Settings for the simplification step (edit these) ----
input_mesh_folder = r'PATH/TO/YOUR/INPUT_MESH_FOLDER'  # meshes written by the split step
output_root_dir = r'PATH/TO/YOUR/OUTPUT_DIRECTORY'     # root folder for simplified meshes
simplify_ratio = 0.5                                   # fraction of faces kept by decimation
# -----------------------------------------------------------

# Results are written to <output_root_dir>/<last component of input_mesh_folder>
folder_name = os.path.basename(os.path.normpath(input_mesh_folder))
output_dir = os.path.join(output_root_dir, folder_name)
os.makedirs(output_dir, exist_ok=True)


def load_and_preprocess_mesh_no_smooth(filename, simplify_ratio=0.1):
    """Load a mesh, decimate it with Open3D quadric decimation, and clean it.

    No smoothing is applied (hence the name).

    Parameters
    ----------
    filename : str
        Path of the mesh file to load.
    simplify_ratio : float, optional
        Target fraction of the original face count to keep (default 0.1).

    Returns
    -------
    trimesh.Trimesh
        The simplified mesh after removal of duplicated vertices, degenerate
        and duplicated triangles, non-manifold edges and unreferenced
        vertices. It is built with ``process=False``.
    """
    print(f"\n📂 Loading mesh file: {filename}")
    # force='mesh' concatenates multi-geometry files into a single Trimesh
    # instead of returning a Scene (which has no .faces / .is_watertight).
    mesh = trimesh.load(filename, force='mesh')

    if not mesh.is_watertight:
        print("⚠️ Warning: The mesh is not watertight")

    original_faces = len(mesh.faces)
    print(f"Original vertices: {len(mesh.vertices)}, faces: {original_faces}")

    o3d_mesh = o3d.geometry.TriangleMesh()
    o3d_mesh.vertices = o3d.utility.Vector3dVector(mesh.vertices)
    o3d_mesh.triangles = o3d.utility.Vector3iVector(mesh.faces)

    target_faces = int(original_faces * simplify_ratio)
    print(f"🔧 Simplifying mesh to ~{target_faces} faces ({simplify_ratio:.2%} of original)...")
    o3d_mesh = o3d_mesh.simplify_quadric_decimation(target_faces)
    print(f"Simplified vertices: {len(o3d_mesh.vertices)}, faces: {len(o3d_mesh.triangles)}")

    o3d_mesh.remove_duplicated_vertices()
    o3d_mesh.remove_degenerate_triangles()
    o3d_mesh.remove_duplicated_triangles()
    o3d_mesh.remove_non_manifold_edges()
    # Removing triangles above can leave vertices no face references. Such
    # vertices would get zero normals and no neighbours in the curvature step,
    # so drop them here.
    o3d_mesh.remove_unreferenced_vertices()

    processed_mesh = trimesh.Trimesh(
        vertices=np.asarray(o3d_mesh.vertices),
        faces=np.asarray(o3d_mesh.triangles),
        process=False
    )

    print(f"✅ Preprocessing complete. Final vertices: {len(processed_mesh.vertices)}, faces: {len(processed_mesh.faces)}")
    return processed_mesh


# Simplify each .obj in the input folder (in os.listdir order) and save it
# next to the others as <name>_simplified.obj.
for filename in os.listdir(input_mesh_folder):
    if filename.lower().endswith('.obj'):
        input_file_path = os.path.join(input_mesh_folder, filename)
        base_name = os.path.splitext(filename)[0]
        output_file_path = os.path.join(output_dir, f"{base_name}_simplified.obj")

        print(f"Processing: {input_file_path}")
        simplified_mesh = load_and_preprocess_mesh_no_smooth(
            input_file_path,
            simplify_ratio=simplify_ratio,
        )
        simplified_mesh.export(output_file_path)
        print(f"💾 Saved simplified mesh to: {output_file_path}")

3. Calculate invagination ratio

In [None]:
# ---- Parameters for the invagination-ratio computation (edit these) ----
input_dir_cal = r'PATH/TO/YOUR/INPUT_MESH_FOLDER'   # simplified meshes produced by the previous step
output_dir_cal = r'PATH/TO/YOUR/OUTPUT_DIRECTORY'   # where curvature results are written
use_third_ring = False      # True: fit each patch on a neighbourhood one ring larger
smooth_iters_first = 5      # majority-vote smoothing passes before bridging concave regions
bridge_length = 2           # max edge-hop distance used when bridging concave regions
smooth_iters_second = 5     # majority-vote smoothing passes after bridging
save_type_txt = True        # also write the per-vertex concave/convex labels to a .txt file
# ------------------------------------------------------------------------

os.makedirs(output_dir_cal, exist_ok=True)

def compute_vertex_normals(mesh):
    """Return the per-vertex normals stored on ``mesh`` (its ``vertex_normals``)."""
    normals = mesh.vertex_normals
    return normals

def vertex_neighbors(mesh):
    """Build the one-ring adjacency list of a triangle mesh.

    Returns a list with one entry per vertex; entry ``v`` lists (without
    duplicates, in first-seen order) every vertex sharing a face with ``v``.
    """
    ring = [[] for _ in range(len(mesh.vertices))]
    for face in mesh.faces:
        a, b, c = face
        # Each corner sees the other two corners of the triangle.
        for center, first, second in ((a, b, c), (b, c, a), (c, a, b)):
            seen = ring[center]
            for other in (first, second):
                if other not in seen:
                    seen.append(other)
    return ring

def vector_rotation_matrix(v):
    """Return the rotation that takes direction ``v`` onto the -x axis.

    The matrix is built with Rodrigues' formula about the axis ``v x (-x)``.
    When ``v`` is (anti-)parallel to the x axis the cross product vanishes
    and two identity matrices are returned.

    Returns
    -------
    (M, M.T) : tuple of 3x3 ndarrays
        The rotation and its transpose (its inverse).
    """
    unit = v / np.linalg.norm(v)
    minus_x = np.array([-1, 0, 0])
    rot_axis = np.cross(unit, minus_x)
    axis_norm = np.linalg.norm(rot_axis)
    if axis_norm < 1e-8:
        return np.eye(3), np.eye(3)

    ux, uy, uz = rot_axis / axis_norm
    theta = np.arccos(np.clip(np.dot(unit, minus_x), -1, 1))
    # Cross-product (skew-symmetric) matrix of the unit rotation axis
    skew = np.array([
        [0, -uz, uy],
        [uz, 0, -ux],
        [-uy, ux, 0],
    ])
    rot = np.eye(3) + np.sin(theta) * skew + (1 - np.cos(theta)) * (skew @ skew)
    return rot, rot.T

def eig2(Dxx, Dxy, Dyy):
    """Eigen-decompose the symmetric 2x2 matrix ``[[Dxx, Dxy], [Dxy, Dyy]]``.

    Returns ``(lam_small, lam_large, vec_small, vec_large)``, where the
    eigenpairs are ordered by ascending absolute eigenvalue.
    """
    hessian = np.array([[Dxx, Dxy], [Dxy, Dyy]])
    eigvals, eigvecs = eigh(hessian)
    small, large = np.argsort(np.abs(eigvals))
    return eigvals[small], eigvals[large], eigvecs[:, small], eigvecs[:, large]

def patch_curvature(mesh, use_third=False, smooth_iters_first=5, bridge_length=2, smooth_iters_second=5):
    """Estimate per-vertex principal curvatures and a concave/convex labelling.

    For each vertex, its neighbourhood (neighbours of its neighbours, or one
    ring further when ``use_third`` is True) is expressed in a local frame
    whose first axis is the vertex normal. The quadric
    ``f = a*x**2 + b*y**2 + c*x*y + d*x + e*y + g`` is then least-squares
    fitted to it. The eigenvalues of the Hessian ``[[2a, c], [c, 2b]]`` are
    the principal curvatures. Its eigenvectors, mapped back to world space,
    are the principal directions.

    A vertex whose two curvatures have opposite signs (both at least 1e-6 in
    magnitude) is labelled ``'concave'``; every other vertex is ``'convex'``.
    The labels are then majority-vote smoothed, bridged, and smoothed again.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Input mesh; its vertex normals, vertices and faces are read.
    use_third : bool
        Use the larger (three-step) neighbourhood for the patch fit.
    smooth_iters_first, smooth_iters_second : int
        Number of smoothing passes before / after bridging.
    bridge_length : int or float
        Maximum edge-hop distance used by the bridging step.

    Returns
    -------
    Lambda1, Lambda2 : numpy.ndarray, shape (n_vertices,)
        Principal curvatures, ordered so that ``|Lambda1| <= |Lambda2|``.
    Dir1, Dir2 : numpy.ndarray, shape (n_vertices, 3)
        Matching principal directions in world coordinates.
    Type_final : list of str
        ``'concave'`` or ``'convex'`` for every vertex.
    """
    normals = compute_vertex_normals(mesh)
    neighbors = vertex_neighbors(mesh)
    nv = len(mesh.vertices)
    Lambda1 = np.zeros(nv)
    Lambda2 = np.zeros(nv)
    Dir1 = np.zeros((nv, 3))
    Dir2 = np.zeros((nv, 3))

    for i in range(nv):
        nbrs = set()
        for nb in neighbors[i]:
            if use_third:
                for nnb in neighbors[nb]:
                    nbrs.update(neighbors[nnb])
            else:
                nbrs.update(neighbors[nb])
        if not nbrs:
            # Isolated vertex: nothing to fit and its normal is undefined.
            # Leave zero curvature / direction (classified as convex below).
            continue
        nbrs = list(nbrs)

        Ve = mesh.vertices[nbrs]
        M, _ = vector_rotation_matrix(normals[i])
        # M maps the normal onto -x. Points are row vectors, so local
        # coordinates are the rows of (Ve - p) @ M.T (i.e. (M @ r).T), and
        # column 0 is then the signed height along the normal.
        # (`@ M`, i.e. `@ Minv.T`, would apply the inverse rotation, which
        # aligns the height axis with the normal only when the normal is
        # perpendicular to the x axis.)
        We = (Ve - mesh.vertices[i]) @ M.T
        f = We[:, 0]
        x = We[:, 1]
        y = We[:, 2]

        FM = np.column_stack([x**2, y**2, x*y, x, y, np.ones_like(x)])
        abcdef, _, _, _ = np.linalg.lstsq(FM, f, rcond=None)
        a, b, c = abcdef[0], abcdef[1], abcdef[2]
        Dxx = 2*a
        Dxy = c
        Dyy = 2*b
        lam1, lam2, I1, I2 = eig2(Dxx, Dxy, Dyy)
        # Local (tangent) eigenvectors back to world space: row @ M == (M.T @ w).T
        Dir1[i, :] = (np.array([0, *I1]) @ M).flatten()
        Dir2[i, :] = (np.array([0, *I2]) @ M).flatten()
        Lambda1[i] = lam1
        Lambda2[i] = lam2

    # Classification based on the sign of principal curvatures:
    # opposite signs (saddle-like) -> 'concave', otherwise 'convex'.
    eps = 1e-6
    Type = []
    for l1, l2 in zip(Lambda1, Lambda2):
        sign1 = 0 if abs(l1) < eps else np.sign(l1)
        sign2 = 0 if abs(l2) < eps else np.sign(l2)
        if sign1 == 0 or sign2 == 0:
            Type.append('convex')
        elif sign1 == sign2:
            Type.append('convex')
        else:
            Type.append('concave')

    def smooth_classification(Type_list, neighbors, iterations):
        """Majority vote over each vertex and its one-ring; ties go to 'convex'."""
        Type_smoothed = Type_list.copy()
        for _ in range(iterations):
            Type_new = Type_smoothed.copy()
            for i, t in enumerate(Type_smoothed):
                votes = {'convex': 0, 'concave': 0}
                votes[t] += 1
                for nb in neighbors[i]:
                    votes[Type_smoothed[nb]] += 1
                Type_new[i] = 'convex' if votes['convex'] >= votes['concave'] else 'concave'
            Type_smoothed = Type_new
        return Type_smoothed

    def bridge_disconnected_concave_regions(Type_list, neighbors, max_bridge_length):
        """Relabel as 'concave' every convex vertex that lies within
        ``max_bridge_length`` edge hops of two distinct concave vertices.

        NOTE(review): the nearest and second-nearest concave vertices are
        compared by vertex id, not by region, so two vertices of the *same*
        concave region also satisfy the test. In practice this appears to
        dilate concave regions rather than only join separate ones. The
        behaviour is kept as-is to preserve results; confirm intent.
        """
        from scipy.sparse import csr_matrix
        from scipy.sparse.csgraph import dijkstra

        n = len(Type_list)
        concave_indices = np.array([i for i, t in enumerate(Type_list) if t == 'concave'])
        # Fewer than two concave vertices (or a negative range) can never
        # satisfy the two-distinct-sources test, so nothing changes.
        if len(concave_indices) < 2 or max_bridge_length < 0:
            return Type_list.copy()

        row, col, data = [], [], []
        for i in range(n):
            for nb in neighbors[i]:
                row.append(i)
                col.append(nb)
                data.append(1)
        graph = csr_matrix((data, (row, col)), shape=(n, n))

        # Only distances <= max_bridge_length matter below. With `limit`,
        # longer ones are reported as inf and each search stops early, which
        # gives identical labels for less work. Note dist_matrix is still a
        # dense (n_concave, n) array.
        dist_matrix = dijkstra(csgraph=graph, directed=False,
                               indices=concave_indices, limit=max_bridge_length)

        cols = np.arange(n)
        first_rows = np.argmin(dist_matrix, axis=0)
        nearest_dist = dist_matrix[first_rows, cols]
        nearest_concave = concave_indices[first_rows]

        # Mask each vertex's closest concave source to find the second closest.
        second_dist_matrix = dist_matrix.copy()
        second_dist_matrix[first_rows, cols] = np.inf
        second_rows = np.argmin(second_dist_matrix, axis=0)
        second_nearest_dist = second_dist_matrix[second_rows, cols]
        second_nearest_concave = concave_indices[second_rows]

        is_convex = np.array([t == 'convex' for t in Type_list])
        to_bridge = (is_convex
                     & (nearest_concave != second_nearest_concave)
                     & (nearest_dist <= max_bridge_length)
                     & (second_nearest_dist <= max_bridge_length))

        Type_new = Type_list.copy()
        for i in np.flatnonzero(to_bridge):
            Type_new[i] = 'concave'
        return Type_new

    Type_smoothed = smooth_classification(Type, neighbors, iterations=smooth_iters_first)
    Type_bridged = bridge_disconnected_concave_regions(Type_smoothed, neighbors, max_bridge_length=bridge_length)
    Type_final = smooth_classification(Type_bridged, neighbors, iterations=smooth_iters_second)
    return Lambda1, Lambda2, Dir1, Dir2, Type_final

def compute_invagination_ratio_v2(mesh, Type):
    """Return the fraction of the mesh surface area covered by concave faces.

    A face counts as concave when at least one of its vertices is labelled
    ``'concave'`` in ``Type``. Vertices beyond the end of ``Type`` are treated
    as not concave.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Mesh whose ``faces``, ``area_faces`` and ``area`` are read.
    Type : sequence of str
        Per-vertex labels, e.g. the ``Type_final`` returned by ``patch_curvature``.

    Returns
    -------
    float
        Concave face area divided by total area. Returns 0.0 when no face is
        concave or when the mesh has zero total area (avoids division by zero).
    """
    concave_idx = np.array([i for i, t in enumerate(Type) if t == 'concave'], dtype=np.int64)
    if concave_idx.size == 0:
        return 0.0

    total_area = float(mesh.area)
    if total_area <= 0.0:
        return 0.0

    # Vectorised "any vertex of the face is concave" test over all faces.
    face_mask = np.isin(np.asarray(mesh.faces), concave_idx).any(axis=1)
    if not face_mask.any():
        return 0.0

    concave_area = float(np.asarray(mesh.area_faces)[face_mask].sum())
    return concave_area / total_area

def save_type_to_txt(filename, Type):
    """Write one label per line to ``filename``, overwriting any existing file."""
    with open(filename, 'w') as out:
        out.writelines(label + '\n' for label in Type)

# ==== Main program ====
# For every .obj in `input_dir_cal`: classify vertices with patch_curvature,
# compute the invagination ratio, optionally save per-vertex labels, and
# finally write all ratios to an Excel table.
# Sorted so the rows of the results table come out in a stable order
# (os.listdir order is arbitrary).
mesh_files = sorted(f for f in os.listdir(input_dir_cal) if f.lower().endswith('.obj'))
invagination_results_list = []

print(f"\n🔍 Found {len(mesh_files)} mesh files\n")

for filename in mesh_files:
    file_path = os.path.join(input_dir_cal, filename)
    # force='mesh' guarantees a Trimesh even for multi-geometry files
    mesh = trimesh.load(file_path, process=False, force='mesh')
    Lambda1, Lambda2, Dir1, Dir2, Type_final = patch_curvature(
        mesh,
        use_third=use_third_ring,
        smooth_iters_first=smooth_iters_first,
        bridge_length=bridge_length,
        smooth_iters_second=smooth_iters_second
    )
    inv_ratio = compute_invagination_ratio_v2(mesh, Type_final)
    invagination_results_list.append([filename, inv_ratio])
    print(f"✅ {filename}: Invagination ratio = {inv_ratio:.4f}")

    if save_type_txt:
        txt_path = os.path.join(output_dir_cal, os.path.splitext(filename)[0] + '_type.txt')
        save_type_to_txt(txt_path, Type_final)

# === Save results ===
excel_save_path = os.path.join(output_dir_cal, "invagination_results.xlsx")
results_df = pd.DataFrame(invagination_results_list, columns=["filename", "invagination_ratio"])
results_df.to_excel(excel_save_path, index=False)
print(f"\n📄 All invagination ratios saved to: {excel_save_path}")
print("\n✅ All done.")

4. Visualisation

In [None]:
# ---- Visualisation run settings (edit these) ----
input_dir_cal = r"PATH/TO/YOUR/SIMPLIFIED_MESH_FOLDER"      # meshes to analyse and plot
output_dir_cal = r"PATH/TO/YOUR/CURVATURE_RESULTS_FOLDER"   # per-vertex label files go here
output_dir_vis = r"PATH/TO/YOUR/VISUALISATION_FOLDER"       # rendered images go here
use_pyvista = False      # Whether to enable PyVista plotting (checked in the processing loop)
# Curvature / classification parameters (same meaning as in the previous cell)
smooth_iters_first = 5
bridge_length = 2
smooth_iters_second = 5
use_third_ring = False
save_type_txt = True
# -------------------------------------------------

os.makedirs(output_dir_cal, exist_ok=True)
os.makedirs(output_dir_vis, exist_ok=True)

def plot_curvature(mesh, values, title, save_path=None):
    """Scatter-plot the mesh vertices in 3D, coloured by ``values`` ('jet' colormap).

    If ``save_path`` is truthy the figure is saved there at 300 dpi and
    closed; otherwise it is shown interactively.
    """
    xyz = mesh.vertices
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    points = ax.scatter(*xyz.T, c=values, cmap='jet', s=5)
    fig.colorbar(points, ax=ax, shrink=0.5)
    ax.set_box_aspect([1, 1, 1])
    plt.title(title)

    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300)
    plt.close(fig)

def plot_directions(mesh, Dir1, Dir2, scale=0.1, save_path=None):
    """Draw both principal-direction fields as normalised 3D arrows at every vertex.

    ``Dir1`` is drawn in green and ``Dir2`` in red, each arrow of length
    ``scale``. If ``save_path`` is truthy the figure is saved at 300 dpi and
    closed; otherwise it is shown.
    """
    pts = mesh.vertices
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(111, projection='3d')
    for field, colour, name in ((Dir1, 'green', 'Dir1'), (Dir2, 'red', 'Dir2')):
        ax.quiver(pts[:, 0], pts[:, 1], pts[:, 2],
                  field[:, 0], field[:, 1], field[:, 2],
                  length=scale, color=colour, normalize=True, label=name)
    ax.set_box_aspect([1, 1, 1])
    plt.title('Principal Curvature Directions')
    ax.legend()

    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300)
    plt.close(fig)

def plot_type(mesh, Type, save_path=None):
    """Scatter-plot vertices coloured by label: concave in red, convex in blue.

    Axis limits are pinned to the mesh bounding box. If ``save_path`` is
    truthy the figure is saved at 300 dpi and closed; otherwise it is shown.
    """
    pts = mesh.vertices
    palette = {'concave': 'red', 'convex': 'blue'}
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    lo = np.min(pts, axis=0)
    hi = np.max(pts, axis=0)

    for label in ('concave', 'convex'):
        members = [k for k, lab in enumerate(Type) if lab == label]
        ax.scatter(pts[members, 0], pts[members, 1], pts[members, 2],
                   label=label.capitalize(), color=palette[label], s=5)

    # Fix the limits to the full bounding box so both classes share it
    ax.set_xlim((lo[0], hi[0]))
    ax.set_ylim((lo[1], hi[1]))
    ax.set_zlim((lo[2], hi[2]))

    ax.legend()
    ax.set_title("Curvature Classification (Concave vs Convex)")
    ax.set_box_aspect([1, 1, 1])

    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300)
    plt.close(fig)

def visualize_curvature(mesh, Cmean, Type, save_path=None):
    """Render mean curvature and the concave/convex labels side by side with PyVista.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Triangle mesh whose vertices and faces are rendered.
    Cmean : array-like, one value per vertex
        Mean curvature, shown with the 'coolwarm' colormap (left panel).
    Type : sequence of str
        Per-vertex labels; 'concave' is drawn red, anything else blue (right panel).
    save_path : str, optional
        If given, the scene is rendered off-screen and saved as a screenshot
        at this path; otherwise an interactive window is opened.
    """
    # PyVista's faces layout: [3, i, j, k] per triangle, flattened.
    faces_pv = np.hstack([np.full((len(mesh.faces), 1), 3), mesh.faces]).astype(np.int64)
    pv_mesh = pv.PolyData(mesh.vertices, faces_pv)
    pv_mesh.point_data['MeanCurvature'] = Cmean
    concave_numeric = np.array([1 if t == 'concave' else 0 for t in Type])
    pv_mesh.point_data['Concave'] = concave_numeric

    # Render off-screen when only a screenshot is wanted, so a batch run does
    # not open (and block on) an interactive window for every mesh.
    plotter = pv.Plotter(shape=(1, 2), window_size=[1200, 600], off_screen=bool(save_path))

    plotter.subplot(0, 0)
    plotter.add_mesh(pv_mesh, scalars='MeanCurvature', cmap='coolwarm', show_scalar_bar=True)
    plotter.add_text("Mean Curvature", font_size=14)
    plotter.camera_position = 'xy'

    plotter.subplot(0, 1)
    cmap = ['blue', 'red']
    plotter.add_mesh(pv_mesh, scalars='Concave', cmap=cmap, clim=[0, 1], show_scalar_bar=True)
    plotter.add_text("Concave (red) vs Convex (blue)", font_size=14)
    plotter.camera_position = 'xy'

    if save_path:
        plotter.show(screenshot=save_path)
        plotter.close()
    else:
        plotter.show()

# Re-list the meshes from *this* cell's `input_dir_cal`. Without this, the
# loop would iterate over `mesh_files` left over from the previous cell
# (listed from that cell's folder), or fail with NameError if run on its own.
mesh_files = sorted(f for f in os.listdir(input_dir_cal) if f.lower().endswith('.obj'))

for filename in mesh_files:
    print(f"Processing {filename} ...")
    file_path = os.path.join(input_dir_cal, filename)
    # force='mesh' guarantees a Trimesh even for multi-geometry files
    mesh = trimesh.load(file_path, process=False, force='mesh')

    Lambda1, Lambda2, Dir1, Dir2, Type_final = patch_curvature(
        mesh,
        use_third=use_third_ring,
        smooth_iters_first=smooth_iters_first,
        bridge_length=bridge_length,
        smooth_iters_second=smooth_iters_second
    )

    inv_ratio = compute_invagination_ratio_v2(mesh, Type_final)
    print(f"  Invagination ratio: {inv_ratio:.4f}")

    if save_type_txt:
        txt_path = os.path.join(output_dir_cal, os.path.splitext(filename)[0] + '_type.txt')
        save_type_to_txt(txt_path, Type_final)

    base_name = os.path.splitext(filename)[0]
    vis_mean_path = os.path.join(output_dir_vis, base_name + '_mean_curvature.png')
    vis_gaussian_path = os.path.join(output_dir_vis, base_name + '_gaussian_curvature.png')
    vis_dir_path = os.path.join(output_dir_vis, base_name + '_directions.png')
    vis_type_path = os.path.join(output_dir_vis, base_name + '_type.png')
    vis_pv_path = os.path.join(output_dir_vis, base_name + '_pyvista.png')

    # Only the classification image is produced by default; uncomment the
    # lines below to also save mean/Gaussian curvature and direction plots.
    # plot_curvature(mesh, (Lambda1+Lambda2)/2, "Mean Curvature", save_path=vis_mean_path)
    # plot_curvature(mesh, Lambda1*Lambda2, "Gaussian Curvature", save_path=vis_gaussian_path)
    # plot_directions(mesh, Dir1, Dir2, save_path=vis_dir_path)
    plot_type(mesh, Type_final, save_path=vis_type_path)

    # Previously this branch was a no-op (`pass`), so the flag had no effect.
    if use_pyvista:
        visualize_curvature(mesh, (Lambda1 + Lambda2) / 2, Type_final, save_path=vis_pv_path)

print("\nAll processing completed!")


5. Visualisation using PyVista for a specified file

In [None]:
# ----- Curvature Calculation Functions -----

def compute_vertex_normals(mesh):
    """Fetch ``mesh.vertex_normals`` (one normal per vertex)."""
    return getattr(mesh, "vertex_normals")

def vertex_neighbors(mesh):
    """Return per-vertex lists of the vertices that share a triangle with it.

    Neighbours appear once each, in the order they are first encountered
    while scanning ``mesh.faces``.
    """
    adjacency = [[] for _ in range(len(mesh.vertices))]
    for tri in mesh.faces:
        for k, center in enumerate(tri):
            bucket = adjacency[center]
            for step in (1, 2):
                other = tri[(k + step) % 3]
                if other not in bucket:
                    bucket.append(other)
    return adjacency

def vector_rotation_matrix(v):
    """Rodrigues rotation mapping the direction of ``v`` onto (-1, 0, 0).

    Returns ``(M, M.T)``. If ``v`` is (anti-)parallel to the x axis the
    rotation axis is undefined and ``(I, I)`` is returned.
    """
    direction = v / np.linalg.norm(v)
    goal = np.array([-1, 0, 0])
    k_vec = np.cross(direction, goal)
    k_len = np.linalg.norm(k_vec)
    if k_len < 1e-8:
        return np.eye(3), np.eye(3)
    k_vec = k_vec / k_len

    cos_arg = np.clip(np.dot(direction, goal), -1, 1)
    angle = np.arccos(cos_arg)
    kx, ky, kz = k_vec
    cross_mat = np.array([[0, -kz, ky],
                          [kz, 0, -kx],
                          [-ky, kx, 0]])
    squared = cross_mat @ cross_mat
    rotation = np.eye(3) + np.sin(angle) * cross_mat + (1 - np.cos(angle)) * squared
    return rotation, rotation.T

def eig2(Dxx, Dxy, Dyy):
    """Eigenpairs of ``[[Dxx, Dxy], [Dxy, Dyy]]`` sorted by |eigenvalue| (small first).

    Returns ``(lambda_a, lambda_b, vec_a, vec_b)`` with ``|lambda_a| <= |lambda_b|``.
    """
    vals, vecs = eigh(np.array([[Dxx, Dxy],
                                [Dxy, Dyy]]))
    order = np.argsort(np.abs(vals))
    first, second = order[0], order[1]
    return vals[first], vals[second], vecs[:, first], vecs[:, second]

def patch_curvature(mesh, use_third=False, smooth_iters_first=5, bridge_length=2, smooth_iters_second=5):
    """Estimate per-vertex principal curvatures and a concave/convex labelling.

    For each vertex, its neighbourhood (neighbours of its neighbours, or one
    ring further when ``use_third`` is True) is expressed in a local frame
    whose first axis is the vertex normal. The quadric
    ``f = a*x**2 + b*y**2 + c*x*y + d*x + e*y + g`` is then least-squares
    fitted to it. The eigenvalues of the Hessian ``[[2a, c], [c, 2b]]`` are
    the principal curvatures. Its eigenvectors, mapped back to world space,
    are the principal directions.

    A vertex whose two curvatures have opposite signs (both at least 1e-6 in
    magnitude) is labelled ``'concave'``; every other vertex is ``'convex'``.
    The labels are then majority-vote smoothed, bridged, and smoothed again.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Input mesh; its vertex normals, vertices and faces are read.
    use_third : bool
        Use the larger (three-step) neighbourhood for the patch fit.
    smooth_iters_first, smooth_iters_second : int
        Number of smoothing passes before / after bridging.
    bridge_length : int or float
        Maximum edge-hop distance used by the bridging step.

    Returns
    -------
    Lambda1, Lambda2 : numpy.ndarray, shape (n_vertices,)
        Principal curvatures, ordered so that ``|Lambda1| <= |Lambda2|``.
    Dir1, Dir2 : numpy.ndarray, shape (n_vertices, 3)
        Matching principal directions in world coordinates.
    Type_final : list of str
        ``'concave'`` or ``'convex'`` for every vertex.
    """
    normals = compute_vertex_normals(mesh)
    neighbors = vertex_neighbors(mesh)
    nv = len(mesh.vertices)
    Lambda1 = np.zeros(nv)
    Lambda2 = np.zeros(nv)
    Dir1 = np.zeros((nv, 3))
    Dir2 = np.zeros((nv, 3))

    for i in range(nv):
        nbrs = set()
        for nb in neighbors[i]:
            if use_third:
                for nnb in neighbors[nb]:
                    nbrs.update(neighbors[nnb])
            else:
                nbrs.update(neighbors[nb])
        if not nbrs:
            # Isolated vertex: nothing to fit and its normal is undefined.
            # Leave zero curvature / direction (classified as convex below).
            continue
        nbrs = list(nbrs)

        Ve = mesh.vertices[nbrs]
        M, _ = vector_rotation_matrix(normals[i])
        # M maps the normal onto -x. Points are row vectors, so local
        # coordinates are the rows of (Ve - p) @ M.T (i.e. (M @ r).T), and
        # column 0 is then the signed height along the normal.
        # (`@ M`, i.e. `@ Minv.T`, would apply the inverse rotation, which
        # aligns the height axis with the normal only when the normal is
        # perpendicular to the x axis.)
        We = (Ve - mesh.vertices[i]) @ M.T
        f = We[:, 0]
        x = We[:, 1]
        y = We[:, 2]

        FM = np.column_stack([x**2, y**2, x*y, x, y, np.ones_like(x)])
        abcdef, _, _, _ = np.linalg.lstsq(FM, f, rcond=None)
        a, b, c = abcdef[0], abcdef[1], abcdef[2]
        Dxx = 2*a
        Dxy = c
        Dyy = 2*b
        lam1, lam2, I1, I2 = eig2(Dxx, Dxy, Dyy)
        # Local (tangent) eigenvectors back to world space: row @ M == (M.T @ w).T
        Dir1[i, :] = (np.array([0, *I1]) @ M).flatten()
        Dir2[i, :] = (np.array([0, *I2]) @ M).flatten()
        Lambda1[i] = lam1
        Lambda2[i] = lam2

    # Classification based on the sign of principal curvatures:
    # opposite signs (saddle-like) -> 'concave', otherwise 'convex'.
    eps = 1e-6
    Type = []
    for l1, l2 in zip(Lambda1, Lambda2):
        sign1 = 0 if abs(l1) < eps else np.sign(l1)
        sign2 = 0 if abs(l2) < eps else np.sign(l2)
        if sign1 == 0 or sign2 == 0:
            Type.append('convex')
        elif sign1 == sign2:
            Type.append('convex')
        else:
            Type.append('concave')

    def smooth_classification(Type_list, neighbors, iterations):
        """Majority vote over each vertex and its one-ring; ties go to 'convex'."""
        Type_smoothed = Type_list.copy()
        for _ in range(iterations):
            Type_new = Type_smoothed.copy()
            for i, t in enumerate(Type_smoothed):
                votes = {'convex': 0, 'concave': 0}
                votes[t] += 1
                for nb in neighbors[i]:
                    votes[Type_smoothed[nb]] += 1
                Type_new[i] = 'convex' if votes['convex'] >= votes['concave'] else 'concave'
            Type_smoothed = Type_new
        return Type_smoothed

    def bridge_disconnected_concave_regions(Type_list, neighbors, max_bridge_length):
        """Relabel as 'concave' every convex vertex that lies within
        ``max_bridge_length`` edge hops of two distinct concave vertices.

        NOTE(review): the nearest and second-nearest concave vertices are
        compared by vertex id, not by region, so two vertices of the *same*
        concave region also satisfy the test. In practice this appears to
        dilate concave regions rather than only join separate ones. The
        behaviour is kept as-is to preserve results; confirm intent.
        """
        from scipy.sparse import csr_matrix
        from scipy.sparse.csgraph import dijkstra

        n = len(Type_list)
        concave_indices = np.array([i for i, t in enumerate(Type_list) if t == 'concave'])
        # Fewer than two concave vertices (or a negative range) can never
        # satisfy the two-distinct-sources test, so nothing changes.
        if len(concave_indices) < 2 or max_bridge_length < 0:
            return Type_list.copy()

        row, col, data = [], [], []
        for i in range(n):
            for nb in neighbors[i]:
                row.append(i)
                col.append(nb)
                data.append(1)
        graph = csr_matrix((data, (row, col)), shape=(n, n))

        # Only distances <= max_bridge_length matter below. With `limit`,
        # longer ones are reported as inf and each search stops early, which
        # gives identical labels for less work. Note dist_matrix is still a
        # dense (n_concave, n) array.
        dist_matrix = dijkstra(csgraph=graph, directed=False,
                               indices=concave_indices, limit=max_bridge_length)

        cols = np.arange(n)
        first_rows = np.argmin(dist_matrix, axis=0)
        nearest_dist = dist_matrix[first_rows, cols]
        nearest_concave = concave_indices[first_rows]

        # Mask each vertex's closest concave source to find the second closest.
        second_dist_matrix = dist_matrix.copy()
        second_dist_matrix[first_rows, cols] = np.inf
        second_rows = np.argmin(second_dist_matrix, axis=0)
        second_nearest_dist = second_dist_matrix[second_rows, cols]
        second_nearest_concave = concave_indices[second_rows]

        is_convex = np.array([t == 'convex' for t in Type_list])
        to_bridge = (is_convex
                     & (nearest_concave != second_nearest_concave)
                     & (nearest_dist <= max_bridge_length)
                     & (second_nearest_dist <= max_bridge_length))

        Type_new = Type_list.copy()
        for i in np.flatnonzero(to_bridge):
            Type_new[i] = 'concave'
        return Type_new

    Type_smoothed = smooth_classification(Type, neighbors, iterations=smooth_iters_first)
    Type_bridged = bridge_disconnected_concave_regions(Type_smoothed, neighbors, max_bridge_length=bridge_length)
    Type_final = smooth_classification(Type_bridged, neighbors, iterations=smooth_iters_second)
    return Lambda1, Lambda2, Dir1, Dir2, Type_final

# ----- PyVista Visualisation -----

def visualize_curvature(mesh, Cmean, Type):
    """Open an interactive two-panel PyVista window for ``mesh``.

    Left panel: ``Cmean`` (one value per vertex) with the 'coolwarm' colormap.
    Right panel: vertices labelled 'concave' in red, all others in blue.
    """
    n_tris = len(mesh.faces)
    # PyVista expects each face as [3, i, j, k], flattened.
    faces_pv = np.hstack([np.full((n_tris, 1), 3), mesh.faces]).astype(np.int64)
    surface = pv.PolyData(mesh.vertices, faces_pv)

    surface.point_data['MeanCurvature'] = Cmean
    surface.point_data['Concave'] = np.array([1 if lab == 'concave' else 0 for lab in Type])

    plotter = pv.Plotter(shape=(1, 2), window_size=[1200, 600])
    panels = (
        ((0, 0),
         dict(scalars='MeanCurvature', cmap='coolwarm', show_scalar_bar=True),
         "Mean Curvature"),
        ((0, 1),
         dict(scalars='Concave', cmap=['blue', 'red'], clim=[0, 1], show_scalar_bar=True),
         "Concave (red) vs Convex (blue)"),
    )
    for (r, c), mesh_kwargs, caption in panels:
        plotter.subplot(r, c)
        plotter.add_mesh(surface, **mesh_kwargs)
        plotter.add_text(caption, font_size=14)
        plotter.camera_position = 'xy'

    plotter.show()

# ===== Main Execution Example =====

if __name__ == "__main__":
    # Replace with your mesh file path. In a raw string a single backslash is
    # already literal, so the separators must not be doubled.
    mesh_file = r"C:\Path\To\Your\Mesh.obj"
    # force='mesh' guarantees a Trimesh even for multi-geometry files
    mesh = trimesh.load(mesh_file, process=False, force='mesh')

    Lambda1, Lambda2, Dir1, Dir2, Type_final = patch_curvature(
        mesh,
        use_third=False,
        smooth_iters_first=5,
        bridge_length=2,
        smooth_iters_second=5
    )

    # Mean curvature = average of the two principal curvatures
    Cmean = (Lambda1 + Lambda2) / 2
    visualize_curvature(mesh, Cmean, Type_final)
