In [None]:
import os
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import plotly.express as px

import sys
sys.path.append('/mnt/raid/C1_ML_Analysis/source/ShapeAXI')
from shapeaxi import utils
import pandas as pd

In [None]:
# Run on the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still executes (slowly) on machines without CUDA instead of
# failing at the first tensor allocation.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [None]:
# Surface file that the sphere will be fitted to.
target_fn = '/mnt/famli_netapp_shared/C1_ML_Analysis/src/diffusion-models/blender/studies/placenta/FAM-025-0499-5/brain/leftWhiteMatter.stl'

# Read the surface and rescale it in a single step. ScaleSurf also hands back
# two values that are reused at export time to map the fitted mesh back into
# the target's original coordinates.
# NOTE(review): presumably the bounding-box center and a scalar scale factor —
# confirm against shapeaxi.utils.ScaleSurf.
target, target_mean_bb, target_scale_factor = utils.ScaleSurf(utils.ReadSurf(target_fn))

# Convert the surface to tensors on the compute device and wrap them in a
# single-element PyTorch3D Meshes batch (the third returned tensor is unused here).
target_v, target_f, target_e = utils.PolyDataToTensors(target, device=device)
target_mesh = Meshes(faces=[target_f], verts=[target_v])

In [None]:
def plot_pointcloud(mesh, title=""):
    """Show an interactive 3D scatter plot of points sampled from a mesh surface.

    Args:
        mesh: PyTorch3D ``Meshes`` object to sample from.
        title: Text used as the figure title (empty by default).

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    # 5000 surface samples per mesh in the batch.
    points = sample_points_from_meshes(mesh, 5000)
    # Collapse the batch dimension into an (P, 3) array on the host and split it
    # into coordinate columns; reshape (rather than squeeze) keeps this working
    # when the Meshes batch holds more than one mesh.
    x, y, z = points.detach().cpu().reshape(-1, 3).numpy().T
    # Request markers only: a point cloud has no meaningful ordering, so the
    # samples must not be joined by line segments.
    cloud = go.Scatter3d(x=x, y=y, z=z, mode="markers", marker=dict(size=2))
    fig = go.Figure(data=[cloud])
    fig.update_layout(title=title)
    fig.show()

In [None]:
# Visual sanity check of the rescaled target surface.
plot_pointcloud(target_mesh, title="Target mesh")

In [None]:
# Starting shape for the fit: an icosphere produced by the project helper.
# NOTE(review): the argument 6 is presumably the subdivision level — confirm
# against shapeaxi.utils.IcoSphere.
source = utils.IcoSphere(6)

# Same tensor conversion / Meshes wrapping as used for the target surface
# (the third returned tensor is unused here).
source_v, source_f, source_e = utils.PolyDataToTensors(source, device=device)
source_mesh = Meshes(faces=[source_f], verts=[source_v])

In [None]:
# Visual sanity check of the initial (undeformed) source sphere.
plot_pointcloud(source_mesh, title="Source mesh (initial)")

In [None]:
# Learnable per-vertex offsets, one row per packed source vertex, starting at
# zero so the first iteration evaluates the undeformed source mesh.
deform_verts = torch.zeros(
    source_mesh.verts_packed().shape,
    device=device,
    requires_grad=True,
)

# The offsets are the only parameters being optimized.
optimizer = torch.optim.SGD(
    [deform_verts],
    lr=1.0,
    momentum=0.9,
)

In [None]:
# ---- Optimization settings -------------------------------------------------
# Draw as many target samples per step as the source mesh has vertices.
NPoints = source_v.shape[0]
Niter = 20000        # number of optimization steps
w_chamfer = 1.0      # weight of the chamfer term
w_edge = 1.0         # weight of the edge-length term
w_normal = 0.01      # weight of the normal-consistency term
w_laplacian = 0.1    # weight of the laplacian-smoothing term
plot_period = 250    # plot period for the losses (not referenced in this cell)

# Per-iteration history of each (unweighted) loss term, used for plotting later.
chamfer_losses, laplacian_losses, edge_losses, normal_losses = [], [], [], []

loop = tqdm(range(Niter))
for i in loop:
    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Apply the current learnable offsets to the source mesh.
    new_source_mesh = source_mesh.offset_verts(deform_verts)

    # Fresh random samples from the target surface at every step.
    sample_target = sample_points_from_meshes(target_mesh, NPoints)

    # (a) Chamfer distance between the target samples and the deformed source
    #     *vertices* (the source surface itself is not re-sampled); the leading
    #     batch axis of size 1 matches the sampled target's batch axis.
    loss_chamfer, _ = chamfer_distance(sample_target, new_source_mesh.verts_packed()[None])

    # (b) Regularizers on the deformed mesh: edge length, normal consistency
    #     and uniform laplacian smoothing.
    loss_edge = mesh_edge_loss(new_source_mesh)
    loss_normal = mesh_normal_consistency(new_source_mesh)
    loss_laplacian = mesh_laplacian_smoothing(new_source_mesh, method="uniform")

    # Weighted sum of all terms (same left-to-right accumulation order).
    loss = (
        loss_chamfer * w_chamfer
        + loss_edge * w_edge
        + loss_normal * w_normal
        + loss_laplacian * w_laplacian
    )

    # Show the current total loss on the progress bar.
    loop.set_description('total_loss = %.6f' % loss)

    # Record each term as a plain Python float.
    for history, term in (
        (chamfer_losses, loss_chamfer),
        (edge_losses, loss_edge),
        (normal_losses, loss_normal),
        (laplacian_losses, loss_laplacian),
    ):
        history.append(term.item())

    # Backpropagate and update the vertex offsets.
    loss.backward()
    optimizer.step()

In [None]:
# One column per loss term, one row per optimization step.
loss_columns = {
    "chamfer": chamfer_losses,
    "edge": edge_losses,
    "normal": normal_losses,
    "laplacian": laplacian_losses,
}
df = pd.DataFrame(loss_columns)

# Last expression of the cell, so the figure is rendered as the cell output.
px.line(df, title="Losses")


In [None]:

# Map the fitted vertices back into the target's original coordinate frame
# before writing them out.
# NOTE(review): this assumes utils.ScaleSurf normalized the target as
# (v - target_mean_bb) * target_scale_factor, so dividing by the scale factor
# and adding the mean inverts it — confirm against shapeaxi.utils.ScaleSurf.
new_source_mesh_v = (new_source_mesh.verts_packed().detach().cpu())/target_scale_factor + target_mean_bb

# Rebuild a surface from the un-normalized vertices and the (unchanged) faces,
# then save it next to the analysis data.
new_source_mesh_surf = utils.TensorToPolyData(new_source_mesh_v, new_source_mesh.faces_packed().detach().cpu())
utils.WriteSurf(new_source_mesh_surf, '/mnt/famli_netapp_shared/C1_ML_Analysis/leftWhiteMatter_fitted.stl')


In [None]:
# Root directory of the analysis data.
mount_point = "/mnt/raid/C1_ML_Analysis"

# Load the CSV found under the simulated-data export folder and print it.
# Note: this rebinds `df`, which earlier held the loss history table.
rest_meshes_csv = os.path.join(mount_point, "simulated_data_export/rest_meshes_vtk.csv")
df = pd.read_csv(rest_meshes_csv)
print(df)