In [1]:
from model.treemeshgpt_inference import TreeMeshGPT
import os
import numpy as np
import open3d as o3d
import torch
from accelerate import Accelerator
from pathlib import Path
from fns import center_vertices, normalize_vertices_scale, str2bool, quantize_verts, dequantize_verts_tensor
import trimesh
import pyvista as pv
import argparse

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
VERSION = "7bit"
CKPT_PATH = "./checkpoints/treemeshgpt_7bit.pt"
MESH_PATH =  "demo/luma_cat.glb"
MESH_PATH =  "demo/FullMesh_1Hole_SplitDisk.obj"
#MESH_PATH =  "demo/FullMesh_1Hole.obj"
#MESH_PATH =  "demo/FullMesh.obj"
MESH_PATH =  "demo/FullMesh_1Hole-2.obj"

#MESH_PATH = "demo/Mesh2_DiskHole.obj"
#MESH_PATH = "demo/Mesh2_Hole.obj"
# MESH_PATH = "demo/Mesh2.obj"
# MESH_PATH = "demo/objaverse_pig_CC0_Decim_2k.obj"
# MESH_PATH = "./demo/NewMesh1_Tri.obj"
MESH_PATH = "./demo/NewMesh3_Tri.obj"

OUTPUT_DIR="./output"

DECIMATION = True
DECIMATION_TARGET_NFACES = 5000
DECIMATION_BOUNDARY_DELETION = True

SAMPLING = "uniform" if VERSION == "7bit" else "fps"
#SAMPLING = "keep_vertices"

if not os.path.exists("./output") :
  os.mkdir("./output")

In [3]:
# Set up model
transformer = TreeMeshGPT(quant_bit = 7 if VERSION == "7bit" else 9, max_seq_len=13000) # can set higher max_seq_len if GPU is L4 or A100
transformer.load(CKPT_PATH)
accelerator = Accelerator(mixed_precision="fp16")
transformer = accelerator.prepare(transformer)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [4]:
# Load and normalize mesh
mesh = o3d.io.read_triangle_mesh(MESH_PATH)
vertices = np.asarray(mesh.vertices)
vertices = center_vertices(vertices)
vertices = normalize_vertices_scale(vertices)
vertices = np.clip(vertices, a_min=-0.5, a_max = 0.5)
triangles = np.asarray(mesh.triangles)

mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(vertices)
mesh.triangles = o3d.utility.Vector3iVector(triangles)

# Mesh decimation
if DECIMATION:
    n_triangles = min(DECIMATION_TARGET_NFACES, len(triangles))
    faces_pyvista = np.hstack([np.full((triangles.shape[0], 1), 3), triangles]).astype(np.int64).flatten()
    mesh = pv.PolyData(vertices, faces_pyvista)
    decimated_mesh = mesh.decimate_pro(1-n_triangles/len(triangles), boundary_vertex_deletion=DECIMATION_BOUNDARY_DELETION)
    decimated_vertices = np.array(decimated_mesh.points)
    decimated_faces = np.array(decimated_mesh.faces).reshape(-1, 4)[:, 1:]  # Remove leading '3' per triangle
    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(decimated_vertices)
    mesh.triangles = o3d.utility.Vector3iVector(decimated_faces)
    print("Mesh is decimated to {} faces".format(len(decimated_faces)))
else:
    print("Sampling from original mesh with {} faces".format(len(triangles)))

o3d.io.write_triangle_mesh(OUTPUT_DIR+"/"+"normalized_"+os.path.split(MESH_PATH)[1], mesh)

Mesh is decimated to 144 faces


True

In [5]:
# Point cloud sampling
if SAMPLING == "uniform":
    pc = mesh.sample_points_uniformly(number_of_points=8192)
    o3d.io.write_point_cloud(OUTPUT_DIR+"/"+"Sampling_Uniform_"+os.path.split(MESH_PATH)[1]+".ply", pc)
elif SAMPLING == "fps":
    pc = mesh.sample_points_uniformly(number_of_points=8192*10)
    pc_array = np.asarray(pc.points)
    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(pc_array)
    pc = pc.farthest_point_down_sample(8192//2)
    o3d.io.write_point_cloud(OUTPUT_DIR+"/"+"Sampling_FPS_"+os.path.split(MESH_PATH)[1]+".ply", pc)
elif SAMPLING == "keep_vertices" :
    pc = o3d.geometry.PointCloud()
    pc.points = mesh.vertices
    o3d.io.write_point_cloud(OUTPUT_DIR+"/"+"Sampling_KeepVertices_"+os.path.split(MESH_PATH)[1]+".ply", pc)
pc_array = np.asarray(pc.points)
pc = torch.tensor(pc_array).unsqueeze(0).float().cuda()

In [6]:
# Generation
with accelerator.autocast(), torch.no_grad():
    out_faces = transformer.generate(pc,n = 0.25)

Sequence length: 222/13000 | Stack length: 69  EOS symbol emited, stopping generation, stack length: 69
Stack length after while/loop : 69


In [7]:
vertices = out_faces.view(-1, 3).cpu().numpy()
n = vertices.shape[0]
faces = torch.arange(1, n + 1).view(-1, 3).numpy()

with open(OUTPUT_DIR+"/"+"GeneratedVertices_"+os.path.split(MESH_PATH)[1], "w") as file :
  for vertex in vertices :
    file.write(f"v  {vertex[0]}  {vertex[1]}  {vertex[2]}\n")

if min(min(faces.tolist())) == 1:
    faces = (np.array(faces) - 1)

# Remove collapsed triangles and duplicates
p0 = vertices[faces[:, 0]]
p1 = vertices[faces[:, 1]]
p2 = vertices[faces[:, 2]]
collapsed_mask = np.all(p0 == p1, axis=1) | np.all(p0 == p2, axis=1) | np.all(p1 == p2, axis=1)
faces = faces[~collapsed_mask]
faces = faces.tolist()
scene_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, force="mesh",
                        merge_primitives=True)
scene_mesh.merge_vertices()
scene_mesh.update_faces(scene_mesh.nondegenerate_faces())
scene_mesh.update_faces(scene_mesh.unique_faces())
scene_mesh.remove_unreferenced_vertices()
scene_mesh.fix_normals()

In [8]:
del out_faces
torch.cuda.empty_cache()

In [9]:
# Plot mesh from: https://colab.research.google.com/drive/1CR_HDvJ2AnjJV3Bf5vwP70K0hx3RcdMb?usp=sharing#scrollTo=kXi90AcckMF5

triangles = np.asarray(scene_mesh.faces)
vertices = np.asarray(scene_mesh.vertices)
colors = None

mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(vertices)
mesh.triangles = o3d.utility.Vector3iVector(triangles)

if not mesh.has_vertex_normals(): mesh.compute_vertex_normals()
if not mesh.has_triangle_normals(): mesh.compute_triangle_normals()

if mesh.has_triangle_normals():
    colors = (0.5, 0.5, 0.5) + np.asarray(mesh.triangle_normals) * 0.5
    colors = tuple(map(tuple, colors))
else:
    colors = (1.0, 0.0, 0.0)

import plotly.graph_objects as go

fig = go.Figure(
    data=[
        go.Mesh3d(
            x=vertices[:,0],
            y=vertices[:,1],
            z=vertices[:,2],
            i=triangles[:,0],
            j=triangles[:,1],
            k=triangles[:,2],
            facecolor=colors,
            opacity=0.50)
    ],
    layout=dict(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False)
        )
    )
)
fig.show()

In [10]:
# Save mesh if necessary
outputFilePath="./output/GENERATE_"+os.path.split(MESH_PATH)[1]
scene_mesh.export(outputFilePath)


'# https://github.com/mikedh/trimesh\nv -0.12992126 -0.26377952 -0.12992126\nv -0.23228347 -0.26377952 -0.09842521\nv -0.13779527 -0.13779527 -0.12992126\nv -0.23228347 -0.13779527 -0.09842521\nv -0.26377952 -0.26377952 -0.00393701\nv -0.20866141 -0.34251970 -0.07480314\nv -0.12992126 -0.36614174 -0.09842521\nv -0.03543308 -0.27165353 -0.09842521\nv 0.00393701 -0.12992126 -0.11417323\nv -0.12992126 0.00393701 -0.12992126\nv -0.23228347 0.00393701 -0.09842521\nv -0.26377952 0.00393701 0.00393701\nv -0.26377952 -0.13779527 -0.00393701\nv -0.23228347 -0.26377952 0.09842521\nv -0.23228347 -0.36614174 -0.00393701\nv -0.12992126 -0.39763778 -0.00393701\nv -0.05905512 -0.34251970 -0.07480314\nv -0.03543308 -0.36614174 -0.00393701\nv -0.00393701 -0.27165353 -0.00393701\nv 0.03543305 -0.16929135 -0.00393701\nv 0.13779527 -0.09842521 -0.09842521\nv 0.12992126 -0.00393701 -0.12992126\nv 0.00393701 0.00393701 -0.12992126\nv -0.12992126 0.13779527 -0.12992126\nv -0.23228347 0.13779527 -0.09842521\n