In [1]:
import numpy as np
import trimesh
import pyvista as pv
from scipy.spatial import KDTree
import torch
import torch.nn as nn
import torch.optim as optim
import os
import matplotlib.pyplot as plt

In [2]:
source_mesh = "TractProjection/segmentation_aa_reduced_1000.ply"
target_mesh = "TractProjection/segmentation_oo_reduced_1000.ply"
output_prefix = "morphing_result"

In [3]:
import numpy as np
import trimesh
import pyvista as pv
from scipy.spatial import KDTree
import os
import matplotlib.pyplot as plt

class LDDMMMapper:
    def __init__(self, kernel_width=1.0, regularization=1.0):
        self.kernel_width = kernel_width
        self.regularization = regularization
        self.momentum = None
        self.control_points = None
        
    def gaussian_kernel(self, x, y):
        """ K(x,y) = exp(-|x-y|^2 / (2*sigma^2))"""
        dist_sq = np.sum((x[:, np.newaxis, :] - y[np.newaxis, :, :]) ** 2, axis=2)
        return np.exp(-dist_sq / (2 * self.kernel_width ** 2))
    
    def compute_geodesic(self, source_points, target_points, num_timesteps=10, max_iter=10):
        self.control_points = source_points.copy()
        num_points = len(source_points)
        
        # init
        self.momentum = np.zeros((num_points, 3))
        
        target_tree = KDTree(target_points)
        
        for iter in range(max_iter):
            # fw
            trajectory = self.integrate_forward(source_points)
            end_points = trajectory[-1]
            
            distances, indices = target_tree.query(end_points)
            current_target = target_points[indices]
            
            gradient = self.compute_gradient(trajectory, end_points, current_target)
            step_size = 0.1 / (iter + 1)
            self.momentum -= step_size * gradient
            loss = np.mean(distances**2) + self.regularization * self.compute_momentum_norm()
            
            # if iter % 10 == 0:
            print(f"Iteration {iter}, Loss: {loss:.6f}")
                    
        return trajectory
    
    def integrate_forward(self, points, num_steps=10):
        dt = 1.0 / num_steps
        current_points = points.copy()
        trajectory = [current_points.copy()]
        
        for _ in range(num_steps):
            # v
            K = self.gaussian_kernel(current_points, self.control_points)
            velocity = np.dot(K, self.momentum)
            
            # update position
            current_points = current_points + dt * velocity
            trajectory.append(current_points.copy())
            
        return trajectory
    
    def compute_gradient(self, trajectory, end_points, target_points):
        data_gradient = end_points - target_points
        
        K = self.gaussian_kernel(self.control_points, self.control_points)
        reg_gradient = 2 * self.regularization * np.dot(K, self.momentum)
        
        return data_gradient + reg_gradient
    
    def compute_momentum_norm(self):
        K = self.gaussian_kernel(self.control_points, self.control_points)
        return np.sum(self.momentum * np.dot(K, self.momentum))

def load_and_normalize_mesh(mesh_path):
    mesh = trimesh.load_mesh(mesh_path)
    vertices = mesh.vertices
    
    # norm
    center = np.mean(vertices, axis=0)
    vertices = vertices - center
    scale = np.max(np.abs(vertices))
    vertices = vertices / scale
    
    return vertices, mesh.faces, center, scale

In [None]:
import numpy as np
import trimesh
import pyvista as pv
import matplotlib.pyplot as plt
from matplotlib import animation
from scipy.spatial import KDTree
import imageio

def convert_faces_to_pyvista(faces):
    """Convert faces from trimesh format to PyVista format"""
    if faces.shape[1] == 3:
        faces_pv = np.hstack((np.ones((faces.shape[0], 1), dtype=np.int64) * 3, 
                             faces))
        return faces_pv.ravel()
    return None

def load_and_normalize_mesh(mesh_path):
    """Load and normalize mesh"""
    mesh = trimesh.load_mesh(mesh_path)
    vertices = mesh.vertices
    center = np.mean(vertices, axis=0)
    vertices = vertices - center
    scale = np.max(np.abs(vertices))
    vertices = vertices / scale
    return vertices, mesh.faces, center, scale

def compute_point_errors(source_points, target_points):
    """Calculate distances from points to their nearest target points"""
    target_tree = KDTree(target_points)
    distances, _ = target_tree.query(source_points)
    return distances

def create_3d_grid_visualization(mapper, trajectory, output_prefix, grid_size=20):
    """Create 3D deformation grid visualization"""
    # Create 3D uniform grid
    x = np.linspace(-1, 1, grid_size)
    y = np.linspace(-1, 1, grid_size)
    z = np.linspace(-1, 1, grid_size)
    X, Y, Z = np.meshgrid(x, y, z)
    grid_points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))
    
    # Create grid connectivity
    grid_lines = []
    # Lines in X direction
    for i in range(grid_size):
        for j in range(grid_size):
            for k in range(grid_size-1):
                idx1 = i * grid_size * grid_size + j * grid_size + k
                idx2 = i * grid_size * grid_size + j * grid_size + k + 1
                grid_lines.append([idx1, idx2])
    # Lines in Y direction
    for i in range(grid_size):
        for j in range(grid_size-1):
            for k in range(grid_size):
                idx1 = i * grid_size * grid_size + j * grid_size + k
                idx2 = i * grid_size * grid_size + (j+1) * grid_size + k
                grid_lines.append([idx1, idx2])
    # Lines in Z direction
    for i in range(grid_size-1):
        for j in range(grid_size):
            for k in range(grid_size):
                idx1 = i * grid_size * grid_size + j * grid_size + k
                idx2 = (i+1) * grid_size * grid_size + j * grid_size + k
                grid_lines.append([idx1, idx2])
    
    grid_lines = np.array(grid_lines)
    
    # Create off-screen renderer
    pv.OFF_SCREEN = True
    
    # Define camera views
    views = {
        'front': {'position': (0, -2, 0), 'up': (0, 0, 1)},
        'side': {'position': (2, 0, 0), 'up': (0, 0, 1)},
        'top': {'position': (0, 0, 2), 'up': (0, 1, 0)},
        'isometric': {'position': (2, -2, 2), 'up': (0, 0, 1)}
    }
    
    print("Creating 3D grid animation...")
    frames = []
    
    for frame in range(len(trajectory)):
        print(f"Processing frame {frame}/{len(trajectory)-1}")
        
        # Calculate current frame grid points position
        current_grid = grid_points.copy()
        for _ in range(frame):
            K = mapper.gaussian_kernel(current_grid, trajectory[0])
            velocity = np.dot(K, mapper.momentum)
            current_grid = current_grid + velocity * 0.1
        
        # Create visualization for current frame
        p = pv.Plotter(off_screen=True, window_size=[1024, 768])
        
        # Add deformed grid
        # Convert grid lines to PyVista format
        grid_cells = []
        for line in grid_lines:
            grid_cells.extend([2, line[0], line[1]])
        
        # Create grid object
        grid_poly = pv.PolyData(current_grid, lines=grid_cells)
        
        # Add source and target meshes
        source_mesh = pv.PolyData(trajectory[frame], source_faces_pv)
        target_mesh = pv.PolyData(target_vertices, target_faces_pv)
        
        # Set camera view
        p.camera.position = views['isometric']['position']
        p.camera.up = views['isometric']['up']
        p.camera.focal_point = (0, 0, 0)
        
        # Add meshes to scene
        p.add_mesh(grid_poly, color='black', line_width=1, opacity=0.3)
        p.add_mesh(source_mesh, color='red', opacity=0.5)
        p.add_mesh(target_mesh, color='blue', opacity=0.3)
        
        # Add frame counter
        p.add_text(f'Frame: {frame}/{len(trajectory)-1}', position='upper_left', font_size=14)
        
        # Render and save frame
        p.show(auto_close=False)
        frames.append(p.screenshot(return_img=True))
        p.close()
    
    # Save animation
    print("Saving animation...")
    imageio.mimsave(f'{output_prefix}_3d_grid.gif', frames, fps=10)

def save_static_views(mapper, trajectory, output_prefix):
    """Save static views from different angles"""
    views = {
        'front': {'position': (0, -2, 0), 'up': (0, 0, 1)},
        'side': {'position': (2, 0, 0), 'up': (0, 0, 1)},
        'top': {'position': (0, 0, 2), 'up': (0, 1, 0)},
        'isometric': {'position': (2, -2, 2), 'up': (0, 0, 1)}
    }
    
    for view_name, camera_params in views.items():
        p = pv.Plotter(off_screen=True, window_size=[1024, 768])
        
        # Set camera
        p.camera.position = camera_params['position']
        p.camera.up = camera_params['up']
        p.camera.focal_point = (0, 0, 0)
        
        # Add final deformation result
        source_mesh = pv.PolyData(trajectory[-1], source_faces_pv)
        target_mesh = pv.PolyData(target_vertices, target_faces_pv)
        
        p.add_mesh(source_mesh, color='red', opacity=0.5)
        p.add_mesh(target_mesh, color='blue', opacity=0.3)
        
        p.add_text(f'{view_name.capitalize()} View', position='upper_left', font_size=14)
        p.screenshot(f'{output_prefix}_{view_name}_view.png')
        p.close()

def visualize_transformation(source_path, target_path, output_prefix):
    """Main function: Execute LDDMM transformation and create visualization"""
    global source_faces_pv, target_vertices, target_faces_pv
    
    # Load original data
    source_vertices, source_faces, source_center, source_scale = load_and_normalize_mesh(source_path)
    target_vertices, target_faces, target_center, target_scale = load_and_normalize_mesh(target_path)
    
    print(f"Source vertices: {len(source_vertices)}")
    print(f"Target vertices: {len(target_vertices)}")
    
    # Create LDDMM mapper and compute deformation trajectory
    mapper = LDDMMMapper(kernel_width=0.2, regularization=0.01)
    trajectory = mapper.compute_geodesic(source_vertices, target_vertices)
    
    # Convert face format (convert in advance to avoid repeated computation)
    source_faces_pv = convert_faces_to_pyvista(source_faces)
    target_faces_pv = convert_faces_to_pyvista(target_faces)
    
    # Create 3D grid visualization
    create_3d_grid_visualization(mapper, trajectory, output_prefix)
    
    # Save static views from different angles
    save_static_views(mapper, trajectory, output_prefix)
    
    return trajectory, source_vertices, target_vertices

if __name__ == "__main__":
    source_mesh = "TractProjection/segmentation_aa_reduced_1000.ply"
    target_mesh = "TractProjection/segmentation_oo_reduced_1000.ply"
    output_prefix = "morphing_result"
    
    trajectory, source_vertices, target_vertices = visualize_transformation(
        source_mesh, target_mesh, output_prefix)

Source vertices: 501
Target vertices: 498
Iteration 0, Loss: 0.007641
Iteration 1, Loss: 0.004659
Iteration 2, Loss: 0.004352
Iteration 3, Loss: 0.004278
Iteration 4, Loss: 0.004272
Iteration 5, Loss: 0.004293
Iteration 6, Loss: 0.004308
Iteration 7, Loss: 0.004326
Iteration 8, Loss: 0.004344
Iteration 9, Loss: 0.004362
Creating 3D grid animation...
Processing frame 0/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2523b1e4970_0&reconnect=auto" class="pyvist…

Processing frame 1/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x252299bcfd0_0&reconnect=auto" class="pyvist…

Processing frame 2/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x25229a32b20_0&reconnect=auto" class="pyvist…

Processing frame 3/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x252278e16a0_0&reconnect=auto" class="pyvist…

Processing frame 4/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2522773d1c0_0&reconnect=auto" class="pyvist…

Processing frame 5/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2522789de80_0&reconnect=auto" class="pyvist…

Processing frame 6/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2522d089700_0&reconnect=auto" class="pyvist…

Processing frame 7/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2522f3c4d90_0&reconnect=auto" class="pyvist…

Processing frame 8/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2522f45feb0_0&reconnect=auto" class="pyvist…

Processing frame 9/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2522f502ee0_0&reconnect=auto" class="pyvist…

Processing frame 10/10


Widget(value='<iframe src="http://localhost:6029/index.html?ui=P_0x2522f45ff40_0&reconnect=auto" class="pyvist…

Saving animation...
