In [1]:
#!pip install pyvista
#!pip install pyvista[jupyter]
#!pip install pyvistaqt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyvista as pv
from pyvistaqt import BackgroundPlotter
import time as time
from skimage.measure import marching_cubes
import nibabel as nib
from nibabel.processing import resample_from_to
import imageio
import pickle
from scipy.optimize import brentq
pv.set_jupyter_backend('none')    # if you want the trame-based UI

def print_data(mesh):
    print(mesh.point_data)
    print(mesh.cell_data)
    print(mesh.field_data)

def load_pickle(path):
    with open(path, "rb") as f:
        return pickle.load(f)

def save_pickle(obj, path):
    path = Path(path)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp, path)  # atomic replace

In [3]:
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
import pyvista as pv
from typing import Callable, Dict, Tuple, Any, Optional
from scipy.optimize import minimize

@dataclass(slots=True)
class SquareRootNormalMesh:
    uid: str = field(default_factory=lambda: uuid.uuid4().hex())
    name = ""
    verts: np.ndarray
    faces_pv: np.ndarray
    q_face: np.ndarray
    verts_sph: np.ndarray
    mesh: pv.PolyData
    mesh_sph: pv.PolyData = field(init=False, repr=False)
    _cache: dict = field(default_factory=dict, repr=False)
    pairwise_parameters: dict

    def __init__(self, pv_mesh: pv.PolyData):
        self.verts = pv_mesh.points
        self.faces = pv_mesh.faces
        self.q_face = get_srnf_from_mesh(pv_mesh)

        self.verts_sph = C2_adaptive(pv_mesh)
        self.mesh_sph = pv.PolyData(self.verts_sph, self.faces)

    def __post_init__(self):
        self.mesh_sph = pv.PolyData(self.verts_sph, self.faces_pv)
        F = len(self.faces_pv)//4
        assert self.q_face.shape == (F, 3)

    def from_voxels(self, voxel_image: np.ndarray):
        self.mesh = pv.PolyData(skimage.marching_cubes(voxel_image))
        self.verts = self.mesh.points
        self.faces = self.mesh.faces
        self.q_face = get_srnf_from_mesh(self.mesh)

        self.verts_sph = C2_adaptive(self.mesh)
        self.mesh_sph = pv.PolyData(self.verts_sph, self.faces)
        
    def sample(self, U: np.ndarray) -> np.ndarray:
        idx = self.mesh_sph.find_closest_cell(U)
        return self.q_face[idx]

    def preindex(self, U: np.ndarray) -> np.ndarray:
        return self.mesh_sph.find_closest_cell(U)

    def gather(self, idx: np.ndarray) -> np.ndarray:
        return self.q_face[idx]

_IncompleteInputError: incomplete input (3302023662.py, line 10)

In [None]:
from dataclasses import dataclass, field
import numpy as np, pyvista as pv, uuid

def make_edge_dict(mesh):
    """
    Extract edge data from a mesh from vertex and face lists.
    Returns a dictionary with each edge as a key in the form
    of a sorted tuple, [(v0, v1)] if v0 < v1, [(v1, v0)] if v1 < v0.
    Each key/edge will contain another dictionary with keys:
    [opposite_vertex_0 : int, opposite_vertex_1 : int, 
     weight (w = cot(opposite_angle_0) + cot(opposite_angle_1)) : float]

     Parameters
     ----------
     mesh : pv.PolyData
     
     Returns
     -------
     edge_dict : dictionary
         Contains a dict for each edge [(v0,v1)] with keys opposite_vertex_0,
         opposite_vertex_1, and weight. opposite_vertex_i contains
         float of the angle in radians made by the three vertices
         v0, opposite_vertex_i, v1
    """
    verts = mesh.points
    faces = mesh.faces
    edge_dict = {}
    if faces.ndim == 1:
        faces = faces.reshape(-1, 4)
    
    for tri in faces[:,1:]:
        i, j, k = map(int, tri)
        for (a,b,c) in [(i,j,k),(j,k,i),(k,i,j)]:
            e = tuple(sorted((a,b)))   # e.g. (min(a,b),max(a,b))
            if e not in edge_dict:
                edge_dict[e] = {}
            v0 = verts[a]
            v1 = verts[b]
            v2 = verts[c]
            e0 = v0 - v2
            e1 = v1 - v2
            costheta = np.dot(e0, e1) / (np.linalg.norm(e0) * np.linalg.norm(e1))
            costheta = np.clip(costheta, -1, 1)
            angle = np.arccos(costheta)
            edge_dict[e][c] = float(angle)     # store the vertex “across” this edge
            if len(edge_dict[e]) == 2:
                angle0 = list(edge_dict[e].items())[0][1]
                w = np.cos(angle0)/np.sin(angle0) + np.cos(angle)/np.sin(angle)
                edge_dict[e]['weight'] = float(w)
    edges_idx = np.array([(i[0], i[1]) for i in edge_dict])
    try:
        weights = np.array([edge_dict[i]['weight'] for i in edge_dict])
    except:
        weights = np.repeat(-999, len(edge_dict.keys()))
        counter = 0
        for i in edge_dict:
            if 'weight' in edge_dict[i].keys():
                weights[counter] = edge_dict[i]['weight']
            counter += 1
    return edge_dict, edges_idx, weights

def string_energy_vec(edges_idx, weights, phi):
    # # phi: (N,3) array of vertex coords on the sphere
    # phi = mesh['spherical_param']
    # edges_idx = mesh['edges_idx']
    # weights = mesh['edges_weights']
    i0 = edges_idx[:,0]
    i1 = edges_idx[:,1]
    # shape (E,3) differences
    diffs = phi[i0] - phi[i1]
    # squared‐norm per edge, shape (E,)
    sqnorm = np.einsum('ij,ij->i', diffs, diffs)
    # sum w * sqnorm
    return float(np.dot(weights, sqnorm))

def extract_vertex_face_energies(og_param, sphere_param, faces, edge_array, edge_weights):
    verts = sphere_param
    original_points = og_param
    faces = faces
    edges_idx = edge_array
    weights = edge_weights
    if faces.ndim == 1:
        faces = faces.reshape(-1, 4)
        faces = faces[:,1:]
    vertex_energy = np.zeros(verts.shape[0])
    face_energy = np.zeros(faces.shape[0])
    edge_energy = np.zeros(edges_idx.shape[0])
    for i in range(edges_idx.shape[0]):
        vertex_idx0 = edges_idx[i,0]
        vertex_idx1 = edges_idx[i,1]
        faces_idx = np.where(np.sum((faces == vertex_idx0)|(faces == vertex_idx1), axis = 1) == 2)[0]
        v0 = verts[vertex_idx0]
        v1 = verts[vertex_idx1]
        weight = weights[i]
        energy = weight * np.linalg.norm(v0 - v1)**2

        vertex_energy[vertex_idx0] += weight
        vertex_energy[vertex_idx1] += weight
        face_energy[faces_idx[0]] += weight
        face_energy[faces_idx[1]] += weight
        edge_energy[i] += weight
    return edge_energy, face_energy, vertex_energy

def d_energy_vec(edges_idx, weights, phi):
    """
    Vectorized ∂E/∂Phi for all vertices at once.
    
    edges_idx : (E,2) int array of [i,j] pairs
    weights   : (E,)  float array of w_ij
    phi       : (N,3) float array of Phi positions

    returns grad : (N,3) float array of 2*sum_j w_ij*(Phi_i - Phi_j)
    """
    i0 = edges_idx[:, 0]      # shape = (E,)
    i1 = edges_idx[:, 1]      # shape = (E,)

    # 1) compute the per-edge vector differences
    diffs    = phi[i0] - phi[i1]         # (E,3)

    # 2) weight them
    w_diffs  = diffs * weights[:, None]  # (E,3)

    # 3) scatter‐add into a per-vertex accumulator
    grad     = np.zeros_like(phi)        # (N,3)
    np.add.at(grad, i0,  w_diffs)        # grad[i0] +=  w_diffs
    np.add.at(grad, i1, -w_diffs)        # grad[i1] -=  w_diffs

    # 4) if you want the true energy‐gradient, multiply by 2
    # grad *= 2

    return grad

def C2_adaptive(verts, faces, normals, edge_array, weights, tol=1e-10, max_iter=10_000, initial_dt=0.1):
    # 0) Setup
    verts, faces = mesh.points, mesh.faces
    phi = mesh.point_normals.copy()
    # build edge/weight arrays once
    edge_array
    weights
    if missing_weights:
        weights[np.where(weights_array = -999)] = 1

    # initial energy
    E = string_energy_vec(edge_array, weights, phi)
    energies = [E]
    delta_t = initial_dt

    # 1) Main loop
    for it in range(1, max_iter+1):
        # Find raw gradient. This is the direction of the set of vectors we can add to
        #     phi_i the maximizes total string energy
        d_phi_energy = d_energy_vec(edge_array, weights, phi)
        
        # The raw gradient is the direction change for each vector in R3, but we need to
        #     find the optimal gradient that keeps the results of phi_i - gradient on the
        #     surface of the sphere (in S2)
        proj_grad = d_phi_energy - (d_phi_energy * phi).sum(axis = 1)[:,None] * phi

        # 2) Backtracking line‐search (up to 10 tries)
        success = False
        for _ in range(6):
            # Proj_grad is the direction that maximizes total string energy, we want to find
            #     the parameterization that minimizes it so we subtract. Also scale by delta_t
            #     because we cannot be sure what magnitude best minimizes
            phi_trial = phi - delta_t * proj_grad
            
            # The new parameterization may be slightly outside S2 due to float rounding and
            #     other reasons from the computer, so we renormalize
            phi_trial = phi_trial / np.linalg.norm(phi_trial, axis = 1)[:,None]
            E_trial   = string_energy_vec(edge_array, weights, phi_trial)

            if E_trial < E:
                # accept
                phi, E = phi_trial, E_trial
                energies.append(E)
                delta_t *= 1.25
                success = True
                break
            else:
                delta_t *= 0.25

        if not success:
            print(f"[C2_adaptive] line‐search failed at iteration {it}")
            break

        # 3) Check convergence on energy drop
        if (energies[-2] - energies[-1]) < tol:
            print(f"[C2_adaptive] converged in {it} iterations; E = {E:.6e}")
            break

        # 4) Optional logging
        if it % 50 == 0:
            print(f" iter {it:4d}  E = {E:.6e}  dt = {delta_t:.2e}")
    return phi, energies

def srnf_from_mesh(mesh):
    mesh.compute_normals(point_normals=False, cell_normals=True, inplace=True)
    mesh['Area'] = mesh.compute_cell_sizes(length=False, area=True, volume=False)['Area']
    u = mesh.cell_data["Normals"]                # (F, 3)
    A = mesh.cell_data["Area"]                  # (F,)
    q = u * np.sqrt(2.0 * A)[:, None]           # (F, 3)
    return q

@dataclass(slots=True)
class SquareRootNormalMesh:
    verts: np.ndarray
    verts_sph: np.ndarray
    faces_pv:  np.ndarray
    q_face:    np.ndarray
    name: str = ""
    uid:  str = field(default_factory=lambda: uuid.uuid4().hex)
    mesh_sph: pv.PolyData = field(init=False, repr=False)
    edge_dict: dict
    edge_array: np.ndarray
    edge_weights: np.ndarray
    sphere_param_edge_energy: np.ndarray
    sphere_param_face_energy: np.ndarray
    sphere_param_vertex_energy: np.ndarray
    sphere_param_log: list

    def __post_init__(self):
        self.mesh_sph = pv.PolyData(self.verts_sph, self.faces_pv)

    @classmethod
    def from_polydata(cls, pv_mesh: pv.PolyData, **kw):
        self.edge_dict, self.edge_array, self.edge_weight = make_edge_dict(pv_mesh)
        self.verts = pv_mesh.points
        verts_sph = C2_adaptive(self.verts, self.faces_pv, self.mesh.point_normals, self.edge_array, self.edge_weights, max_iter = """GET MAX ITER FROM CONSTRUCTOR KEYWORD SOMEHOW""")
        self.sphere_param_edge_energy, self.sphere_param_face_energy, self.param_vertex_energy = extract_vertex_face_energies(self.verts, self.verts_sph, self.faces_pv, self.edge_array, self.edge_weights)
        self.mesh_sph = pv.PolyData(self.verts_sph, self.faces_pv)
        self.q_face = srnf_from_mesh(self.mesh)
        return cls(verts_sph=verts_sph, faces_pv=pv_mesh.faces, q_face=q_face, name=kw.get("name",""))

    @classmethod
    def from_voxels(cls, vox: np.ndarray, *, iso=0.5, spacing=(1,1,1), smoothing=None, **kw):
        from skimage.measure import marching_cubes
        v,f,_,_ = marching_cubes(vox, level=iso, spacing=spacing)
        faces_pv = np.c_[np.full((len(f),1),3,np.int32), f.astype(np.int32)].ravel()
        pv_mesh = pv.PolyData(v, faces_pv)
        if smoothing: pv_mesh = pv_mesh.smooth(n_iter=smoothing)
        return cls.from_polydata(pv_mesh, **kw)

    def sample(self, U: np.ndarray) -> np.ndarray:
        idx = self.mesh_sph.find_closest_cell(U)
        return self.q_face[idx]

    def preindex(self, U: np.ndarray) -> np.ndarray:
        return self.mesh_sph.find_closest_cell(U)

    def gather(self, idx: np.ndarray) -> np.ndarray:
        return self.q_face[idx]

In [None]:
from dataclasses import dataclass, field
import numpy as np, pyvista as pv, uuid

# ---------- your helper funcs (kept as-is or lightly cleaned) ----------
def _to_pv_faces(tri_faces: np.ndarray) -> np.ndarray:
    if tri_faces.ndim == 2 and tri_faces.shape[1] == 3:
        f = np.concatenate([np.full((len(tri_faces),1), 3, dtype=np.int32),
                            tri_faces.astype(np.int32)], axis=1).ravel()
        return f
    return tri_faces

def make_edge_dict(mesh: pv.PolyData):
    verts = mesh.points
    faces = mesh.faces
    if faces.ndim == 1: faces = faces.reshape(-1, 4)  # [3,i,j,k]
    edge_dict = {}
    for tri in faces[:, 1:]:
        i, j, k = map(int, tri)
        for (a,b,c) in [(i,j,k),(j,k,i),(k,i,j)]:
            e = tuple(sorted((a,b)))
            if e not in edge_dict: edge_dict[e] = {}
            v0, v1, v2 = verts[a], verts[b], verts[c]
            e0, e1 = v0 - v2, v1 - v2
            cosang = np.dot(e0, e1) / (np.linalg.norm(e0) * np.linalg.norm(e1) + 1e-15)
            ang = np.arccos(np.clip(cosang, -1.0, 1.0))
            edge_dict[e][c] = float(ang)
            if len(edge_dict[e]) == 2:
                ang0 = list(edge_dict[e].values())[0]
                w = np.cos(ang0)/np.sin(ang0 + 1e-15) + np.cos(ang)/np.sin(ang + 1e-15)
                edge_dict[e]['weight'] = float(w)
    edges_idx = np.array([(i[0], i[1]) for i in edge_dict], dtype=int)
    weights = np.array([edge_dict[e].get('weight', 0.0) for e in edge_dict], dtype=float)
    return edge_dict, edges_idx, weights

def string_energy_vec(edges_idx, weights, phi):
    i0, i1 = edges_idx[:,0], edges_idx[:,1]
    diffs = phi[i0] - phi[i1]
    sq = np.einsum('ij,ij->i', diffs, diffs)
    return float(np.dot(weights, sq))

def d_energy_vec(edges_idx, weights, phi):
    i0, i1 = edges_idx[:,0], edges_idx[:,1]
    diffs   = phi[i0] - phi[i1]
    wdiffs  = diffs * weights[:,None]
    grad = np.zeros_like(phi)
    np.add.at(grad, i0,  wdiffs)
    np.add.at(grad, i1, -wdiffs)
    return grad

def extract_vertex_face_energies(verts_sph, faces_pv, edges_idx, weights):
    faces = faces_pv.reshape(-1,4)[:,1:]
    V = verts_sph.shape[0]
    F = faces.shape[0]
    E = edges_idx.shape[0]
    vE = np.zeros(V); fE = np.zeros(F); eE = np.zeros(E)
    for e in range(E):
        i, j = edges_idx[e]
        w = weights[e]
        # faces incident to edge (i,j)
        faces_idx = np.where(np.sum((faces == i) | (faces == j), axis=1) == 2)[0]
        vE[i] += w; vE[j] += w
        if len(faces_idx) > 0: fE[faces_idx] += w
        eE[e] = w
    return eE, fE, vE

def C2_adaptive(verts_sph_init, edges_idx, weights, *,
                tol=1e-10, max_iter=10000, initial_dt=0.1):
    """Simplified: improve a sphere embedding by edge-energy descent on S^2."""
    phi = verts_sph_init.copy()
    E = string_energy_vec(edges_idx, weights, phi)
    delta_t = initial_dt
    for _ in range(1, max_iter+1):
        grad = d_energy_vec(edges_idx, weights, phi)
        proj = grad - (np.einsum('ij,ij->i', grad, phi))[:,None] * phi
        ok = False
        for _ in range(6):
            trial = phi - delta_t * proj
            trial /= np.linalg.norm(trial, axis=1, keepdims=True) + 1e-15
            E_new = string_energy_vec(edges_idx, weights, trial)
            if E_new < E:
                phi, E = trial, E_new
                delta_t *= 1.25
                ok = True
                break
            delta_t *= 0.25
        if not ok or delta_t < 1e-12 or abs(E_new - E) < tol:
            break
    return phi, E  # (final verts on S^2, last energy)

def srnf_from_mesh(mesh: pv.PolyData) -> np.ndarray:
    mesh = mesh.copy()
    mesh.compute_normals(point_normals=False, cell_normals=True, inplace=True)
    mesh['Area'] = mesh.compute_cell_sizes(length=False, area=True, volume=False)['Area']
    u = mesh.cell_data["Normals"]
    A = mesh.cell_data["Area"]
    return u * np.sqrt(2.0 * A)[:, None]
# -----------------------------------------------------------------------


@dataclass(slots=True)
class SquareRootNormalMesh:
    # primary input is a PolyData
    pv_mesh: pv.PolyData
    name: str = ""
    uid:  str = field(default_factory=lambda: uuid.uuid4().hex)

    # derived/computed
    verts: np.ndarray = field(init=False, repr=False)
    faces_pv: np.ndarray = field(init=False, repr=False)
    verts_sph: np.ndarray = field(init=False, repr=False)
    q_face: np.ndarray = field(init=False, repr=False)
    mesh_sph: pv.PolyData = field(init=False, repr=False)

    # connectivity/weights & diagnostics
    edge_dict: dict = field(init=False, repr=False)
    edge_array: np.ndarray = field(init=False, repr=False)
    edge_weights: np.ndarray = field(init=False, repr=False)
    sphere_param_edge_energy: float = field(init=False, repr=False)
    sphere_param_face_energy: np.ndarray = field(init=False, repr=False)
    sphere_param_vertex_energy: np.ndarray = field(init=False, repr=False)

    def __post_init__(self):
        # core geometry
        self.verts = self.pv_mesh.points
        self.faces_pv = self.pv_mesh.faces

        # edges & cotan weights
        self.edge_dict, self.edge_array, self.edge_weights = make_edge_dict(self.pv_mesh)

        # initial spherical param: (use normals, or normalized verts, or your existing map)
        # here I’ll start from normalized vertex positions as a simple seed:
        seed = self.verts / (np.linalg.norm(self.verts, axis=1, keepdims=True) + 1e-15)

        # improve with your C2_adaptive energy
        self.verts_sph, _ = C2_adaptive(seed, self.edge_array, self.edge_weights,
                                        tol=1e-10, max_iter=2000, initial_dt=0.1)

        # spherical mesh for sampling
        self.mesh_sph = pv.PolyData(self.verts_sph, self.faces_pv)

        # SRNF (per-face)
        self.q_face = srnf_from_mesh(self.pv_mesh)

        # energies per entity (diagnostics)
        eE, fE, vE = extract_vertex_face_energies(self.verts_sph, self.faces_pv,
                                                  self.edge_array, self.edge_weights)
        self.sphere_param_edge_energy = float(np.sum(eE))
        self.sphere_param_face_energy = fE
        self.sphere_param_vertex_energy = vE

    # ---- alternate constructor from voxels ----
    @classmethod
    def from_voxels(cls, vox: np.ndarray, *, iso=0.5, spacing=(1,1,1), smoothing=None, name=""):
        from skimage.measure import marching_cubes
        v,f,_,_ = marching_cubes(vox, level=iso, spacing=spacing)
        faces_pv = _to_pv_faces(f)
        mesh = pv.PolyData(v, faces_pv)
        if smoothing:
            mesh = mesh.smooth(n_iter=smoothing)
        return cls(mesh, name=name)

    # ---- sampling helpers ----
    def sample(self, U: np.ndarray) -> np.ndarray:
        idx = self.mesh_sph.find_closest_cell(U)
        return self.q_face[idx]

    def preindex(self, U: np.ndarray) -> np.ndarray:
        return self.mesh_sph.find_closest_cell(U)

    def gather(self, idx: np.ndarray) -> np.ndarray:
        return self.q_face[idx]
