In [13]:
import numpy as np
import torch
import trimesh
import scipy.io
import pyvista as pv
from geomloss import SamplesLoss
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import grad

# NOTE(review): MatchPoints and deform_mesh are called by the iterative-LDDMM cell
# further down in this notebook. While this import stays commented out, that cell is
# expected to fail with a NameError unless the names are defined elsewhere -- confirm
# before running it.
# from point_matching import MatchPoints, deform_mesh

# Default size (width, height) of the matplotlib figures created in this notebook.
plt.rcParams['figure.figsize'] = [16, 10]

In [14]:
# Optional back-ends, selected through the two module-level flags below.

# KeOps library for kernel convolutions -- useless for small datasets
#!pip install pykeops 
use_keops = False # set to True to wrap the GaussKernel operands in KeOps LazyTensors

# pyvista for displaying 3D graphics
#!pip install pyvista[all]
# When True, PlotRes3D renders with pyvista; when False it uses matplotlib instead.
# Note that pyvista is imported unconditionally at the top of the notebook either way.
use_pyvista = True

# LazyTensor is only imported (and only needed) when the KeOps path is enabled.
if use_keops:
    from pykeops.torch import LazyTensor

### Kernel Functions

In [15]:
# Gaussian kernel operator: (K(x,y)b)_i = sum_j exp(-|x_i-y_j|^2/sigma^2) b_j
def GaussKernel(sigma):
    """Build a Gaussian kernel operator with bandwidth `sigma`.

    The returned callable `K(x, y, b)` evaluates the kernel between every row of
    `x` and every row of `y` and applies the resulting matrix to `b`.
    When the module-level flag `use_keops` is true, the operands are wrapped in
    `LazyTensor` before the reduction.
    """
    inv_sigma_sq = 1/sigma**2

    def K(x,y,b):
        # Insert singleton axes so that x_i - y_j broadcasts over all (i, j) pairs.
        x_i = x[:,None,:]
        y_j = y[None,:,:]
        if use_keops:
            x_i = LazyTensor(x_i)
            y_j = LazyTensor(y_j)

        sq_dist = ((x_i-y_j)**2).sum(dim=2)
        gram = (-inv_sigma_sq*sq_dist).exp()
        return gram@b

    return K

# Composite kernel: a Gaussian (radial basis function) term on positions multiplied by a
# squared dot product on attached vectors (e.g. normals or directions), so that both the
# location and the unsigned orientation of geometric elements enter the similarity.
def GaussLinKernel(sigma, lib="keops"):
    """Build the operator `K(x, y, u, v, b)` for the Gaussian x squared-linear kernel.

    `sigma` is the bandwidth of the Gaussian term. `lib` is accepted for interface
    compatibility but is not used by this implementation, which always runs on torch.
    """
    inv_sigma_sq = 1/sigma**2

    def K(x,y,u,v,b):
        # Gaussian weight from the squared Euclidean distance between x_i and y_j.
        pairwise_diff = x[:,None,:]-y[None,:,:]
        gauss = torch.exp(-inv_sigma_sq*torch.sum(pairwise_diff**2,dim=2))

        # Squared dot product <u_i, v_j>^2: the same value for v_j and -v_j.
        dots = torch.sum(u[:,None,:]*v[None,:,:],dim=2)
        orient = dots**2

        # Apply the (n_x, n_y) composite kernel matrix to b.
        return (gauss*orient)@b

    return K

### Ordinary Differential Equations (ODEs) solver and Optimizer

In [16]:
# Explicit two-stage (Ralston) integrator for systems of ordinary differential equations;
# used to integrate the Hamiltonian dynamics during the shape deformation.
def RalstonIntegrator(nt=10):
    """Return an integrator that advances a state tuple with `nt` Ralston steps.

    The returned callable has the signature `f(ODESystem, x0, deltat=1.0)`:
      * `ODESystem(*state)` returns one derivative per state component,
      * `x0` is the tuple of initial components (cloned, never modified in place),
      * `deltat` is the total integration time, split into `nt` equal steps.
    It returns the tuple of state components at the final time.
    """
    def f(ODESystem,x0,deltat=1.0):
        state = tuple(component.clone() for component in x0)
        dt = deltat/nt
        for _ in range(nt):
            # Stage 1: slope at the current state.
            k1 = ODESystem(*state)
            # Stage 2: slope at the state predicted two thirds of a step ahead.
            predicted = tuple(s+(2*dt/3)*d for s,d in zip(state,k1))
            k2 = ODESystem(*predicted)
            # Combine both slopes with weights 1/4 and 3/4.
            state = tuple(s+(.25*dt)*(d1+3*d2) for s,d1,d2 in zip(state,k1,k2))

        return state

    return f

# Minimise a loss with L-BFGS
def Optimize(loss,args,niter=5):
    """Run `niter` L-BFGS steps on `loss(*args)`.

    `loss` maps the parameters in `args` to a scalar tensor; `args` is the list of
    tensors being optimised (updated in place by the optimiser).
    Returns `(args, losses)`, where `losses` holds the value of every loss evaluation
    made by the optimiser (there can be several evaluations per step).
    """
    optimizer = torch.optim.LBFGS(args)
    losses = []

    def closure():
        # Clear gradients left over from the previous evaluation.
        optimizer.zero_grad()
        L = loss(*args)
        losses.append(L.item())
        # Populate .grad on the parameters for the optimiser.
        L.backward()

        return L

    print('performing optimization...')
    for i in range(niter):
        print("iteration ",i+1,"/",niter)
        # L-BFGS re-evaluates the loss through the closure as often as it needs to.
        optimizer.step(closure)

    print("Done.")

    return args, losses

### Implementation of LDDMM

In [17]:
# Hamiltonian of the system as a function of momentum p and position q:
# H(p,q) = 1/2 * sum(p * K(q,q,p)), where the kernel K couples the points to each other.
def Hamiltonian(K):
    """Return the function `H(p, q)` computing `0.5 * (p * K(q, q, p)).sum()`."""
    def H(p,q):
        kernel_applied = K(q,q,p)
        return .5*(p*kernel_applied).sum()

    return H

# Builds the Hamiltonian system integrated during the "shooting" process:
# (dp/dt, dq/dt) = (-dH/dq, dH/dp), with both gradients obtained by autograd.
def HamiltonianSystem(K):
    """Return `HS(p, q) -> (-dH/dq, dH/dp)` for the Hamiltonian built from kernel `K`."""
    H = Hamiltonian(K)

    def HS(p,q):
        energy = H(p,q)
        # create_graph=True keeps the gradient computation itself differentiable.
        dH_dp, dH_dq = grad(energy,(p,q), create_graph=True)

        return -dH_dq, dH_dp

    return HS

# Geodesic shooting: integrate the Hamiltonian system from (p0, q0) over a time span
# `deltat` and return the state (p, q) reached at the end.
def Shooting(p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator()):
    """Return the final `(p, q)` obtained by integrating the Hamiltonian system of `K`."""
    system = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(system, initial_state, deltat)

# Integration of the flow equations: transports the points x0 with the velocity field
# generated by (p, q) while (p, q) themselves follow the Hamiltonian system.
def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator()):
    """Return the positions reached by `x0` after integrating the flow over `deltat`."""
    HS = HamiltonianSystem(K)

    def FlowEq(x,p,q):
        # Velocity of the transported points, followed by the (p, q) dynamics.
        velocity = K(x,q,p)
        dp, dq = HS(p,q)
        return velocity, dp, dq

    final_state = Integrator(FlowEq,(x0,p0,q0),deltat)
    # Only the transported points are returned; the final p and q are discarded.
    return final_state[0]

# Builds the LDDMM objective: gamma * (Hamiltonian of the initial state) plus a data
# attachment term evaluated on the shape obtained by shooting.
def LDDMMloss(q0,K,dataloss,gamma=0.):
    """Return `loss(p0)`, the function to minimise over the initial momentum.

    q0:       initial configuration of the source shape
    K:        kernel operator used for both the shooting and the energy term
    dataloss: callable scoring the mismatch between the deformed shape and the target
    gamma:    weight of the deformation-energy term
    """
    def loss(p0):
        # Shoot from (p0, q0); only the end positions q enter the data term.
        _, q = Shooting(p0,q0,K)
        # Energy of the initial state (p0, q0).
        energy = Hamiltonian(K)(p0,q0)
        return gamma * energy + dataloss(q)

    ############################# ToDo #############################
    # when adding contours use the flow function, set deltat parameter
    # flow -> dataloss for curve -> flow
    # extract contours and get corresponding coordinates in 3D
    ###############################################################

    return loss

### Data attachment function

In [18]:
# data attachment function for triangulated surfaces, varifold model
# compares surfaces through the centers, areas and unit normals of their triangles, so no
# explicit point correspondence between the two meshes is required
def lossVarifoldSurf(FS,VT,FT,K):
    """Build the varifold data-attachment loss towards a fixed target surface.

    FS: (nS, 3) integer tensor, vertex indices of the source triangles
    VT: (nT_vertices, 3) tensor, vertex coordinates of the target surface
    FT: (nT, 3) integer tensor, vertex indices of the target triangles
    K:  varifold kernel operator called as K(x, y, u, v, b)

    Returns `loss(VS)`, which maps source vertex coordinates to the scalar
    <S,S> + <T,T> - 2<S,T>.

    Triangle centers are the true centroids (V0+V1+V2)/3. Earlier revisions used
    .5*(V0+V1+V2), i.e. 1.5 times the centroid, which made the kernel see both
    surfaces scaled by 1.5; a kernel bandwidth tuned against that version may need
    to be re-tuned.
    """
    # Floor applied to the squared norm of the area-weighted normal so that degenerate
    # (zero-area) triangles give a zero normal instead of 0/0 = NaN.
    min_sq_area = 1e-24

    # centers (C), areas (L) and unit normals (N/L) of the triangles F over vertices V
    def CompCLNn(F,V):
        V0, V1, V2 = V.index_select(0,F[:,0]), V.index_select(0,F[:,1]), V.index_select(0,F[:,2])
        # centroid of each triangle
        C = (V0+V1+V2)/3
        # half the cross product of two edges: direction = normal, norm = triangle area
        N = .5*torch.linalg.cross(V1-V0,V2-V0)
        # clamp before the square root: keeps the value and its gradient finite at zero area
        L = (N**2).sum(dim=1,keepdim=True).clamp_min(min_sq_area).sqrt()

        return C,L,N/L

    CT,LT,NTn = CompCLNn(FT,VT)
    # target self-interaction term; constant with respect to the source vertices
    cst = (LT*K(CT,CT,NTn,NTn,LT)).sum()

    def loss(VS):
        CS,LS,NSn = CompCLNn(FS,VS)

        # source self-interaction plus target self-interaction minus twice the cross term
        return cst + (LS*K(CS,CS,NSn,NSn,LS)).sum() - 2*(LS*K(CS,CT,NSn,NTn,LT)).sum()

    return loss

### Plotting functions

In [19]:
# result display function for landmark or point cloud data
def PlotRes2D(z, pts=None):
    """Build a plotting function for a 2D matching result.

    z:   tensor of target points (columns 0 and 1 are plotted)
    pts: optional tensor of extra points that are transported by the flow and plotted

    The returned `plotfun(q0, p0, Kv, showgrid=True)` draws, on the current matplotlib
    figure, the target points ('.'), the source points q0 ('+'), the shot points ('o'),
    optionally the deformed grid, and the trajectories of the source points.
    """
    def plotfun(q0,p0,Kv, showgrid=True):
        # Only the end positions are drawn; the final momentum is not needed.
        _, q = Shooting(p0,q0,Kv)
        q0np, qnp, znp = q0.data.numpy(), q.data.numpy(), z.data.numpy()
        plt.plot(znp[:,0],znp[:,1],'.')
        plt.plot(q0np[:,0],q0np[:,1],'+')
        plt.plot(qnp[:,0],qnp[:,1],'o')
        plt.axis('equal')
        if showgrid:
            # Draw the deformed grid lines along both grid directions.
            X = get_def_grid(p0,q0,Kv)
            plt.plot(X[0],X[1],'k',linewidth=.25)
            plt.plot(X[0].T,X[1].T,'k',linewidth=.25)
        # Trajectories are sampled by shooting with the momentum scaled by t in [0, 1].
        n,d = q0.shape
        nt = 20
        Q = np.zeros((n,d,nt))
        for i in range(nt):
            t = i/(nt-1)
            Q[:,:,i] = Shooting(t*p0,q0,Kv)[1].data.numpy()
        plt.plot(Q[:,0,:].T,Q[:,1,:].T,'y')
        if pts is not None:
            phipts = Flow(pts,p0,q0,Kv).data.numpy()
            plt.plot(phipts[:,0],phipts[:,1],'.b',markersize=.1)
    return plotfun

# display function for triangulated surface type data
def PlotRes3D(VS, FS, VT, FT, filename="deformation.html"):
    """Build a plotting function for a 3D surface matching result.

    VS:       source vertices (accepted for interface symmetry; the source shape that is
              drawn is the `q0` passed to the returned function)
    FS:       source triangle indices
    VT, FT:   target vertices and triangle indices
    filename: path handed to `Plotter.export_html` on the pyvista path

    The returned `plotfun(q0, p0, Kv, src_opacity=1, tgt_opacity=1, def_opacity=1,
    showgrid=True)` shoots from (p0, q0) and draws the source, deformed and target
    surfaces, plus the deformed grid when `showgrid` is true. With `use_pyvista` set
    the scene is exported to `filename`; otherwise it is drawn with matplotlib, where
    the three opacity arguments are not used (every surface is drawn with alpha=.5).
    """
    def plotfun(q0, p0, Kv, src_opacity=1, tgt_opacity=1, def_opacity=1, showgrid=True):
        # Only the end positions are drawn; the final momentum is not needed.
        _, q = Shooting(p0,q0,Kv)
        # numpy copies of the initial and final vertex positions
        q0np, qnp = q0.data.numpy(), q.data.numpy()
        # numpy arrays of the source faces and of the target vertices and faces
        FSnp, VTnp, FTnp = FS.data.numpy(), VT.data.numpy(), FT.data.numpy()

        def grid_polylines(ng=20):
            # Yield each deformed grid line as a (3, ng) array. The grid axes are
            # rotated after every pass so that lines along all 3 directions are produced.
            X = get_def_grid(p0,q0,Kv,ng=ng)
            for _ in range(3):
                for i in range(ng):
                    for j in range(ng):
                        yield X[:,i,j,:]
                X = X.transpose((0,2,3,1))

        if use_pyvista:
            plotter = pv.Plotter()
            # mesh for the initial shape
            plotter.add_mesh(surf_to_pv(q0np,FSnp), color='lightblue', opacity=src_opacity)
            # mesh for the deformed shape
            plotter.add_mesh(surf_to_pv(qnp,FSnp), color='lightcoral', opacity=def_opacity)
            # mesh for target shape
            plotter.add_mesh(surf_to_pv(VTnp,FTnp), color='lightgreen', opacity=tgt_opacity)
            if showgrid:
                for line in grid_polylines():
                    plotter.add_mesh(lines_from_points(line.T))
            # plotter.show()
            plotter.export_html(filename)
            plotter.close()
        else:
            fig = plt.figure()
            plt.axis('off')
            plt.title('LDDMM matching example')
            ax = Axes3D(fig)
            # triangular mesh for the initial shape
            ax.plot_trisurf(q0np[:,0],q0np[:,1],q0np[:,2],triangles=FSnp,alpha=.5)
            # triangular mesh for the deformed shape
            ax.plot_trisurf(qnp[:,0],qnp[:,1],qnp[:,2],triangles=FSnp,alpha=.5)
            # triangular mesh for the target shape
            ax.plot_trisurf(VTnp[:,0],VTnp[:,1],VTnp[:,2],triangles=FTnp,alpha=.5)
            if showgrid:
                for line in grid_polylines():
                    ax.plot(line[0],line[1],line[2],'k',linewidth=.25)
            fig.add_axes(ax)
    return plotfun

def get_def_grid(p0,q0,Kv,ng=50):
    """Return a regular grid transported by the deformation defined by (p0, q0, Kv).

    The grid has `ng` nodes per axis and covers the bounding box of the initial and
    shot positions, enlarged by 20% of its extent on every side. The result is a numpy
    array of shape [d] + [ng]*d, where d = p0.shape[1]; entry k holds the k-th
    coordinate of the transported nodes.
    """
    d = p0.shape[1]
    # Only the end positions are needed to size the grid.
    _, q = Shooting(p0,q0,Kv)
    q0np, qnp = q0.data.numpy(), q.data.numpy()
    # per-axis minimum (a) and maximum (b) over the initial (q0) and final (q) positions
    a = list(np.min(np.vstack((q0np[:,k],qnp[:,k]))) for k in range(d))
    b = list(np.max(np.vstack((q0np[:,k],qnp[:,k]))) for k in range(d))
    # enlarge the bounds by 20% so the grid extends slightly beyond both shapes
    sz = 0.2
    lsp = list(np.linspace(a[k]-sz*(b[k]-a[k]),b[k]+sz*(b[k]-a[k]),ng,dtype=np.float32) for k in range(d))
    X = np.meshgrid(*lsp)
    # flatten the grid into an (ng**d, d) array of points
    x = np.concatenate(list(X[k].reshape(ng**d,1) for k in range(d)),axis=1)
    # transport every grid node with the flow generated by (p0, q0, Kv)
    phix = Flow(torch.from_numpy(x),p0,q0,Kv).detach().numpy()
    # back to grid layout: one (ng, ..., ng) array per coordinate
    X = phix.transpose().reshape([d]+[ng]*d)
    return X

def lines_from_points(points):
    """Given an array of points, make a line set.

    Consecutive points are joined by 2-point line cells, so `len(points)` points give
    `len(points) - 1` segments. Returns the resulting `pyvista.PolyData`.
    """
    # The docstring must be the first statement of the function to be picked up as
    # __doc__; the local import therefore comes after it.
    import pyvista as pv
    poly = pv.PolyData()
    poly.points = points
    # Each row is one cell: [number of points in the cell (2), start index, end index].
    cells = np.full((len(points) - 1, 3), 2, dtype=np.int_)
    cells[:, 1] = np.arange(0, len(points) - 1, dtype=np.int_)
    cells[:, 2] = np.arange(1, len(points), dtype=np.int_)
    poly.lines = cells
    return poly

def surf_to_pv(V,F):
    """Wrap vertices `V` and triangle indices `F` into a `pyvista.PolyData` surface."""
    n_faces = F.shape[0]
    # Prefix every triangle with its vertex count (3), then flatten into a single
    # 1D array before handing it to PolyData.
    vertex_counts = np.ones((n_faces,1),dtype="int")*3
    cells = np.hstack((vertex_counts,F)).flatten()
    return pv.PolyData(V,cells)

# display function for triangulated surface data (matplotlib only)
def PlotResSurf(VS,FS,VT,FT):
    """Build `plotfun(q0, p0, Kv)`, which draws source, deformed and target surfaces.

    `VS` is accepted for interface symmetry; the source surface that is drawn uses the
    `q0` passed to the returned function together with the faces `FS`.
    """
    def plotfun(q0,p0,Kv):
        fig = plt.figure()
        plt.axis('off')
        plt.title('LDDMM matching example')
        # Only the end positions are drawn; the final momentum is not needed.
        _, q_end = Shooting(p0,q0,Kv)
        source_faces = FS.data.numpy()
        # (vertices, faces) pairs in drawing order: source, deformed, target
        surfaces = (
            (q0.data.numpy(), source_faces),
            (q_end.data.numpy(), source_faces),
            (VT.data.numpy(), FT.data.numpy()),
        )
        ax = Axes3D(fig, auto_add_to_figure=False)
        for verts, faces in surfaces:
            ax.plot_trisurf(verts[:,0],verts[:,1],verts[:,2],triangles=faces,alpha=.5)
        fig.add_axes(ax)
    return plotfun

### Helper functions

In [20]:
# Load a mesh file and return its geometry as PyTorch tensors
def load_mesh_as_tensors(filepath):
    """Read the mesh at `filepath` with trimesh and return `(vertices, faces)`.

    `vertices` is a float32 tensor built from `mesh.vertices` and `faces` is an int64
    tensor built from `mesh.faces`. The file is loaded with `process=False`.
    """
    mesh = trimesh.load(filepath, process=False)

    vertex_tensor = torch.tensor(mesh.vertices, dtype=torch.float32)
    face_tensor = torch.tensor(mesh.faces, dtype=torch.long)

    return vertex_tensor, face_tensor

### Surface matching with LDDMM (without constraints)

In [None]:
# Load the source (VS, FS) and target (VT, FT) surfaces as vertex / face tensors.
VS, FS = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_a_mesh_500.ply')
VT, FT = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_r_mesh_500.ply')

# Save the tensors into a .pt file
# NOTE(review): presumably the 'vt_tensors' directory has to exist already -- confirm.
torch.save((VS, FS, VT, FT), 'vt_tensors/a-r_mesh.pt')

# Reload the tensors that were just written (round trip through the .pt file).
VS,FS,VT,FT = torch.load('vt_tensors/a-r_mesh.pt') 
# q0: source vertices, cloned and detached so they form a fresh leaf that tracks gradients
q0 = VS.clone().detach().requires_grad_(True)
# Deformation kernel used for shooting and for the energy term of the loss.
Kv = GaussKernel(sigma=20)
# NOTE(review): the original comment here said "multiscale kernel", but a single
# Gaussian scale (sigma=20) is what is actually passed below.
Dataloss = lossVarifoldSurf(FS,VT,FT,GaussLinKernel(sigma=20)) # try smaller values (15-20)
loss = LDDMMloss(q0,Kv,Dataloss, gamma=0.1)
# Initial momentum, one vector per source vertex, starting from zero.
p0 = torch.zeros(q0.shape, requires_grad=True)

# Optimize returns the parameter list it was given, so from here on p0 is a
# one-element list and the optimised momentum tensor is p0[0].
# NOTE(review): the trailing comment mentions 100 iterations while niter=10 is used.
p0, losses = Optimize(loss,[p0],niter=10) # 100 iterations
# PlotRes3D(VS,FS,VT,FT, filename="output/src.html")(q0,p0[0],Kv,src_opacity=1,tgt_opacity=0,def_opacity=0, showgrid=False)
# PlotRes3D(VS,FS,VT,FT)(q0,p0[0],Kv)

In [14]:
# Export the source, target and deformed shapes to separate HTML files. Each export
# shows a single surface by setting the opacity of the other two to 0.
single_shape_views = [
    # (output file, source opacity, target opacity, deformed opacity)
    ("output/src.html", 1, 0, 0),          # source shape
    ("output/tgt.html", 0, 1, 0),          # target shape
    ("output/deformation.html", 0, 0, 1),  # deformed shape
]
for html_path, src_alpha, tgt_alpha, def_alpha in single_shape_views:
    PlotRes3D(VS,FS,VT,FT, filename=html_path)(q0,p0[0],Kv,src_opacity=src_alpha,tgt_opacity=tgt_alpha,def_opacity=def_alpha, showgrid=False)

### Iterative LDDMM

The approach:

for t=0 to nt(timesteps):

1. Full 3D deformation from source (S) to target (T)
2. Get mesh at t+1 (M*)
3. Extract midsagittal coordinates from M*
4. Perform point matching between midsagittal coordinates and rtMRI coordinates (phi)
5. Deform M* with phi to get the new mesh at t+1 (M')
6. t=t+1 and S=M'
7. Repeat

In [60]:
# Per-axis voxel sizes used to convert between physical and pixel coordinates.
# NOTE(review): the physical unit is not stated anywhere in this notebook -- confirm.
RTMRI_VOXEL_SIZE = (1.9178, 1.9178) 
VOLUMETRIC_VOXEL_SIZE = (1.6, 1.6)

def get_rtMRI_contours(points_rtmri, midsagittal_coords, R):
    """Rigidly place an rtMRI contour onto the midsagittal contour of the 3D mesh.

    points_rtmri:       (n, 2) array of rtMRI contour points in physical coordinates
    midsagittal_coords: (m, 2) array of midsagittal contour points in physical coordinates
    R:                  (2, 2) rotation matrix applied to the centered rtMRI points

    Both contours are converted to pixel coordinates with their own voxel size. The
    rtMRI points are centered on their centroid, rotated by `R`, translated onto the
    centroid of the midsagittal points and converted back to physical coordinates with
    the volumetric voxel size. Returns an (n, 2) array.

    The midsagittal centroid is measured before any centering. It used to be measured
    after the midsagittal points had been centered, which made the translation
    numerically zero whatever the input was.
    """
    rtmri_px = np.asarray(points_rtmri) / RTMRI_VOXEL_SIZE
    midsagittal_px = np.asarray(midsagittal_coords) / VOLUMETRIC_VOXEL_SIZE

    # Translation target: centroid of the midsagittal contour in pixel coordinates.
    midsagittal_centroid = midsagittal_px.mean(axis=0)

    # Rotate about the rtMRI centroid, then move onto the midsagittal centroid.
    rtmri_centered = rtmri_px - rtmri_px.mean(axis=0)
    rtmri_rigid = (R @ rtmri_centered.T).T + midsagittal_centroid

    return rtmri_rigid * VOLUMETRIC_VOXEL_SIZE

def plot_vt_boundary(coords, title, show_grid, fontsize, show_axis):
    """Scatter-plot 2D boundary points on a new 6x6 figure and show it.

    coords:    array whose columns 0 and 1 are plotted as x and y
    title:     figure title, drawn with font size `fontsize`
    show_grid: draw grid lines when true
    show_axis: hide the axes when false
    """
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot()
    ax.scatter(coords[:, 0], coords[:, 1], s=1, color='black')
    ax.set_title(title, fontsize=fontsize)
    if not show_axis:
        ax.axis('off')
    if show_grid:
        ax.grid(True)
    plt.show()

def extract_midsagittal_slice(mesh, threshold=1.5):
    """Extract the 2D contour of the vertices lying close to the plane x = 0.

    mesh:      object exposing a `vertices` array with at least 3 columns (x, y, z)
    threshold: maximum distance |x| from the plane x = 0 for a vertex to be kept

    The mesh is used in its own coordinate frame: no centering is applied before the
    slice is taken. The selected vertices are projected onto the plane by keeping their
    (y, z) coordinates; the first of these is mirrored and both are centered on their
    mean. The resulting contour is plotted and returned as an array with 2 columns.

    Raises ValueError when no vertex lies within `threshold` of the plane.
    """
    vertices = mesh.vertices

    # Keep the vertices whose distance to the plane x = 0 is within the threshold.
    near_plane = np.abs(vertices[:, 0]) <= threshold
    if not np.any(near_plane):
        raise ValueError(
            f"no vertex lies within threshold={threshold} of the plane x = 0; "
            "cannot extract a midsagittal slice"
        )

    # Projection onto the plane: drop x and keep (y, z) in a fresh array.
    coords = vertices[near_plane][:, 1:3].copy()

    # Mirror the first in-plane coordinate.
    x_min, x_max = coords[:, 0].min(), coords[:, 0].max()
    coords[:, 0] = -(coords[:, 0] - x_max) - x_min

    # Center the coordinates
    coords[:, 0] -= coords[:, 0].mean()
    coords[:, 1] -= coords[:, 1].mean()

    plot_vt_boundary(coords, '3D Midsagittal VT Boundary', show_grid=True, fontsize=14, show_axis=True)

    return coords

In [None]:
# Main iterative algorithm
#
# Each outer iteration registers the current source mesh to the target with LDDMM,
# shoots part of the way, extracts the midsagittal contour of the intermediate mesh,
# matches it to an rtMRI contour, and uses the corrected mesh as the next source.

nt = 2 # number of deformation time steps
niter = 1 # number of optimization iterations per full deformation

# Load the meshes (source: vowel, target: consonant)
VS, FS = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_a_mesh_500.ply')
VT, FT = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_r_mesh_500.ply')

# Round trip through the .pt file (same save / load as in the cell without constraints).
torch.save((VS, FS, VT, FT), 'vt_tensors/a-r_mesh.pt')
VS,FS,VT,FT = torch.load('vt_tensors/a-r_mesh.pt')

# Initialize the source mesh (S) and define the kernel.
current_source = VS.clone().detach().requires_grad_(True)
Kv = GaussKernel(sigma=20)

# Set up the data attachment loss 
Dataloss = lossVarifoldSurf(FS, VT, FT, GaussLinKernel(sigma=20))

# NOTE(review): this list is still empty, so rtMRI_contours[t] in step 5 raises an
# IndexError on the first outer iteration until the contours are actually loaded.
rtMRI_contours = [] #################### to be implemented
R = np.load("rigid_rotation_matrix.npy") 

for t in range(nt):
    print(f"\n=== Outer Iteration {t+1}/{nt} ===")

    # 1. Full deformation from current_source (S) to target.
    loss_func = LDDMMloss(current_source, Kv, Dataloss, gamma=0.9)

    # 2. Optimize the initial momentum p0 for the registration.
    p0 = torch.zeros_like(current_source, requires_grad=True)
    p0_opt, opt_losses = Optimize(loss_func, [p0], niter=niter)

    # 3. Compute the deformed mesh (M*) using the optimized momentum.
    # NOTE(review): the loss optimises p0 for a shooting over the default deltat=1.0,
    # whereas this call only integrates over deltat=1/nt -- confirm this is intended.
    p_deformed, M_star = Shooting(p0_opt[0], current_source, Kv, deltat=1/nt)

    # 4. Extract midsagittal coordinates from M*
    # Only the vertex positions are wrapped here; the faces FS are not passed along.
    M_star_mesh = trimesh.PointCloud(M_star.detach().numpy())
    midsagittal_coords = extract_midsagittal_slice(M_star_mesh)

    # 5. Perform point matching (phi deformation) between midsagittal coords and the rtMRI contour.
    rtMRI_contour = get_rtMRI_contours(rtMRI_contours[t], midsagittal_coords, R)

    # plt.figure(figsize=(6, 6))

    # plt.scatter(midsagittal_coords[:, 0], midsagittal_coords[:, 1], color='blue', label="Midsagittal Boundary", s=25)
    # plt.scatter(rtMRI_contour[:, 0], rtMRI_contour[:, 1], color='red', label="Transformed rtMRI Boundary", s=25)
    
    # plt.legend()
    # plt.title("Alignment of Vocal Tract Boundaries")
    # plt.xlabel("X Coordinate")
    # plt.ylabel("Y Coordinate")
    # plt.show()
    
    # NOTE(review): MatchPoints and deform_mesh are not defined in the visible part of
    # this notebook; the only visible import of them (from point_matching) is commented
    # out in the first cell, so these calls are expected to raise NameError as it stands.
    phi = MatchPoints(midsagittal_coords, rtMRI_contour, sigma=20)

    # 6. Apply phi to deform M* and get the new mesh M′.
    M_prime = deform_mesh(M_star, phi)

    # 7. Update current_source with M′ for the next iteration.
    # Detached clone: the next iteration starts from a fresh leaf tensor.
    current_source = M_prime.clone().detach().requires_grad_(True)

    print(f"Iteration {t+1} complete. (Updated mesh has {current_source.shape[0]} vertices)")

