In [None]:
import trimesh
from trimesh.visual.texture import TextureVisuals, SimpleMaterial
import open3d as o3d
from PIL import Image
import numpy as np
import pyrender
import math
from tqdm import tqdm
import os


In [None]:
def fps_sample_vertices(mesh: "trimesh.Trimesh",
                        max_samples: int = 100,
                        seed: int = 42) -> np.ndarray:
    """
    Pick well-spread vertices of a mesh with greedy farthest point sampling.

    The first vertex is drawn at random (reproducibly, from ``seed``); every
    following vertex is the one whose distance to the already selected set is
    largest.

    Parameters
    ----------
    mesh : object exposing a ``vertices`` array of shape (N, D)
    max_samples : upper bound on the number of returned vertices. At most N
        vertices are returned, so no vertex is ever returned twice just
        because more samples were requested than vertices exist.
    seed : seed for the choice of the first vertex.

    Returns
    -------
    np.ndarray of shape (min(max_samples, N), D), in selection order.
    Empty when the mesh has no vertices or ``max_samples <= 0``.
    """
    vertices = np.asarray(mesh.vertices)
    num_vertices = len(vertices)
    num_samples = min(int(max_samples), num_vertices)
    if num_samples <= 0:
        return vertices[:0]

    # A private generator keeps the call reproducible without reseeding the
    # global NumPy RNG behind the caller's back. RandomState(seed) yields the
    # same first index as np.random.seed(seed) + np.random.randint(N) did.
    rng = np.random.RandomState(seed)
    first_idx = rng.randint(num_vertices)
    selected_idxs = [first_idx]

    # dist[i] = distance from vertex i to its nearest already-selected vertex
    dist = np.linalg.norm(vertices - vertices[first_idx], axis=1)

    for _ in range(1, num_samples):
        # the farthest vertex from the current selection becomes the next sample
        next_idx = np.argmax(dist)
        selected_idxs.append(next_idx)

        # fold the new sample into the nearest-selected distances
        new_dist = np.linalg.norm(vertices - vertices[next_idx], axis=1)
        dist = np.minimum(dist, new_dist)

    return vertices[selected_idxs]

In [None]:

# ----------------------------------------
# Configuration
# ----------------------------------------
DATA_DIR      = '../3d_fish'                         # folder that holds the .glb assets
INPUT_MESH    = 'reference_fish_smpl23_resized.glb'  # reference fish asset inside DATA_DIR
SCALE_RATIO   = 0.06                                 # factor passed to apply_scale at placement time
NUM_INSTANCES = 360                                  # number of vertices sampled on the reference fish


# Load the reference asset, dump() it and concatenate the result into a
# single mesh. No scaling happens in this cell.
reference_fish = trimesh.util.concatenate(
    trimesh.load(os.path.join(DATA_DIR, INPUT_MESH)).dump()
)

In [None]:
# Farthest point sampling on the reference fish: the sampled vertices are
# used further down as the positions where the small fish are placed.
points = fps_sample_vertices(reference_fish, max_samples=NUM_INSTANCES)

In [None]:
import numpy as np
import plotly.graph_objs as go
from matplotlib import cm

def visualize_sorted_vertices_plotly(vertices):
    """
    Show a point set as an interactive 3D scatter plot, coloured along X.

    The points are ordered by their first coordinate and coloured with the
    matplotlib ``rainbow`` colormap in that order, so the colour encodes each
    point's rank along the X axis.

    Parameters
    ----------
    vertices : array-like of shape (N, 3)
    """
    pts = np.array(vertices)
    pts = pts[np.argsort(pts[:, 0])]

    # one colormap sample per point, converted from RGBA floats to integer RGB
    rgba = cm.rainbow(np.linspace(0, 1, len(pts)))
    rgb = (rgba[:, :3] * 255).astype(int)
    marker_colors = ['rgb({}, {}, {})'.format(r, g, b) for r, g, b in rgb]

    scatter = go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode='markers',
        marker={'size': 3, 'color': marker_colors},
    )

    fig = go.Figure(data=[scatter])
    fig.update_layout(
        title="Interactive 3D Rainbow Scatter",
        scene={'xaxis_title': 'X', 'yaxis_title': 'Y', 'zaxis_title': 'Z'},
    )
    fig.show()

In [None]:
# Load every small-fish asset listed in the ranking CSV (this cell only loads;
# nothing is rescaled here).
import pandas as pd

file_path = "fish_asset_color_ranking_kmeans.csv"
df = pd.read_csv(file_path)

# CSV entry "<stem>.<ext>"  ->  "<DATA_DIR>/<stem>_smpl.glb"
file_names = [
    os.path.join(DATA_DIR, os.path.splitext(name)[0] + "_smpl.glb")
    for name in df["file_name"]
]

# Each asset is loaded, dump()-ed and concatenated into one mesh; the list
# keeps the row order of the CSV.
small_fishes = [
    trimesh.util.concatenate(trimesh.load(path).dump())
    for path in file_names
]


In [None]:
# Direction modification
import numpy as np

def get_main_axis(mesh):
    """
    Estimate the dominant axis of a mesh from a sub-sample of its vertices.

    100 vertices are drawn with ``fps_sample_vertices``, centred on their
    mean and decomposed with an SVD. The returned vector is the first right
    singular vector, i.e. the direction of largest spread (taken here as the
    length direction of the fish).

    The sign of the result is arbitrary: it does not say which end is the
    head and which is the tail.
    """
    samples = fps_sample_vertices(mesh, max_samples=100)
    centered = samples - samples.mean(axis=0)

    # third SVD factor: rows are the right singular vectors, largest first
    right_vectors = np.linalg.svd(centered)[2]
    return right_vectors[0]


In [None]:
def project_along_axis(mesh, axis):
    """
    Project every vertex of ``mesh`` onto ``axis``.

    Returns one scalar per vertex: the product ``mesh.vertices @ axis``,
    which is the vertex's coordinate along ``axis`` when ``axis`` is a unit
    vector.
    """
    coords_on_axis = np.matmul(mesh.vertices, axis)
    return coords_on_axis


In [None]:
def detect_head_direction(mesh, axis, ratio=0.05):
    """
    Choose a sign for ``axis`` by comparing the two ends of the mesh.

    The vertices are projected onto ``axis``; the outermost ``ratio``
    fraction of the projected extent at the +axis end is called the "head"
    side and the same fraction at the -axis end the "tail" side. For each
    side the variance of its projections onto ``axis`` is computed.

    Returns ``axis`` unchanged when the +axis end has the strictly larger
    variance, otherwise ``-axis`` (ties included).

    NOTE(review): the compared quantity is the spread *along* ``axis`` inside
    each end slab, not the width of the cross-section, although the original
    notes speak of the "wider" side being the head. Confirm that this is the
    intended heuristic.
    """
    proj = project_along_axis(mesh, axis)
    lo, hi = np.min(proj), np.max(proj)
    margin = (hi - lo) * ratio

    # vertex subsets at both extremes of the projected range
    head_side = mesh.vertices[proj > hi - margin]
    tail_side = mesh.vertices[proj < lo + margin]

    # spread of each subset, measured along the axis
    head_var = np.var(head_side @ axis)
    tail_var = np.var(tail_side @ axis)

    # keep the axis when the +axis end wins, flip it otherwise
    return axis if head_var > tail_var else -axis

In [None]:
def get_fish_forward_vector(mesh):
    """
    Estimate the direction vector of a fish mesh.

    The dominant axis from ``get_main_axis`` has an arbitrary sign;
    ``detect_head_direction`` then decides which of the two orientations to
    return.

    Returns
    -------
    The signed axis vector chosen by ``detect_head_direction``.
    """
    # The two bare string literals that used to sit here were no-op
    # expression statements (they printed nothing) and have been removed.
    axis = get_main_axis(mesh)
    return detect_head_direction(mesh, axis)

In [None]:
forward_vec = get_fish_forward_vector(reference_fish)  # direction vector of the reference fish (original note: head → tail; the sign comes from detect_head_direction)

In [None]:
import plotly.graph_objects as go
def visualize_direction_vec_w_fish(fish_vertices, vector):
    """
    Plot sampled vertices of a fish together with a direction arrow.

    Parameters
    ----------
    fish_vertices : despite the name, this is passed straight to
        ``fps_sample_vertices``, which reads ``.vertices`` from it, so a
        mesh-like object is expected rather than a raw vertex array.
    vector : indexable with at least three components; drawn as a cone whose
        tail sits at the origin (0, 0, 0).

    100 vertices are sampled, ordered by X and coloured with the matplotlib
    ``rainbow`` colormap in that order.
    """
    sampled = np.array(fps_sample_vertices(fish_vertices, 100))
    sampled = sampled[np.argsort(sampled[:, 0])]

    # one colormap sample per point, converted from RGBA floats to integer RGB
    rgba = cm.rainbow(np.linspace(0, 1, len(sampled)))
    rgb = (rgba[:, :3] * 255).astype(int)
    marker_colors = ['rgb({}, {}, {})'.format(r, g, b) for r, g, b in rgb]

    points_trace = go.Scatter3d(
        x=sampled[:, 0],
        y=sampled[:, 1],
        z=sampled[:, 2],
        mode='markers',
        marker={'size': 3, 'color': marker_colors},
    )

    # the arrow starts at the origin and points along `vector`
    arrow_trace = go.Cone(
        x=[0], y=[0], z=[0],
        u=[vector[0]], v=[vector[1]], w=[vector[2]],
        sizemode="absolute",
        sizeref=0.5,
        anchor="tail",  # the cone's tail is placed on the start point
        showscale=False,
        colorscale='Blues',
    )

    fig = go.Figure(data=[points_trace, arrow_trace])
    fig.update_layout(
        title="Interactive 3D Rainbow Scatter",
        scene={'xaxis_title': 'X', 'yaxis_title': 'Y', 'zaxis_title': 'Z'},
    )
    fig.show()

In [None]:
# Spot check: estimate the direction vector of one small fish and plot it
# together with that fish's sampled vertices.
idx = 300  # position in small_fishes of the fish to inspect; must be < len(small_fishes)
specific_fish = small_fishes[idx]
dir_vec = get_fish_forward_vector(specific_fish)  # direction estimated for this fish
visualize_direction_vec_w_fish(specific_fish, dir_vec)

In [None]:
# Same plot for the reference fish, using the direction vector computed for it earlier.
visualize_direction_vec_w_fish(reference_fish, forward_vec)

In [None]:
from trimesh.registration import icp

# Register every small fish onto the reference fish with ICP, using the raw
# vertex sets of both meshes and the default ICP settings.
ref_points = reference_fish.vertices
small_fishes_aligned = []
for small_fish in tqdm(small_fishes):
    # icp returns a tuple; only its first element, the transform matrix, is used
    matrix = icp(small_fish.vertices, ref_points)[0]

    # apply_transform changes the mesh itself, so the objects collected in
    # small_fishes_aligned are the very same objects that live in
    # small_fishes (both lists see the aligned geometry).
    small_fish.apply_transform(matrix)
    small_fishes_aligned.append(small_fish)

In [None]:
# Place small fish instances: one small fish per sampled point, walking the
# points in order of increasing X.
scene = trimesh.Scene()
sorted_indices = np.argsort(points[:, 0])
sorted_points = points[sorted_indices]

# zip stops at the shorter sequence; tell tqdm that length so it can show progress.
n_placed = min(len(small_fishes), len(sorted_points))
for small_fish, p in tqdm(zip(small_fishes, sorted_points),
                          total=n_placed, desc='Instance Placement'):
    # Scale and move a copy, never the shared mesh: apply_scale and
    # apply_translation modify a mesh in place, so transforming the entries
    # of small_fishes directly would compound the scale and offset every
    # time this cell (or a later placement cell) runs.
    instance = small_fish.copy()
    instance.apply_scale(SCALE_RATIO)
    instance.apply_translation(p)
    scene.add_geometry(instance)
scene.export('school_of_different_fishes_aligned_reduced.glb')

In [None]:
from scipy.spatial.transform import Rotation as R

def rotation_matrix_from_vectors(vec1, vec2):
    """
    Return a 3x3 rotation matrix that turns ``vec1`` onto ``vec2``.

    The rotation is fitted with scipy's ``Rotation.align_vectors`` on the
    single vector pair; the reported fit error is discarded.
    """
    # align_vectors(a, b) fits the rotation taking b onto a, hence vec2 first
    fitted = R.align_vectors([vec2], [vec1])
    return fitted[0].as_matrix()

def rotate_mesh(mesh, R_mat):
    """
    Rotate a mesh in place by the 3x3 matrix ``R_mat`` and return it.

    Vertices are stored as rows, so each one is multiplied by the transpose
    of ``R_mat``. The rotation is applied to the raw coordinates, i.e. it
    pivots about the coordinate origin.
    """
    rotated = np.dot(mesh.vertices, R_mat.T)
    mesh.vertices = rotated
    return mesh
def rotation_matrix_axis_angle(axis, angle):
    """
    Build the 3x3 rotation matrix for a rotation of ``angle`` about ``axis``.

    axis: (3,) numpy array - rotation axis; it is normalised here, so it does
          not have to be a unit vector (it must not be the zero vector)
    angle: float - rotation angle in radians
    """
    unit = axis / np.linalg.norm(axis)  # normalise to unit length
    ux, uy, uz = unit
    cos_theta = np.cos(angle)
    sin_theta = np.sin(angle)

    # cross-product (skew-symmetric) matrix of the unit axis
    skew = np.array([
        [0, -uz, uy],
        [uz, 0, -ux],
        [-uy, ux, 0],
    ])

    # Rodrigues' formula: R = cos*I + sin*[u]x + (1 - cos) * u u^T
    return (cos_theta * np.eye(3)
            + sin_theta * skew
            + (1.0 - cos_theta) * np.outer(unit, unit))

In [None]:
def get_rotation_matrix_from_vectors(a, b):
    """
    Return the 3x3 rotation matrix that turns the direction of ``a`` onto
    the direction of ``b``.

    Both inputs are normalised first, so only their directions matter.
    Three cases are handled:
      * directions (numerically) equal      -> identity matrix
      * directions (numerically) opposite   -> half turn about an axis
        perpendicular to ``a``
      * anything else                       -> Rodrigues' formula built from
        the cross and dot product of the two unit vectors
    """
    src = a / np.linalg.norm(a)
    dst = b / np.linalg.norm(b)
    cos_angle = np.dot(src, dst)

    # already aligned
    if np.isclose(cos_angle, 1.0):
        return np.eye(3)

    # exactly opposite: the cross product vanishes, so pick any axis
    # perpendicular to src and rotate by pi around it
    if np.isclose(cos_angle, -1.0):
        helper = np.array([1, 0, 0]) if abs(src[0]) < 0.9 else np.array([0, 1, 0])
        turn_axis = np.cross(src, helper)
        turn_axis = turn_axis / np.linalg.norm(turn_axis)
        return rotation_matrix_axis_angle(turn_axis, np.pi)

    # general case: the cross product gives the (unnormalised) rotation axis,
    # its length is the sine of the angle
    cross = np.cross(src, dst)
    sin_angle = np.linalg.norm(cross)
    cx, cy, cz = cross
    skew = np.array([[0, -cz, cy],
                     [cz, 0, -cx],
                     [-cy, cx, 0]])
    return np.eye(3) + skew + skew @ skew * ((1 - cos_angle) / (sin_angle**2))
def rotate_mesh_vertices(mesh, R_mat):
    """
    Apply the rotation matrix ``R_mat`` to all vertices of ``mesh`` in place.

    Vertices are row vectors, hence the multiplication by the transpose of
    ``R_mat``. The same mesh object is returned for convenience.
    """
    rotated = np.matmul(mesh.vertices, R_mat.T)  # apply the rotation
    mesh.vertices = rotated
    return mesh

In [None]:
# Rotate every small fish so that its estimated direction vector matches the
# one of the reference fish (forward_vec).
# tqdm wraps the list itself (not the enumerate object) so that it knows the
# number of items and can show real progress.
for i, small_fish in enumerate(tqdm(small_fishes, desc="Aligning Small Fish")):
    src_dir = get_fish_forward_vector(small_fish)
    R_mat = get_rotation_matrix_from_vectors(src_dir, forward_vec)

    # rotate_mesh_vertices modifies the mesh in place and returns that same mesh
    small_fishes[i] = rotate_mesh_vertices(small_fish, R_mat)

In [None]:
# Place small fish instances: one small fish per sampled point, walking the
# points in order of increasing X.
scene = trimesh.Scene()
sorted_indices = np.argsort(points[:, 0])
sorted_points = points[sorted_indices]

# zip stops at the shorter sequence; tell tqdm that length so it can show progress.
n_placed = min(len(small_fishes), len(sorted_points))
for small_fish, p in tqdm(zip(small_fishes, sorted_points),
                          total=n_placed, desc='Instance Placement'):
    # Scale and move a copy, never the shared mesh: apply_scale and
    # apply_translation modify a mesh in place, so transforming the entries
    # of small_fishes directly would shrink and shift them again on every
    # run of this cell.
    instance = small_fish.copy()
    instance.apply_scale(SCALE_RATIO)
    instance.apply_translation(p)
    scene.add_geometry(instance)
scene.export('school_of_different_fishes_aligned.glb')