# This notebook runs the whole pipeline, starting from the raw dataset.

### Import Libraries

In [None]:
import os
import numpy as np
import trimesh
import pyrender
import matplotlib.pyplot as plt
# from dataset_processing.fix_normal import traverse_and_fix

import trimesh
import numpy as np
from trimesh.visual import ColorVisuals
from trimesh.scene.lighting import DirectionalLight, PointLight

# sanity = trimesh.load("/home/athiwat/progressive_img2sketch/resources/LOD_for_icp/46/lod2.obj", process=True)
# # to mesh
# if isinstance(sanity, trimesh.Scene):
#     sanity = sanity.dump(concatenate=True)
# trimesh.repair.fix_normals(sanity)
# trimesh.repair.fix_inversion(sanity)
# trimesh.repair.fix_winding(sanity)
# trimesh.repair.broken_faces(sanity)
# sanity.show()

  # NOTE: stray leftover of the commented-out sanity check above. `sanity` is never
  # defined and the unexpected indent is a syntax error, so keep it commented out.
  # sanity = sanity.dump(concatenate=True)


### 1. Correcting the UV Paths for the LOD dataset

This step only needs to be run once.

In [2]:
# Root of the raw LOD dataset. The loading cell below reads
# <root>/<scene_num>/lod<k>.obj from it.
RAW_LOD_DATASET_ROOT = "/home/athiwat/progressive_img2sketch/resources/LOD_for_icp"

# UV-path fix is disabled: its import (dataset_processing.fix_normal) is also
# commented out at the top of the notebook. Re-enable both to run it again.
# traverse_and_fix(RAW_LOD_DATASET_ROOT)

In [3]:
import trimesh
import pyrender
import numpy as np
import matplotlib.pyplot as plt

def center_scene_by_bbox(scene: trimesh.Scene) -> trimesh.Scene:
    """
    Move ``scene`` so that its axis-aligned bounding-box center lands on the origin.

    Only ``scene.bounds`` (a min-corner / max-corner pair) and
    ``scene.apply_translation`` are used, so a ``trimesh.Trimesh`` is accepted
    as well (``align_lods_1_2_only`` passes meshes here). The object is
    translated in place and also returned for convenience.
    """
    lower, upper = scene.bounds
    scene.apply_translation(-(lower + upper) / 2.0)
    return scene

def get_registration_matrix(
    source_mesh: trimesh.Trimesh,
    target_mesh: trimesh.Trimesh,
    samples: int = 3000,
    icp_first: int = 1,
    icp_final: int = 30
) -> np.ndarray:
    """
    Return the rigid transform that registers ``source_mesh`` onto ``target_mesh``.

    Thin wrapper around ``trimesh.registration.mesh_other`` with scaling
    disabled, so only rotation and translation are estimated.

    Args:
        source_mesh: mesh to be moved.
        target_mesh: mesh that stays fixed.
        samples: number of surface points sampled for registration.
        icp_first: ICP iterations run for each initial candidate pose.
        icp_final: ICP iterations run on the best candidate.

    Returns:
        The homogeneous transform from ``mesh_other`` (the alignment cost it
        also returns is discarded). Apply it to ``source_mesh`` to align it.
    """
    icp_options = dict(
        samples=samples,
        scale=False,
        icp_first=icp_first,
        icp_final=icp_final,
    )
    transform, _cost = trimesh.registration.mesh_other(
        source_mesh, target_mesh, **icp_options
    )
    return transform


def align_lods(scenes: dict[int, trimesh.Scene], center_before: bool = False):
    """
    Rigidly align all LOD scenes onto the lowest-numbered LOD with chained ICP.

    Each LOD is registered onto its predecessor in ascending key order
    (2→1, 3→2, ...). The pairwise transforms are composed so every scene ends
    up in the frame of the first LOD; for example, LOD3 receives
    ``T(2→1) @ T(3→2)``. All scenes are then translated so that the aligned
    reference LOD's bounding-box center sits at the origin.

    With keys {1, 2, 3} this is exactly the 2→1 / 3→2 pipeline. Other key
    sets such as {1, 2} also work.

    Args:
        scenes: mapping of LOD index -> scene. Scenes are transformed in place.
        center_before: if True, bbox-center every scene first so that ICP
            starts from a closer pose.

    Returns:
        The same ``scenes`` dict.
    """
    if not scenes:
        return scenes
    order = sorted(scenes)
    reference = order[0]

    # — step 1: (optional) rough centering to help ICP converge —
    if center_before:
        for lod in scenes:
            scenes[lod] = center_scene_by_bbox(scenes[lod])

    # — step 2: flatten each scene into one world-space mesh for ICP —
    # Scene.dump() returns copies of the geometry with the scene-graph
    # transforms baked in. Concatenating scene.geometry.values() would instead
    # use the raw local geometry and ignore the graph. The Scene centering
    # from step 1 is applied through the graph, so it would never reach ICP.
    meshes = {
        lod: trimesh.util.concatenate(list(scenes[lod].dump()))
        for lod in scenes
    }

    # show pre-ICP bbox centers
    for lod, mesh in meshes.items():
        min_corner, max_corner = mesh.bounds
        center = (min_corner + max_corner) / 2.0
        print(f"LOD {lod} original center: {center}")

    # ICP chain: register each LOD onto its predecessor. All pairs are
    # computed on the untouched meshes. Each result is composed with the
    # predecessor's cumulative transform so it lands in the reference frame.
    to_reference = {reference: np.eye(4)}
    for prev, lod in zip(order, order[1:]):
        step = get_registration_matrix(meshes[lod], meshes[prev])
        to_reference[lod] = to_reference[prev] @ step

    # apply those transforms (the reference LOD stays where it is)
    for lod in order[1:]:
        scenes[lod].apply_transform(to_reference[lod])

    # show aligned bbox centers
    for lod, scene in scenes.items():
        min_corner, max_corner = scene.bounds
        center = (min_corner + max_corner) / 2.0
        print(f"LOD {lod} aligned center: {center}")

    # — step 3: final centering based on the aligned reference-LOD bbox —
    min_ref, max_ref = scenes[reference].bounds
    center_ref = (min_ref + max_ref) * 0.5
    for lod in scenes:
        scenes[lod].apply_translation(-center_ref)

    return scenes



def look_at_matrix(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
    """
    Create a camera-to-world pose matrix for pyrender given eye, target, up vectors.

    The returned matrix has these columns:
    - the camera's right axis (+X),
    - its up axis (+Y),
    - its backward axis (+Z),
    - ``eye`` as the translation.

    The camera therefore looks down its local -Z axis towards ``target``.
    This matrix equals the inverse of the classic gluLookAt view matrix.

    Args:
        eye: camera position (3 components; arrays or sequences, any numeric dtype).
        target: point the camera looks at.
        up: approximate world-up direction. If it is parallel to the viewing
            direction, +Y is used instead, or +Z when looking along the Y axis.

    Returns:
        (4, 4) float array: the camera pose (camera→world).

    Raises:
        ValueError: if ``eye`` and ``target`` coincide, leaving no viewing direction.
    """
    # Work in float: the in-place normalization below would raise on
    # integer-dtype arrays, and plain lists do not support subtraction.
    eye = np.asarray(eye, dtype=float)
    target = np.asarray(target, dtype=float)
    up = np.asarray(up, dtype=float)

    forward = target - eye
    distance = np.linalg.norm(forward)
    if distance == 0.0:
        raise ValueError("look_at_matrix: eye and target must be distinct points")
    forward /= distance

    # avoid parallel up/forward (the cross product would vanish)
    if np.isclose(np.linalg.norm(np.cross(forward, up)), 0):
        up = (
            np.array([0.0, 0.0, 1.0])
            if np.isclose(abs(forward[1]), 1)
            else np.array([0.0, 1.0, 0.0])
        )

    right = np.cross(forward, up)
    right /= np.linalg.norm(right)
    true_up = np.cross(right, forward)
    true_up /= np.linalg.norm(true_up)

    # Camera→world pose: columns are the camera axes expressed in world space
    # plus the camera position. It is built directly rather than by inverting
    # the view matrix.
    pose = np.eye(4)
    pose[:3, 0] = right
    pose[:3, 1] = true_up
    pose[:3, 2] = -forward
    pose[:3, 3] = eye
    return pose

# Display flag for drawing the Y axis. No visible cell reads it; presumably a
# later rendering cell does (TODO confirm).
SHOW_Y_AXIS = True  # Set to True to visualize the Y-axis



def align_lods_1_2_only(meshes: dict[int, trimesh.base.Trimesh], center_before: bool = False, samples: int = 3000):
    """
    Register the LOD2 mesh onto the LOD1 mesh with ICP, then recenter both.

    Steps:
      1. print both meshes' bounding-box centers;
      2. optionally bbox-center both meshes, giving ICP a closer start;
      3. register LOD2 -> LOD1 and apply that transform to LOD2;
      4. print the centers again, then translate both meshes so the LOD1
         bounding-box center sits at the origin.

    Args:
        meshes: dict that must contain keys 1 and 2. Both meshes are modified
            in place.
        center_before: pre-center both meshes before running ICP.
        samples: number of surface samples forwarded to ``get_registration_matrix``.

    Returns:
        The same ``meshes`` dict that was passed in.
    """
    pair = (1, 2)

    def _report(label):
        # print each mesh's bbox center, tagged with the current stage
        for key in pair:
            lo, hi = meshes[key].bounds
            print(f"LOD {key} {label} center: {(lo + hi) / 2.0}")

    _report("original")

    # optional rough centering to help ICP converge
    if center_before:
        for key in pair:
            meshes[key] = center_scene_by_bbox(meshes[key])

    # ICP: LOD2 (moving) onto LOD1 (reference)
    reference, moving = meshes[1], meshes[2]
    icp_transform = get_registration_matrix(moving, reference, samples=samples)
    moving.apply_transform(icp_transform)

    _report("aligned")

    # recenter both meshes on the aligned LOD1 bbox center
    lo_ref, hi_ref = reference.bounds
    offset = -((lo_ref + hi_ref) * 0.5)
    for key in pair:
        meshes[key].apply_translation(offset)

    return meshes


### 2. Load the scene and orbit capture

In [5]:
import trimesh
import numpy as np
import os

# ─── Config ───────────────────────────────────────────────────────
scene_num = 46
LODS = [1, 2]
threshold_degrees = 5.0
# Crease-angle threshold in radians. Only the commented-out crease detection
# further down uses it.
angle_thresh = np.deg2rad(threshold_degrees)
# RAW_LOD_DATASET_ROOT = "/home/athiwat/progressive_img2sketch/resources/LOD_data_50"
RAW_LOD_DATASET_ROOT = "/home/athiwat/progressive_img2sketch/resources/LOD_for_icp"

# ─── 1. Load LOD meshes into dict ─────────────────────────────────
# Expected layout: <RAW_LOD_DATASET_ROOT>/<scene_num>/lod<k>.obj
lod_meshes = {}
for lod in LODS:
    path = os.path.join(RAW_LOD_DATASET_ROOT, str(scene_num), f"lod{lod}.obj")
    loaded = trimesh.load(path, process=False)
    # trimesh.load may return a Scene instead of a single Trimesh. Flatten it
    # with Scene.dump(), which returns the geometries with their scene-graph
    # transforms applied (the raw scene.geometry values would ignore them).
    # The explicit list() avoids handing an iterator/view to concatenate.
    mesh = (
        trimesh.util.concatenate(list(loaded.dump()))
        if isinstance(loaded, trimesh.Scene)
        else loaded
    )
    lod_meshes[lod] = mesh

print(f"check type of lod_meshes[1]: {type(lod_meshes[1])}")
print(f"check type of lod_meshes[2]: {type(lod_meshes[2])}")
# ─── 2. Align meshes ──────────────────────────────────────────────
# align_lods_1_2_only mutates lod_meshes in place and returns that same dict,
# so aligned_meshes and lod_meshes refer to the same object.
aligned_meshes = align_lods_1_2_only(lod_meshes, center_before=True, samples=4000)
# aligned_meshes = lod_meshes.copy()
scene = trimesh.Scene()
# for lod, mesh in aligned_meshes.items():
#     scene.add_geometry(mesh, geom_name=f"LOD{lod}")
# Preview only the aligned LOD2 mesh, named after the LOD it actually holds.
scene.add_geometry(aligned_meshes[2], geom_name="LOD2")

scene.show()
# # ─── 3. Build scene dict with crease lines ────────────────────────
# scene_dict = {}
# for lod, mesh in aligned_meshes.items():
#     # Step 1: Create base scene and add the aligned mesh
#     scene = trimesh.Scene()
#     scene.add_geometry(mesh)

#     # Step 2: Weld mesh to prepare for crease detection
#     welded = trimesh.Trimesh(vertices=mesh.vertices.copy(),
#                              faces=mesh.faces.copy(),
#                              process=True)

#     # Step 3: Crease detection
#     fa = welded.face_adjacency_angles
#     edges = welded.face_adjacency_edges
#     mask = fa > angle_thresh

#     # Optional: filter for manifold edges only (edge shared by 2 faces)
#     from collections import defaultdict
#     edge_count = defaultdict(int)
#     for face in welded.faces:
#         for i in range(3):
#             e = tuple(sorted((face[i], face[(i+1)%3])))
#             edge_count[e] += 1
#     filtered_edges = [e for e in edges[mask] if edge_count[tuple(sorted(e))] == 2]

#     if len(filtered_edges) > 0:
#         segments = welded.vertices[np.array(filtered_edges)]
#         crease = trimesh.load_path(segments)
#         crease.colors = np.tile([0, 0, 0, 255], (len(crease.entities), 1))
#         scene.add_geometry(crease)

#     # Save scene
#     scene_dict[lod] = scene

# # ─── Show LOD1 as test ────────────────────────────────────────────
# scene_dict[1].show()

check type of lod_meshes[1]: <class 'trimesh.base.Trimesh'>
check type of lod_meshes[2]: <class 'trimesh.base.Trimesh'>
LOD 1 original center: [ 7.10542736e-15 -4.44089210e-16  7.10542736e-15]
LOD 2 original center: [7.10542736e-15 0.00000000e+00 7.10542736e-15]
LOD 1 aligned center: [ 7.10542736e-15 -4.44089210e-16  7.10542736e-15]
LOD 2 aligned center: [-0.03559707 -0.1268338  -0.10113088]
