# This notebook does the whole pipeline starting from the raw dataset

### Import Libraries

In [1]:
import os
os.environ["PYOPENGL_PLATFORM"] = "osmesa"  # must be set before importing pyrender

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import trimesh
from trimesh.visual import ColorVisuals
from trimesh.scene.lighting import DirectionalLight, PointLight

import pyrender
from pyrender import Primitive, Mesh as PyMesh, PerspectiveCamera, SpotLight, OffscreenRenderer
from pyrender.constants import RenderFlags

# test_mesh_path = "/home/athiwat/progressive_img2sketch/resources/LOD50_opaque_normalized_triangulated/49/lod1.obj"
# trimesh_scene = trimesh.load(test_mesh_path)

# print(trimesh_scene.vertices.shape)  # (#V, 3)
# print(trimesh_scene.faces.shape)     # (#F, 3) — these are still triangles

# trimesh_scene.show()


### 1. Correcting the UV Paths for the LOD dataset

just need to run this once

In [2]:
SHOW_Y_AXIS = True  # Set to True to visualize the Y-axis
# RAW_LOD_DATASET_ROOT = "/home/athiwat/progressive_img2sketch/resources/LOD_for_icp"


# traverse_and_fix(RAW_LOD_DATASET_ROOT)

In [None]:

def center_scene_by_bbox(scene: trimesh.Scene) -> trimesh.Scene:
    """
    Center the scene at the origin based on its bounding-box center.
    """
    min_corner, max_corner = scene.bounds
    center = (min_corner + max_corner) / 2.0
    scene.apply_translation(-center)
    return scene

def get_registration_matrix(
    source_mesh: trimesh.Trimesh,
    target_mesh: trimesh.Trimesh,
    samples: int = 3000,
    icp_first: int = 1,
    icp_final: int = 30
) -> np.ndarray:
    """
    Compute the ICP transformation matrix that aligns source_mesh to target_mesh.
    """
    matrix, _ = trimesh.registration.mesh_other(
        source_mesh,
        target_mesh,
        samples=samples,
        scale=False,
        icp_first=icp_first,
        icp_final=icp_final
    )
    return matrix

def align_lods(scenes: dict[int, trimesh.Scene], center_before: bool = False):
    # — step 1: (optional) rough centering to help ICP converge —
    if center_before:
        for lod in scenes:
            scenes[lod] = center_scene_by_bbox(scenes[lod])

    # — step 2: extract single meshes for ICP —
    
    meshes = {
        lod: trimesh.util.concatenate(list(scenes[lod].geometry.values()))
        for lod in scenes
    }

    #show original bbox centers
    for lod, mesh in meshes.items():
        min_corner, max_corner = mesh.bounds
        center = (min_corner + max_corner) / 2.0
        print(f"LOD {lod} original center: {center}")
        
    # ICP: 2→1 then 3→2
    t2_1 = get_registration_matrix(meshes[2], meshes[1])
    t3_2 = get_registration_matrix(meshes[3], meshes[2])

    # apply those transforms
    scenes[2].apply_transform(t2_1)
    scenes[3].apply_transform(t2_1 @ t3_2)

    # show aligned bbox centers
    for lod, scene in scenes.items():
        min_corner, max_corner = scene.bounds
        center = (min_corner + max_corner) / 2.0
        print(f"LOD {lod} aligned center: {center}")
    # — step 3: **final centering** based on aligned LOD1 bbox —
    min1, max1 = scenes[1].bounds
    center1 = (min1 + max1) * 0.5
    for lod in scenes:
        scenes[lod].apply_translation(-center1)

    return scenes

def look_at_matrix(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
    """
    Create a camera-to-world pose matrix for pyrender given eye, target, up vectors.
    """
    f = (target - eye)
    f /= np.linalg.norm(f)
    # avoid parallel up/f
    if np.isclose(np.linalg.norm(np.cross(f, up)), 0):
        up = np.array([0, 0, 1]) if np.isclose(abs(f.dot([0, 1, 0])), 1) else np.array([0, 1, 0])
    s = np.cross(f, up); s /= np.linalg.norm(s)
    u = np.cross(s, f); u /= np.linalg.norm(u)

    # view matrix (world→camera)
    view = np.array([
        [ s[0],  s[1],  s[2], -s.dot(eye)],
        [ u[0],  u[1],  u[2], -u.dot(eye)],
        [-f[0], -f[1], -f[2],  f.dot(eye)],
        [    0,     0,     0,           1]
    ])
    # invert → camera pose (camera→world)
    return np.linalg.inv(view)

def align_lods_1_2_only(meshes: dict[int, trimesh.base.Trimesh], center_before: bool = False, samples: int = 3000):
    # Show original bbox centers
    for lod, mesh in [(1, meshes[1]), (2, meshes[2])]:
        min_corner, max_corner = mesh.bounds
        center = (min_corner + max_corner) / 2.0
        print(f"LOD {lod} original center: {center}")
        
    # — step 1: (optional) rough centering to help ICP converge —
    if center_before:
        for lod in [1, 2]:
            meshes[lod] = center_scene_by_bbox(meshes[lod])

    # — step 2: extract meshes —

    mesh1 = meshes[1]
    mesh2 = meshes[2]
    # ICP: 2 → 1
    t2_1 = get_registration_matrix(mesh2, mesh1, samples=samples)
    meshes[2].apply_transform(t2_1)

    # Show aligned bbox centers
    for lod in [1, 2]:
        min_corner, max_corner = meshes[lod].bounds
        center = (min_corner + max_corner) / 2.0
        print(f"LOD {lod} aligned center: {center}")

    # — step 3: center both based on aligned LOD1 —
    min1, max1 = meshes[1].bounds
    center1 = (min1 + max1) * 0.5
    for lod in [1, 2]:
        meshes[lod].apply_translation(-center1)

    return meshes

def render_orbit_with_creases(mesh, line_mesh, lod_meshes, scene_number, lod, output_root,
                               azimuths, elevations, width=2048, height=2048):
    """
    Orbit render for one mesh+line using pyrender, transparent background.
    """
    # Compute camera orbit radius
    max_bbox = max([m.bounding_box.extents.max() for m in lod_meshes.values()])
    radius = max_bbox * 1.5
    target = np.array([0.0, 0.0, 0.0])

    renderer = pyrender.OffscreenRenderer(viewport_width=width, viewport_height=height)

    for az in azimuths:
        for el in elevations:
            # spherical → cartesian
            rad_az = np.deg2rad(az)
            rad_el = np.deg2rad(el)
            x = radius * np.cos(rad_el) * np.sin(rad_az)
            y = radius * np.sin(rad_el)
            z = radius * np.cos(rad_el) * np.cos(rad_az)
            eye = np.array([x, y, z])

            # Setup scene
            scene = pyrender.Scene()
            # scene = pyrender.Scene(bg_color=[255, 255, 255, 0], ambient_light=[0.8, 0.8, 0.8])
            
            cam_pose = look_at_matrix(eye, target, up=np.array([0, 1, 0]))
            camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=width / height)
            scene.add(camera, pose=cam_pose)

            # In your render_orbit_with_creases function, replace:
            # mesh_tex = PyMesh.from_trimesh(mesh, smooth=True)

            # With:
            mesh_copy = mesh.copy()
            # Force all materials to be opaque
            if hasattr(mesh_copy.visual, 'material'):
                mesh_copy.visual.material.alphaMode = 'OPAQUE'

            mesh_tex = PyMesh.from_trimesh(mesh_copy, smooth=True)
            scene.add(mesh_tex)
            
            if line_mesh is not None:
                scene.add(line_mesh)

            color, _ = renderer.render(scene, flags=RenderFlags.RGBA)

            # Save to file
            save_dir = os.path.join(output_root, str(scene_number), f"lod{lod}")
            os.makedirs(save_dir, exist_ok=True)
            filename = f"lod{lod}_az{az:03d}_el{el:02d}.png"
            save_path = os.path.join(save_dir, filename)

            if color.shape[-1] == 4:
                Image.fromarray(color, mode="RGBA").save(save_path)
            elif color.shape[-1] == 3:
                Image.fromarray(color, mode="RGB").save(save_path)
            else:
                raise ValueError(f"Unexpected image shape: {color.shape}")


    renderer.delete()

def line_segments_to_cylinders(vertices, edges, radius=0.1, sections=6):
    cylinders = []
    for edge in edges:
        start = vertices[edge[0]]
        end   = vertices[edge[1]]
        direction = end - start
        height = np.linalg.norm(direction)
        if height < 1e-6:
            continue

        # Create a base cylinder aligned to z-axis
        cyl = trimesh.creation.cylinder(radius=radius, height=height, sections=sections)
        cyl.apply_translation([0, 0, height / 2.0])  # base at origin

        # Rotate to align with actual direction
        cyl.apply_transform(trimesh.geometry.align_vectors([0, 0, 1], direction))

        # Translate to start point
        cyl.apply_translation(start)

        cylinders.append(cyl)

    if not cylinders:
        return None

    # Combine all into one mesh
    combined = trimesh.util.concatenate(cylinders)

    # Assign black color (RGBA = 0,0,0,255) to each face
    black_color = np.tile([0, 0, 0, 255], (len(combined.faces), 1))  # uint8 by default
    combined.visual.face_colors = black_color

    return pyrender.Mesh.from_trimesh(combined, smooth=False)

def force_opaque_materials(mesh):
    """Force all materials to be opaque"""
    if hasattr(mesh.visual, 'material'):
        if hasattr(mesh.visual.material, 'diffuse'):
            # Ensure alpha is 1.0 for diffuse color
            if len(mesh.visual.material.diffuse) == 4:
                mesh.visual.material.diffuse[3] = 0.0
        if hasattr(mesh.visual.material, 'baseColorFactor'):
            # For PBR materials
            if len(mesh.visual.material.baseColorFactor) == 4:
                mesh.visual.material.baseColorFactor[3] = 0.0

def make_faces_opaque_preserve_texture(meshes):
    """
    Sets alpha channel of face colors to 255 (opaque), preserving RGB values.
    Works for both RGB and RGBA face color arrays.
    """
    for lod, mesh in meshes.items():
        print(f"Processing LOD {lod} transparency...")
        
        # Handle face colors
        if hasattr(mesh.visual, 'face_colors') and mesh.visual.face_colors is not None:
            fc = mesh.visual.face_colors
            print(f"  Face colors shape: {fc.shape}")
            
            if fc.shape[1] == 4:
                # Already RGBA, set alpha to 255
                mesh.visual.face_colors[:, 3] = 255
                print(f"  Set alpha to 255 for {len(fc)} faces")
            elif fc.shape[1] == 3:
                # RGB only, add alpha channel
                alpha = np.full((fc.shape[0], 1), 255, dtype=fc.dtype)
                mesh.visual.face_colors = np.hstack((fc, alpha))
                print(f"  Added alpha channel to {len(fc)} faces")
        
        # Handle vertex colors if present
        if hasattr(mesh.visual, 'vertex_colors') and mesh.visual.vertex_colors is not None:
            vc = mesh.visual.vertex_colors
            if vc.shape[1] == 4:
                mesh.visual.vertex_colors[:, 3] = 255
            elif vc.shape[1] == 3:
                alpha = np.full((vc.shape[0], 1), 255, dtype=vc.dtype)
                mesh.visual.vertex_colors = np.hstack((vc, alpha))
        
        # Force material opacity
        force_opaque_materials(mesh)
                
def debug_face_colors(meshes):
    """Debug function to check face color transparency"""
    for lod, mesh in meshes.items():
        fc = getattr(mesh.visual, "face_colors", None)
        if fc is not None:
            print(f"LOD {lod} face colors shape: {fc.shape}")
            print(f"LOD {lod} alpha channel min/max: {fc[:, 3].min()}/{fc[:, 3].max()}")
            print(f"LOD {lod} unique alpha values: {np.unique(fc[:, 3])}")
        else:
            print(f"LOD {lod} has no face colors")

### 2. Load the scene and orbit capture

In [4]:
# ─── Config ───────────────────────────────────────────────────────
# scene_num = 46

threshold_degrees = 5.0
angle_thresh = np.deg2rad(threshold_degrees)
# RAW_LOD_DATASET_ROOT = "/home/athiwat/progressive_img2sketch/resources/LOD_data_50"
RAW_LOD_DATASET_ROOT = "/home/athiwat/progressive_img2sketch/resources/LOD50_opaque_normalized_triangulated"

SCENES = range(46, 51)  # Assuming scenes are numbered from 1 to 50
LODS = [1, 2]

for scene_num in SCENES:
    # ─── 1. Load LOD meshes into dict ─────────────────────────────────
    lod_meshes = {}
    for lod in LODS:
        path = os.path.join(RAW_LOD_DATASET_ROOT, str(scene_num), f"lod{lod}.obj")
        loaded = trimesh.load(path, process=False)
        lod_mesh = (
            trimesh.util.concatenate(loaded.geometry.values())
            if isinstance(loaded, trimesh.Scene)
            else loaded
        )
        lod_meshes[lod] = lod_mesh

    # ─── 2. Align meshes ──────────────────────────────────────────────
    aligned_meshes = align_lods_1_2_only(lod_meshes, center_before=True, samples=4000)
    
    # ─── 2.1 Sanitize transparency ────────────────────────────────────
    # ─── 2.1 Sanitize transparency ────────────────────────────────────
    # debug_face_colors(aligned_meshes)  # Debug first
    # make_faces_opaque_preserve_texture(aligned_meshes)
    # debug_face_colors(aligned_meshes)  # Check after fix

    scene = trimesh.Scene()
    # # ─── 3. Build scene dict with crease lines ────────────────────────
    scene_dict = {}
    
    for lod, mesh in aligned_meshes.items():
        print(f"Processing LOD{lod}...")

        # Step 1: Weld mesh for edge adjacency analysis
        welded = trimesh.Trimesh(vertices=mesh.vertices.copy(),
                                faces=mesh.faces.copy(),
                                process=True)
        # Step 2: Detect creases
        fa = welded.face_adjacency_angles
        edges = welded.face_adjacency_edges
        mask = fa > angle_thresh
        
        # Filter to manifold edges only
        from collections import defaultdict
        edge_count = defaultdict(int)
        for face in welded.faces:
            for i in range(3):
                e = tuple(sorted((face[i], face[(i+1)%3])))
                edge_count[e] += 1
        filtered_edges = [e for e in edges[mask] if edge_count[tuple(sorted(e))] == 2]

        # Step 3: Create pyrender line mesh for rendering
        line_mesh = line_segments_to_cylinders(welded.vertices, filtered_edges)

        # Step 4: Render together with the mesh
        AZIMUTH_STEP = 45
        ELEVATIONS = [0, 30]
        OUTPUT_ROOT = "/home/athiwat/progressive_img2sketch/resources/test_orbit"  # customize this

        render_orbit_with_creases(
            mesh=mesh,
            line_mesh=line_mesh,
            lod_meshes=aligned_meshes,
            scene_number=scene_num,
            lod=lod,
            output_root=OUTPUT_ROOT,
            azimuths=range(0, 360, AZIMUTH_STEP),
            elevations=ELEVATIONS
        )

LOD 1 original center: [0.e+00 1.e-06 0.e+00]
LOD 2 original center: [0.e+00 1.e-06 0.e+00]
LOD 1 aligned center: [0. 0. 0.]
LOD 2 aligned center: [-1.53834609 -0.14552135 -2.9632285 ]
Processing LOD1...


  Image.fromarray(color, mode="RGB").save(save_path)


Processing LOD2...
LOD 1 original center: [0.e+00 5.e-07 0.e+00]
LOD 2 original center: [ 0.00000000e+00 -5.00000001e-07  0.00000000e+00]
LOD 1 aligned center: [0. 0. 0.]
LOD 2 aligned center: [-0.01959294  0.65156799  0.01684448]
Processing LOD1...
Processing LOD2...
LOD 1 original center: [0. 0. 0.]
LOD 2 original center: [0. 0. 0.]
LOD 1 aligned center: [0. 0. 0.]
LOD 2 aligned center: [-0.00043213 -0.00210114 -0.00912404]
Processing LOD1...
Processing LOD2...
LOD 1 original center: [0. 0. 0.]
LOD 2 original center: [0. 0. 0.]
LOD 1 aligned center: [0. 0. 0.]
LOD 2 aligned center: [-0.03830383 -1.17018868 -1.00643318]
Processing LOD1...
Processing LOD2...
LOD 1 original center: [ 0.e+00 -5.e-07  0.e+00]
LOD 2 original center: [ 0.e+00 -5.e-07  0.e+00]


KeyboardInterrupt: 