In [10]:
%cd /ibex/user/slimhy/PADS/code
"""
Visualize a 3D mesh with an oriented bounding box for each part, colored with a matplotlib colormap.
Includes helpers for manual camera control and zoom.
"""
from datasets.CoMPaT.compat3D import SegmentedMeshLoader
from datasets.metadata import class_to_idx
from util.misc import get_bb_vecs
import k3d
import numpy as np
import matplotlib.cm as cm



# Constants
# Filesystem locations handed to SegmentedMeshLoader further down in this cell
# (as its zip_path and meta_dir arguments respectively).
ZIP_PATH = "/ibex/project/c2273/3DCoMPaT/3DCoMPaT_ZIP.zip"
META_DIR = "/ibex/project/c2273/3DCoMPaT/3DCoMPaT-v2/metadata"


def create_rotation_matrix(deg_x, deg_y, deg_z):
    """
    Build a 3x3 rotation matrix from per-axis angles given in degrees.

    The result is the product ``Rx @ Ry @ Rz``. Applied to a column vector,
    the Z rotation therefore acts first, then Y, then X.

    Args:
        deg_x (float): Rotation angle around the X-axis in degrees
        deg_y (float): Rotation angle around the Y-axis in degrees
        deg_z (float): Rotation angle around the Z-axis in degrees

    Returns:
        np.ndarray: 3x3 rotation matrix
    """
    rad_x = np.deg2rad(deg_x)
    rad_y = np.deg2rad(deg_y)
    rad_z = np.deg2rad(deg_z)

    # Precompute the trigonometric terms once per axis.
    cos_x, sin_x = np.cos(rad_x), np.sin(rad_x)
    cos_y, sin_y = np.cos(rad_y), np.sin(rad_y)
    cos_z, sin_z = np.cos(rad_z), np.sin(rad_z)

    rot_x = np.array([
        [1, 0, 0],
        [0, cos_x, -sin_x],
        [0, sin_x, cos_x],
    ])
    rot_y = np.array([
        [cos_y, 0, sin_y],
        [0, 1, 0],
        [-sin_y, 0, cos_y],
    ])
    rot_z = np.array([
        [cos_z, -sin_z, 0],
        [sin_z, cos_z, 0],
        [0, 0, 1],
    ])

    # Compose the three elementary rotations (Z applied first, X last).
    return rot_x @ rot_y @ rot_z


def transform_mesh(mesh_map, rotation_angles=(0, 0, 0), scale=1.0, center=None):
    """
    Rotate and uniformly scale every mesh in a mesh map about a common center.

    The input meshes are left untouched: each one is duplicated with its
    ``copy()`` method and the copy receives the transformed vertices.

    Args:
        mesh_map (dict): Maps part names to either a single mesh or a list of
            submeshes. Each mesh must expose ``vertices`` and ``copy()``.
        rotation_angles (tuple): Rotation angles (deg_x, deg_y, deg_z) in
            degrees. They are forwarded to ``create_rotation_matrix``, which
            converts degrees to radians itself.
        scale (float): Uniform scaling factor, applied about ``center``
        center (np.ndarray, optional): Center of rotation and scaling. If
            None, the mean of all vertices in ``mesh_map`` is used (the
            origin when the map contains no vertices at all).

    Returns:
        dict: New map with the same keys and the same mesh/list structure,
        holding the transformed copies.
    """
    def _as_list(part):
        # A part is either one mesh or a list of submeshes; normalize to a list.
        return part if isinstance(part, list) else [part]

    if center is None:
        # Stack per-mesh vertex arrays instead of extending a Python list row
        # by row. Assumes each ``vertices`` is an (n, 3) array - TODO confirm.
        vertex_blocks = [
            np.asarray(mesh.vertices)
            for part in mesh_map.values()
            for mesh in _as_list(part)
        ]
        vertex_blocks = [block for block in vertex_blocks if len(block)]
        if vertex_blocks:
            center = np.concatenate(vertex_blocks, axis=0).mean(axis=0)
        else:
            # Nothing to average: avoid a NaN center (and NumPy's warning).
            center = np.zeros(3)

    rotation = create_rotation_matrix(*rotation_angles)

    def _transformed_copy(mesh):
        # Rotate/scale a copy of ``mesh`` about ``center``.
        new_mesh = mesh.copy()
        offsets = np.array(new_mesh.vertices) - center
        # Row vectors: multiplying by the transpose applies ``rotation`` to each vertex.
        new_mesh.vertices = scale * (offsets @ rotation.T) + center
        return new_mesh

    transformed_map = {}
    for part_name, part_mesh in mesh_map.items():
        if isinstance(part_mesh, list):
            transformed_map[part_name] = [_transformed_copy(sub) for sub in part_mesh]
        else:
            transformed_map[part_name] = _transformed_copy(part_mesh)

    return transformed_map

def rgb_to_hex(rgb):
    """
    Convert RGB values (0-1) to a packed ``0xRRGGBB`` integer color code.

    Only the first three entries of ``rgb`` are used, so RGBA tuples are
    accepted and their alpha channel is ignored. Each channel is scaled by
    255, truncated to an integer and clamped to [0, 255] so that out-of-range
    inputs cannot spill into neighbouring channels or produce an invalid code.

    Args:
        rgb (sequence of float): Red, green and blue components, nominally
            in the range 0-1 (extra trailing components are ignored)

    Returns:
        int: Color packed as ``(r << 16) | (g << 8) | b``
    """
    red, green, blue = (
        min(255, max(0, int(channel * 255))) for channel in rgb[:3]
    )
    return (red << 16) | (green << 8) | blue

def add_bbox_visualization(plot, mesh, color, plot_points=False):
    """
    Add the oriented bounding box of ``mesh`` to ``plot``.

    The box is drawn as a translucent k3d mesh tinted with ``color`` scaled
    by 0.7. With ``plot_points`` set, every box vertex is also drawn as a red
    point and the centroid returned by ``get_bb_vecs`` as a green point.

    Args:
        plot: k3d plot object the drawables are added to
        mesh: Mesh exposing ``bounding_box_oriented``
        color: RGB(A) color with components in 0-1
        plot_points (bool): Also draw the box vertices and centroid

    Returns:
        The plot passed in, with the new drawables added.
    """
    obb = mesh.bounding_box_oriented
    # Second return value (box vectors) is not needed for drawing.
    box_center, _box_vecs = get_bb_vecs(obb)

    # Darken the part color so the box stays distinguishable from the part.
    darkened = np.array(color) * 0.7
    plot += k3d.mesh(
        obb.vertices,
        obb.faces,
        color=rgb_to_hex(darkened),
        opacity=0.2,
    )

    if not plot_points:
        return plot

    # One point object per box vertex, then the centroid marker.
    for corner in obb.vertices:
        plot += k3d.points(corner, point_size=0.01, color=0xff0000, opacity=0.95)
    plot += k3d.points(box_center, point_size=0.02, color=0x00ff00)

    return plot

def setup_camera(plot, target_point=None, camera_position=None, up_vector=None, zoom_factor=1.0):
    """
    Set up camera parameters for the visualization.

    The camera is placed on the line from ``target_point`` to
    ``camera_position``, at the original distance divided by ``zoom_factor``,
    and the result is written to ``plot.camera`` as a flat list of nine
    floats: position (3), target (3), up vector (3).

    Args:
        plot: k3d plot object; its ``objects`` are read when no target is
            given, and its ``camera`` attribute is assigned
        target_point (np.ndarray): Point the camera looks at (default: mean
            of the ``vertices`` of all plot objects that have that attribute,
            or the origin if there are none)
        camera_position (np.ndarray): Position of the camera (default:
            ``target_point + [2, 2, 2]``)
        up_vector (np.ndarray): Camera up vector (default: [0, 1, 0]); it is
            normalized to unit length
        zoom_factor (float): Zoom level (default: 1.0); values above 1 move
            the camera closer to the target

    Returns:
        The plot passed in, with ``camera`` updated.

    Raises:
        ValueError: If ``zoom_factor`` is not positive or ``up_vector`` has
            zero length (both would otherwise yield inf/NaN camera values).
    """
    if zoom_factor <= 0:
        raise ValueError("zoom_factor must be positive")

    if target_point is None:
        # Default target: mean vertex over every plot object exposing vertices.
        all_vertices = []
        for obj in plot.objects:
            if hasattr(obj, 'vertices'):
                all_vertices.extend(obj.vertices)
        if all_vertices:
            target_point = np.mean(all_vertices, axis=0)
        else:
            target_point = np.zeros(3)
    target_point = np.asarray(target_point, dtype=float)

    if camera_position is None:
        # Default camera position relative to the target.
        camera_position = target_point + np.array([2.0, 2.0, 2.0])
    camera_position = np.asarray(camera_position, dtype=float)

    if up_vector is None:
        up_vector = [0, 1, 0]
    up_vector = np.asarray(up_vector, dtype=float)
    up_norm = np.linalg.norm(up_vector)
    if up_norm == 0:
        raise ValueError("up_vector must be non-zero")
    up_vector = up_vector / up_norm

    # Scaling the target->camera offset by 1/zoom is equivalent to
    # "direction * distance / zoom" but stays finite when the camera sits
    # exactly on the target (zero distance), where the direction is undefined.
    adjusted_position = target_point + (camera_position - target_point) / zoom_factor

    # Plain Python floats keep the camera list free of NumPy scalar types.
    plot.camera = (
        adjusted_position.tolist()  # camera position
        + target_point.tolist()     # target point
        + up_vector.tolist()        # up vector
    )

    return plot

def visualize_mesh_with_bbox(mesh_map, colormap='viridis', cut_n_parts=0):
    """
    Create a k3d visualization of mesh parts with their oriented bounding boxes.

    Each part gets one color sampled from ``colormap`` at
    ``part_index / number_of_parts``; all submeshes of a part share it. The
    camera is not configured here (the plot is created with
    ``camera_auto_fit=False``).

    Args:
        mesh_map: Dictionary of mesh parts; each value is a mesh or a list
            of submeshes exposing ``vertices`` and ``faces``
        colormap (str): Name of matplotlib colormap to use (a colormap
            object is also accepted and used as-is)
        cut_n_parts (int): Number of trailing parts (in dictionary order) to
            leave out of the plot. Values larger than the number of parts
            hide everything; negative values are treated as 0.

    Returns:
        k3d plot object containing the part meshes and their bounding boxes.
    """
    # Local import: only matplotlib.cm is imported at file level.
    import matplotlib

    # Create the plot with minimal visual elements
    plot = k3d.plot(
        grid_visible=False,
        axes_helper=0,
        camera_auto_fit=False,
        grid=[1, 1, 1, 1, 1, 1],
        camera_fov=45,
        camera_zoom_speed=8.,
        height=1024,
    )

    # Resolve the colormap. cm.get_cmap is deprecated in recent matplotlib,
    # so prefer the colormap registry and only fall back on old versions.
    if callable(colormap):
        cmap = colormap
    else:
        registry = getattr(matplotlib, "colormaps", None)
        if registry is None:
            cmap = cm.get_cmap(colormap)
        else:
            if colormap is None:
                # Same default the legacy get_cmap(None) used.
                colormap = matplotlib.rcParams["image.cmap"]
            cmap = registry[colormap]

    # Colors are spread over ALL parts, including the ones that are cut,
    # so a part keeps its color regardless of cut_n_parts.
    n_parts = len(mesh_map)
    # Clamp so an oversized cut hides everything instead of wrapping around
    # through a negative slice index.
    n_shown = max(0, n_parts - max(0, cut_n_parts))

    # Add each part with its bounding box
    for part_idx, part_mesh in enumerate(list(mesh_map.values())[:n_shown]):
        color = cmap(part_idx / n_parts)
        hex_color = rgb_to_hex(color)

        # Treat a single mesh as a one-element list of submeshes.
        submeshes = part_mesh if isinstance(part_mesh, list) else [part_mesh]
        for submesh in submeshes:
            plot += k3d.mesh(submesh.vertices, submesh.faces,
                             color=hex_color,
                             opacity=0.8)
            plot = add_bbox_visualization(plot, submesh, color)

    return plot

# Instantiate dataset
# Restricted to the "chair" class of the train split; shuffling uses a fixed seed.
seg_dataset = SegmentedMeshLoader(
    filter_class=[class_to_idx("chair")],
    zip_path=ZIP_PATH,
    meta_dir=META_DIR,
    split="train",
    shuffle=True,
    get_instances=True,
    seed=0,
)

# Example rotation angles in degrees (create_rotation_matrix converts them
# to radians with np.deg2rad)
angles = (105, 5, -5)
scale_factor = 2
mesh_map = seg_dataset[1]

# NOTE(review): trans_map holds the 3x3 rotation matrix (not a mesh map) and
# is not used again in this cell; transform_mesh rebuilds the matrix itself.
trans_map = create_rotation_matrix(*angles)
mesh_map = transform_mesh(mesh_map, rotation_angles=angles, scale=scale_factor)

# Create and display visualization
# The last 4 parts (in mesh_map order) are left out of the plot.
plot = visualize_mesh_with_bbox(mesh_map, colormap='viridis', cut_n_parts=4)
plot

/ibex/user/slimhy/PADS/code


  cmap = cm.get_cmap(colormap)


Plot(antialias=3, axes=['x', 'y', 'z'], axes_helper_colors=[16711680, 65280, 255], background_color=16777215, …