In [None]:
import os
import sys

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch3d
from pytorch3d import _C
from pytorch3d.io import load_obj, save_obj, load_objs_as_meshes
from pytorch3d.utils import ico_sphere, torus
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.ops import sample_points_from_meshes, knn_points, estimate_pointcloud_normals, knn_gather, cubify
from pytorch3d.loss.point_mesh_distance import point_mesh_edge_distance, point_mesh_face_distance
from pytorch3d.loss import chamfer_distance, mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency
import trimesh

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

from ops.mesh_geometry import *

import pyvista as pv
# Configure PyVista once, before any Plotter is created:
#  - start_xvfb() launches a virtual X framebuffer (presumably so rendering works
#    on a headless machine -- confirm for the target environment);
#  - the 'html' Jupyter backend is selected for the plots shown in later cells.
pv.start_xvfb()
pv.set_jupyter_backend('html')

# # from pytorch3d.structures import Meshes

# from .utils import one_hot_sparse


In [None]:
# Every tensor and mesh created below is placed on the first CUDA device.
device = torch.device("cuda:0")

# Load the target surface from disk as a PyTorch3D `Meshes` batch holding one mesh.
target_obj_paths = ["./data_example/Bull.obj"]
mesh_trg = load_objs_as_meshes(target_obj_paths, device=device)

In [None]:
# mesh_tem = Meshes(verts=[torch.from_numpy(trimesh_tem.vertices).float()], 
#                   faces=[torch.from_numpy(trimesh_tem.faces).long()]).to(device)

# normalize the mesh to fit in the unit cube (helper comes from ops.mesh_geometry)
mesh_tem = normalize_mesh(mesh_trg)

# Quick sanity check: shape of the packed vertex tensor of the normalized mesh.
packed_verts_shape = mesh_tem.verts_packed().shape
print(packed_verts_shape)

In [None]:
# voxelizer = Differentiable_Voxelizer(bbox_density=128)


# Number of samples along each axis; the grid holds sample_size**3 query points.
sample_size = 64

# Bounding box of the first (only) mesh in the batch, indexed as [axis, (min, max)].
meshbbox = mesh_tem.get_bounding_boxes()[0]

# One evenly spaced coordinate axis per spatial dimension, spanning the bounding box.
grid_axes = [
    torch.linspace(meshbbox[axis, 0], meshbbox[axis, 1], sample_size)
    for axis in range(3)
]

# indexing="ij" is the ordering the original call relied on implicitly. Passing it
# explicitly silences the torch.meshgrid deprecation warning and pins the (x, y, z)
# index order that the later `permute(0, 3, 2, 1)` reshapes depend on.
coordinates_downsampled = torch.stack(torch.meshgrid(*grid_axes, indexing="ij"), dim=-1)

# Flatten to a (sample_size**3, 3) list of query points on the working device.
coordinates_downsampled = coordinates_downsampled.view(-1, 3).to(device)

n_points = coordinates_downsampled.shape[0]


In [None]:
# current_mesh = torus(r=1, R=2, sides=64, rings=64).to(device)
# Initial source shape: a level-3 icosphere, passed through the same
# normalize_mesh helper that is applied to the target mesh.
current_mesh = normalize_mesh(ico_sphere(3, device=device))

In [None]:
# graph network
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.nn import NNConv, Set2Set, CGConv, ChebConv, global_mean_pool


In [None]:
class GraphNet(torch.nn.Module):
    """Three-layer Chebyshev graph-convolution network (K=3 in every layer).

    Layer widths are ``input_dim -> 512 -> 1024``; the 1024 features are then
    concatenated with ``data.pos`` and mapped to ``output_dim`` by the last layer.

    Args:
        input_dim: number of feature channels in ``data.x``.
        output_dim: number of channels produced per node.

    Note:
        The last layer is sized as ``1024 + input_dim`` while the tensor that is
        concatenated is ``data.pos``, so ``data.pos`` must have ``input_dim``
        columns for the shapes to line up.
    """

    def __init__(self, input_dim = 3, output_dim = 3):
        super().__init__()
        hidden_width, wide_width = 512, 1024
        self.conv1 = ChebConv(input_dim, hidden_width, K=3)
        self.conv2 = ChebConv(hidden_width, wide_width, K=3)
        self.conv3 = ChebConv(wide_width + input_dim, output_dim, K=3)

    def forward(self, data):
        """Run the network on a graph object exposing ``x``, ``pos`` and ``edge_index``."""
        edge_index = data.edge_index
        features = F.silu(self.conv1(data.x, edge_index))
        features = F.silu(self.conv2(features, edge_index))
        # Skip connection: re-inject the node positions before the output layer.
        fused = torch.cat([features, data.pos], dim=-1)
        return self.conv3(fused, edge_index)

In [None]:
def dice_loss(input, target, smooth=1e-5):
    """Soft Dice loss computed over all elements of ``input`` and ``target``.

    ``input`` is passed through a sigmoid first, so it is treated as raw
    logits; ``target`` is used as given.

    Args:
        input: tensor of logits.
        target: tensor that is multiplied element-wise with the sigmoid of
            ``input`` (ordinary broadcasting rules apply).
        smooth: constant added to numerator and denominator so the ratio stays
            defined when both sums are zero. Defaults to the previously
            hard-coded value ``1e-5``.

    Returns:
        Scalar tensor ``1 - (2 * sum(p * target) + smooth) / (sum(p) + sum(target) + smooth)``
        where ``p = sigmoid(input)``.
    """
    probs = input.sigmoid()
    intersection = (probs * target).sum()
    total = probs.sum() + target.sum()
    return 1 - (2 * intersection + smooth) / (total + smooth)

In [None]:
# Occupancy of the normalized target mesh at every grid point. It serves as the
# fixed target of the optimization, so it is evaluated with autograd disabled.
with torch.no_grad():
    occp_gt = occupancy(
        mesh_tem,
        coordinates_downsampled,
        allow_grad=False,
    )

In [None]:
# Per-vertex displacement of the source mesh, optimized directly.
#
# It is created on the source mesh's own device (zeros_like inherits device and
# dtype) and is NOT passed through `.to(device)` afterwards: whenever `.to` has to
# copy, the copy of a `requires_grad=True` tensor is a non-leaf tensor, which the
# optimizer rejects and which would never receive the updates.
#
# Alternative kept for reference -- predict the offsets with the graph network:
#   network = GraphNet(input_dim=3, output_dim=3).to(device)
#   data = Data(x=current_mesh.verts_packed(), pos=current_mesh.verts_packed(),
#               edge_index=current_mesh.edges_packed().T)
#   offset = network(data)
offset = torch.zeros_like(current_mesh.verts_packed(), requires_grad=True)

optimizer = torch.optim.AdamW([offset], lr=1e-2)

# Not referenced by the loop below; kept so the name stays available to other cells.
arctan_occp_gt = arctan_det_occp(mesh_tem)

n_iter = 1000
for i in range(n_iter):
    optimizer.zero_grad()

    # Deformed source mesh for this iteration.
    new_mesh = current_mesh.offset_verts(offset)

    # Occupancy of the deformed mesh at the grid points (dice_loss applies a
    # sigmoid to this value before comparing it with occp_gt).
    occp = Winding_Occupancy_Face(new_mesh, coordinates_downsampled)

    # Data term plus mesh regularizers; each regularizer is weighted by 0.1 and the
    # offset magnitude by 1e-3.
    loss_sdf = dice_loss(occp, occp_gt)
    loss_smooth = mesh_laplacian_smoothing(new_mesh)
    loss_normal = mesh_normal_consistency(new_mesh)
    loss_edge = mesh_edge_loss(new_mesh)
    loss_normalize = torch.norm(offset, p=2)*1e-3
    loss = loss_smooth*0.1 + loss_normal*0.1 + loss_edge*0.1 + loss_sdf + loss_normalize

    loss.backward()
    optimizer.step()
    print("Iter: ", i, " loss_sdf: ", loss_sdf.item())

# The mesh built inside the loop predates the last optimizer step; rebuild it so
# that `new_mesh` reflects the final offsets (and exists even when n_iter == 0).
with torch.no_grad():
    new_mesh = current_mesh.offset_verts(offset)




In [None]:
pl = pv.Plotter(notebook=True)

# Optimized source mesh (first mesh of the batch), drawn in light blue.
fitted_verts = new_mesh.verts_list()[0].detach().cpu().numpy()
fitted_faces = new_mesh.faces_list()[0].detach().cpu().numpy()
trimesh_tem = trimesh.Trimesh(vertices=fitted_verts, faces=fitted_faces)
pl.add_mesh(trimesh_tem, color='lightblue', show_edges=True, opacity=0.1)

# Normalized target mesh, drawn in light green for comparison.
target_verts = mesh_tem.verts_list()[0].detach().cpu().numpy()
target_faces = mesh_tem.faces_list()[0].detach().cpu().numpy()
trimesh_target = trimesh.Trimesh(vertices=target_verts, faces=target_faces)
pl.add_mesh(trimesh_target, color='lightgreen', show_edges=True, opacity=0.1)

# pl.add_points(coordinates_downsampled.cpu().numpy()[np.random.choice(n_points, 1000)], color='red', point_size=5)

pl.camera.roll = 600
# pl.camera.elevation = 180
# pl.camera.azimuth = 180
# `Camera.zoom` is a method, not a property: the previous `pl.camera.zoom = 1.3`
# only replaced the method with a float on this camera object and applied no zoom.
pl.camera.zoom(1.3)

pl.show() #screenshot='out_exp/tem_mesh_2.png', window_size=[800,800])


In [None]:
# Occupancy evaluator built from the normalized target mesh.
arctan_occp = arctan_det_occp(mesh_tem)

# if u wanna use multi-gpus (but may not be faster)
# Use the GPUs that are actually present (capped at the four this notebook was
# written for) instead of hard-coding device ids 0-3, which raises on machines
# with fewer GPUs.
gpu_ids = list(range(min(4, torch.cuda.device_count())))
arctan_occp = nn.DataParallel(arctan_occp, device_ids=gpu_ids)
arctan_occp = arctan_occp.cuda()
arctan_occp = arctan_occp.half()
arctan_occp.eval()

# Pair every grid point with its flat index so each batch can be written back to
# the right slots of the result buffer.
dats_set = torch.utils.data.TensorDataset(coordinates_downsampled.half().cpu(), torch.arange(0, n_points).long())

dataloader = torch.utils.data.DataLoader(dats_set, batch_size=128**3, shuffle=False, num_workers=10, drop_last=False)

device = torch.device("cuda:0")

# Half-precision result buffer, one value per grid point. (The former
# `opp_result = occp_result.half()` alias was a misspelled, unused copy and is gone.)
occp_result = torch.zeros(n_points, dtype=torch.half, device=device)

# Inference only: disable autograd once around the whole loop.
with torch.no_grad():
    for data, idx in dataloader: ### multi-gpu
        points = data.cuda()
        indx = idx.to(device)
        # Call the module rather than `.forward` so module hooks are honoured.
        # The second argument (2000) is forwarded unchanged to the evaluator;
        # presumably an internal chunk size -- TODO confirm in ops.mesh_geometry.
        occp = arctan_occp(points, 2000)
        occp_result[indx] = occp.to(device)


In [None]:
# Inputs of the PyTorch3D point-to-face distance kernel for a single cloud and a
# single mesh: packed points, packed triangles and the start index of each batch element.
grid_points = coordinates_downsampled.view(-1, 3).to(device)
points_first_idx = torch.tensor([0], device=device, dtype=torch.int64)
target_tris = mesh_tem.verts_packed()[mesh_tem.faces_packed(),:].to(device)
tris_first_idx = torch.tensor([0], device=device, dtype=torch.int64)

# NOTE(review): PyTorch3D documents this kernel as returning *squared* distances, so
# the magnitude of `sdf` below is presumably a squared distance -- confirm that is intended.
dist, _ = _C.point_face_dist_forward(
    grid_points,
    points_first_idx,
    target_tris,
    tris_first_idx,
    n_points,
    1e-5,
)

# Flat per-point values are reshaped to the sampling grid, then the three spatial
# axes are reversed with permute(0, 3, 2, 1).
grid_shape = (1, sample_size, sample_size, sample_size)

occpfield = occp_result.view(*grid_shape).permute(0, 3, 2, 1)

# Steep tanh around occupancy 0.5 acts as a soft sign: negative inside, positive outside.
inside_sign = torch.tanh((occp_result-0.5)*100)
sdf = -inside_sign*dist

sdf_trg = sdf.view(*grid_shape).permute(0, 3, 2, 1)



# plt.imshow(sdf_trg[0,80].detach().cpu().numpy())

In [None]:
# occp = torch.where(sdf_trg<0, torch.ones_like(sdf_trg), torch.zeros_like(sdf_trg))

# cubified = cubify(occp_result.view(1, sample_size, sample_size, sample_size), 0.5) # cubify the voxel grid, which is the inverse operation of voxelization

# cubify the voxel grid, which is the inverse operation of voxelization
# (voxels with -sdf above -1e-5 are treated as occupied).
cubified = cubify(-sdf_trg, -1e-5)

# cubify returns vertices in normalized [-1, 1] coordinates. Map them back to the
# sampling bounding box with the full affine transform (centre + half extent).
# Scaling by the bbox maximum alone, as before, is only correct when the box is
# symmetric about the origin; in that case this expression gives the same result.
bbox_min = meshbbox[:, 0].view(1, 1, 3).to(device)
bbox_max = meshbbox[:, 1].view(1, 1, 3).to(device)
bbox_center = (bbox_max + bbox_min) / 2
bbox_half_extent = (bbox_max - bbox_min) / 2
cubified = cubified.update_padded(cubified.verts_padded()*bbox_half_extent + bbox_center)
# cubified = pytorch3d.ops.taubin_smoothing(cubified, 1, 0.1)

In [None]:
pl = pv.Plotter(notebook=True)

# Reconstructed (cubified) surface, lightly smoothed to soften the voxel staircase.
cubified_verts = cubified.verts_packed().detach().cpu().numpy()
cubified_faces = cubified.faces_packed().detach().cpu().numpy()
trimesh_cubified = trimesh.Trimesh(cubified_verts, cubified_faces)
trimesh_cubified = trimesh.smoothing.filter_laplacian(trimesh_cubified, iterations=4)

# Normalized target mesh; only drawn when the add_mesh line below is re-enabled.
original_verts = mesh_tem.verts_list()[0].cpu().numpy()
original_faces = mesh_tem.faces_list()[0].cpu().numpy()
trimesh_original = trimesh.Trimesh(original_verts, original_faces)

# pl.add_mesh(trimesh_original, color='lightgreen', show_edges=True, opacity=0.2)

pl.add_mesh(trimesh_cubified, color='lightgreen', opacity=1, show_edges=False)

pl.camera.elevation = 140
pl.camera.azimuth = 60
# `Camera.zoom` is a method, not a property: the previous `pl.camera.zoom = 1.3`
# only replaced the method with a float on this camera object and applied no zoom.
pl.camera.zoom(1.3)

pl.show() #screenshot='out_exp/cubified_2.png', window_size=[800,800])
