In [1]:
import os
import plotly.graph_objects as go
import trimesh
from create_labels_dataset import create_labels_for_mesh, create_useful_vertices_dict, find_connected_components

In [2]:
root_dir = 'datasets/human_seg'
train_path = os.path.join(root_dir, 'train')
seg_path = os.path.join(root_dir, 'seg_train')

mesh_files = [f for f in os.listdir(train_path) if f.endswith('.obj')]
mesh_files.sort()
seg_files = [f for f in os.listdir(seg_path) if f.endswith('.eseg')]
seg_files.sort()
pair_list = list(zip(mesh_files, seg_files))

i = 0
for mesh_file, seg_file in pair_list:
  
  edge_label_pairs = create_labels_for_mesh(train_path, seg_path, mesh_file, seg_file)
  all_labels = set()
  for edge_label_pair in edge_label_pairs:
    all_labels.add(edge_label_pair[1])
  
  useful_vertices_dict = create_useful_vertices_dict(edge_label_pairs)

  useful_pairs = [elem for elem in edge_label_pairs if elem[0][0] in useful_vertices_dict or elem[0][1] in useful_vertices_dict]
  print(useful_pairs)
  connected_components = find_connected_components(useful_pairs)
  
  mesh = trimesh.load_mesh(os.path.join(train_path, mesh_file))
  vertices = mesh.vertices
  connected_components_vertices = []
  for connected_component in connected_components:
    connected_component_vertices = set()
    for vertex in connected_component:
      connected_component_vertices.add(vertices[vertex])   
    connected_components_vertices.append(connected_component_vertices)

  keypoints = []
  for connected_component_vertices in connected_components_vertices:
    # compute the mean of the connected component
    mean = [0, 0, 0]
    for vertex in connected_component_vertices:
      mean[0] += vertex[0]
      mean[1] += vertex[1]
      mean[2] += vertex[2]
    mean[0] /= len(connected_component_vertices)
    mean[1] /= len(connected_component_vertices)
    mean[2] /= len(connected_component_vertices)
    keypoints.append(mean)
  
  useful_vertices_coordinates = []
  for vertex in useful_vertices_dict:
    useful_vertices_coordinates.append(vertices[vertex])


  fig = go.Figure()
  
  # Add the mesh data
  fig.add_trace(go.Mesh3d(
      x=vertices[:, 0], 
      y=vertices[:, 1], 
      z=vertices[:, 2], 
      i=mesh.faces[:, 0], 
      j=mesh.faces[:, 1], 
      k=mesh.faces[:, 2], 
      opacity=0.5
  ))
  
  # Add the keypoints
  # for keypoint in keypoints:
  #     fig.add_trace(go.Scatter3d(
  #         x=[keypoint[0]], 
  #         y=[keypoint[1]], 
  #         z=[keypoint[2]], 
  #         mode='markers', 
  #         marker=dict(size=5)
  #     ))
  
  #plot the useful vertices
  # for useful_vertex in useful_vertices_coordinates:
  #     fig.add_trace(go.Scatter3d(
  #         x=[useful_vertex[0]], 
  #         y=[useful_vertex[1]], 
  #         z=[useful_vertex[2]], 
  #         mode='markers', 
  #         marker=dict(size=3)
  #     ))
      
  # plot connected components vertices with different colors (same color for the same connected component)
  colors = ['red', 'green', 'blue', 'yellow', 'purple', 'orange', 'pink', 'cyan', 'magenta', 'brown']
  # for i, connected_component_vertices in enumerate(connected_components_vertices):
  #   x = []
  #   y = []
  #   z = []
  #   for vertex in connected_component_vertices:
  #     x.append(vertex[0])
  #     y.append(vertex[1])
  #     z.append(vertex[2])
  #   fig.add_trace(go.Scatter3d(
  #       x=x, 
  #       y=y, 
  #       z=z, 
  #       mode='markers', 
  #       marker=dict(size=3, color=colors[i])
  #   ))
    
  # plot useful_pairs edges using one color for each class
  for pair in useful_pairs:
    edge = pair[0]
    cls = pair[1]
    v0 = vertices[edge[0]]
    v1 = vertices[edge[1]]
    fig.add_trace(go.Scatter3d(
        x=[v0[0], v1[0]], 
        y=[v0[1], v1[1]], 
        z=[v0[2], v1[2]], 
        mode='lines', 
        line=dict(width=2, color=colors[int(cls)])
    ))
      
  
  # Set the aspect ratio to ensure correct proportions
  fig.update_layout(
      scene=dict(
          aspectmode='data'
      )
  )
  
  # Show the plot
  fig.show()
  break
  

[((1, 53), '8'), ((18, 53), '8'), ((18, 48), '8'), ((40, 48), '8'), ((48, 53), '8'), ((47, 53), '8'), ((8, 56), '8'), ((45, 56), '8'), ((38, 56), '8'), ((25, 43), '8'), ((26, 43), '8'), ((43, 719), '8'), ((34, 49), '8'), ((40, 49), '8'), ((33, 56), '7'), ((33, 45), '8'), ((26, 39), '8'), ((29, 39), '8'), ((39, 43), '8'), ((53, 63), '8'), ((47, 63), '8'), ((36, 45), '8'), ((35, 36), '8'), ((31, 36), '8'), ((41, 54), '8'), ((38, 54), '8'), ((42, 49), '8'), ((52, 72), '7'), ((39, 52), '8'), ((39, 72), '8'), ((43, 72), '8'), ((31, 43), '8'), ((54, 56), '8'), ((36, 43), '8'), ((28, 51), '8'), ((41, 51), '8'), ((63, 713), '8'), ((47, 713), '8'), ((28, 704), '8'), ((29, 704), '8'), ((48, 60), '8'), ((40, 60), '8'), ((49, 60), '8'), ((52, 62), '7'), ((39, 62), '7'), ((59, 60), '7'), ((48, 59), '7'), ((49, 50), '8'), ((50, 710), '7'), ((49, 710), '7'), ((44, 55), '8'), ((42, 55), '8'), ((51, 54), '8'), ((39, 704), '8'), ((55, 58), '8'), ((60, 710), '7'), ((62, 704), '7'), ((36, 66), '7'), ((43,