In [1]:
#!pip install pyvista
#!pip install pyvista[jupyter]
#!pip install pyvistaqt
import imageio, pickle, uuid, os, time, shutil, tarfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyvista as pv
from pyvistaqt import BackgroundPlotter
from skimage.measure import marching_cubes
import nibabel as nib
from nibabel.processing import resample_from_to
from scipy.optimize import brentq
from scipy.spatial import cKDTree, transform, ConvexHull, SphericalVoronoi
from dataclasses import dataclass, field
from pathlib import Path
from scipy.optimize import differential_evolution
from multiprocessing.dummy import Pool as ThreadPool
from tqdm.notebook import tqdm
pv.set_jupyter_backend('none')    # if you want the trame-based UI
import openxlab.dataset as oxlds
import torch


with open("C:/Users/colli/Downloads/filelist.txt", "r", encoding='utf-8') as f:
    filelist_str = f.read()
filelist = filelist_str.split('|')
filelist = filelist[7::3]
filelist = [i.strip() for i in filelist]

filedict = {}
for path in filelist:  # e.g. ['/raw/a1/mesh/data1.obj', ...]
    parts = path.strip("/").split("/")  # ['raw', 'a1', 'mesh', 'data1.obj']
    current = filedict

    for part in parts[:-1]:  # Traverse all but the last part
        current = current.setdefault(part, {})  # Create if missing

    # Add the final file name to a list at the deepest level
    current[parts[-1].split('.')[0]] = path

def download_openxlab(path):
    oxlds.download("omniobject3d/OmniObject3D-New", path, './data')
    pathlist = path.strip('/').split('/')
    # Path where it got saved
    nested_path = "./data/omniobject3d___OmniObject3D-New" + path
    
    # Path where you want it
    target_path = "./data/" + pathlist[-1].split('.')[0]

    with tarfile.open(nested_path, 'r:gz') as tar:
        tar.extractall(path=target_path)
        
    # Optional: remove empty folders
    shutil.rmtree("./data/omniobject3d___OmniObject3D-New/")
    print(f'File Downloaded and Extracted to {target_path}')
    
def print_data(mesh):
    print(mesh.point_data)
    print(mesh.cell_data)
    print(mesh.field_data)

def load_pickle(path):
    with open(path, "rb") as f:
        return pickle.load(f)

def save_pickle(obj, path):
    path = Path(path)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp, path)  # atomic replace

overall_start = time.time()

In [2]:
def make_edge_dict(mesh):
    """
    Extract edge data from a mesh from vertex and face lists.
    Returns a dictionary with each edge as a key in the form
    of a sorted tuple, [(v0, v1)] if v0 < v1, [(v1, v0)] if v1 < v0.
    Each key/edge will contain another dictionary with keys:
    [opposite_vertex_0 : int, opposite_vertex_1 : int, 
     weight (w = cot(opposite_angle_0) + cot(opposite_angle_1)) : float]

     Parameters
     ----------
     mesh : pv.PolyData
     
     Returns
     -------
     edge_dict : dictionary
         Contains a dict for each edge [(v0,v1)] with keys opposite_vertex_0,
         opposite_vertex_1, and weight. opposite_vertex_i contains
         float of the angle in radians made by the three vertices
         v0, opposite_vertex_i, v1
    """
    verts = mesh.points
    faces = mesh.faces
    edge_dict = {}
    if faces.ndim == 1:
        faces = faces.reshape(-1, 4)
    
    for tri in faces[:,1:]:
        i, j, k = map(int, tri)
        for (a,b,c) in [(i,j,k),(j,k,i),(k,i,j)]:
            e = tuple(sorted((a,b)))   # e.g. (min(a,b),max(a,b))
            if e not in edge_dict:
                edge_dict[e] = {}
            v0 = verts[a]
            v1 = verts[b]
            v2 = verts[c]
            e0 = v0 - v2
            e1 = v1 - v2
            costheta = np.dot(e0, e1) / (np.linalg.norm(e0) * np.linalg.norm(e1) + 1e-15)
            costheta = np.clip(costheta, -1, 1)
            angle = np.arccos(costheta)
            edge_dict[e][c] = float(angle)     # store the vertex “across” this edge
            if len(edge_dict[e]) == 2:
                angle0 = list(edge_dict[e].items())[0][1]
                w = (np.cos(angle0)/(np.sin(angle0) + 1e-15)) + (np.cos(angle)/(np.sin(angle) + 1e-15))
                edge_dict[e]['weight'] = float(w)
    edges_idx = np.array([(i[0], i[1]) for i in edge_dict], dtype=int)
    weights = np.array([edge_dict[i].get('weight', 0.0) for i in edge_dict], dtype=float)

    return edge_dict, edges_idx, weights

def string_energy_vec(edges_idx, weights, phi):
    # # phi: (N,3) array of vertex coords on the sphere
    # phi = mesh['spherical_param']
    # edges_idx = mesh['edges_idx']
    # weights = mesh['edges_weights']
    i0 = edges_idx[:,0]
    i1 = edges_idx[:,1]
    # shape (E,3) differences
    diffs = phi[i0] - phi[i1]
    # squared‐norm per edge, shape (E,)
    sqnorm = np.einsum('ij,ij->i', diffs, diffs)
    # sum w * sqnorm
    return float(np.dot(weights, sqnorm))

def extract_vertex_face_energies(
    verts_sph: np.ndarray,
    faces_pv: np.ndarray,
    edges_idx: np.ndarray,
    weights: np.ndarray,
    *,
    return_means: bool = True
):
    """
    Compute discrete Dirichlet energy contributions of a spherical parameterization.

    Returns
    -------
    eE : (E,) per-edge energy w_ij * ||phi_i - phi_j||^2
    fE : (F,) per-face sum of incident edge energies (split equally among incident faces)
    vE : (V,) per-vertex sum of incident edge energies (split 0.5 to each endpoint)
    (optionally) fE_mean, vE_mean : means normalized by incidence counts
    """
    # faces as (F,3)
    faces = faces_pv.reshape(-1, 4)[:, 1:] if faces_pv.ndim == 1 else faces_pv
    V = verts_sph.shape[0]
    F = faces.shape[0]
    E = edges_idx.shape[0]

    # 1) per-edge energy
    i0, i1 = edges_idx[:, 0], edges_idx[:, 1]
    diffs = verts_sph[i0] - verts_sph[i1]               # (E,3)
    len2  = np.einsum('ij,ij->i', diffs, diffs)         # (E,)
    eE    = weights * len2                               # (E,)

    # 2) per-vertex accumulation (split evenly to endpoints)
    vE = np.zeros(V, dtype=float)
    np.add.at(vE, i0, 0.5 * eE)
    np.add.at(vE, i1, 0.5 * eE)

    # 3) per-face accumulation
    # Build a light vertex->faces adjacency to find faces adjacent to each edge
    v2f = [[] for _ in range(V)]
    for fi, (a, b, c) in enumerate(faces):
        v2f[a].append(fi); v2f[b].append(fi); v2f[c].append(fi)

    fE = np.zeros(F, dtype=float)
    face_inc_counts = np.zeros(F, dtype=int)

    for k, (a, b) in enumerate(edges_idx):
        adj = list(set(v2f[a]) & set(v2f[b]))  # 1 (boundary) or 2 (interior)
        if len(adj) == 0:
            # Non-manifold or disconnected edge; skip or assign nowhere
            continue
        share = eE[k] / len(adj)
        for fi in adj:
            fE[fi] += share
            face_inc_counts[fi] += 1

    if not return_means:
        return eE, fE, vE

    # 4) Optional: means normalized by incidence (avoid division by zero)
    # For vertices, count how many incident edges per vertex
    vertex_inc_counts = np.zeros(V, dtype=int)
    np.add.at(vertex_inc_counts, i0, 1)
    np.add.at(vertex_inc_counts, i1, 1)

    with np.errstate(divide='ignore', invalid='ignore'):
        vE_mean = np.where(vertex_inc_counts > 0, vE / vertex_inc_counts, 0.0)
        fE_mean = np.where(face_inc_counts   > 0, fE / face_inc_counts,     0.0)

    return eE, fE_mean, vE_mean


def d_energy_vec(edges_idx, weights, phi):
    """
    Vectorized ∂E/∂Phi for all vertices at once.
    
    edges_idx : (E,2) int array of [i,j] pairs
    weights   : (E,)  float array of w_ij
    phi       : (N,3) float array of Phi positions

    returns grad : (N,3) float array of 2*sum_j w_ij*(Phi_i - Phi_j)
    """
    i0 = edges_idx[:, 0]      # shape = (E,)
    i1 = edges_idx[:, 1]      # shape = (E,)

    # 1) compute the per-edge vector differences
    diffs    = phi[i0] - phi[i1]         # (E,3)

    # 2) weight them
    w_diffs  = diffs * weights[:, None]  # (E,3)

    # 3) scatter‐add into a per-vertex accumulator
    grad     = np.zeros_like(phi)        # (N,3)
    np.add.at(grad, i0,  w_diffs)        # grad[i0] +=  w_diffs
    np.add.at(grad, i1, -w_diffs)        # grad[i1] -=  w_diffs

    # 4) if you want the true energy‐gradient, multiply by 2
    # grad *= 2

    return grad

def C2_adaptive(phi0, edges_idx, weights, *,
                tol=1e-10, max_iter=10_000, initial_dt=0.1, verbose=True):
    """
    Projected gradient descent on S^2 to minimize sum w_ij ||phi_i - phi_j||^2.
    Inputs:
      phi0: (N,3) initial directions (will be normalized)
      edges_idx: (E,2) int
      weights: (E,) float
    Returns:
      phi: (N,3) final S^2 embedding
      energies: list of energy values over iterations
    """
    if verbose:
        pbar = tqdm(total=max_iter, desc='Spherical Param Gradient Descent', unit='iter', dynamic_ncols=True)
    # normalize just in case
    phi = phi0 / (np.linalg.norm(phi0, axis=1, keepdims=True) + 1e-15)

    E = string_energy_vec(edges_idx, weights, phi)
    energies = [E]
    dt = float(initial_dt)

    for it in range(1, max_iter+1):
        g = d_energy_vec(edges_idx, weights, phi)
        # project gradient onto tangent space
        g_proj = g - (np.einsum('ij,ij->i', g, phi))[:, None] * phi

        ok = False
        dt_try = dt
        for _ in range(6):
            phi_try = phi - dt_try * g_proj
            phi_try /= (np.linalg.norm(phi_try, axis=1, keepdims=True) + 1e-15)
            E_try = string_energy_vec(edges_idx, weights, phi_try)
            if E_try < E:
                phi, E = phi_try, E_try
                energies.append(E)
                dt = min(dt_try * 1.25, 1.0)
                ok = True
                break
            dt_try *= 0.25

        if not ok or (energies[-2] - energies[-1]) < tol:
            break
        if verbose:
            pbar.update(1)
    return phi, energies

def normalize_mesh(mesh):
    mesh_copy = mesh.copy()
    mesh_copy.points -= mesh_copy.center_of_mass()
    area = mesh_copy.compute_cell_sizes(area=True)['Area'].sum()
    mesh_copy.points /= np.sqrt(area)
    return mesh_copy

In [3]:
def srnf_from_mesh(mesh):
    mesh.compute_normals(point_normals=False, cell_normals=True, inplace=True)
    mesh['Area'] = mesh.compute_cell_sizes(length=False, area=True, volume=False)['Area']
    u = mesh.cell_data["Normals"]                # (F, 3)
    A = mesh.cell_data["Area"]                  # (F,)
    q = u * np.sqrt(2.0 * A)[:, None]           # (F, 3)
    return q

def spherical_triangle_area(a, b, c):
    # a,b,c: (M,3) unit vectors; returns (M,) steradians
    num = np.abs(np.einsum('ij,ij->i', a, np.cross(b, c)))
    den = 1.0 + np.einsum('ij,ij->i', a, b) \
              + np.einsum('ij,ij->i', b, c) \
              + np.einsum('ij,ij->i', c, a)
    return 2.0 * np.arctan2(num, den)

def icosphere_quadrature(nsub=3, radius=1.0):
    ico = pv.Icosphere(nsub=nsub, radius=radius)
    faces = ico.faces
    F = faces.reshape(-1, 4)[:, 1:]         # (M,3)
    P = ico.points                               # (Nv,3) ~ unit vectors
    vertices = P
    tri = P[F]                                   # (M,3,3)
    centers = tri.mean(axis=1)
    centers /= np.linalg.norm(centers, axis=1, keepdims=True)
    areas = spherical_triangle_area(tri[:,0], tri[:,1], tri[:,2]) * (radius**0)  # already on sphere
    # sanity: areas.sum() ≈ 4π * (radius**2)
    return centers, areas, vertices, faces

def general_quadrature(P, F):
    if F.ndim == 1:
        F = F.reshape(-1,4)[:,1:]
    tri = P[F]
    centers = tri.mean(axis=1)
    centers /= np.linalg.norm(centers, axis=1, keepdims=True)
    areas = spherical_triangle_area(tri[:,0], tri[:,1], tri[:,2])
    return centers, areas
    
def get_icosphere_level(A, B, nsub):
    if nsub in A.icosphere:
        return A.icosphere[nsub]['points'], A.icosphere[nsub]['weights'], A.icosphere[nsub]['vertices'], A.icosphere[nsub]['faces']
    if nsub in B.icosphere:
        return B.icosphere[nsub]['points'], B.icosphere[nsub]['weights'], B.icosphere[nsub]['vertices'], B.icosphere[nsub]['faces']
    U, w, v, f = icosphere_quadrature(nsub)  # (M,3), (M,)
    A.icosphere[nsub] = {'points': U, 'weights': w, 'vertices': v, 'faces': f}
    B.icosphere[nsub] = {'points': U, 'weights': w, 'vertices': v, 'faces': f}
    return U, w, v, f

def get_sample_points(A, B):
    # ensure alignments dicts exist
    if getattr(A, "alignments", None) is None:
        A.alignments = {}
    if getattr(B, "alignments", None) is None:
        B.alignments = {}

    # if already cached on B for A, return it
    if A.uid in B.alignments and 'sample_points' in B.alignments[A.uid]:
        return B.alignments[A.uid]['sample_points']

    # otherwise build once
    sample_dict = build_sample_points(A, B)

    # store symmetrically, creating nested dicts if needed
    A.alignments.setdefault(B.uid, {})['sample_points'] = sample_dict
    B.alignments.setdefault(A.uid, {})['sample_points'] = sample_dict

    return sample_dict

    
def l2_q(qA, qB, w):
    d = qA - qB
    return float(np.dot(w, np.einsum('ij,ij->i', d, d)))

# ... keep your helpers as provided above ...
def l2(A: "SquareRootNormalMesh", B: "SquareRootNormalMesh", nsub = 5):
    #sample_points, sample_weights, _, _ = get_icosphere_level(A, B, nsub)
    sample_dict = get_sample_points(A,B)
    sample_points = sample_dict['points']
    sample_weights = sample_dict['weights']
    if B.uid not in A.alignments.keys():
        A.get_alignment(B)
    mobius_parameters = A.alignments[B.uid]['mobius_params']
    R = A.alignments[B.uid]['space_rotation']
    sample_reparam, sqrtJ = mobius_apply_on_sphere(sample_points, mobius_parameters)

    q1 = A.sample(sample_points)
    gamma_q2 = B.sample(sample_reparam)
    R_gamma_q2 = (gamma_q2 @ R.T) * sqrtJ[:, None]
    
    l2_sq = l2_q(q1, R_gamma_q2, sample_weights)

    return l2_sq

def stereographic_project(P):  # P: (N,3) unit vectors
    X, Y, Z = P[:,0], P[:,1], P[:,2]
    denom = 1.0 - Z
    # guard near north pole: denom -> 0 => z -> ∞
    denom = np.where(np.abs(denom) < 1e-15, 1e-15, denom)
    return (X + 1j*Y) / denom  # complex array (N,)

def stereographic_unproject(w):  # w: (N,) complex -> (N,3) unit vectors
    x = np.real(w); y = np.imag(w)
    r2 = x*x + y*y
    denom = 1.0 + r2
    P = np.stack([2*x/denom, 2*y/denom, (r2 - 1.0)/denom], axis=1)
    # numeric safety
    P /= np.linalg.norm(P, axis=1, keepdims=True)
    return P

def mobius_apply_on_sphere(P, mobius_params):
    """
    P: (M,3) unit sphere points
    returns: P_gamma (M,3), sqrtJ (M,)
    """
    a, b, c, d = mobius_params
    z = stereographic_project(P)            # (M,) complex
    czd = c*z + d
    w = (a*z + b) / (czd)               # (M,) complex
    P_gamma = stereographic_unproject(w)    # (M,3)

    # sqrt(J) on the sphere (see derivation in the message)
    num = 1.0 + (z.real**2 + z.imag**2)
    den = 1.0 + (w.real**2 + w.imag**2)
    den2 = np.maximum(np.abs(czd)**2, 1e-15)
    sqrtJ = (num/den) / den2      # (M,)
    return P_gamma, sqrtJ

# ---------------------------
# Low-discrepancy sampling on SO(3)
# (Halton-based; named "sobol" for API compatibility)
# ---------------------------
def _van_der_corput(n, base, skip=0, shift=0.0):
    i = np.arange(skip+1, skip+n+1, dtype=float)
    seq = np.zeros_like(i)
    denom = 1.0
    while np.any(i > 0):
        i, rem = divmod(i, base)
        denom *= base
        seq += rem / denom
    return (seq + shift) % 1.0

def sample_quaternions_sobol(N, skip=0, shifts=(0.0,0.0,0.0)):
    u1 = _van_der_corput(N, 2, skip=skip, shift=shifts[0])
    u2 = _van_der_corput(N, 3, skip=skip, shift=shifts[1])
    u3 = _van_der_corput(N, 5, skip=skip, shift=shifts[2])
    # ... Shoemake as before ...

    theta1 = 2*np.pi*u2
    theta2 = 2*np.pi*u3
    r1 = np.sqrt(1.0 - u1)
    r2 = np.sqrt(u1)
    w = r2*np.cos(theta2)
    x = r1*np.sin(theta1)
    y = r1*np.cos(theta1)
    z = r2*np.sin(theta2)
    Q = np.stack([w,x,y,z], axis=1)
    # numeric guard
    Q /= np.linalg.norm(Q, axis=1, keepdims=True) + 1e-15
    return Q

def su2_from_quat(w, x, y, z):
    """
    Map unit quaternion to SU(2) matrix [[α, β],[-β*, α*]].
    """
    alpha = w + 1j*z
    beta  = y + 1j*x
    s = np.sqrt((alpha*alpha.conjugate()).real + (beta*beta.conjugate()).real)
    alpha /= (s + 1e-15)
    beta  /= (s + 1e-15)
    return np.array([[alpha, beta],
                     [-beta.conjugate(), alpha.conjugate()]], dtype=np.complex128)

# ---------------------------
# Möbius plumbing
# ---------------------------
def compose_similarity_then_rotation(lam, t, Mrot):
    """
    Msim = [[lam, lam*t], [0, 1]], then M = Msim @ Mrot.
    Returns a,b,c,d (complex scalars).
    """
    Msim = np.array([[lam, lam*t],
                     [0+0j, 1+0j]], dtype=np.complex128)
    M = Msim @ Mrot
    a, b = M[0,0], M[0,1]
    c, d = M[1,0], M[1,1]
    # normalize (det ~ 1), optional but helps stability
    det = a*d - b*c
    if np.abs(det) > 1e-15:
        s = np.sqrt(det)
        a /= s; b /= s; c /= s; d /= s
    return a, b, c, d

# ---------------------------
# SRNF alignment helpers
# ---------------------------
def kabsch_srnf(Q1, Q2, weights=None, proper_rotation=True):
    """
    Weighted Kabsch on 3D vectors (rows correspond).
    Returns R (3x3).
    """
    if weights is None:
        w = np.ones(len(Q1), float)
    else:
        w = np.asarray(weights, float)
    w = w / (w.sum() + 1e-15)
    C = Q1.T @ (Q2 * w[:, None])
    U, S, Vt = np.linalg.svd(C, full_matrices=False)
    R = U @ Vt
    if proper_rotation and np.linalg.det(R) < 0:
        U[:, -1] *= -1
        R = U @ Vt
    return R

def face_centroids_on_sphere(mesh_sph: pv.PolyData) -> np.ndarray:
    F = mesh_sph.faces.reshape(-1, 4)[:, 1:]     # (nF, 3) vertex indices
    P = mesh_sph.points                           # (nV, 3)
    C = P[F].mean(axis=1)                         # (nF, 3) Euclidean centroids
    # ensure they lie on S^2 (good for meshes slightly off the unit sphere)
    C /= (np.linalg.norm(C, axis=1, keepdims=True) + 1e-15)
    return C

def build_face_kdtree(mesh_sph: pv.PolyData):
    C = face_centroids_on_sphere(mesh_sph)       # (nF, 3)
    tree = cKDTree(C)
    return tree, C

def wrap_angle(a):
    a = np.mod(a, 2*np.pi)
    return a

def zoom_range(center, half_width, n, is_angle=False):
    if is_angle:
        grid = np.linspace(-half_width, +half_width, n)
        return wrap_angle(center + grid)
    else:
        return np.linspace(center - half_width, center + half_width, n)

def locate_on_mesh_from_sphere(A: "SquareRootNormalMesh", U: np.ndarray):
    """
    U: (M,3) unit directions on S^2.
    Returns:
      P_world: (M,3) interpolated points on A.mesh (original geometry)
      face_ids: (M,) triangle indices on A.mesh_sph
      bary: (M,3) barycentric weights (for optional reuse)
    """
    P_sph = A.mesh_sph.points
    F = A.mesh_sph.faces.reshape(-1,4)[:,1:]
    face_ids = A.preindex(U)                 # (M,)
    tri_idx = F[face_ids]                    # (M,3) vertex ids
    a_s, b_s, c_s = P_sph[tri_idx[:,0]], P_sph[tri_idx[:,1]], P_sph[tri_idx[:,2]]

    # barycentrics in the chord triangle (good for small faces)
    # vectorized computation
    v0 = b_s - a_s
    v1 = c_s - a_s
    v2 = U   - a_s
    d00 = np.einsum('ij,ij->i', v0, v0)
    d01 = np.einsum('ij,ij->i', v0, v1)
    d11 = np.einsum('ij,ij->i', v1, v1)
    d20 = np.einsum('ij,ij->i', v2, v0)
    d21 = np.einsum('ij,ij->i', v2, v1)
    denom = d00 * d11 - d01 * d01 + 1e-15
    w2 = (d00 * d21 - d01 * d20) / denom
    w1 = (d11 * d20 - d01 * d21) / denom
    w0 = 1.0 - w1 - w2
    bary = np.stack([w0, w1, w2], axis=1)    # (M,3)

    # (optional) clamp tiny negatives due to numeric error
    # bary = np.clip(bary, 0.0, 1.0); bary /= bary.sum(axis=1, keepdims=True)

    # interpolate original-geometry point positions
    V_world = A.mesh.points                  # (N,3) original mesh vertices
    pa, pb, pc = V_world[tri_idx[:,0]], V_world[tri_idx[:,1]], V_world[tri_idx[:,2]]
    P_world = (bary[:,[0]] * pa +
               bary[:,[1]] * pb +
               bary[:,[2]] * pc)

    return P_world

#def get_equal_area_sample_points(A: "SquareRootNormalMesh", B: "SquareRootNormalMesh"):

In [4]:
EPS = 1e-15

def unit(X): return X / (np.linalg.norm(X, axis=1, keepdims=True)+EPS)
def build_sample_points(A, B, nsub=5):
    type_error = TypeError("A and B must both be type ndarray or SquareRootNormalMesh!")
    if isinstance(A, np.ndarray):
        A_samp = A
        if isinstance(B, np.ndarray):
            B_samp = B
        elif B is None:
            B_samp = np.empty((0,3), float)
        else:
            raise type_error
    elif hasattr(A, 'mesh_sph'):
        A_samp = general_quadrature(A.mesh_sph.points, A.mesh_sph.faces)[0]
        if hasattr(B, 'mesh_sph'):
            B_samp = general_quadrature(B.mesh_sph.points, B.mesh_sph.faces)[0]
        elif B is None:
            B_samp = np.empty((0,3), float)
        else:
            raise type_error
    else:
        raise type_error
    ico_samp = pv.Icosphere(nsub=nsub).points if nsub>=0 else np.empty((0,3), float)
    samples = np.vstack([A_samp, B_samp, ico_samp])
    def dedupe_on_sphere(samples, eps=1e-10):
        S = samples / np.linalg.norm(samples, axis=1)[:,None]
        tree = cKDTree(S)
        parent = np.arange(len(S))
        def find(a):
            while parent[a] != a:
                parent[a] = parent[parent[a]]
                a = parent[a]
            return a
        def union(a,b):
            ra, rb = find(a), find(b)
            if ra != rb: parent[rb] = ra
        for a,b in tree.query_pairs(eps):
            union(a,b)
        roots = np.array([find(i) for i in range(len(S))])
        uniq_roots, labels = np.unique(roots, return_inverse=True)
        uniq  = np.zeros((len(uniq_roots),3))
        counts = np.zeros(len(uniq_roots), dtype=int)
        for k,r in enumerate(uniq_roots):
            idx = np.where(roots==r)[0]
            uniq[k] = unit(S[idx].mean(axis=0, keepdims=True))[0]
            #uniq = S[idx] / np.linalg.norm(S[idx])
            counts[k] = len(idx)
        return uniq, labels, counts
    
    def sort_region_vertices_ccw(region_idx, vertices, site):
        """Sort polygon vertices CCW around 'site' on S^2 in a tangent frame."""
        pts = vertices[np.asarray(region_idx, dtype=int)]
        pts = unit(pts)
        n = unit(site[None,:])[0]
        # build orthonormal tangent basis (t1 along first edge, t2 = n x t1)
        t1 = unit((pts[0] - n*np.dot(n, pts[0]))[None,:])[0]
        t2 = np.cross(n, t1)
        # angles
        x = pts @ t1
        y = pts @ t2
        ang = np.arctan2(y, x)
        order = np.argsort(ang)
        return np.asarray(region_idx, dtype=int)[order]
    
    def spherical_polygon_area(sorted_poly):
        """Girard on already-ordered polygon."""
        poly = unit(sorted_poly)
        m = len(poly)
        A = 0.0
        for i in range(m):
            u = poly[(i-1)%m]; v = poly[i]; w = poly[(i+1)%m]
            n = v
            def proj(x):
                x = x - n*np.dot(n, x)
                return x / (np.linalg.norm(x)+EPS)
            a = proj(u); b = proj(w)
            num = np.linalg.norm(np.cross(a,b)); den = np.clip(np.dot(a,b), -1.0, 1.0)
            A += np.arctan2(num, den)
        return A - (m-2)*np.pi
    
    def voronoi_weights(samples, eps=1e-10):
        uniq, labels, counts = dedupe_on_sphere(samples, eps=eps)
        print("Sample points dropped:", samples.shape[0] - uniq.shape[0])
        sv = SphericalVoronoi(uniq.astype(np.float64),
                              radius=1.0,
                              center=np.array([0.0,0.0,0.0], dtype=np.float64),
                              threshold=eps*10)
    
        # Manually sort each region CCW in a tangent frame at the generator
        areas_uniq = np.empty(len(uniq))
        for i, reg in enumerate(sv.regions):        
            if len(reg) < 3:
                areas_uniq[i] = 0
                continue
            reg_sorted = sort_region_vertices_ccw(reg, sv.vertices, uniq[i])
            poly = sv.vertices[reg_sorted]
            areas_uniq[i] = spherical_polygon_area(poly)
    
        weights_per_site = areas_uniq / (4*np.pi)
        weights = weights_per_site[labels] / counts[labels]
        weights = np.clip(weights, 0.0, None)
        weights /= weights.sum()
        return weights, areas_uniq, labels, counts, uniq
    out = voronoi_weights(samples)
    return {'points': out[4], 'weights': out[0]*4*np.pi}
    



class GenLogger:
    def __init__(self, base_func, pop_size):
        self.base_func = base_func
        self.pop_size = pop_size
        self.eval_count = 0
        self.failure_count = 0
        self.cur_gen_best = float('inf')
        self.history = []   # list of (gen_idx, best_energy)

    def __call__(self, x):
        val = self.base_func(x)
        self.eval_count += 1
        if val > 999:
            self.failure_count += 1
        if val < self.cur_gen_best:
            self.cur_gen_best = val
        # end of generation?
        if self.eval_count % self.pop_size == 0:
            gen_idx = self.eval_count // self.pop_size
            self.history.append((gen_idx, self.cur_gen_best))
            # reset for next gen
            self.cur_gen_best = float('inf')
        return val

def project_psl2c(a,b,c,d, eps=1e-12):
    det = a*d - b*c
    if not np.isfinite(det.real + det.imag) or abs(det) <= eps:
        return None
    # complex sqrt; +0j forces complex dtype
    k = 1.0 / np.sqrt(det + 0j)
    # optional: guard against insane scaling (rare but helpful)
    if abs(k) > 1e6:   # tune as you like
        return None
    return a*k, b*k, c*k, d*k

def coarse_candidates(A: "SquareRootNormalMesh", B: "SquareRootNormalMesh", nsub=5,
                      param_range_array=None, max_iter=1000, popsize=15, mutation=(0.5,1),
                      recombination=0.7, seed = None, REL_TOL=0.01, JMIN=1e-6, JMAX=1e6,
                      MAXLOGVAR=4):
    """
    Return topK candidates sorted by L2: each item is
    (L2, (qw,qx,qy,qz), s, th, r, phi, a,b,c,d, Rspace)
    """

    #U, w, verts, faces = get_icosphere_level(A, B, nsub)
    sample_dict = build_sample_points(A, B)
    U = sample_dict['points']
    w = sample_dict['weights']
    # initial_icosphere_points = verts.copy()
    # initial_icosphere_faces = faces.reshape(-1,4)[:,1:]
    if param_range_array is None:
        c0 = np.repeat(-1, 8).reshape(-1,1)
        c1 = np.repeat(1, 8).reshape(-1,1)
        param_range_array = np.hstack([c0,c1])

    def robust_calculation(mobius_params):
        U2 = mobius_apply_on_sphere(initial_icosphere_points, mobius_params)[0]          # map VERTICES
        U2 = U2 / (np.linalg.norm(U2, axis=1, keepdims=True) + 1e-15)  # <-- normalize vertices
        C2, A1 = general_quadrature(U2, initial_icosphere_faces)                   # mapped centers + areas
        sqrtJ  = np.sqrt(A1 / (w + 1e-15))                          # w == A0
        return C2, sqrtJ


    FOUR_PI = 4.0*np.pi
    
    def check_mobius(w, sqrtJ):
        J = (sqrtJ**2)
        S = float((w * J).sum())
        J_logvar = np.var(np.log(J + 1e-10))
        return (np.isfinite(S)
                and abs(S - FOUR_PI) <= REL_TOL * FOUR_PI
                and (J > JMIN).all() and (J < JMAX).all()
                and J_logvar < MAXLOGVAR)
        

    if seed is None:
        np.random.seed(np.random.randint(10000))
    else:
        np.random.seed(seed)
    q1 = A.sample(U)
    B_tree, C = build_face_kdtree(B.mesh_sph)

    # [a, ai, b, bi, c, ci, d, di]
    def score(params):
        p = np.array(params, copy=True)
        a = complex(params[0], params[1])
        b = complex(params[2], params[3])
        c = complex(params[4], params[5])
        d = complex(params[6], params[7])
        proj = project_psl2c(a,b,c,d)
        if proj is None:
            return 1e3
        a,b,c,d = proj

        U2, sqrtJ = mobius_apply_on_sphere(U, (a,b,c,d))
        if not check_mobius(w, sqrtJ):
            return 1e3
            # U2, sqrtJ = robust_calculation((a,b,c,d))
            # if not check_mobius(w, sqrtJ):
            #     return 1e3
        _, idx = B_tree.query(U2)
        gamma_q2 = B.q_face[idx] * sqrtJ[:,None]
        R = kabsch_srnf(q1, gamma_q2, w, proper_rotation=True)
        R_gamma_q2 = gamma_q2 @ R.T
        L2 = l2_q(q1, R_gamma_q2, w)
        return L2

    pbar = tqdm(total=max_iter, desc='DE Iteration', unit='gen', dynamic_ncols=True)
    gen = {"i": 0}
    def cb(xk, convergence):
        gen["i"] += 1
        pbar.update(1)

    for i in range(15):
        if ((popsize * 8) / 2**i) < 1:
            sobol_size = 2**i
            break
    scorer = GenLogger(score, sobol_size)

    pool = ThreadPool()
    
    result = differential_evolution(
        scorer,
        bounds=param_range_array,
        strategy="best1bin",
        maxiter=max_iter,
        popsize=popsize,
        mutation=mutation,       # (min, max) → jitter range
        recombination=recombination,         # aka crossover probability
        tol=1e-6,
        seed=seed,
        workers=pool.map,                # parallel evaluation
        updating="deferred",       # better with workers>1
        polish=True,               # local L-BFGS-B polish at the end
        init="sobol",              # Sobol’ or "latinhypercube" are good
        callback=cb
    )
    pool.close();pool.join()
    result.population_energies = scorer.history
    print(f"Total: {scorer.eval_count} Failed: {scorer.failure_count}")
    return result

def compute_alignment(A: "SquareRootNormalMesh", B: "SquareRootNormalMesh", nsub: int=5):
    result = coarse_candidates(A, B, nsub=nsub)

    l2 = result.fun
    params = result.x
    log = result.population_energies
    a = complex(params[0], params[1])
    b = complex(params[2], params[3])
    c = complex(params[4], params[5])
    d = complex(params[6], params[7])

    proj = project_psl2c(a,b,c,d)
    a,b,c,d = proj
    
    #U, w, v, f = get_icosphere_level(A, B, nsub)
    sample_dict = build_sample_points(A, B)
    U = sample_dict['points']
    w = sample_dict['weights']
    U2, sqrtJ = mobius_apply_on_sphere(U, (a,b,c,d))
    q1 = A.sample(U)
    gamma_q2 = B.sample(U2) * sqrtJ[:, None]
    space_rotation = kabsch_srnf(q1, gamma_q2, w, proper_rotation=True)
    robust_l2 = l2_q(q1, (gamma_q2 @ space_rotation.T), w)
    if np.abs(l2 - robust_l2) > 1e-5:
        print("L2 mismatch greater than 1e-5")
    return (a,b,c,d), space_rotation, robust_l2, log

In [5]:
def scalars_polar_hue(U, sat_max=0.9, gamma=0.65, p=2):
    """
    North pole -> white, south pole -> black.
    Brightness V from z (with gamma for contrast).
    Saturation increases toward equator: S = sat_max*(1-|z|^p).
    Hue from longitude for orientation cues.
    """
    x, y, z = U.T
    # HSV components
    H = (np.arctan2(y, x) + np.pi) / (2*np.pi)          # [0,1) longitude
    V = np.clip(0.5*(z + 1.0), 0, 1) ** gamma            # brightness: south→0, north→1
    S = sat_max * (1.0 - np.abs(z)**p)                   # low at poles, high near equator

    # HSV -> RGB (vectorized, standard)
    C = V * S
    h6 = H * 6.0
    i = np.floor(h6).astype(int)
    f = h6 - i
    X = C * (1.0 - np.abs((h6 % 2.0) - 1.0))
    Z = np.zeros_like(C)

    rgbp = np.zeros((len(U), 3), float)
    i_mod = i % 6
    rgbp[i_mod==0] = np.stack([C, X, Z], axis=1)[i_mod==0]
    rgbp[i_mod==1] = np.stack([X, C, Z], axis=1)[i_mod==1]
    rgbp[i_mod==2] = np.stack([Z, C, X], axis=1)[i_mod==2]
    rgbp[i_mod==3] = np.stack([Z, X, C], axis=1)[i_mod==3]
    rgbp[i_mod==4] = np.stack([X, Z, C], axis=1)[i_mod==4]
    rgbp[i_mod==5] = np.stack([C, Z, X], axis=1)[i_mod==5]

    m = (V - C)[:, None]
    rgb = np.clip(rgbp + m, 0, 1)
    return (255*rgb).astype(np.uint8)

def scalars_oct(P):
    octant = (P[:,0] > 0).astype(int)*4 + (P[:,1] > 0).astype(int)*2 + (P[:,2] > 0).astype(int)
    return octant
    # colormap: 'tab10' or 'tab20'

def scalars_fifty(P, seed = None):
    x = (P[:,0] > np.sqrt(3)/3).astype(int)*4 + (P[:,0] > 0).astype(int)*2 + (P[:,0] > -np.sqrt(3)/3).astype(int)
    y = (P[:,1] > np.sqrt(3)/3).astype(int)*32 + (P[:,1] > 0).astype(int)*16 + (P[:,1] > -np.sqrt(3)/3).astype(int)*8
    z = (P[:,2] > np.sqrt(3)/3).astype(int)*256 + (P[:,2] > 0).astype(int)*128 + (P[:,2] > -np.sqrt(3)/3).astype(int)*64
    uni, cats = np.unique(x+y+z, return_inverse=True)
    K = len(uni)
    rng = np.random.default_rng(seed)
    perm = rng.permutation(K)
    return perm[cats]

def plot_registration_one_panel(A: "SquareRootNormalMesh", B: "SquareRootNormalMesh", nsub=5,
                                cmap='tab20', scalars_method='hue', hue=(1, 0.05, 15), mobius_params=None,
                                space_rotation = None, gap=(2,0,0), point_size = 5, scalars_array=None):
    if B.uid not in A.alignments:
        raise BaseException('Mesh B not in Mesh A alignment dict')
    mobius_params = A.alignments[B.uid]['mobius_params'] if mobius_params is None else mobius_params
    space_rotation = A.alignments[B.uid]['space_rotation'] if space_rotation is None else space_rotation
    def _barycentric_weights(p, a, b, c):
        """Planar barycentrics of p in triangle (a,b,c) (all 3D)."""
        v0, v1, v2 = b - a, c - a, p - a
        d00 = np.dot(v0, v0); d01 = np.dot(v0, v1); d11 = np.dot(v1, v1)
        d20 = np.dot(v2, v0); d21 = np.dot(v2, v1)
        denom = d00 * d11 - d01 * d01 + 1e-15
        v = (d11 * d20 - d01 * d21) / denom
        w = (d00 * d21 - d01 * d20) / denom
        u = 1.0 - v - w
        return np.array([u, v, w])
    
    #icosphere_points = pv.Icosphere(nsub=nsub).points
    sample_dict = build_sample_points(A,B)
    icosphere_points = sample_dict['points']
    a_ico_mapping = A.locate(icosphere_points)
    gamma_icosphere_points = mobius_apply_on_sphere(icosphere_points, mobius_params)[0]
    b_ico_mapping = B.locate(gamma_icosphere_points)

    # --- A ---
    a_mesh = A.mesh.copy()
    a_points = a_mesh.points

    a_R = kabsch_srnf(icosphere_points, a_ico_mapping, proper_rotation=True)
    a_points = a_points @ a_R.T
    a_ico_mapping = a_ico_mapping @ a_R.T
    
    scale = np.max(np.linalg.norm(a_points, axis=1))
    a_points *= 2 / scale
    a_ico_mapping *= 2 / scale
    
    a_points -= np.array(gap)
    a_ico_mapping  -= np.array(gap)
    a_mesh.points = a_points

    # --- B ---
    b_mesh = B.mesh.copy()
    b_points = b_mesh.points
    b_points = b_points @ space_rotation.T
    b_ico_mapping = b_ico_mapping @ space_rotation.T
    b_R = kabsch_srnf(icosphere_points, b_ico_mapping, proper_rotation=True)
    b_points = b_points @ b_R
    b_ico_mapping = b_ico_mapping @ b_R

    scale = np.max(np.linalg.norm(b_points, axis=1))
    b_points *= 2 / scale
    b_ico_mapping *= 2 / scale
    
    b_ico_mapping += np.array(gap)
    b_points += np.array(gap)
    b_mesh.points = b_points

    jo = 0
    while True:
        if (np.min(np.linalg.norm(a_points, axis = 1)) < 0.9**(jo // 5)) or (np.min(np.linalg.norm(b_points, axis = 1)) < 0.9**(jo // 5)):
            rotation = transform.Rotation.random().as_matrix()
            a_points += np.array(gap)
            a_points @= rotation.T
            a_points -= np.array(gap)
            b_points -= np.array(gap)
            b_points @= rotation.T
            b_points += np.array(gap)
            a_ico_mapping = (a_ico_mapping + np.array(gap)) @ rotation.T - np.array(gap)
            b_ico_mapping = (b_ico_mapping - np.array(gap)) @ rotation.T + np.array(gap)
            icosphere_points = icosphere_points @ rotation.T
            jo += 1
            if jo % 5 == 0:
                icosphere_points *= 0.9
        else:
            break
        
    if scalars_method == 'hue':
        scalars = scalars_polar_hue(icosphere_points, sat_max=hue[0], gamma=hue[1], p=hue[2])

        p = BackgroundPlotter()
        p.add_mesh(a_mesh, color = 'black')
        p.add_mesh(a_ico_mapping, scalars=scalars, rgb=True, point_size=point_size)

        p.add_mesh(b_mesh, color = 'black')
        p.add_mesh(b_ico_mapping, scalars=scalars, rgb=True, point_size=point_size)

        p.add_mesh(icosphere_points, scalars=scalars, rgb=True, point_size=point_size)
        plot_registration_one_panel.plotter = p
        return 1
    elif scalars_method == 'fifty':
        scalars = scalars_fifty(icosphere_points)
    elif scalars_method == 'oct':
        scalars = scalars_oct(icosphere_points)
    elif scalars_array is None:
        scalars = np.arange(a_mesh.n_cells)
    else:
        scalars = scalars_array
        
    p = BackgroundPlotter()
    p.add_mesh(a_mesh, color = 'black')
    p.add_mesh(a_ico_mapping, cmap=cmap, scalars=scalars, point_size=point_size)
    
    p.add_mesh(b_mesh, color = 'black')
    p.add_mesh(b_ico_mapping, cmap=cmap, scalars=scalars, point_size=point_size)
    
    p.add_mesh(icosphere_points, cmap=cmap, scalars=scalars, point_size=point_size)
    p.show()
    plot_registration_one_panel.plotter = p
    return 1

def gif(A: "SquareRootNormalMesh", B=None, camera_position=None, n_frames=300,
        fps=30, cmap='tab20c', scalars_array=None, scalars_method='hue', point=None, hue=(1, 0.05, 15),
        out_path="overwritten_file.gif", wire=False, spin=False):
    if out_path[:11] != 'animations/':
        out_path = 'animations/' + out_path
    camera_position = [(3.862980654462277, 3.8625836577257506, 3.86289431713432),
                       (0.00011098384857177734, -0.0002860128879547119, 2.4646520614624023e-05),
                       (0.0, 0.0, 1.0)] if camera_position is None else camera_position
    mesh1 = A.mesh
    mesh2 = A.mesh_sph if B is None else B

    def f_exp(t, a):
        return (np.exp(a*t) - 1) / (np.exp(a) - 1)

    def find_a(x0, y0, K=50.0):
        if not (0 < y0 < 1):
            raise ValueError("y0 must lie strictly between 0 and 1")
        if abs(y0 - x0) < 1e-8:
            return 0.0
        def fn(a):
            return f_exp(x0, a) - y0

        a_star = brentq(fn, -K, K)
        return a_star

    a = 1e-8 if point is None else find_a(*point)
    diffeo = lambda t: f_exp(t, a)

    if scalars_method == 'hue':
        scalar_points = general_quadrature(A.mesh_sph.points, A.mesh_sph.faces.reshape(-1,4)[:,1:])[0]
        scalars = scalars_polar_hue(scalar_points, sat_max=hue[0], gamma=hue[1], p=hue[2])
    elif scalars_method == 'fifty':
        scalars = scalars_fifty(mesh1.points)
    elif scalars_method == 'oct':
        scalars = scalars_oct(mesh1.points)
    elif scalars_array is None:
        cmap = "PiYG" if cmap == 'tab20c' else cmap
        scalars = np.arange(mesh1.n_cells)
    else:
        cmap = "PiYG" if cmap == 'tab20c' else cmap
        scalars = scalars_array
    style = 'wireframe' if wire else 'surface'
        
    working = mesh1.copy()
    working['cell_ids'] = scalars

    plotter = pv.Plotter(off_screen=True)
    if scalars_method != 'hue':
        plotter.add_mesh(working,
                         scalars='cell_ids',
                         cmap=cmap,
                         show_scalar_bar=False,
                         smooth_shading=False,
                         style=style)
    else:
        plotter.add_mesh(working,
                         scalars='cell_ids',
                         rgb=True,
                         show_scalar_bar=False,
                         smooth_shading=False,
                         style=style)
    plotter.camera_position = camera_position

    frames = []
    if spin:
        for t in np.linspace(0, 1, n_frames):
            plotter.camera.Azimuth(360/n_frames)
            plotter.render()
            img = plotter.screenshot(return_img=True)
            frames.append(img)

    for t in np.linspace(0, 1, n_frames):
        ft = diffeo(t)
        working.points = (1 - ft) * mesh1.points + ft * mesh2.points
        if spin:
            plotter.camera.Azimuth(360/n_frames)
        plotter.render()
        img = plotter.screenshot(return_img=True)
        frames.append(img)

    file_saved_message = f"Animation Save at {out_path}"
    if out_path[-3:] == 'mp4':
        imageio.mimsave(out_path, frames, format = "FFMPEG", fps = fps, codec = "libx264", quality = 8)
        print(file_saved_message)
    elif out_path[-3:] == 'gif':
        imageio.mimsave(out_path, frames, fps = fps)
        print(file_saved_message)
    else:
        print('Frames returned. Please use valid file type.')
        return frames            

In [6]:
@dataclass
class SquareRootNormalMesh:
    verts: np.ndarray
    verts_sph: np.ndarray
    faces_pv:  np.ndarray
    q_face:    np.ndarray
    name: str = ""
    uid:  str = field(default_factory=lambda: uuid.uuid4().hex)

    mesh: pv.PolyData = field(init=False, repr=False)
    mesh_sph: pv.PolyData = field(init=False, repr=False)
    mesh_initial: pv.PolyData = field(init=False, repr=False)

    edge_dict: dict = field(init=False, repr=False)
    edge_array: np.ndarray = field(init=False, repr=False)
    edge_weights: np.ndarray = field(init=False, repr=False)

    sphere_param_edge_energy: np.ndarray = field(init=False, repr=False)
    sphere_param_face_energy_mean: np.ndarray = field(init=False, repr=False)
    sphere_param_vertex_energy_mean: np.ndarray = field(init=False, repr=False)
    sphere_param_log: list = field(default_factory=list)
    icosphere: dict = field(default_factory=dict)
    alignments: dict = field(default_factory=dict)
    pairwise_l2: dict = field(default_factory=dict)

    def __post_init__(self):
        self.mesh_sph = pv.PolyData(self.verts_sph, self.faces_pv)

    @classmethod
    def from_polydata(cls, initial_mesh: pv.PolyData, *, max_iter=5000, name=""):
        # translate and scale so mesh centroid is (0,0,0) and mesh surface area is 1
        pv_mesh = normalize_mesh(initial_mesh)
        
        # build edges/weights
        edge_dict, edge_array, edge_weight = make_edge_dict(pv_mesh)

        # ensure point normals exist and normalize as φ0 ∈ S²
        m = pv_mesh.copy()
        m.compute_normals(point_normals=True, cell_normals=False, inplace=True)
        phi0 = m.point_normals
        phi0 = phi0 / (np.linalg.norm(phi0, axis=1, keepdims=True) + 1e-15)

        # optimize on S²
        verts_sph, logE = C2_adaptive(phi0, edge_array, edge_weight, max_iter=max_iter)
        R = kabsch_srnf(m.points, verts_sph, proper_rotation=False)
        if np.linalg.det(R) < 0:
            print('Inverse Rotation')
        verts_sph = verts_sph @ R.T

        faces_pv = pv_mesh.faces
        q_face = srnf_from_mesh(pv_mesh)

        # construct
        obj = cls(
            verts=pv_mesh.points.copy(),
            verts_sph=verts_sph,
            faces_pv=faces_pv.copy(),
            q_face=q_face,
            name=name
        )
        obj.edge_dict = edge_dict
        obj.edge_array = edge_array
        obj.edge_weights = edge_weight
        obj.mesh = pv_mesh
        obj.mesh_sph = pv.PolyData(verts_sph, faces_pv)
        obj.mesh_initial = initial_mesh

        # energies (NOTE: your function aggregates weights, not energy)
        eE, fE, vE = extract_vertex_face_energies(verts_sph, faces_pv, edge_array, edge_weight)
        obj.sphere_param_edge_energy = eE
        obj.sphere_param_face_energy_mean = fE
        obj.sphere_param_vertex_energy_mean = vE
        obj.sphere_param_log = logE
        return obj

    @classmethod
    def from_voxels(cls, vox: np.ndarray, *, iso=0.5, spacing=(1,1,1), smoothing=None, **kw):
        from skimage.measure import marching_cubes
        v,f,_,_ = marching_cubes(vox, level=iso, spacing=spacing)
        faces_pv = np.c_[np.full((len(f),1),3,np.int32), f.astype(np.int32)].ravel()
        pv_mesh = pv.PolyData(v, faces_pv)
        if smoothing: pv_mesh = pv_mesh.smooth(n_iter=smoothing)
        return cls.from_polydata(pv_mesh, **kw)

    def sample(self, U: np.ndarray) -> np.ndarray:
        idx = self.mesh_sph.find_closest_cell(U)
        return self.q_face[idx]

    def preindex(self, U: np.ndarray) -> np.ndarray:
        return self.mesh_sph.find_closest_cell(U)

    def gather(self, idx: np.ndarray) -> np.ndarray:
        return self.q_face[idx]

    def locate(self, U: np.ndarray) -> np.ndarray:
        return locate_on_mesh_from_sphere(self, U)

    def get_alignment(self, B: "SquareRootNormalMesh", nsub: int = 5, inplace: bool = True) -> None:
        output = compute_alignment(self, B, nsub)
        if not inplace:
            if B.uid in self.alignments:
                return output
            else:
                self.alignments[B.uid] = {'mobius_params': output[0], 'space_rotation': output[1], "logs": output[3]}
                self.pairwise_l2[B.uid] = output[2]
                B.pairwise_l2[B.uid] = output[2]
        else:
            self.alignments[B.uid] = {'mobius_params': output[0], 'space_rotation': output[1], "logs": output[3]}
            self.pairwise_l2[B.uid] = output[2]
            B.pairwise_l2[self.uid] = output[2]

    def plot(self, scalars: np.ndarray = None, cmap: str = "coolwarm", sphere: bool = False) -> None:
        if scalars is None:
            scalars = np.arange(self.verts.shape[0])
        plot_mesh = self.mesh_sph if sphere else self.mesh
        p = BackgroundPlotter()
        p.add_mesh(plot_mesh, cmap = cmap, scalars = scalars)
        p.show()

    def plot_registration(self, B: "SquareRootNormalMesh", nsub: int=7, scalars_method: str='hue',
                            cmap: str='viridis', gap: tuple=(2,0,0), point_size: int=5) -> None:
        if B.uid not in self.alignments:
            raise ValueError("Mesh B is not aligned with mesh A!")

        mobius_params = self.alignments[B.uid]['mobius_params']
        space_rotation = self.alignments[B.uid]['space_rotation'].T

        plot_registration_one_panel(self, B, nsub, cmap, scalars_method, mobius_params=mobius_params,
                                    space_rotation=space_rotation, gap=gap, point_size=point_size)

    def plot_sphere_gif(self, camera_position=None, n_frames=300, fps=30, cmap='tab20c', scalars_array=None,
                       scalars_method='hue', point=None, hue=(1,0.05,15), out_path='overwritten_file.gif',
                       wire=False, spin=False):
        gif(self, camera_position=camera_position, n_frames=n_frames, fps=fps, cmap=cmap, scalars_array=scalars_array,
           scalars_method=scalars_method, point=point, hue=hue, out_path=out_path, wire=wire, spin=spin)

In [7]:
d1_file = "C:/Users/colli/Documents/Thesis/data/d1/d1.nii"
d1_image = nib.load(d1_file)
d1_mat = d1_image.get_fdata()
d1_seg_file = "C:/Users/colli/Documents/Thesis/data/d1/seg/segmented_seg.nii"
d1_seg_image = nib.load(d1_seg_file)
d1_seg_mat = d1_seg_image.get_fdata()

mask_file = "//wsl.localhost/Ubuntu-22.04/home/colli/jo/corpus_callosum_mask.nii.gz"
mask_image = nib.load(mask_file)
mask_resample = nib.processing.resample_from_to(mask_image, d1_seg_image)
mask_mat = mask_resample.get_fdata()
mask_mat[np.where(mask_mat < 0.5)] = 0
mask_mat[np.where(mask_mat >= 0.5)] = 1

verts, faces, normals, values = marching_cubes(mask_mat)
col = np.full((faces.shape[0], 1), 3)
raw_cc_mesh = pv.PolyData(verts, np.hstack((col, faces)).flatten())
cc_mesh = raw_cc_mesh.smooth(n_iter = 500, relaxation_factor = 0.005)

cc_srnm = SquareRootNormalMesh.from_polydata(cc_mesh, max_iter=8000)

def prep_for_param(mesh: pv.PolyData, target_tris=None, seal=True) -> pv.PolyData:
    m = mesh.triangulate().clean().compute_normals(inplace=False)
    if seal:
        m = m.fill_holes(1e6).extract_largest().clean()
    # optional: decimate or subdivide to a comfortable triangle count
    if target_tris is not None:
        if m.n_cells > target_tris:
            frac = 1 - target_tris / m.n_cells
            m = m.decimate_pro(fraction=frac, preserve_topology=True)
        else:
            # Loop subdivision for smoother sampling
            iters = max(1, int(np.ceil(np.log2(target_tris / m.n_cells))))
            m = m.subdivide(nsub=iters, subfilter='loop')
    return m.compute_normals(inplace=False)

shark_mesh = pv.examples.download_great_white_shark()
shark_mesh = prep_for_param(shark_mesh)

shark_srnm = SquareRootNormalMesh.from_polydata(shark_mesh, max_iter=14000)

create_mesh_end = time.time()
goober = create_mesh_end - overall_start
print(f"To make meshes: {goober:.2f}")

Spherical Param Gradient Descent:   0%|                                                     | 0/8000 [00:00<?,…

Inverse Rotation


Spherical Param Gradient Descent:   0%|                                                    | 0/14000 [00:00<?,…

To make meshes: 281.84


In [8]:
def prep_for_param(mesh: pv.PolyData, target_tris=None, seal=True) -> pv.PolyData:
    m = mesh.triangulate().clean().compute_normals(inplace=False)
    if seal:
        m = m.fill_holes(1e6).extract_largest().clean()
    # optional: decimate or subdivide to a comfortable triangle count
    if target_tris is not None:
        if m.n_cells > target_tris:
            frac = 1 - target_tris / m.n_cells
            m = m.decimate_pro(fraction=frac, preserve_topology=True)
        else:
            # Loop subdivision for smoother sampling
            iters = max(1, int(np.ceil(np.log2(target_tris / m.n_cells))))
            m = m.subdivide(nsub=iters, subfilter='loop')
    return m.compute_normals(inplace=False)

shark_mesh = pv.examples.download_great_white_shark()
shark_mesh = prep_for_param(shark_mesh)

#shark_srnm = SquareRootNormalMesh.from_polydata(shark_mesh, max_iter=20000)

cow_mesh = pv.examples.download_cow()
cow_mesh = prep_for_param(cow_mesh)

#cow_srnm = SquareRootNormalMesh.from_polydata(cow_mesh, max_iter=20000)

cow_srnm.get_alignment(shark_srnm)

cow_srnm.plot_registration(shark_srnm)

NameError: name 'cow_srnm' is not defined

In [137]:
def initialize_point(pts, faces, idx0, idx1, og_point, edges=None, n_iter=5000):
    if faces.ndim == 1:
        faces = faces.reshape(-1, 4)[:, 1:]
    if edges is None:
        edges = build_edges(faces)

    # --- 1. Identify Neighbors (Same as your code) ---
    idx_pair = np.array([idx0, idx1])
    # Find faces connected to the pair
    rel_face_mask = np.any(np.isin(faces, idx_pair), axis=1)
    rel_faces = faces[rel_face_mask]
    
    # Calculate radius based on ring (Same as your code)
    ring = edges[np.any(np.isin(edges, idx_pair), axis=1)].flatten()
    ring = np.unique(ring[~np.isin(ring, idx_pair)])
    maxy = np.max(np.linalg.norm(og_point - pts[ring], axis=1))

    # --- 2. Orientation Logic (Kept your logic, but risky!) ---
    other_faces = faces[~rel_face_mask]
    centers = np.mean(pts[other_faces], axis=1)
    cross_other = get_cross(pts, other_faces) 
    # Normalized dot product check
    dots = np.einsum('ij,ij->i', centers, cross_other) # assuming get_cross is unnormalized
    if np.mean(dots) < 0:
        rel_faces = rel_faces[:, [0, 2, 1]] # Flip winding order

    # --- 3. Vectorized Sampling ---
    # Generate all random parameters at once
    r0 = maxy * np.sqrt(np.random.uniform(0, 1, n_iter))
    r1 = maxy * np.sqrt(np.random.uniform(0, 1, n_iter))
    phi0 = np.random.uniform(0, 2 * np.pi, n_iter)
    phi1 = np.random.uniform(0, 2 * np.pi, n_iter)

    # Tangent Basis Construction
    o = og_point / (np.linalg.norm(og_point) + 1e-10)
    a = np.array([1.0, 0.0, 0.0]) if abs(o[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    e0 = np.cross(o, a); e0 /= np.linalg.norm(e0)
    e1 = np.cross(o, e0)

    # Create Sample Offsets (n_iter, 3)
    samp0 = r0[:, None] * (np.cos(phi0)[:, None] * e0 + np.sin(phi0)[:, None] * e1)
    samp1 = r1[:, None] * (np.cos(phi1)[:, None] * e0 + np.sin(phi1)[:, None] * e1)

    # Candidate positions on sphere
    cands0 = og_point + samp0
    cands0 /= np.linalg.norm(cands0, axis=1, keepdims=True)
    cands1 = og_point + samp1
    cands1 /= np.linalg.norm(cands1, axis=1, keepdims=True)

    # --- 4. Vectorized Evaluation ---
    # We need to compute metrics for 'rel_faces' for all 5000 candidates.
    # We create a tensor of points: (n_iter, n_rel_faces, 3, 3) 
    # or just gather indices.
    
    # Map global indices to 0, 1 (for idx0, idx1) and 2..N (for static neighbors)
    # This part is tricky to vectorize without fancy indexing, 
    # but here is the cleanest way:
    
    # Create a copy of pts for the static geometry
    local_pts = pts.copy() 
    
    # We need to construct the triangle vertices for all candidates.
    # Shape: (n_iter, n_rel_faces, 3 coordinates, 3 xyz)
    # This requires constructing a gathered array where we swap in cands0 and cands1
    
    # Fast approach: Look at rel_faces. 
    # Replace idx0 with a special flag -1, idx1 with -2
    mapped_faces = rel_faces.copy()
    mask0 = (mapped_faces == idx0)
    mask1 = (mapped_faces == idx1)
    
    # Prepare a tensor of shape (n_iter, n_rel_faces, 3, 3)
    # Fill with static points first
    # (Note: This is memory intensive. If faces > 1000, do batches. 
    # But usually rel_faces is small ~10 faces).
    num_rel = len(rel_faces)
    
    # Get static coords for all vertices in relative faces
    base_coords = pts[rel_faces] # (n_rel, 3, 3)
    batch_coords = np.tile(base_coords[None, ...], (n_iter, 1, 1, 1)) # (n_iter, n_rel, 3, 3)
    
    # Inject candidate 0
    # mask0 is (n_rel, 3) -> expand to (n_iter, n_rel, 3)
    batch_coords[:, mask0, :] = cands0.repeat(np.sum(mask0), axis=0).reshape(n_iter, -1, 3)
    
    # Inject candidate 1
    batch_coords[:, mask1, :] = cands1.repeat(np.sum(mask1), axis=0).reshape(n_iter, -1, 3)

    # --- 5. Compute Metrics on Tensor ---
    # v0, v1, v2 for cross product
    v0 = batch_coords[:, :, 0, :]
    v1 = batch_coords[:, :, 1, :]
    v2 = batch_coords[:, :, 2, :]
    
    # Cross product (normals * 2*area)
    # Shape: (n_iter, n_rel, 3)
    cross_vecs = np.cross(v1 - v0, v2 - v0)
    cross_norms = np.linalg.norm(cross_vecs, axis=2) + 1e-10
    
    # Centers
    tri_centers = np.mean(batch_coords, axis=2)
    tri_centers_norm = np.linalg.norm(tri_centers, axis=2) + 1e-10
    
    # Dots (Outward check)
    # (n_iter, n_rel)
    dot_vals = np.einsum('ijk,ijk->ij', tri_centers/tri_centers_norm[...,None], cross_vecs/cross_norms[...,None])
    
    mean_dots = np.mean(dot_vals, axis=1) # (n_iter,)
    
    # CV Area
    cv_areas = np.std(cross_norms, axis=1) / np.mean(cross_norms, axis=1) # (n_iter,)
    
    # Final Metric
    final_metrics = mean_dots - (cv_areas / 3)
    
    best_idx = np.argmax(final_metrics)
    
    return {idx0: cands0[best_idx], idx1: cands1[best_idx]}

def append_three_col(F):
    return np.hstack([np.repeat(3, F.shape[0])[:,None], F])

def remove_unused_points(P, F):
    if F.ndim == 1:
        F = F.reshape(-1,4)[:,1:]
    unico = np.unique(F)
    new_P = P[unico]
    new_F = np.searchsorted(unico, F)
    return new_P, new_F, unico
    
def ordered_one_ring(v, faces):
    faces_v = faces[np.any(faces == v, axis=1)]
    n, p = faces_v.shape
    cols = np.arange(p)
    rows = np.arange(n)[:, None]
    
    pos = np.where(faces_v == v)[1]      # original positions of 100
    target = 2                        # column where you want 100 to end up
    
    shifts = (target - pos) % p       # per-row shift amounts
    col_shift = (cols - shifts[:,None]) % p
    
    out = faces_v[rows, col_shift]

    out2 = np.zeros((n,p), dtype=faces.dtype)
    out2[0] = out[0]
    col1 = out[:,0]
    for i in range(1,n):
        x = out2[i-1, 1]
        idx = np.where(col1 == x)[0]
        if len(idx) != 1:
            print('Bad Connectivity')
            return np.unique(out[:,:2])
        out2[i] = out[idx[0]]
    return out2[:,0]
    
def split_vertex_insert(current_faces, insert_tuple, replace_tuple, S=None, return_replaced_idx=False):
    """
    Insert two deleted faces back (by row), rewire vertex ids using recorded (rows, cols),
    then compute ordered rings for the two restored vertices.
    Optionally place them on S² using a normalized centroid.
    """
    # --- 1) Insert the two faces back by ROW (axis=0), descending index first
    i0, f0 = insert_tuple[0]    # row indices to insert at
    i1, f1 = insert_tuple[1]    # the two 3-idx faces to insert
    if i0 > i1:
        i0, i1, f0, f1 = i1, i0, f1, f0  # insert high row first
    current_faces = np.insert(current_faces, i0, f0, axis=0)
    current_faces = np.insert(current_faces, i1, f1, axis=0)

    # --- 2) Rewire recorded (row, col) positions back to original vertex ids
    # replace_tuple = [ ( (rows0, cols0), vid0 ), ( (rows1, cols1), vid1 ) ]
    (rows0, cols0), vid0 = replace_tuple[0]
    (rows1, cols1), vid1 = replace_tuple[1]
    if return_replaced_idx:
        uniq = np.unique(np.concat([current_faces[rows0, cols0], current_faces[rows1, cols1]]))
        replaced_idx = [i for i in uniq if (i != vid0) and (i != vid1)][0]
    current_faces[rows0, cols0] = vid0
    current_faces[rows1, cols1] = vid1

    # --- 3) Build ordered 1-rings for the two restored vertices
    ring0 = ordered_one_ring(vid0, current_faces)
    ring1 = ordered_one_ring(vid1, current_faces)

    # --- 4) Optional: initialize spherical positions for vid0/vid1
    if S is not None:
        S[vid0] = spherical_kernel_centroid(ring0, S)
        S[vid1] = spherical_kernel_centroid(ring1, S)

    if vid1 not in ring0:
        ring0 = np.append(ring0, vid1)
    if vid0 not in ring1:
        ring1 = np.append(ring1, vid0)

    if return_replaced_idx:
        return current_faces, ring0, ring1, replaced_idx
    else:
        return current_faces, ring0, ring1
        
def get_cross(points, faces):
    if faces.ndim == 1:
        faces = faces.reshape(-1,4)[:,1:]
    e0 = points[faces[:,1]] - points[faces[:,0]]
    e1 = points[faces[:,2]] - points[faces[:,0]]
    cross = np.cross(e0, e1)
    return cross
    
def build_edges(faces):
    if faces.ndim == 1:
        faces = faces.reshape(-1, 4)[:, 1:]
    edges = np.vstack([
        faces[:, [0, 1]],
        faces[:, [1, 2]],
        faces[:, [2, 0]],
    ])
    edges = np.unique(np.sort(edges, axis=1), axis=0)
    return edges

def cotan_weights_for_edge_list(
    edges_idx: np.ndarray,
    points: np.ndarray,
    faces: np.ndarray,
    *,
    boundary_weight: float = 0.0,
) -> np.ndarray:
    """
    Compute cotangent weights for a provided edge list (E,2), returning weights
    in the SAME order as edges_idx.

    Weight for an interior edge (u,v):
        w_uv = cot(angle at opp0) + cot(angle at opp1)
    where opp0/opp1 are the two vertices opposite (u,v) in the two incident triangles.

    Parameters
    ----------
    edges_idx : (E,2) int array
        Edge endpoints. Can be directed or undirected; treated as undirected.
    points : (V,3) float array
        Vertex positions in R^3 (the geometry you measure angles on).
    faces : (F,3) or (F,4) or (4F,) array
        Triangles. If PyVista style, may be (F,4) with leading 3 or flat (4F,) array.
    boundary_weight : float
        Weight assigned to boundary/non-manifold edges that do not have exactly 2 incident triangles.
        Common choices: 0.0 (default), or cot(single angle) if you want 1-sided weights.

    Returns
    -------
    weights : (E,) float array
        Cotan weights aligned with edges_idx row order.
    """
    edges_idx = np.asarray(edges_idx)
    points = np.asarray(points)
    faces = np.asarray(faces)

    # ---- normalize faces to (F,3) int array ----
    if faces.ndim == 1:
        # PyVista flat format: [3,i,j,k, 3,i,j,k, ...]
        faces = faces.reshape(-1, 4)
    if faces.shape[1] == 4:
        # PyVista (F,4): [3,i,j,k]
        faces = faces[:, 1:]
    faces = faces.astype(np.int64, copy=False)

    # ---- canonicalize edges and build lookup to provided order ----
    edges_can = np.sort(edges_idx.astype(np.int64, copy=False), axis=1)
    edge_to_pos = {tuple(e): i for i, e in enumerate(map(tuple, edges_can))}

    E = edges_idx.shape[0]
    cot_sum = np.zeros(E, dtype=np.float64)
    cot_cnt = np.zeros(E, dtype=np.int8)

    # ---- helper: cot(angle at vertex c) in triangle (a,b,c), opposite edge (a,b) ----
    def cot_angle_at_vertex(c: int, a: int, b: int) -> float:
        v0 = points[a] - points[c]
        v1 = points[b] - points[c]
        # cot(theta) = dot(v0,v1) / ||cross(v0,v1)||
        denom = np.linalg.norm(np.cross(v0, v1)) + 1e-15
        return float(np.dot(v0, v1) / denom)

    # ---- single pass over faces: accumulate cotangents into edges present in edges_idx ----
    for (i, j, k) in faces:
        # edge (i,j) opposite k
        e = (i, j) if i < j else (j, i)
        pos = edge_to_pos.get(e)
        if pos is not None:
            cot_sum[pos] += cot_angle_at_vertex(k, i, j)
            cot_cnt[pos] += 1

        # edge (j,k) opposite i
        e = (j, k) if j < k else (k, j)
        pos = edge_to_pos.get(e)
        if pos is not None:
            cot_sum[pos] += cot_angle_at_vertex(i, j, k)
            cot_cnt[pos] += 1

        # edge (k,i) opposite j
        e = (k, i) if k < i else (i, k)
        pos = edge_to_pos.get(e)
        if pos is not None:
            cot_sum[pos] += cot_angle_at_vertex(j, k, i)
            cot_cnt[pos] += 1

    # ---- finalize ----
    # For a closed manifold triangle mesh, every edge should have cot_cnt==2.
    weights = np.full(E, float(boundary_weight), dtype=np.float64)
    interior = (cot_cnt == 2)
    weights[interior] = cot_sum[interior]

    return weights


def progressive_mesh(mesh, min_faces=4, randomization_percentage=0.95, expire_time=100):
    # faces as (F,3) int array
    current_faces = mesh.faces.reshape(-1, 4)[:, 1:]
    current_points = mesh.points.copy()

    # preallocate some extra room for new vertices
    current_points = np.vstack(
        [current_points, np.zeros((current_points.shape[0], 3))]
    )
    next_vid = mesh.n_points  # index for next new vertex

    prog_face_del = []
    prog_idx_replace = []
    prog_weight = []
    

    start = time.time()
    init_bool = True
    while True:
        if time.time() - start > expire_time:
            print("expired...")
            break

        # build undirected edge list from current_faces
        edges = np.vstack(
            [
                current_faces[:, [0, 1]],
                current_faces[:, [1, 2]],
                current_faces[:, [0, 2]],
            ]
        )
        edges = np.unique(np.sort(edges, axis=1), axis=0)
        edge_array = edges

        edge_lens = np.linalg.norm(
            current_points[edge_array[:, 0]] - current_points[edge_array[:, 1]],
            axis=1,
        )
        n_edge_remaining = len(edge_lens)
        if init_bool:
            edge_to_idx = {tuple(e): i for i, e in enumerate(edge_array)}
            edge_dict, _, _ = make_edge_dict(mesh)
            current_weights = np.repeat(-999, 4*n_edge_remaining).astype(float)
            n_total_edges = n_edge_remaining
            for e in edge_dict:
                idx = edge_to_idx[e]
                weight = edge_dict[e]['weight']
                current_weights[idx] = weight
            init_edges = n_edge_remaining
            init_bool = False
            pbar = tqdm(total=n_edge_remaining//3, desc='Edge Reduction', dynamic_ncols=True)

        # stop if already super coarse
        if current_faces.shape[0] <= min_faces or n_edge_remaining <= 6:
            print("hit coarse stop:", current_faces.shape[0], "faces")
            break

        edge_del_idx = np.argsort(edge_lens)
        np.random.shuffle(edge_del_idx[:(n_edge_remaining // 5)]) if (n_edge_remaining / init_edges) > randomization_percentage else None
        

        accepted = False
        vidx_0 = vidx_1 = None
        fidx_del = None

        for idx in edge_del_idx:
            v0 = edge_array[idx, 0]
            v1 = edge_array[idx, 1]

            # simulate collapse v0,v1 -> next_vid
            mask_edge_faces = (
                np.sum(np.isin(current_faces, [v0, v1]), axis=1) == 2
            )
            fidx_del = np.where(mask_edge_faces)[0]
            face_del_temp = current_faces[fidx_del]
            faces_temp = np.delete(current_faces, fidx_del, axis=0)
            faces_temp[np.isin(faces_temp, [v0, v1])] = next_vid

            # 1) no degenerate triangles: 3 distinct verts per face
            a, b, c = faces_temp.T
            good_tri = (a != b) & (b != c) & (a != c)
            if not np.all(good_tri):
                continue

            # 2) no duplicate faces (up to permutation)
            ordered = np.sort(faces_temp, axis=1)
            if np.unique(ordered, axis=0).shape[0] != faces_temp.shape[0]:
                continue

            # 3) Euler characteristic stays 2 (genus-0)
            V = np.unique(faces_temp).size
            e_sim = np.vstack(
                [
                    faces_temp[:, [0, 1]],
                    faces_temp[:, [1, 2]],
                    faces_temp[:, [0, 2]],
                ]
            )
            e_sim = np.unique(np.sort(e_sim, axis=1), axis=0)
            E = e_sim.shape[0]
            F = faces_temp.shape[0]
            chi = V - E + F
            if chi != 2:
                continue

            # if we get here, collapse is accepted
            prog_face_del.append(((fidx_del[0], face_del_temp[0]), (fidx_del[1], face_del_temp[1])))
            vidx_0, vidx_1 = v0, v1
            replace_idx_0 = np.where(current_faces == vidx_0)
            replace_idx_1 = np.where(current_faces == vidx_1)
            prog_idx_replace.append(((replace_idx_0, vidx_0), (replace_idx_1, vidx_1)))
            old_faces = current_faces[np.any(np.isin(current_faces, [vidx_0, vidx_1]), axis=1)]
            current_faces = faces_temp
            accepted = True
            pbar.update(1)
            break

        if not accepted:
            print("no valid collapse found; stopping")
            break

        # create the new vertex position (midpoint of v0,v1)
        mid = (current_points[vidx_0] + current_points[vidx_1]) / 2
        current_points[next_vid] = mid
        

        rel_faces = current_faces[np.any(current_faces == next_vid, axis=1)]
        rel_edges = np.vstack([rel_faces[:,[0,1]], rel_faces[:,[0,2]], rel_faces[:,[1,2]]])
        rel_edges = np.unique(np.sort(rel_edges, axis=1), axis=0)
        new_weights = cotan_weights_for_edge_list(rel_edges, current_points, current_faces)
        weight_change_idx = []
        prev_weights = []
        for i in range(len(new_weights)):
            edge_tuple = tuple(rel_edges[i])
            weight = new_weights[i]
            if edge_tuple not in edge_to_idx:
                widx = n_total_edges
                n_total_edges += 1
                if n_total_edges == len(current_weights):
                    current_weights = np.hstack([current_weights, np.repeat(-999, 1000)])
                edge_to_idx[edge_tuple] = widx
            else:
                widx = edge_to_idx[edge_tuple]
            prev_weights.append(current_weights[widx])
            current_weights[widx] = weight
            weight_change_idx.append(widx)

        old_edges = np.vstack([old_faces[:,[0,1]], old_faces[:,[0,2]], old_faces[:,[1,2]]])
        old_edges = np.unique(np.sort(old_edges, axis=1), axis=0)
        old_edges = old_edges[np.any(np.isin(old_edges, [vidx_0, vidx_1]), axis=1)]
        old_edges_idx = []
        for e in old_edges:
            widx = edge_to_idx[tuple(e)]
            prev_weights.append(current_weights[widx])
            current_weights[widx] = 0
            old_edges_idx.append(widx)
        undo = (weight_change_idx + old_edges_idx, prev_weights)
        prog_weight.append(undo)
        # prog_weight.append(((np.array(weight_change_idx), new_weights), (np.array(old_edges_idx), np.zeros(len(old_edges_idx)))))
        next_vid += 1

        # expand point buffer if needed
        if next_vid == current_points.shape[0]:
            current_points = np.vstack(
                [current_points, np.zeros((50, 3))]
            )

    zeros = np.all(current_points == 0, axis=1)
    if np.any(zeros):
        current_points = current_points[:np.argmax(zeros)]
    if np.any(current_weights == -999):
        current_weights = current_weights[:np.argmin(current_weights)]
    print(f"lifetime edges: {len(current_weights)}!")
    return current_points, current_faces, prog_idx_replace[::-1], prog_face_del[::-1], edge_to_idx, current_weights, prog_weight[::-1]

def repulsion_force(points, edges=None, edge_cancel=True):
    """
    points: (N,3) on S^2
    edges:  (E,2) int or None
    returns (N,3) total repulsion per vertex
    """
    N = points.shape[0]
    diff = points[:, None, :] - points[None, :, :]          # (N,N,3)
    r2 = np.sum(diff * diff, axis=2, keepdims=True) + 1e-8  # (N,N,1)
    np.fill_diagonal(r2[:, :, 0], np.inf)
    F = diff / r2                                           # 1/r^2 repulsion

    if edge_cancel and edges is not None:
        F[edges[:, 0], edges[:, 1]] = 0.0
        F[edges[:, 1], edges[:, 0]] = 0.0

    return F.sum(axis=1)                                    # (N,3)

def check_spherical_flips(phi, faces):
    """
    Returns True if ANY face has flipped or collapsed (signed volume <= 0).
    Uses the scalar triple product: (p0 x p1) . p2
    """
    p0 = phi[faces[:, 0]]
    p1 = phi[faces[:, 1]]
    p2 = phi[faces[:, 2]]
    
    # Vectorized Triple Product
    # Cross product of first two edges
    cross = np.cross(p0, p1)
    
    # Dot product with third vertex
    # Result is ~6x the signed volume of the tetrahedron formed with the origin
    volumes = np.einsum('ij,ij->i', cross, p2)
    
    # If any volume is <= epsilon, it's degenerate or flipped
    # Using 1e-12 as a safety margin against float errors
    if np.any(volumes < 1e-12):
        return True
    return False

def C2_adaptive(phi0, edges_idx, weights, faces, mask=None, *,
                tol=1e-10, max_iter=10000, initial_dt=0.1, verbose=True, 
                repulsion=False):
    """
    Drop-in replacement with FACE FLIP CHECKING.
    Added 'faces' argument to signature.
    """
    if verbose:
        pbar = tqdm(total=max_iter, desc='Spherical Param', unit='iter', dynamic_ncols=True)
        
    phi = phi0.copy()
    
    # Normalize initial active points
    if mask is not None:
        phi[mask] /= (np.linalg.norm(phi[mask], axis=1, keepdims=True) + 1e-15)
    else:
        phi /= (np.linalg.norm(phi, axis=1, keepdims=True) + 1e-15)

    E = string_energy_vec(edges_idx, weights, phi)
    energies = [E]
    dt = float(initial_dt)

    # Auto-balance repulsion
    if repulsion:
        n_active = np.sum(mask) if mask is not None else len(phi)
        repulsion_weight = 50.0 / (n_active if n_active > 0 else 1.0)
        
        # Pre-calculate inverse mask for safety clamp
        if mask is not None:
            inv_mask = ~mask

    for it in range(1, max_iter+1):
        g = d_energy_vec(edges_idx, weights, phi)
        
        if repulsion:
            # Calculate repulsion only on ACTIVE points
            # (Passing mask to repulsion_force if you optimized it, 
            #  or just slicing like this if using your original function)
            rf_active = repulsion_force(phi[mask])
            
            # SUBTRACT repulsion (Gradient Descent moves opposite to g)
            g[mask] -= (rf_active * repulsion_weight)
            
        # Project gradient
        g_proj = g - (np.einsum('ij,ij->i', g, phi))[:, None] * phi

        ok = False
        dt_try = dt
        
        # Line Search
        for _ in range(8): # Increased tries slightly to find valid non-flipping steps
            phi_try = phi - dt_try * g_proj
            
            # --- SAFETY CLAMP & CENTERING ---
            if mask is not None:
                phi_try[inv_mask] = 0.0 # Force ghosts dead
                
                # Center active points
                center = np.sum(phi_try, axis=0) / np.sum(mask)
                phi_try[mask] -= center
                
                # Normalize active points
                norms = np.linalg.norm(phi_try, axis=1)
                safe_norm_mask = norms > 1e-15
                phi_try[safe_norm_mask] /= norms[safe_norm_mask][:,None]
            else:
                phi_try -= np.mean(phi_try, axis=0)
                phi_try /= (np.linalg.norm(phi_try, axis=1, keepdims=True) + 1e-15)
            # --------------------------------

            # 1. Check Energy
            E_try = string_energy_vec(edges_idx, weights, phi_try)
            
            if E_try < E:
                # 2. CRITICAL: Check for Flipped Faces
                # We only check if the energy looks good.
                # If energy improves but faces flip, we MUST REJECT.
                if check_spherical_flips(phi_try, faces):
                    # Flipped! Reject this step size.
                    dt_try *= 0.5
                    continue
                
                # If we get here: Energy is lower AND no faces flipped. Accept.
                phi, E = phi_try, E_try
                energies.append(E)
                dt = min(dt_try * 1.25, 1.0) # Grow step size carefully
                ok = True
                break
                
            dt_try *= 0.5 # Energy didn't improve, shrink step

        if not ok or (len(energies) > 1 and abs(energies[-2] - energies[-1]) < tol):
            break
        if verbose:
            pbar.update(1)
            
    return phi, energies

In [10]:
cow, horse, shark2, shark = load_pickle('cow_horse_shark2_shark_processed_meshes')
boio3 = progressive_mesh(horse, expire_time=1000)

Edge Reduction:   0%|                                                                         | 0/3391 [00:00<…

hit coarse stop: 4 faces
lifetime edges: 34203!


In [12]:
n_step = 30
c2_iter= 3000
step_per_plot = 1
repul = True

# ---------- main visualization ----------

points, faces_i, prog_point, prog_face, edge_to_idx, weights_i, prog_edge = boio3
full_edge_array = np.array(list(edge_to_idx.keys()))
# rescale horse just so it fits nicely
points = points.copy()
points *= 2 / np.max(np.max(points, axis=0) - np.min(points, axis=0))
points -= (np.max(points, axis=0) - 1)

width = 2.0
height = 2.0

# spherical positions (full N, but many unused initially)
spho = np.zeros_like(points)
# start coarse sphere as tetrahedron (pyvista platonic solid: 0=tetra)
spho[np.unique(faces_i)] = pv.PlatonicSolid(0).points
active_point_mask = np.zeros(spho.shape[0], dtype=bool)

p = BackgroundPlotter()
h = 0
ho = np.array([0.0, width * 1.1, 0.0])
v = 0
vo = np.array([0.0, 0.0, height * 1.1])

for step in range(n_step):
    propo = prog_point[step]
    profa = prog_face[step]
    proed = prog_edge[step]

    faces_i, r0, r1, repo = split_vertex_insert(
        faces_i, profa, propo, return_replaced_idx=True
    )

    weight_idx = np.array(proed[0])
    weight_val = np.array(proed[1])
    weight_val[weight_val < -998] = 0
    weights_i[weight_idx] = weight_val

    # 2) horse mesh at this resolution (green)
    hpnts, hfaces, _ = remove_unused_points(points, faces_i)
    if step % step_per_plot == 0:
        horse_mesh = pv.PolyData(hpnts, append_three_col(hfaces))
        horse_mesh.points += h*ho + v*vo
        p.add_mesh(horse_mesh, color='seagreen', show_edges=True, line_width=1)

    # 3) spherical mesh BEFORE relaxation (pink)
    spnts, sfaces, global_idx = remove_unused_points(spho, faces_i)
    propo_0_rep_idx = np.where(global_idx == propo[0][1])[0][0]
    propo_1_rep_idx = np.where(global_idx == propo[1][1])[0][0]
    og_point = spho[repo]
    init = initialize_point(spnts, sfaces, propo_0_rep_idx, propo_1_rep_idx, og_point, n_iter=50000)
    spnts[propo_0_rep_idx] = init[propo_0_rep_idx]
    spnts[propo_1_rep_idx] = init[propo_1_rep_idx]

    if step % step_per_plot == 0:
        spho_before = pv.PolyData(spnts + h*ho + v*vo + np.array([-4.0, 0, 0]),
                                  append_three_col(sfaces))
        p.add_mesh(spho_before, color='lightcoral', show_edges=True, line_width=1)


    spho[global_idx] = spnts
    active_mask = weights_i > 0
    active_weights_i = weights_i[active_mask]
    active_edge_array_i = full_edge_array[active_mask]
    
    active_point_mask[global_idx] = True
    
    spho, energ = C2_adaptive(spho, active_edge_array_i, active_weights_i, faces_i, max_iter=c2_iter, verbose=False, repulsion=repul, mask=active_point_mask)

    active_point_mask[global_idx] = False
    
    # write back updated spherical points
    # spho[global_idx] = pts

    if step % step_per_plot == 0:
        # grid layout
        h += 1
        if h % 5 == 0:
            h = 0
            v += 1

# TESTING OTHER PROGRESSIVE MESH METHODS

In [None]:
def rebuild_bro(prog_output, n_max=None, n=25, n_row=5, ppcolor='seagreen'):
    points, faces_i, prog_point, prog_face, edge_to_idx, weights_i, prog_edge = prog_output
    n_max = len(prog_point) if n_max is None else n_max
    n_plot = np.linspace(0, n_max, n).astype(int)

    points = points.copy()
    points *= 2 / np.max(np.max(points, axis=0) - np.min(points, axis=0))
    points -= (np.max(points, axis=0) - 1)

    width = 2.0
    height = 2.0

    p = BackgroundPlotter()
    h = 0
    ho = np.array([0.0, width*1.1, 0.0])
    v = 0
    vo = np.array([0.0, 0.0, height * 1.1])

    for step in range(len(prog_point)):
        propo = prog_point[step]
        profa = prog_face[step]
        proed = prog_edge[step]
    
        faces_i, r0, r1, repo = split_vertex_insert(
            faces_i, profa, propo, return_replaced_idx=True
        )
        if step in n_plot:
            hpnts, hfaces, _ = remove_unused_points(points, faces_i)
            horse_mesh = pv.PolyData(hpnts, append_three_col(hfaces))
            horse_mesh.points += h*ho + v*vo
            p.add_mesh(horse_mesh, color=ppcolor, show_edges=True, line_width=1)
            h += 1
            if h % n_row == 0:
                h = 0
                v += 1
    return 'poop'

def progressive_mesh_deg(mesh, min_faces=4, randomization_percentage=0.95, expire_time=100):
    centy = mesh.center_of_mass()
    # faces as (F,3) int array
    current_faces = mesh.faces.reshape(-1, 4)[:, 1:]
    current_points = mesh.points.copy()

    # preallocate some extra room for new vertices
    current_points = np.vstack(
        [current_points, np.zeros((current_points.shape[0], 3))]
    )
    next_vid = mesh.n_points  # index for next new vertex

    prog_face_del = []
    prog_idx_replace = []
    prog_weight = []
    

    start = time.time()
    init_bool = True
    while True:
        if time.time() - start > expire_time:
            print("expired...")
            break

        # build undirected edge list from current_faces
        edges = np.vstack(
            [
                current_faces[:, [0, 1]],
                current_faces[:, [1, 2]],
                current_faces[:, [0, 2]],
            ]
        )
        edges = np.unique(np.sort(edges, axis=1), axis=0)
        edge_array = edges

        edge_lens = np.linalg.norm(
            current_points[edge_array[:, 0]] - current_points[edge_array[:, 1]],
            axis=1,
        )
        n_edge_remaining = len(edge_lens)
        if init_bool:
            edge_to_idx = {tuple(e): i for i, e in enumerate(edge_array)}
            edge_dict, _, _ = make_edge_dict(mesh)
            current_weights = np.repeat(-999, 4*n_edge_remaining).astype(float)
            n_total_edges = n_edge_remaining
            for e in edge_dict:
                idx = edge_to_idx[e]
                weight = edge_dict[e]['weight']
                current_weights[idx] = weight
            init_edges = n_edge_remaining
            init_bool = False
            pbar = tqdm(total=n_edge_remaining//3, desc='Edge Reduction', dynamic_ncols=True)

        # stop if already super coarse
        if current_faces.shape[0] <= min_faces or n_edge_remaining <= 6:
            print("hit coarse stop:", current_faces.shape[0], "faces")
            break

        #edge_del_idx = np.argsort(edge_lens)
        #np.random.shuffle(edge_del_idx[:(n_edge_remaining // 5)]) if (n_edge_remaining / init_edges) > randomization_percentage else None
        # degris:
        # 1. Get degrees of all vertices
        deggy = np.bincount(np.ravel(current_faces))
        
        # 2. Calculate a score for every edge
        # This adds the degrees of both endpoints and multiplies by length
        # High degree + Long edge = Massive score (Top of the list)
        edge_scores = (deggy[edge_array[:, 0]] + deggy[edge_array[:, 1]]) * edge_lens
        
        # 3. Create the flat edge_del_idx
        edge_del_idx = np.argsort(edge_scores)[::-1]

        accepted = False
        vidx_0 = vidx_1 = None
        fidx_del = None

        for idx in edge_del_idx:
            v0 = edge_array[idx, 0]
            v1 = edge_array[idx, 1]

            # simulate collapse v0,v1 -> next_vid
            mask_edge_faces = (
                np.sum(np.isin(current_faces, [v0, v1]), axis=1) == 2
            )
            fidx_del = np.where(mask_edge_faces)[0]
            face_del_temp = current_faces[fidx_del]
            faces_temp = np.delete(current_faces, fidx_del, axis=0)
            faces_temp[np.isin(faces_temp, [v0, v1])] = next_vid

            # 1) no degenerate triangles: 3 distinct verts per face
            a, b, c = faces_temp.T
            good_tri = (a != b) & (b != c) & (a != c)
            if not np.all(good_tri):
                continue

            # 2) no duplicate faces (up to permutation)
            ordered = np.sort(faces_temp, axis=1)
            if np.unique(ordered, axis=0).shape[0] != faces_temp.shape[0]:
                continue

            # 3) Euler characteristic stays 2 (genus-0)
            V = np.unique(faces_temp).size
            e_sim = np.vstack(
                [
                    faces_temp[:, [0, 1]],
                    faces_temp[:, [1, 2]],
                    faces_temp[:, [0, 2]],
                ]
            )
            e_sim = np.unique(np.sort(e_sim, axis=1), axis=0)
            E = e_sim.shape[0]
            F = faces_temp.shape[0]
            chi = V - E + F
            if chi != 2:
                continue

            # if we get here, collapse is accepted
            prog_face_del.append(((fidx_del[0], face_del_temp[0]), (fidx_del[1], face_del_temp[1])))
            vidx_0, vidx_1 = v0, v1
            replace_idx_0 = np.where(current_faces == vidx_0)
            replace_idx_1 = np.where(current_faces == vidx_1)
            prog_idx_replace.append(((replace_idx_0, vidx_0), (replace_idx_1, vidx_1)))
            old_faces = current_faces[np.any(np.isin(current_faces, [vidx_0, vidx_1]), axis=1)]
            current_faces = faces_temp
            accepted = True
            pbar.update(1)
            break

        if not accepted:
            print("no valid collapse found; stopping")
            break

        # create the new vertex position (midpoint of v0,v1)
        mid = (current_points[vidx_0] + current_points[vidx_1]) / 2
        current_points[next_vid] = mid
        

        rel_faces = current_faces[np.any(current_faces == next_vid, axis=1)]
        rel_edges = np.vstack([rel_faces[:,[0,1]], rel_faces[:,[0,2]], rel_faces[:,[1,2]]])
        rel_edges = np.unique(np.sort(rel_edges, axis=1), axis=0)
        new_weights = cotan_weights_for_edge_list(rel_edges, current_points, current_faces)
        weight_change_idx = []
        prev_weights = []
        for i in range(len(new_weights)):
            edge_tuple = tuple(rel_edges[i])
            weight = new_weights[i]
            if edge_tuple not in edge_to_idx:
                widx = n_total_edges
                n_total_edges += 1
                if n_total_edges == len(current_weights):
                    current_weights = np.hstack([current_weights, np.repeat(-999, 1000)])
                edge_to_idx[edge_tuple] = widx
            else:
                widx = edge_to_idx[edge_tuple]
            prev_weights.append(current_weights[widx])
            current_weights[widx] = weight
            weight_change_idx.append(widx)

        old_edges = np.vstack([old_faces[:,[0,1]], old_faces[:,[0,2]], old_faces[:,[1,2]]])
        old_edges = np.unique(np.sort(old_edges, axis=1), axis=0)
        old_edges = old_edges[np.any(np.isin(old_edges, [vidx_0, vidx_1]), axis=1)]
        old_edges_idx = []
        for e in old_edges:
            widx = edge_to_idx[tuple(e)]
            prev_weights.append(current_weights[widx])
            current_weights[widx] = 0
            old_edges_idx.append(widx)
        undo = (weight_change_idx + old_edges_idx, prev_weights)
        prog_weight.append(undo)
        # prog_weight.append(((np.array(weight_change_idx), new_weights), (np.array(old_edges_idx), np.zeros(len(old_edges_idx)))))
        next_vid += 1

        # expand point buffer if needed
        if next_vid == current_points.shape[0]:
            current_points = np.vstack(
                [current_points, np.zeros((50, 3))]
            )

    zeros = np.all(current_points == 0, axis=1)
    if np.any(zeros):
        current_points = current_points[:np.argmax(zeros)]
    if np.any(current_weights == -999):
        current_weights = current_weights[:np.argmin(current_weights)]
    print(f"lifetime edges: {len(current_weights)}!")
    return current_points, current_faces, prog_idx_replace[::-1], prog_face_del[::-1], edge_to_idx, current_weights, prog_weight[::-1]

def progressive_mesh_deg2(mesh, min_faces=4, randomization_percentage=0.95, expire_time=100):
    centy = mesh.center_of_mass()
    # faces as (F,3) int array
    current_faces = mesh.faces.reshape(-1, 4)[:, 1:]
    current_points = mesh.points.copy()

    # preallocate some extra room for new vertices
    current_points = np.vstack(
        [current_points, np.zeros((current_points.shape[0], 3))]
    )
    next_vid = mesh.n_points  # index for next new vertex

    prog_face_del = []
    prog_idx_replace = []
    prog_weight = []
    

    start = time.time()
    init_bool = True
    while True:
        if time.time() - start > expire_time:
            print("expired...")
            break

        # build undirected edge list from current_faces
        edges = np.vstack(
            [
                current_faces[:, [0, 1]],
                current_faces[:, [1, 2]],
                current_faces[:, [0, 2]],
            ]
        )
        edges = np.unique(np.sort(edges, axis=1), axis=0)
        edge_array = edges

        edge_lens = np.linalg.norm(
            current_points[edge_array[:, 0]] - current_points[edge_array[:, 1]],
            axis=1,
        )
        n_edge_remaining = len(edge_lens)
        if init_bool:
            edge_to_idx = {tuple(e): i for i, e in enumerate(edge_array)}
            edge_dict, _, _ = make_edge_dict(mesh)
            current_weights = np.repeat(-999, 4*n_edge_remaining).astype(float)
            n_total_edges = n_edge_remaining
            for e in edge_dict:
                idx = edge_to_idx[e]
                weight = edge_dict[e]['weight']
                current_weights[idx] = weight
            init_edges = n_edge_remaining
            init_bool = False
            pbar = tqdm(total=n_edge_remaining//3, desc='Edge Reduction', dynamic_ncols=True)

        # stop if already super coarse
        if current_faces.shape[0] <= min_faces or n_edge_remaining <= 6:
            print("hit coarse stop:", current_faces.shape[0], "faces")
            break

        #edge_del_idx = np.argsort(edge_lens)
        #np.random.shuffle(edge_del_idx[:(n_edge_remaining // 5)]) if (n_edge_remaining / init_edges) > randomization_percentage else None
        # degris:
        # 1. Get degrees of all vertices
        deggy = np.bincount(np.ravel(current_faces))
        
        # 2. Calculate a score for every edge
        # This adds the degrees of both endpoints and multiplies by length
        # High degree + Long edge = Massive score (Top of the list)
        edge_scores = (deggy[edge_array[:, 0]] + deggy[edge_array[:, 1]]) * edge_lens
        
        # 3. Create the flat edge_del_idx
        edge_del_idx = np.argsort(edge_scores)

        accepted = False
        vidx_0 = vidx_1 = None
        fidx_del = None

        for idx in edge_del_idx:
            v0 = edge_array[idx, 0]
            v1 = edge_array[idx, 1]

            # simulate collapse v0,v1 -> next_vid
            mask_edge_faces = (
                np.sum(np.isin(current_faces, [v0, v1]), axis=1) == 2
            )
            fidx_del = np.where(mask_edge_faces)[0]
            face_del_temp = current_faces[fidx_del]
            faces_temp = np.delete(current_faces, fidx_del, axis=0)
            faces_temp[np.isin(faces_temp, [v0, v1])] = next_vid

            # 1) no degenerate triangles: 3 distinct verts per face
            a, b, c = faces_temp.T
            good_tri = (a != b) & (b != c) & (a != c)
            if not np.all(good_tri):
                continue

            # 2) no duplicate faces (up to permutation)
            ordered = np.sort(faces_temp, axis=1)
            if np.unique(ordered, axis=0).shape[0] != faces_temp.shape[0]:
                continue

            # 3) Euler characteristic stays 2 (genus-0)
            V = np.unique(faces_temp).size
            e_sim = np.vstack(
                [
                    faces_temp[:, [0, 1]],
                    faces_temp[:, [1, 2]],
                    faces_temp[:, [0, 2]],
                ]
            )
            e_sim = np.unique(np.sort(e_sim, axis=1), axis=0)
            E = e_sim.shape[0]
            F = faces_temp.shape[0]
            chi = V - E + F
            if chi != 2:
                continue

            # if we get here, collapse is accepted
            prog_face_del.append(((fidx_del[0], face_del_temp[0]), (fidx_del[1], face_del_temp[1])))
            vidx_0, vidx_1 = v0, v1
            replace_idx_0 = np.where(current_faces == vidx_0)
            replace_idx_1 = np.where(current_faces == vidx_1)
            prog_idx_replace.append(((replace_idx_0, vidx_0), (replace_idx_1, vidx_1)))
            old_faces = current_faces[np.any(np.isin(current_faces, [vidx_0, vidx_1]), axis=1)]
            current_faces = faces_temp
            accepted = True
            pbar.update(1)
            break

        if not accepted:
            print("no valid collapse found; stopping")
            break

        # create the new vertex position (midpoint of v0,v1)
        mid = (current_points[vidx_0] + current_points[vidx_1]) / 2
        current_points[next_vid] = mid
        

        rel_faces = current_faces[np.any(current_faces == next_vid, axis=1)]
        rel_edges = np.vstack([rel_faces[:,[0,1]], rel_faces[:,[0,2]], rel_faces[:,[1,2]]])
        rel_edges = np.unique(np.sort(rel_edges, axis=1), axis=0)
        new_weights = cotan_weights_for_edge_list(rel_edges, current_points, current_faces)
        weight_change_idx = []
        prev_weights = []
        for i in range(len(new_weights)):
            edge_tuple = tuple(rel_edges[i])
            weight = new_weights[i]
            if edge_tuple not in edge_to_idx:
                widx = n_total_edges
                n_total_edges += 1
                if n_total_edges == len(current_weights):
                    current_weights = np.hstack([current_weights, np.repeat(-999, 1000)])
                edge_to_idx[edge_tuple] = widx
            else:
                widx = edge_to_idx[edge_tuple]
            prev_weights.append(current_weights[widx])
            current_weights[widx] = weight
            weight_change_idx.append(widx)

        old_edges = np.vstack([old_faces[:,[0,1]], old_faces[:,[0,2]], old_faces[:,[1,2]]])
        old_edges = np.unique(np.sort(old_edges, axis=1), axis=0)
        old_edges = old_edges[np.any(np.isin(old_edges, [vidx_0, vidx_1]), axis=1)]
        old_edges_idx = []
        for e in old_edges:
            widx = edge_to_idx[tuple(e)]
            prev_weights.append(current_weights[widx])
            current_weights[widx] = 0
            old_edges_idx.append(widx)
        undo = (weight_change_idx + old_edges_idx, prev_weights)
        prog_weight.append(undo)
        # prog_weight.append(((np.array(weight_change_idx), new_weights), (np.array(old_edges_idx), np.zeros(len(old_edges_idx)))))
        next_vid += 1

        # expand point buffer if needed
        if next_vid == current_points.shape[0]:
            current_points = np.vstack(
                [current_points, np.zeros((50, 3))]
            )

    zeros = np.all(current_points == 0, axis=1)
    if np.any(zeros):
        current_points = current_points[:np.argmax(zeros)]
    if np.any(current_weights == -999):
        current_weights = current_weights[:np.argmin(current_weights)]
    print(f"lifetime edges: {len(current_weights)}!")
    return current_points, current_faces, prog_idx_replace[::-1], prog_face_del[::-1], edge_to_idx, current_weights, prog_weight[::-1]

# selecting on lowest average degree of the two vertices of an edge works much better than choosing by highest!