In [1]:
import os

import torch
import trimesh
import plotly.graph_objs as go

In [2]:
def plot_points(points, fig=None, shift=(0, 0, 0), rowcol=None, camera=None):
    """Add a 3D scatter trace of ``points`` to a plotly figure.

    Parameters
    ----------
    points : array-like
        Indexed as ``points[:, 0]``, ``points[:, 1]``, ``points[:, 2]`` for
        x, y, z -- presumably an (N, 3) numpy array or torch tensor (TODO
        confirm against callers). Torch tensors are detached and copied to
        CPU numpy first, so tensors that require grad or live on the GPU work.
    fig : plotly.graph_objs.Figure, optional
        Figure to add the trace to. A new empty figure is created when None.
    shift : sequence of 3 numbers
        Offset added to the x, y and z coordinates respectively.
    rowcol : (int, int), optional
        Subplot (row, col) to place the trace in. The figure must have a
        subplot grid (e.g. created with ``plotly.subplots.make_subplots``).
    camera : optional
        Accepted for backward compatibility; currently ignored.

    Returns
    -------
    plotly.graph_objs.Figure
        ``fig`` (or the newly created figure) with the scatter trace added.
    """
    if fig is None:
        # Create a fresh figure per call. A ``go.Figure()`` default would be
        # evaluated once at definition time and shared by every call, so
        # traces from unrelated plots would pile up in the same figure.
        fig = go.Figure()
    if hasattr(points, "detach"):
        # torch tensor: drop autograd history and move to CPU numpy so
        # plotly can consume it.
        points = points.detach().cpu().numpy()

    marker = dict(
        color="LightSkyBlue",
        size=2,
        line=dict(color="MediumPurple", width=2),
    )
    figpoints = go.Scatter3d(
        x=points[:, 0] + shift[0],
        y=points[:, 1] + shift[1],
        z=points[:, 2] + shift[2],
        mode="markers",
        marker=marker,
    )
    if rowcol is None:
        fig.add_trace(figpoints)
    else:
        # add_trace(row=, col=) is the supported replacement for the
        # deprecated append_trace.
        fig.add_trace(figpoints, row=rowcol[0], col=rowcol[1])
    return fig

def plot_mesh(mesh, faces, points=None, color="blue", opacity=1, fig=None,
              rowcol=None, show_points=True, shift=(0, 0, 0)):
    """Add a triangle mesh (and optionally a point cloud) to a plotly figure.

    Parameters
    ----------
    mesh : array-like
        Vertex coordinates, indexed as ``mesh[:, 0..2]`` for x, y, z --
        presumably (V, 3) (TODO confirm).
    faces : array-like
        Triangle vertex indices, indexed as ``faces[:, 0..2]`` and passed to
        ``go.Mesh3d`` as ``i``, ``j``, ``k``.
    points : array-like, optional
        Extra points drawn with ``plot_points`` on top of the mesh.
    color, opacity :
        Forwarded to ``go.Mesh3d``.
    fig : plotly.graph_objs.Figure, optional
        Figure to add traces to. A new empty figure is created when None.
    rowcol : (int, int), optional
        Subplot (row, col) for the traces. The figure must have a subplot grid.
    show_points : bool
        If False, ``points`` is not drawn even when given.
    shift : sequence of 3 numbers
        Offset added to x, y, z of both the mesh and the points.

    Returns
    -------
    plotly.graph_objs.Figure
        The figure with the mesh (and point) traces added.

    Torch tensors are detached and copied to CPU numpy before plotting.
    """
    if fig is None:
        # A fresh figure per call (a ``go.Figure()`` default is created once
        # and shared across all calls).
        fig = go.Figure()

    def _to_plottable(x):
        # torch tensors -> CPU numpy without autograd; anything else unchanged.
        return x.detach().cpu().numpy() if hasattr(x, "detach") else x

    V = _to_plottable(mesh)
    F = _to_plottable(faces)
    figmesh = go.Mesh3d(
        x=V[:, 0] + shift[0],
        y=V[:, 1] + shift[1],
        z=V[:, 2] + shift[2],
        i=F[:, 0], j=F[:, 1], k=F[:, 2],
        color=color, opacity=opacity,
    )
    if rowcol is None:
        fig.add_trace(figmesh)
    else:
        # Replacement for the deprecated append_trace.
        fig.add_trace(figmesh, row=rowcol[0], col=rowcol[1])

    if show_points and points is not None and len(points) > 0:
        fig = plot_points(_to_plottable(points), fig, shift=shift, rowcol=rowcol)

    return fig

In [None]:
def load_mesh_as_tensors(filepath):
    """Read a mesh file with trimesh and return its geometry as torch tensors.

    The file is loaded with ``process=False``, so trimesh's processing step
    (e.g. merging duplicate vertices) is skipped. Vertex and face ordering
    therefore follow the file.

    Parameters
    ----------
    filepath : str or path-like
        Path of the mesh file handed to ``trimesh.load``.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``(vertices, faces)``: vertices as ``float32``, faces as ``torch.long``.
    """
    loaded = trimesh.load(filepath, process=False)
    return (
        torch.tensor(loaded.vertices, dtype=torch.float32),
        torch.tensor(loaded.faces, dtype=torch.long),
    )

In [None]:
# Load the two remeshed segmentation surfaces ("a" -> VS/FS, "r" -> VT/FT).
VS, FS = load_mesh_as_tensors('data/surface_meshes/segmentation_a_remesh.ply')
VT, FT = load_mesh_as_tensors('data/surface_meshes/segmentation_r_remesh.ply')

# Cache the tensors. Create the output directory first so a fresh checkout
# doesn't fail with FileNotFoundError.
MESH_CACHE_PATH = 'data/vt_tensors/a-r_mesh.pt'
os.makedirs(os.path.dirname(MESH_CACHE_PATH), exist_ok=True)
torch.save((VS, FS, VT, FT), MESH_CACHE_PATH)

# NOTE(review): torch.load unpickles arbitrary objects unless weights_only=True
# is passed; only load trusted files here.
phi_mesh_list_1 = torch.load('data/deformed_meshes_sigma_15.pt')

# Plot the last entry of the loaded list (presumably the final deformed
# vertices) using the faces of the "a" mesh.
fig = go.Figure()
plot_mesh(phi_mesh_list_1[-1], FS, color="red", fig=fig, opacity=.5)