In [1]:
from model.treemeshgpt_inference import TreeMeshGPT
import os
import numpy as np
import open3d as o3d
import torch
from accelerate import Accelerator
from pathlib import Path
from fns import center_vertices, normalize_vertices_scale, quantize_verts, dequantize_verts
import trimesh
import pyvista as pv
from utils.utils import GenerateAreaToRemesh, SaveAreaToRemeshInOBJ, ExtractRingsAroundTriangles, SaveOBJ
from tokenizer import prepare_halfedge_mesh
import random

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
VERSION = "7bit"
CKPT_PATH = "./checkpoints/treemeshgpt_7bit.pt"

OUTPUT_DIR="./output"

DECIMATION_TARGET_NFACES = 5000
SAMPLING = "uniform" if VERSION == "7bit" else "fps"

TORCH_DEVICE="cuda:1"

if not os.path.exists("./output") :
  os.mkdir("./output")

In [3]:
# #Set cuda device
# torch.device(TORCH_DEVICE)

# # Define the seed value
# seed = 42

# # Set seed for PyTorch
# torch.manual_seed(seed)

# # Set seed for CUDA (if using GPUs)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

# # Set seed for Python's random module
# random.seed(seed)

# # Set seed for NumPy
# np.random.seed(seed)

# # Ensure deterministic behavior for PyTorch operations
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

In [4]:
NB_SAMPLING_POINTS=8192

# MESH_PATH = "demo/Mesh2.obj"
# TRIANGLES_TO_REMESH=[ 107, 106, 105, 104, 103, 102,
#                       117, 116, 115, 114, 113, 112,
#                       127, 126, 125, 124, 123, 122 ]
# CONTEXT_RING=-1

# # Model "demo/objaverse_pig.obj", remove back -> OK
# MESH_PATH="demo/objaverse_pig.obj"
# TRIANGLES_TO_REMESH=[ 3052, 3053, 3062, 3063, 1409, 1408, 1398, 1399, 1388, 1389, 3042, 3043 ]
# CONTEXT_RING=-1

# # Model "demo/objaverse_pig.obj", remove ear -> free boundaries
# MESH_PATH="demo/objaverse_pig.obj"
# mesh = o3d.io.read_triangle_mesh(MESH_PATH)
# TRIANGLES_TO_REMESH=ExtractRingsAroundTriangles(mesh, [2298], iMaxRingSize=2)
# CONTEXT_RING=-1

# MESH_PATH="./demo/objaverse_pig_CC0.obj"
# #TRIANGLES_TO_REMESH=[2299]  #Ear
# TRIANGLES_TO_REMESH=[1399]  #Back
# mesh = o3d.io.read_triangle_mesh(MESH_PATH)
# TRIANGLES_TO_REMESH=ExtractRingsAroundTriangles(mesh, TRIANGLES_TO_REMESH, iMaxRingSize=3)
# CONTEXT_RING=-1

# MESH_PATH="./demo/objaverse_pig_CC0_Decim_2k.obj"
# TRIANGLES_TO_REMESH=[1594]  #Back
# mesh = o3d.io.read_triangle_mesh(MESH_PATH)
# TRIANGLES_TO_REMESH=ExtractRingsAroundTriangles(mesh, TRIANGLES_TO_REMESH, iMaxRingSize=1)
# CONTEXT_RING=-1

# MESH_PATH="./demo/objaverse_pig_CC0_Decim_2k_norm7.obj"
# TRIANGLES_TO_REMESH=[1593]  #Back
# mesh = o3d.io.read_triangle_mesh(MESH_PATH)
# TRIANGLES_TO_REMESH=ExtractRingsAroundTriangles(mesh, TRIANGLES_TO_REMESH, iMaxRingSize=1)
# CONTEXT_RING=-1

# #-----------------------------------------------------
# MESH_PATH="./demo/NewMesh1_Tri.obj"
# TRIANGLES_TO_REMESH=[394]  #front (+x)
# TRIANGLES_TO_REMESH=[174]  #back (-x)
# TRIANGLES_TO_REMESH=[205]  #top (+z)
# TRIANGLES_TO_REMESH=[275]  #bottom (-z)
# #TRIANGLES_TO_REMESH=[552]  #left (-y)
# #TRIANGLES_TO_REMESH=[85]  #right (+y)

# #TRIANGLES_TO_REMESH=[206]  #Top
# #TRIANGLES_TO_REMESH=[390]  #Tip
# #TRIANGLES_TO_REMESH=[272]  #bottom
# #TRIANGLES_TO_REMESH=[165]  #back
# mesh = o3d.io.read_triangle_mesh(MESH_PATH)
# TRIANGLES_TO_REMESH=ExtractRingsAroundTriangles(mesh, TRIANGLES_TO_REMESH, iMaxRingSize=2)
# CONTEXT_RING=-1

# #-----------------------------------------------------
# MESH_PATH="./demo/NewMesh2_Tri.obj"
# TRIANGLES_TO_REMESH=[242]  #front (+x)
# TRIANGLES_TO_REMESH=[170]  #back (-x)
# TRIANGLES_TO_REMESH=[276]  #top (+z)
# TRIANGLES_TO_REMESH=[205]  #bottom (-z)
# #TRIANGLES_TO_REMESH=[45]  #left (-y)
# #TRIANGLES_TO_REMESH=[434]  #right (+y)
# mesh = o3d.io.read_triangle_mesh(MESH_PATH)
# TRIANGLES_TO_REMESH=ExtractRingsAroundTriangles(mesh, TRIANGLES_TO_REMESH, iMaxRingSize=2)
# CONTEXT_RING=-1

#-----------------------------------------------------
MESH_PATH="./demo/NewMesh3_Tri.obj"
TRIANGLES_TO_REMESH=[91]  #front (+x)
TRIANGLES_TO_REMESH=[41]  #back (-x)
TRIANGLES_TO_REMESH=[63]  #top (+z)
TRIANGLES_TO_REMESH=[52]  #bottom (-z)
#TRIANGLES_TO_REMESH=[11]  #left (-y)
#TRIANGLES_TO_REMESH=[137]  #right (+y)
mesh = o3d.io.read_triangle_mesh(MESH_PATH)
TRIANGLES_TO_REMESH=ExtractRingsAroundTriangles(mesh, TRIANGLES_TO_REMESH, iMaxRingSize=2)
CONTEXT_RING=-1

In [5]:
#--- Load and normalize mesh
mesh = o3d.io.read_triangle_mesh(MESH_PATH)

# #-- Rotate the mesh to align the area to remesh with the vector (1, 1, 1)
# mesh = AlignBoundaryWithVector(mesh, TRIANGLES_TO_REMESH, vector=(0, 0, 1))
# o3d.io.write_triangle_mesh(OUTPUT_DIR+"/"+"AlignBoundaryWithVector_"+os.path.split(MESH_PATH)[1], mesh)

vertices = np.asarray(mesh.vertices)
vertices = center_vertices(vertices)
vertices = normalize_vertices_scale(vertices)
vertices = np.clip(vertices, a_min=-0.5, a_max = 0.5)
triangles = np.asarray(mesh.triangles)

#-- Reorder mesh elements
he_mesh, _, _, TRIANGLES_TO_REMESH=prepare_halfedge_mesh(vertices, triangles, TRIANGLES_TO_REMESH)
# he_mesh = o3d.geometry.TriangleMesh()
# he_mesh.vertices = o3d.utility.Vector3dVector(vertices)
# he_mesh.triangles = o3d.utility.Vector3iVector(triangles)

#Debug : save normalized and reorderedmesh
localMesh = o3d.geometry.TriangleMesh()
localMesh.vertices = he_mesh.vertices
localMesh.triangles = he_mesh.triangles
o3d.io.write_triangle_mesh(OUTPUT_DIR+"/"+"normalized_"+os.path.split(MESH_PATH)[1], localMesh)
SaveOBJ(he_mesh.vertices, he_mesh.triangles, OUTPUT_DIR+"/"+"normalizedColored_"+os.path.split(MESH_PATH)[1]+".obj", True)

#--- Extract boundary to remesh
submesh, remeshBoundary, sampledPoints=GenerateAreaToRemesh(localMesh, TRIANGLES_TO_REMESH, iMaxRingSize=CONTEXT_RING, iNbSamples=NB_SAMPLING_POINTS)

#Debug : Save area to remesh in OBJ
SaveAreaToRemeshInOBJ(submesh, remeshBoundary, OUTPUT_DIR+"/"+"area2remesh_"+os.path.split(MESH_PATH)[1], iSampledPoints=sampledPoints)
o3d.io.write_point_cloud(OUTPUT_DIR+"/"+"Sampling_"+os.path.split(MESH_PATH)[1]+".xyz", sampledPoints)

#Check number of faces
print("Number of faces to remesh: ", len(submesh.triangles))
if len(submesh.triangles) >= DECIMATION_TARGET_NFACES:
    raise Exception("@@@@ Number of faces to remesh is larger than target number of faces")

Sampled 8192 points on the submesh
19 boundary edges on local area
Number of faces to remesh:  103


In [6]:
#Point cloud sampling structures
pc_array = np.asarray(sampledPoints.points)
pc = torch.tensor(pc_array).unsqueeze(0).float().cuda()

#Halfedge mesh
halfEdgeTriangularMesh=o3d.geometry.HalfEdgeTriangleMesh.create_from_triangle_mesh(submesh)

In [7]:
# Set up model
transformer = TreeMeshGPT(quant_bit = 7 if VERSION == "7bit" else 9, max_seq_len=13000) # can set higher max_seq_len if GPU is L4 or A100
transformer.load(CKPT_PATH)
accelerator = Accelerator(mixed_precision="fp16")
transformer = accelerator.prepare(transformer)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [8]:
# Generation
with accelerator.autocast(), torch.no_grad():
    out_faces = transformer.generate_completion(halfEdgeTriangularMesh, remeshBoundary, pc, n = 0.25)
    #out_faces = transformer.generate_completion(halfEdgeTriangularMesh, remeshBoundary, pc, n = 0.0)

Starting initialization with existing mesh...
Stack size: 1
Stack size: 3
Stack size: 4
Stack size: 5
Stack size: 5
Stack size: 5
Stack size: 6
Stack size: 7
Stack size: 7
Stack size: 8
Stack size: 8
Stack size: 9
Stack size: 9
Stack size: 10
Stack size: 11
Stack size: 11
Stack size: 11
Stack size: 12
Stack size: 13
Stack size: 13
Stack size: 14
Stack size: 14
Stack size: 15
Stack size: 16
Stack size: 16
Stack size: 17
Stack size: 17
Stack size: 18
Stack size: 19
Stack size: 19
Stack size: 19
Stack size: 20
Stack size: 21
Stack size: 21
Stack size: 22
Stack size: 22
Stack size: 23
Stack size: 24
Stack size: 24
Stack size: 25
Stack size: 26
Stack size: 26
Stack size: 27
Stack size: 28
Stack size: 28
Stack size: 29
Stack size: 28
Stack size: 29
Stack size: 30
Stack size: 30
Stack size: 29
Stack size: 30
Stack size: 31
Stack size: 32
Stack size: 32
Stack size: 31
Stack size: 32
Stack size: 32
Stack size: 31
Stack size: 32
Stack size: 33
Stack size: 34
Stack size: 34
Stack size: 33
Stack s

In [9]:
vertices = out_faces.view(-1, 3).cpu().numpy()
n = vertices.shape[0]
faces = torch.arange(1, n + 1).view(-1, 3).numpy()

with open(OUTPUT_DIR+"/"+"GeneratedVertices_"+os.path.split(MESH_PATH)[1], "w") as file :
  for vertex in vertices :
    file.write(f"v  {vertex[0]}  {vertex[1]}  {vertex[2]}\n")

with open(OUTPUT_DIR+"/"+"GeneratedFaces_"+os.path.split(MESH_PATH)[1], "w") as file :
  for vertex in vertices :
    file.write(f"v  {vertex[0]}  {vertex[1]}  {vertex[2]}\n")

  for face in faces :
    file.write(f"f  {face[0]}  {face[1]}  {face[2]}\n")

if min(min(faces.tolist())) == 1:
    faces = (np.array(faces) - 1)

# Remove collapsed triangles and duplicates
p0 = vertices[faces[:, 0]]
p1 = vertices[faces[:, 1]]
p2 = vertices[faces[:, 2]]
collapsed_mask = np.all(p0 == p1, axis=1) | np.all(p0 == p2, axis=1) | np.all(p1 == p2, axis=1)
faces = faces[~collapsed_mask]
faces = faces.tolist()
scene_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, force="mesh",
                        merge_primitives=True)
scene_mesh.merge_vertices()
scene_mesh.update_faces(scene_mesh.nondegenerate_faces())
scene_mesh.update_faces(scene_mesh.unique_faces())
scene_mesh.remove_unreferenced_vertices()
scene_mesh.fix_normals()

In [10]:
del out_faces
torch.cuda.empty_cache()

In [11]:
# Plot mesh from: https://colab.research.google.com/drive/1CR_HDvJ2AnjJV3Bf5vwP70K0hx3RcdMb?usp=sharing#scrollTo=kXi90AcckMF5

triangles = np.asarray(scene_mesh.faces)
vertices = np.asarray(scene_mesh.vertices)
colors = None

mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(vertices)
mesh.triangles = o3d.utility.Vector3iVector(triangles)

if not mesh.has_vertex_normals(): mesh.compute_vertex_normals()
if not mesh.has_triangle_normals(): mesh.compute_triangle_normals()

if mesh.has_triangle_normals():
    colors = (0.5, 0.5, 0.5) + np.asarray(mesh.triangle_normals) * 0.5
    colors = tuple(map(tuple, colors))
else:
    colors = (1.0, 0.0, 0.0)

import plotly.graph_objects as go

fig = go.Figure(
    data=[
        go.Mesh3d(
            x=vertices[:,0],
            y=vertices[:,1],
            z=vertices[:,2],
            i=triangles[:,0],
            j=triangles[:,1],
            k=triangles[:,2],
            facecolor=colors,
            opacity=0.50)
    ],
    layout=dict(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False)
        )
    )
)
fig.show()

In [12]:
# Save mesh if necessary
outputFilePath="./output/COMPLETION_"+os.path.split(MESH_PATH)[1]
scene_mesh.export(outputFilePath)

'# https://github.com/mikedh/trimesh\nv -0.12992126 -0.26377952 -0.12992126\nv -0.20866141 -0.34251970 -0.07480314\nv -0.23228347 -0.26377952 -0.09842521\nv -0.26377952 -0.26377952 0.00393701\nv -0.26377952 -0.12992126 0.00393701\nv -0.23228347 -0.12992126 -0.09842521\nv -0.26377952 0.00393701 0.00393701\nv -0.23228347 0.00393701 0.09842521\nv -0.23228347 0.12992126 0.09842521\nv -0.26377952 0.12992126 0.00393701\nv -0.23228347 0.26377952 0.09842521\nv -0.26377952 0.26377952 0.00393701\nv -0.20866141 0.34251970 0.07480317\nv -0.23228347 0.36614174 0.00393701\nv -0.12992126 0.36614174 0.09842521\nv -0.12992126 0.39763778 0.00393701\nv -0.20866141 0.34251970 -0.07480314\nv -0.12992126 0.36614174 -0.09842521\nv -0.03543308 0.36614174 0.00393701\nv -0.05905512 0.34251970 -0.07480314\nv -0.03543308 0.26377952 -0.09842521\nv 0.00393701 0.26377952 0.00393701\nv 0.03543305 0.16929132 0.00393701\nv 0.00393701 0.12992126 0.11417323\nv 0.12992126 0.09842521 0.09842521\nv 0.12992126 0.12992126 0.0