In [4]:
import numpy as np
import torch
import trimesh
import pyvista as pv
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import grad

plt.rcParams.update({'figure.figsize': [16, 10]})

In [5]:
# Optional dependencies:

# KeOps library for kernel convolutions -- not useful for small datasets
# %pip install pykeops
use_keops = False  # set True to evaluate GaussKernel through pykeops LazyTensors

# pyvista for displaying 3D graphics (PlotRes3D exports HTML scenes with it)
# %pip install pyvista[all]
use_pyvista = True

if use_keops:
    from pykeops.torch import LazyTensor

### Kernel Functions

In [6]:
# Gaussian kernel: (K(x,y)b)_i = sum_j exp(-|x_i - y_j|^2 / sigma^2) b_j
def GaussKernel(sigma):
    """Return a function K(x, y, b) applying the Gaussian kernel matrix to b.

    x and y are point arrays indexed as (points, coordinates); b holds one row
    per point of y. When the module-level flag ``use_keops`` is set, the
    pairwise distances are evaluated through pykeops LazyTensors instead of
    dense torch tensors.
    """
    inv_sigma2 = 1 / sigma**2

    def K(x, y, b):
        xi = x[:, None, :]
        yj = y[None, :, :]
        if use_keops:
            xi, yj = LazyTensor(xi), LazyTensor(yj)
        sqdist = ((xi - yj) ** 2).sum(dim=2)
        return (-inv_sigma2 * sqdist).exp() @ b

    return K

# Cauchy kernel: (K(x,y)b)_i = sum_j b_j / (1 + |x_i - y_j|^2 / sigma^2)
def CauchyKernel(sigma):
    """Return K(x, y, b) applying the Cauchy kernel matrix to b (dense torch only)."""
    inv_sigma2 = 1 / sigma**2

    def K(x, y, b):
        diff = x[:, None, :] - y[None, :, :]
        sqdist = torch.sum(diff**2, dim=2)
        return (1 / (1 + inv_sigma2 * sqdist)) @ b

    return K

# Sum of kernels, e.g. the same kernel at several scales sigma
def SumKernel(*kernels):
    """Return a kernel that calls every kernel with the same arguments and adds the results."""
    def K(*args):
        total = 0
        for kernel in kernels:
            total = total + kernel(*args)
        return total
    return K

# Varifold kernel on (position, direction) pairs: a Gaussian kernel on the
# positions multiplied by the SQUARED dot product of the direction vectors.
# Because the dot product is squared, u and -u give the same value, so the
# kernel does not depend on the sign (orientation) of normals/tangents.
def GaussLinKernel(sigma, lib="keops"):
    """Return K(x, y, u, v, b) = (exp(-|x_i - y_j|^2 / sigma^2) * (u_i . v_j)^2) @ b.

    Args:
        sigma: width of the Gaussian on positions.
        lib: accepted for backward compatibility but currently ignored -- the
            kernel is always evaluated with dense torch tensors, including
            for the default ``lib="keops"``.

    The returned K takes positions x and y, direction vectors u (one per
    point of x) and v (one per point of y), and a weight array b with one
    row per point of y. All pairs (i, j) are compared, not only i == j.
    """
    inv_sigma2 = 1 / sigma**2

    def K(x, y, u, v, b):
        # Gaussian similarity between every pair of positions (x_i, y_j).
        Kxy = torch.exp(-inv_sigma2 * torch.sum((x[:, None, :] - y[None, :, :]) ** 2, dim=2))
        # Squared dot product between every pair of directions (u_i, v_j).
        Sxy = torch.sum(u[:, None, :] * v[None, :, :], dim=2) ** 2
        return (Kxy * Sxy) @ b

    return K

### Ordinary Differential Equations (ODEs) solver and Optimizer

In [7]:
# Ralston's second-order Runge-Kutta scheme, used to integrate the Hamiltonian
# system during shooting (and the flow equations in Flow).
def RalstonIntegrator(nt=10):
    """Return an integrator f(ODESystem, x0, deltat=1.0) performing nt Ralston steps.

    ODESystem maps the current state tensors to a tuple of their time
    derivatives; x0 is a tuple of initial tensors (cloned, never modified).
    Each step of size dt = deltat / nt computes
        k1 = F(x),  k2 = F(x + 2/3 dt k1),  x <- x + dt/4 (k1 + 3 k2).
    Only the final state tuple is returned.
    """
    def f(ODESystem, x0, deltat=1.0):
        state = tuple(v.clone() for v in x0)
        dt = deltat / nt
        for _ in range(nt):
            k1 = ODESystem(*state)
            trial = tuple(s + (2 * dt / 3) * d for s, d in zip(state, k1))
            k2 = ODESystem(*trial)
            state = tuple(s + (.25 * dt) * (d1 + 3 * d2) for s, d1, d2 in zip(state, k1, k2))
        return state

    return f

# Minimize a loss with L-BFGS
def Optimize(loss, args, niter=5):
    """Run `niter` L-BFGS steps on the tensors in `args` to minimize loss(*args).

    Args:
        loss: callable returning a scalar tensor from the tensors in args.
        args: list of leaf tensors with requires_grad=True (here [p0]); the
            optimizer updates them in place.
        niter: number of optimizer.step calls; each step may evaluate the
            closure several times.

    Returns:
        (args, losses): the same list of tensors, and every loss value
        computed by the closure, in evaluation order.
    """
    optimizer = torch.optim.LBFGS(args)
    history = []

    def closure():
        # Clear stale gradients, evaluate, record the value, back-propagate.
        optimizer.zero_grad()
        value = loss(*args)
        history.append(value.item())
        value.backward()
        return value

    print('performing optimization...')
    for it in range(1, niter + 1):
        print("iteration ", it, "/", niter)
        optimizer.step(closure)
    print("Done.")

    return args, history

def Optimize_with_vis(loss, args, niter=10, q0=None, Kv=None, meshes=None,
                      filename_pattern="deformation_iter_{}.html"):
    """L-BFGS optimization that exports the deformed shape after every step.

    Same optimization loop as Optimize. After each optimizer.step, the current
    momentum args[0] is shot from q0, and the deformed surface is written to
    an HTML scene via PlotRes3D (deformed shape only, no grid).

    Args:
        loss, args, niter: as in Optimize.
        q0: source vertices; defaults to the notebook-level ``q0``.
        Kv: deformation kernel; defaults to the notebook-level ``Kv``.
        meshes: (VS, FS, VT, FT) passed to PlotRes3D; defaults to the
            notebook-level tensors with those names.
        filename_pattern: format string receiving the 1-based iteration number.

    Returns:
        (args, losses), like Optimize.
    """
    # Fall back to notebook globals, but allow explicit arguments so the
    # function does not depend on hidden kernel state.
    ns = globals()
    if q0 is None:
        q0 = ns["q0"]
    if Kv is None:
        Kv = ns["Kv"]
    if meshes is None:
        meshes = (ns["VS"], ns["FS"], ns["VT"], ns["FT"])

    optimizer = torch.optim.LBFGS(args)
    losses = []
    print('Performing optimization with visualization...')
    for i in range(niter):
        print("Iteration", i+1, "of", niter)
        def closure():
            optimizer.zero_grad()
            L = loss(*args)
            losses.append(L.item())
            L.backward()
            return L
        optimizer.step(closure)

        # Detached copies that require grad: Shooting (run inside PlotRes3D)
        # differentiates the Hamiltonian with torch.autograd.grad, so its inputs
        # must require grad, but they must not be connected to args.
        p_vis = args[0].detach().clone().requires_grad_()
        q_vis = q0.detach().clone().requires_grad_()

        # PlotRes3D performs the shooting itself, so no separate Shooting call is needed.
        filename = filename_pattern.format(i+1)
        PlotRes3D(*meshes, filename)(q_vis, p_vis, Kv, src_opacity=0, tgt_opacity=0, def_opacity=1, showgrid=False)

    print("Optimization done.")

    return args, losses

### Implementation of LDDMM

In [8]:
# Hamiltonian H(p, q) = 1/2 <p, K(q, q) p>: energy of the deformation carried
# by momenta p attached to the points q, with the points coupled through K.
def Hamiltonian(K):
    """Return H(p, q) = 0.5 * sum(p * K(q, q, p))."""
    def H(p, q):
        Kp = K(q, q, p)
        return .5 * (p * Kp).sum()
    return H

# Hamiltonian equations used for shooting: dp/dt = -dH/dq, dq/dt = dH/dp.
# create_graph=True keeps these derivatives differentiable, so a loss computed
# from the integrated trajectory can itself be back-propagated.
def HamiltonianSystem(K):
    """Return HS(p, q) -> (-dH/dq, dH/dp) for the Hamiltonian built from K."""
    H = Hamiltonian(K)

    def HS(p, q):
        energy = H(p, q)
        dH_dp, dH_dq = grad(energy, (p, q), create_graph=True)
        return -dH_dq, dH_dp

    return HS

# Shooting: integrate the Hamiltonian system from the initial momenta/positions
# (p0, q0) over [0, deltat] and return the final (p, q).
def Shooting(p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator()):
    """Return the final (p, q) obtained by integrating HamiltonianSystem(K) from (p0, q0)."""
    system = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(system, initial_state, deltat)

# Integration of the flow equations: transport arbitrary points x0 along the
# deformation generated by the momenta p0 attached to the points q0.
def Flow(x0, p0, q0, K, deltat=1.0, Integrator=RalstonIntegrator()):
    """Return the image of x0 under the deformation defined by (p0, q0, K).

    The points x move with velocity K(x, q, p) while (p, q) follow the
    Hamiltonian system; only the final x is returned.
    """
    HS = HamiltonianSystem(K)

    def FlowEq(x, p, q):
        velocity = K(x, q, p)
        return (velocity, *HS(p, q))

    final_state = Integrator(FlowEq, (x0, p0, q0), deltat)
    return final_state[0]

# LDDMM objective: gamma * (deformation energy) + (data attachment term).
def LDDMMloss(q0, K, dataloss, gamma=0.):
    """Build loss(p0) for matching the source points q0.

    Args:
        q0: initial configuration (vertices) of the source shape.
        K: deformation kernel.
        dataloss: callable measuring the discrepancy between the deformed
            source points and the target.
        gamma: weight of the Hamiltonian (energy) regularization term.

    Returns:
        loss(p0) = gamma * H(p0, q0) + dataloss(q1), where q1 is the shape
        obtained by shooting p0 from q0 (Shooting's default deltat=1.0).
    """
    energy = Hamiltonian(K)

    def loss(p0):
        # Only the final positions are needed for the data term.
        _, q_final = Shooting(p0, q0, K)
        return gamma * energy(p0, q0) + dataloss(q_final)

    # TODO: to add contour constraints, use Flow (with a suitable deltat):
    #   flow -> contour data loss -> flow, after extracting the contours and
    #   their corresponding 3D coordinates.

    return loss

### Data attachment functions

In [9]:
# Data attachment for landmarks: sum of squared distances between corresponding points.
def losslmk(z):
    """Return loss(q) = sum_i |q_i - z_i|^2 (q and z correspond point by point)."""
    def loss(q):
        residual = q - z
        return (residual ** 2).sum()
    return loss

# Data attachment for point clouds via the measure (kernel) model: squared
# kernel norm between the uniform empirical measures on q and on z.
def lossmeas(z, Kw):
    """Return loss(q) = |mu_q - mu_z|^2 in the kernel norm of Kw, with uniform weights 1/n.

    The weight vectors are created with z's / q's dtype and device, so that
    float64 or GPU point clouds can be multiplied by Kw's kernel matrix.
    """
    nz = z.shape[0]
    wz = torch.ones(nz, 1, dtype=z.dtype, device=z.device)
    # Target self-interaction term: independent of q, computed once.
    cst = (1/nz**2)*Kw(z, z, wz).sum()
    def loss(q):
        nq = q.shape[0]
        wq = torch.ones(nq, 1, dtype=q.dtype, device=q.device)
        return cst + (1/nq**2)*Kw(q, q, wq).sum() + (-2/(nq*nz))*Kw(q, z, wz).sum()
    return loss

# Data attachment for point clouds via regularized optimal transport
# (requires the geomloss package, imported when loss_OT is called).
def loss_OT(z):
    """Return loss(q) = SamplesLoss()(mu_q, mu_z) between uniform probability measures.

    SamplesLoss is used with its default settings (a Sinkhorn divergence per
    the geomloss documentation). The weights are 1/nq and 1/nz, so both
    measures have total mass 1 even when q and z have different numbers of
    points. This matches the 1/n normalization used in lossmeas.
    """
    from geomloss import SamplesLoss
    loss_ = SamplesLoss()
    nz = z.shape[0]
    wz = torch.full((nz, 1), 1 / nz, dtype=z.dtype, device=z.device)
    def loss(q):
        nq = q.shape[0]
        wq = torch.full((nq, 1), 1 / nq, dtype=q.dtype, device=q.device)
        return loss_(wq, q, wz, z)
    return loss

# Data attachment for triangulated surfaces, varifold model.
# Compares the distributions of triangle centers and unit normals, weighted by
# triangle areas, so no point-to-point correspondence between the source and
# target meshes is required.
def lossVarifoldSurf(FS, VT, FT, K):
    """Return loss(VS), the squared varifold distance between source and target surfaces.

    Args:
        FS: integer tensor of source triangles (three vertex indices per row).
        VT: target vertex coordinates (one 3D point per row).
        FT: integer tensor of target triangles.
        K: varifold kernel K(x, y, u, v, b), e.g. GaussLinKernel(sigma).

    Returns:
        loss(VS) = <S,S> + <T,T> - 2 <S,T>, where VS are the (deformed) source
        vertex coordinates and <.,.> is the kernel inner product.
    """
    def CompCLNn(F, V):
        # Centers C, areas L and unit normals N/L of each triangle.
        V0, V1, V2 = V.index_select(0, F[:, 0]), V.index_select(0, F[:, 1]), V.index_select(0, F[:, 2])
        # True centroid (mean of the three vertices). A factor .5 here would
        # scale every center by 1.5 and thus shrink the effective kernel width.
        C = (V0 + V1 + V2) / 3
        # Half the cross product of two edges: its norm is the triangle area.
        N = .5 * torch.linalg.cross(V1 - V0, V2 - V0)
        L = (N**2).sum(dim=1)[:, None].sqrt()
        return C, L, N / L

    CT, LT, NTn = CompCLNn(FT, VT)
    # Target self-interaction <T,T>: independent of VS, computed once.
    cst = (LT * K(CT, CT, NTn, NTn, LT)).sum()

    def loss(VS):
        CS, LS, NSn = CompCLNn(FS, VS)
        return cst + (LS * K(CS, CS, NSn, NSn, LS)).sum() - 2 * (LS * K(CS, CT, NSn, NTn, LT)).sum()

    return loss

# Data attachment for curves, varifold model: segment midpoints, lengths and
# unit tangents play the role of triangle centers, areas and normals.
def lossVarifoldCurve(FS, VT, FT, K):
    """Return loss(VS), the squared varifold distance between source and target curves.

    FS and FT hold segment endpoints as pairs of vertex indices; VT holds the
    target vertex coordinates; K is a varifold kernel K(x, y, u, v, b).
    """
    def get_center_length_tangents(F, V):
        start = V.index_select(0, F[:, 0])
        end = V.index_select(0, F[:, 1])
        edge = end - start
        seg_len = (edge**2).sum(dim=1)[:, None].sqrt()
        return .5 * (start + end), seg_len, edge / seg_len

    CT, LT, TTn = get_center_length_tangents(FT, VT)
    target_term = (LT * K(CT, CT, TTn, TTn, LT)).sum()

    def loss(VS):
        CS, LS, TSn = get_center_length_tangents(FS, VS)
        source_term = (LS * K(CS, CS, TSn, TSn, LS)).sum()
        cross_term = (LS * K(CS, CT, TSn, TTn, LT)).sum()
        return target_term + source_term - 2 * cross_term
    return loss

### Plotting functions

In [10]:
# Result display for 2D landmark or point-cloud data: target z, source q0,
# shot points q, optional deformed grid, point paths and transported pts.
def PlotRes2D(z, pts=None):
    """Return plotfun(q0, p0, Kv, showgrid=True) drawing a 2D matching result.

    z: target points. pts: optional extra points transported with Flow and
    drawn as small blue dots.
    """
    def plotfun(q0, p0, Kv, showgrid=True):
        _, q = Shooting(p0, q0, Kv)
        q0np, qnp, znp = q0.data.numpy(), q.data.numpy(), z.data.numpy()
        plt.plot(znp[:, 0], znp[:, 1], '.')
        plt.plot(q0np[:, 0], q0np[:, 1], '+')
        plt.plot(qnp[:, 0], qnp[:, 1], 'o')
        plt.axis('equal')
        if showgrid:
            X = get_def_grid(p0, q0, Kv)
            plt.plot(X[0], X[1], 'k', linewidth=.25)
            plt.plot(X[0].T, X[1].T, 'k', linewidth=.25)
        # Point paths: positions reached when shooting the scaled momenta
        # t * p0 for nt values of t in [0, 1].
        n, d = q0.shape
        nt = 20
        Q = np.zeros((n, d, nt))
        for i in range(nt):
            t = i / (nt - 1)
            Q[:, :, i] = Shooting(t * p0, q0, Kv)[1].data.numpy()
        plt.plot(Q[:, 0, :].T, Q[:, 1, :].T, 'y')
        if pts is not None:
            phipts = Flow(pts, p0, q0, Kv).data.numpy()
            plt.plot(phipts[:, 0], phipts[:, 1], '.b', markersize=.1)
    return plotfun

# Result display for triangulated surfaces: source (q0), deformed (shot q)
# and target meshes, with an optional deformed 3D grid.
def PlotRes3D(VS, FS, VT, FT, filename="deformation.html"):
    """Return plotfun(q0, p0, Kv, src_opacity=1, tgt_opacity=1, def_opacity=1, showgrid=True).

    With use_pyvista, the scene is exported to `filename` as HTML. Otherwise it
    is drawn with matplotlib, and filename is not used. VS is not read: the
    source surface is drawn from q0 with the faces FS.
    """
    def plotfun(q0, p0, Kv, src_opacity=1, tgt_opacity=1, def_opacity=1, showgrid=True):
        # Shoot the momenta p0 from q0 to get the deformed vertices q.
        _, q = Shooting(p0, q0, Kv)
        q0np, qnp = q0.data.numpy(), q.data.numpy()
        FSnp, VTnp, FTnp = FS.data.numpy(), VT.data.numpy(), FT.data.numpy()
        if use_pyvista:
            # Named `plotter` so it does not shadow the momentum variable p.
            plotter = pv.Plotter()
            plotter.add_mesh(surf_to_pv(q0np, FSnp), color='lightblue', opacity=src_opacity)
            plotter.add_mesh(surf_to_pv(qnp, FSnp), color='lightcoral', opacity=def_opacity)
            plotter.add_mesh(surf_to_pv(VTnp, FTnp), color='lightgreen', opacity=tgt_opacity)
            if showgrid:
                ng = 20
                X = get_def_grid(p0, q0, Kv, ng=ng)
                # Draw grid lines along each axis in turn by cyclically
                # permuting the three spatial axes of X.
                for _ in range(3):
                    for i in range(ng):
                        for j in range(ng):
                            plotter.add_mesh(lines_from_points(X[:, i, j, :].T))
                    X = X.transpose((0, 2, 3, 1))
            plotter.export_html(filename)
            plotter.close()
        else:
            fig = plt.figure()
            # Create the 3D axes through the figure, rather than Axes3D(fig)
            # followed by fig.add_axes.
            ax = fig.add_subplot(projection='3d')
            ax.set_title('LDDMM matching example')
            # Meshes with opacity 0 are skipped so the opacity arguments
            # select what is drawn, as in the pyvista branch.
            for verts, faces, opacity in ((q0np, FSnp, src_opacity),
                                          (qnp, FSnp, def_opacity),
                                          (VTnp, FTnp, tgt_opacity)):
                if opacity > 0:
                    ax.plot_trisurf(verts[:, 0], verts[:, 1], verts[:, 2], triangles=faces, alpha=.5 * opacity)
            if showgrid:
                ng = 20
                X = get_def_grid(p0, q0, Kv, ng=ng)
                for _ in range(3):
                    for i in range(ng):
                        for j in range(ng):
                            ax.plot(X[0, i, j, :], X[1, i, j, :], X[2, i, j, :], 'k', linewidth=.25)
                    X = X.transpose((0, 2, 3, 1))
    return plotfun

def get_def_grid(p0, q0, Kv, ng=50):
    """Return the image of a regular ng^d grid under the deformation (p0, q0, Kv).

    The grid covers the bounding box of the source points q0 and the shot
    points q, enlarged by 20% of its extent on each side. The result has
    shape (d, ng, ..., ng): entry [k] holds the k-th coordinate of the
    deformed grid nodes, laid out as produced by np.meshgrid.
    """
    d = p0.shape[1]
    _, q = Shooting(p0, q0, Kv)
    q0np, qnp = q0.data.numpy(), q.data.numpy()
    # Per-coordinate bounds of the initial and final positions.
    a = np.minimum(q0np.min(axis=0), qnp.min(axis=0))
    b = np.maximum(q0np.max(axis=0), qnp.max(axis=0))
    sz = 0.2
    lsp = [np.linspace(a[k] - sz * (b[k] - a[k]), b[k] + sz * (b[k] - a[k]), ng, dtype=np.float32)
           for k in range(d)]
    X = np.meshgrid(*lsp)
    x = np.concatenate([X[k].reshape(ng**d, 1) for k in range(d)], axis=1)
    # Transport every grid node with Flow, matching q0's dtype so that
    # float64 source shapes also work.
    x_t = torch.from_numpy(x).to(dtype=q0.dtype)
    phix = Flow(x_t, p0, q0, Kv).detach().numpy()
    return phix.transpose().reshape([d] + [ng] * d)

def lines_from_points(points):
    """Given an array of points, make a line set.

    Returns a pyvista PolyData polyline that joins consecutive points, with one
    two-point line cell per consecutive pair (points[i], points[i+1]).
    """
    # pyvista is imported lazily here as well as at the top of the notebook.
    import pyvista as pv
    poly = pv.PolyData()
    poly.points = points
    # Each cell row is [2, i, i+1]: a line made of two point ids.
    cells = np.full((len(points) - 1, 3), 2, dtype=np.int_)
    cells[:, 1] = np.arange(0, len(points) - 1, dtype=np.int_)
    cells[:, 2] = np.arange(1, len(points), dtype=np.int_)
    poly.lines = cells
    return poly

def surf_to_pv(V, F):
    """Build a pyvista PolyData surface from vertices V and triangles F.

    pyvista expects a flat face array in which each face is prefixed by its
    vertex count, i.e. [3, i0, i1, i2, 3, ...].
    """
    counts = np.full((F.shape[0], 1), 3, dtype="int")
    faces_flat = np.concatenate((counts, F), axis=1).ravel()
    return pv.PolyData(V, faces_flat)

# Display function for triangulated-surface data (matplotlib only): source,
# deformed source and target meshes drawn semi-transparently.
def PlotResSurf(VS, FS, VT, FT):
    """Return plotfun(q0, p0, Kv) drawing q0, the shot shape and the target with matplotlib."""
    def plotfun(q0, p0, Kv):
        fig = plt.figure()
        _, q = Shooting(p0, q0, Kv)
        q0np, qnp = q0.data.numpy(), q.data.numpy()
        FSnp, VTnp, FTnp = FS.data.numpy(), VT.data.numpy(), FT.data.numpy()
        # Axes3D(fig, auto_add_to_figure=False) is rejected by recent matplotlib
        # releases; add_subplot(projection='3d') creates and attaches the axes.
        ax = fig.add_subplot(projection='3d')
        ax.set_title('LDDMM matching example')
        ax.plot_trisurf(q0np[:, 0], q0np[:, 1], q0np[:, 2], triangles=FSnp, alpha=.5)
        ax.plot_trisurf(qnp[:, 0], qnp[:, 1], qnp[:, 2], triangles=FSnp, alpha=.5)
        ax.plot_trisurf(VTnp[:, 0], VTnp[:, 1], VTnp[:, 2], triangles=FTnp, alpha=.5)
    return plotfun

### Helper functions

In [11]:
# Load a mesh file with trimesh and return its vertices and faces as tensors.
def load_mesh_as_tensors(filepath):
    """Return (vertices, faces) of the mesh stored at `filepath`.

    vertices: float32 tensor of vertex coordinates.
    faces: torch.long tensor of vertex indices per face.
    process=False asks trimesh not to post-process (e.g. merge) the loaded
    vertices, so the face indices refer to the vertices as stored.
    """
    mesh = trimesh.load(filepath, process=False)
    return (
        torch.tensor(mesh.vertices, dtype=torch.float32),
        torch.tensor(mesh.faces, dtype=torch.long),
    )

### Surface matching with LDDMM

In [13]:
VS, FS = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_a_mesh_500.ply')
VT, FT = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_r_mesh_500.ply')

# Save the tensors into a .pt file, then reload them from that file
torch.save((VS, FS, VT, FT), 'a-r_mesh.pt')

VS,FS,VT,FT = torch.load('a-r_mesh.pt') 
q0 = VS.clone().detach().requires_grad_(True)
# single-scale Gaussian deformation kernel (SumKernel could combine several sigmas)
Kv = GaussKernel(sigma=20)
# varifold data term; kernel widths around 15-20 are worth trying
Dataloss = lossVarifoldSurf(FS,VT,FT,GaussLinKernel(sigma=20))
loss = LDDMMloss(q0,Kv,Dataloss, gamma=0.1)
p0 = torch.zeros(q0.shape, requires_grad=True)

# NOTE: Optimize_with_vis returns (args, losses), so p0 is rebound to the list [p0];
# later cells therefore use p0[0].
p0, losses = Optimize_with_vis(loss,[p0],niter=10)  # increase niter (e.g. 100) for a closer match
PlotRes3D(VS,FS,VT,FT, filename="src.html")(q0,p0[0],Kv,src_opacity=1,tgt_opacity=0,def_opacity=0, showgrid=False)

Performing optimization with visualization...
Iteration 1 of 10
Iteration 2 of 10
Iteration 3 of 10
Iteration 4 of 10
Iteration 5 of 10
Iteration 6 of 10
Iteration 7 of 10
Iteration 8 of 10
Iteration 9 of 10
Iteration 10 of 10
Optimization done.


In [14]:
# Export the source, target and deformed shapes to separate HTML scenes.
for fname, (src_op, tgt_op, def_op) in [("src.html", (1, 0, 0)),
                                         ("tgt.html", (0, 1, 0)),
                                         ("deformation.html", (0, 0, 1))]:
    PlotRes3D(VS, FS, VT, FT, filename=fname)(q0, p0[0], Kv, src_opacity=src_op, tgt_opacity=tgt_op, def_opacity=def_op, showgrid=False)

Widget(value='<iframe src="http://localhost:56162/index.html?ui=P_0x18d6592a300_0&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:56162/index.html?ui=P_0x18d5b3510d0_0&reconnect=auto" class="pyvis…

Widget(value='<iframe src="http://localhost:56162/index.html?ui=P_0x18d5c6fba40_0&reconnect=auto" class="pyvis…

### Constrained LDDMM

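The approach:
1. Modify the integrator to store the intermediate states (the history of the deformation at every time step).
2. Extract the midsagittal slice (2D) from each intermediate deformed shape.
3. Apply the transformation (matrix and translation vector) that maps each slice into the rtMRI frame coordinates.
4. Compute a contour loss that compares these transformed 2D points with the corresponding rtMRI contours. The two point sets may differ in size and have no point-to-point correspondence, so a plain MSE does not apply; a Chamfer-type distance is one option.
5. Add the weighted contour loss to the original LDDMM loss.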

In [None]:
# Ralston integrator that also records the trajectory: returns the list of
# state tuples at t = 0, dt, ..., deltat (nt + 1 entries).
def RalstonIntegratorWithHistory(nt=10):
    """Like RalstonIntegrator, but the returned f(...) gives every intermediate state."""
    def f(ODESystem, x0, deltat=1.0):
        state = tuple(v.clone() for v in x0)
        trajectory = [state]  # initial state at t = 0
        dt = deltat / nt
        for _ in range(nt):
            k1 = ODESystem(*state)
            trial = tuple(s + (2 * dt / 3) * d for s, d in zip(state, k1))
            k2 = ODESystem(*trial)
            state = tuple(s + (0.25 * dt) * (d1 + 3 * d2)
                          for s, d1, d2 in zip(state, k1, k2))
            trajectory.append(state)
        return trajectory
    return f

# Shooting that keeps the intermediate shapes: returns the positions q at
# every time step recorded by the integrator (nt + 1 entries by default).
def ShootingHistory(p0, q0, K, deltat=1.0, Integrator=RalstonIntegratorWithHistory()):
    """Return [q(t_0), ..., q(t_nt)] along the Hamiltonian flow from (p0, q0)."""
    states = Integrator(HamiltonianSystem(K), (p0, q0), deltat)
    return [q for _, q in states]

# Extract the midsagittal slice from the 3D shape.
def extract_midsagittal_slice(vertices):
    """Return the 2D midsagittal slice of the 3D vertex set `vertices`.

    Placeholder: raises NotImplementedError until the MATLAB implementation
    is ported (ToDo), so callers get an explicit error rather than a NameError.
    """
    raise NotImplementedError(
        "extract_midsagittal_slice: port the MATLAB midsagittal slicing first")
    
# Apply transformations to the extracted 2D slice so that it matches the rtMRI frame coordinates.
def apply_transformation(points, transformation_matrix, translation_vector):
    """Map 2D slice points into rtMRI frame coordinates.

    Placeholder: raises NotImplementedError until the existing Python
    implementation (and its matrix/translation conventions) is ported
    (ToDo), so callers get an explicit error rather than a NameError.
    """
    raise NotImplementedError(
        "apply_transformation: port the rtMRI alignment transformation first")

# Symmetric Chamfer distance between two 2D point sets (no correspondence needed).
def _chamfer_2d(a, b):
    """Mean squared distance from each point of a to its nearest point of b, plus the reverse.

    a and b are (n, 2) and (m, 2) point sets; n and m may differ.
    """
    sqdist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=2)
    return sqdist.min(dim=1).values.mean() + sqdist.min(dim=0).values.mean()


# Compute the contour loss over the history of deformed shapes.
# For each intermediate deformed shape, extract the midsagittal slice,
# apply the transformation, and compare with the corresponding rtMRI contour.
def contour_loss(q_history, rtMRI_contours, transformation_matrix, translation_vector):
    """Average symmetric Chamfer distance between transformed slices and rtMRI contours.

    Args:
        q_history: list of deformed shapes (tensors of shape (N, 3)), one per
            time step (ideal length = 11 with the default integrator).
        rtMRI_contours: list of 2D contours (M, 2), one per time step. The
            contours need not have as many points as the slices, and no
            point correspondence is assumed.
        transformation_matrix, translation_vector: passed to apply_transformation.

    Returns:
        The mean Chamfer distance over the compared (shape, contour) pairs, or
        0.0 when there is nothing to compare. Pairs are formed with zip, so
        extra entries in the longer list are ignored.
    """
    total_loss = 0.0
    num = 0
    for q, contour_rt in zip(q_history, rtMRI_contours):
        slice_2d = extract_midsagittal_slice(q)
        slice_2d_transformed = apply_transformation(slice_2d, transformation_matrix, translation_vector)
        contour = torch.as_tensor(contour_rt, dtype=slice_2d_transformed.dtype,
                                  device=slice_2d_transformed.device)
        total_loss = total_loss + _chamfer_2d(slice_2d_transformed, contour)
        num += 1
    return total_loss / num if num else 0.0

# Extended LDDMM loss function that includes the contour loss.
# Contour loss is weighted by a parameter lambda_contour
def LDDMMloss_extended(q0, K, dataloss, rtMRI_contours, transformation_matrix, translation_vector, gamma=0., lambda_contour=1.0):
    """Build loss(p0) = gamma*H(p0, q0) + dataloss(q1) + lambda_contour*contour_loss(...).

    Args:
        q0: initial source shape (3D vertices, tensor).
        K: deformation kernel.
        dataloss: data attachment between the deformed shape and the 3D target.
        rtMRI_contours: list of rtMRI contour tensors, each of shape (M, 2).
        transformation_matrix, translation_vector: parameters aligning the
            extracted 2D slice with the rtMRI coordinates.
        gamma: weight of the Hamiltonian (energy) term.
        lambda_contour: weight of the contour term.
    """
    H = Hamiltonian(K)

    def loss(p0):
        # One integration with history is enough: its last entry is the final
        # deformed shape (same Ralston scheme and default nt=10 as Shooting),
        # so the trajectory is not integrated twice.
        q_history = ShootingHistory(p0, q0, K)
        q_final = q_history[-1]
        base_loss = gamma * H(p0, q0) + dataloss(q_final)
        contour_loss_val = contour_loss(q_history, rtMRI_contours, transformation_matrix, translation_vector)
        return base_loss + lambda_contour * contour_loss_val

    return loss

In [None]:
VS, FS = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_a_mesh_500.ply')
VT, FT = load_mesh_as_tensors('../../Data/surface_meshes/segmentation_r_mesh_500.ply')

# Save the tensors into a .pt file, then reload them from that file
torch.save((VS, FS, VT, FT), 'a-r_mesh.pt')

VS,FS,VT,FT = torch.load('a-r_mesh.pt') 
q0 = VS.clone().detach().requires_grad_(True)
Kv = GaussKernel(sigma=20)

# Placeholders, to be filled with the rtMRI contours (one (M, 2) tensor per time
# step) and the slice-to-rtMRI alignment parameters. While the contour list is
# empty, the contour term is 0 and only the standard LDDMM loss is optimized.
rtMRI_contours = []
transformation_matrix = []
translation_vector = []

Dataloss = lossVarifoldSurf(FS,VT,FT,GaussLinKernel(sigma=20)) 
# Pass the varifold data term `Dataloss` defined above (`dataloss` is not defined in this cell).
loss = LDDMMloss_extended(q0, Kv, Dataloss, rtMRI_contours, transformation_matrix, translation_vector, gamma=0.01, lambda_contour=1.0)

p0 = torch.zeros(q0.shape, requires_grad=True)
p0, losses = Optimize(loss,[p0],niter=10)