In [3]:
import numpy as np
import trimesh
import pyvista as pv
from scipy.spatial import KDTree
import torch
import torch.nn as nn
import torch.optim as optim
import os

In [4]:
source_mesh = "TractProjection/segmentation_aa_reduced.ply"
target_mesh = "TractProjection/segmentation_ll_reduced.ply"

In [7]:
class KernelFunction:
    def __init__(self, sigma=1.0):
        self.sigma = sigma
    
    def gaussian_kernel(self, x1, x2):
        """计算两组点之间的高斯核"""
        dist = torch.cdist(x1, x2)
        return torch.exp(-dist**2 / (2 * self.sigma**2))

In [8]:
class LDDMM:
    def __init__(self, source_mesh, target_mesh, num_timesteps=10, sigma=1.0, alpha=1.0):
        # 加载网格
        print("Loading meshes...")
        self.source = trimesh.load_mesh(source_mesh)
        self.target = trimesh.load_mesh(target_mesh)
        
        print(f"\nSource mesh: {len(self.source.vertices)} vertices, {len(self.source.faces)} faces")
        print(f"Target mesh: {len(self.target.vertices)} vertices, {len(self.target.faces)} faces")
        
        self.num_timesteps = num_timesteps
        self.kernel = KernelFunction(sigma)
        self.alpha = alpha
        
        # 转换为torch张量
        self.source_vertices = torch.tensor(self.source.vertices, dtype=torch.float32)
        self.target_vertices = torch.tensor(self.target.vertices, dtype=torch.float32)
        
        # 初始化动量向量
        self.momentum = nn.Parameter(torch.zeros_like(self.source_vertices))
    
    def forward_flow(self):
        """计算顶点的前向流动"""
        vertices = self.source_vertices.clone()
        trajectories = [vertices.clone()]
        
        for t in range(self.num_timesteps):
            # 使用核函数计算速度场
            velocity = self.kernel.gaussian_kernel(vertices, vertices) @ self.momentum
            # 使用欧拉积分更新位置
            vertices = vertices + velocity * (1.0 / self.num_timesteps)
            trajectories.append(vertices.clone())
        
        return trajectories
    
    def compute_loss(self, final_vertices):
        """计算变形后的网格与目标网格之间的损失"""
        # 计算最终状态与目标之间的距离
        matching_loss = torch.mean(torch.min(torch.cdist(final_vertices, self.target_vertices), dim=1)[0])
        # 动量的正则化
        reg_loss = self.alpha * torch.sum(self.momentum**2)
        return matching_loss + reg_loss
    
    def optimize(self, num_iterations=100, learning_rate=0.1):
        """优化动量向量"""
        optimizer = optim.Adam([self.momentum], lr=learning_rate)
        
        for i in range(num_iterations):
            optimizer.zero_grad()
            trajectories = self.forward_flow()
            final_vertices = trajectories[-1]
            loss = self.compute_loss(final_vertices)
            loss.backward()
            optimizer.step()
            
            if (i + 1) % 10 == 0:
                print(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item():.4f}')
        
        return trajectories

In [9]:
def create_interpolated_mesh(vertices, faces, output_file):
    """创建并保存插值网格"""
    mesh = trimesh.Trimesh(vertices=vertices.detach().numpy(), faces=faces)
    mesh.export(output_file)
    return mesh

In [10]:
def morph_vocal_tracts(source_mesh_path, target_mesh_path, output_prefix, num_frames=10):
    """执行声道形状变形的主函数"""
    print("Initializing LDDMM...")
    lddmm = LDDMM(source_mesh_path, target_mesh_path)
    
    print("\nOptimizing transformation...")
    trajectories = lddmm.optimize()
    
    print("\nCreating interpolated meshes...")
    source_mesh = trimesh.load_mesh(source_mesh_path)
    faces = source_mesh.faces
    
    interpolated_meshes = []
    for i, vertices in enumerate(trajectories):
        output_file = f"{output_prefix}_frame_{i:03d}.ply"
        mesh = create_interpolated_mesh(vertices, faces, output_file)
        interpolated_meshes.append(mesh)
        print(f"Created frame {i+1}/{len(trajectories)}")
    
    return interpolated_meshes

In [11]:
def visualize_morphing(meshes):
    """使用PyVista可视化变形序列"""
    print("Visualizing morphing sequence...")
    plotter = pv.Plotter()
    
    for i, mesh in enumerate(meshes):
        print(f"Showing frame {i+1}/{len(meshes)}")
        pv_mesh = pv.wrap(mesh)
        plotter.add_mesh(pv_mesh, show_edges=True)
        plotter.show()
        plotter.clear()

In [12]:
source_mesh = "TractProjection/segmentation_aa_reduced.ply"
target_mesh = "TractProjection/segmentation_ll_reduced.ply"
output_prefix = "morphing_result"

In [13]:
morphed_meshes = morph_vocal_tracts(source_mesh, target_mesh, output_prefix)

Initializing LDDMM...
Loading meshes...

Source mesh: 152 vertices, 300 faces
Target mesh: 148 vertices, 300 faces

Optimizing transformation...
Iteration 10/100, Loss: 5.9000
Iteration 20/100, Loss: 4.7440
Iteration 30/100, Loss: 4.7278
Iteration 40/100, Loss: 4.7363
Iteration 50/100, Loss: 4.7100
Iteration 60/100, Loss: 4.6963
Iteration 70/100, Loss: 4.6907
Iteration 80/100, Loss: 4.6890
Iteration 90/100, Loss: 4.6889
Iteration 100/100, Loss: 4.6889

Creating interpolated meshes...
Created frame 1/11
Created frame 2/11
Created frame 3/11
Created frame 4/11
Created frame 5/11
Created frame 6/11
Created frame 7/11
Created frame 8/11
Created frame 9/11
Created frame 10/11
Created frame 11/11


Visualizing morphing sequence...
Showing frame 1/11


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 2/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 3/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 4/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 5/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 6/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 7/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 8/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 9/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 10/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…

Showing frame 11/11
A view with name (P_0x25788c370d0_1) is already registered
 => returning previous one


Widget(value='<iframe src="http://localhost:14584/index.html?ui=P_0x25788c370d0_1&reconnect=auto" class="pyvis…