In [1]:
from colordict import *
import numpy as np

# Build the colour dictionary with RGB values on a 0-255 scale.
# NOTE(review): `palettes='all'` presumably loads every available palette — confirm
# against the colordict documentation.
color = ColorDict(
    norm=255,
    mode='rgb',
    palettes_path="",
    is_grayscale=False,
    palettes='all',
)

# One row per colour entry, truncated to its first three channels.
pallete = np.array([*color.values()])[:, :3]

In [2]:
import os
import sys

# Make the parent of the training script's directory importable.
# NOTE(review): this is a machine-specific absolute path; the notebook will not
# find the project modules on another machine unless it is adjusted.
_train_script = '/home/ntphat/thuan/ICLR/eman/experiments/HUMAN/train_human.py'
_script_dir = os.path.dirname(_train_script)
_parent_dir = os.path.abspath(os.path.join(_script_dir, os.path.pardir))
sys.path.append(_parent_dir)

import copy
import os.path as osp
from typing import Callable, Optional
import numpy as np
import torch
import torch_geometric.transforms as T
import exp_utils
from egnn.transform.preprocess import *
from torch_geometric.data import Data, InMemoryDataset, DataLoader
import trimesh

class HUMAN(InMemoryDataset):
    """In-memory mesh segmentation dataset with a train and a test split.

    ``process`` reads every file found in ``<root>/raw/train`` and
    ``<root>/raw/test`` with ``trimesh``, pairs it with the label file of the
    same stem (``.eseg`` extension) in ``<root>/raw/seg``, and stores one
    collated file per split.

    Args:
        root: Dataset root directory; raw files are read from ``<root>/raw``.
        train: If True load the first processed file (train split), otherwise
            the second one (test split).
        transform: Forwarded to the base class.
        pre_transform: Applied to every sample inside ``process`` before saving.
        pre_transform_str: Prefix prepended to the processed file names.
        pre_filter: Samples for which this returns a falsy value are dropped
            inside ``process``.
        skip_process: If True the processed tensors are *not* loaded into
            ``self.data`` / ``self.slices`` after construction.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_transform_str: str = "",
        pre_filter: Optional[Callable] = None,
        skip_process: bool = False,
    ):
        # Set before the base-class constructor runs: `processed_file_names`
        # reads `pre_transform_str`, and the base class is expected to query
        # it during initialisation.
        self.root = root
        self.pre_transform = pre_transform
        self.pre_transform_str = pre_transform_str
        self.skip_process = skip_process

        super().__init__(
            root, transform, pre_transform=pre_transform, pre_filter=pre_filter
        )

        # processed_paths is ordered [train, test]; use an explicit integer
        # index rather than indexing with a bool.
        path = self.processed_paths[0 if train else 1]
        if not self.skip_process:
            self.data, self.slices = torch.load(path)

    @property
    def processed_file_names(self):
        """Processed file names for the train and test split, in that order."""
        base_paths = ["train.pt", "test.pt"]
        return [self.pre_transform_str + name for name in base_paths]

    def _process_mesh(self, data, tform, dlist):
        """Append a deep copy of ``data`` (transformed by ``tform`` if given) to ``dlist``."""
        aux_data = copy.deepcopy(data)
        if tform is not None:
            aux_data = tform(aux_data)
        dlist.append(aux_data)

    def read_seg(self, seg):
        """Load the integer labels stored in the text file ``seg``.

        The file is opened in a ``with`` block so the handle is always closed.
        """
        with open(seg, 'r') as handle:
            return np.loadtxt(handle, dtype='long')

    def process(self):
        """Read the raw meshes and labels and save one collated file per split."""
        data_list = [[], []]
        seg_dir = os.path.join(self.root, 'raw', 'seg')

        for i, split in enumerate(['train', 'test']):
            split_dir = os.path.join(self.root, 'raw', split)
            # os.listdir returns entries in arbitrary order; sort them so the
            # sample order (and therefore dataset[idx]) is reproducible.
            for fname in sorted(os.listdir(split_dir)):
                mesh = trimesh.load(os.path.join(split_dir, fname))
                pos = torch.tensor(mesh.vertices).float()
                face = torch.tensor(mesh.faces)
                face = face - face.min()  # Ensure zero-based index.

                # Swap only the file extension; a plain str.replace on the full
                # path would also rewrite a '.ply' occurring in a directory name.
                seg_name = os.path.splitext(fname)[0] + '.eseg'
                label = self.read_seg(os.path.join(seg_dir, seg_name))

                data = Data(
                    pos=pos,
                    face=face.T.contiguous(),
                    y=torch.tensor(label, dtype=torch.long),
                )
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list[i].append(data)

        for dl, _path in zip(data_list, self.processed_paths):
            torch.save(self.collate(dl), _path)


# Pre-processing steps applied to every mesh, in this order.
pre_tform = T.Compose(
    [
        compute_normals_edges_from_mesh,
        compute_area_from_mesh,
        compute_triangle_info_from_mesh,
        compute_di_angle,
        compute_hks,
    ]
)

path = exp_utils.get_dataset_path("human_direct")
# Override the processed directory for the whole class so cached files are
# read from / written to "processed_direct" under the dataset path.
HUMAN.processed_dir = osp.join(path, "processed_direct")

test_dataset = HUMAN(path, pre_transform=pre_tform, train=False)

In [3]:
# Pick one mesh from the test split and pull out what the plot needs.
sample = test_dataset[1]

# First three columns of the positions: per-vertex x, y, z coordinates.
x, y, z = sample.pos[:, :3].T

# First three rows of the face index tensor: the three corner indices of each triangle.
i, j, k = sample.face[:3]

label = sample.y

In [4]:
# Display the palette array built in the first cell.
pallete

array([[196.,  98.,  16.],
       [ 46.,  88., 148.],
       [156.,  37.,  66.],
       [191.,  79.,  81.],
       [165., 113., 100.],
       [ 88.,  66., 124.],
       [ 74., 100., 108.],
       [133., 117.,  78.],
       [ 49., 145., 119.],
       [ 10., 126., 140.],
       [156., 124.,  56.],
       [141.,  78., 133.],
       [143., 212.,   0.],
       [217., 134., 149.],
       [117., 117., 117.],
       [255.,  53.,  94.],
       [253.,  91., 120.],
       [255.,  96.,  55.],
       [255., 153., 102.],
       [255., 153.,  51.],
       [255., 204.,  51.],
       [255., 255., 102.],
       [204., 255.,   0.],
       [102., 255., 102.],
       [170., 240., 209.],
       [ 80., 191., 230.],
       [255., 110., 255.],
       [238.,  52., 210.],
       [255.,   0., 204.],
       [254., 254., 250.],
       [255., 209.,  42.],
       [ 79., 134., 247.],
       [255., 211., 248.],
       [201.,  90.,  73.],
       [218.,  38.,  71.],
       [254., 254., 254.],
       [255., 255.,  49.],
 

In [5]:
# Display the label tensor of the selected sample.
label

tensor([0, 0, 0,  ..., 7, 7, 7])

In [6]:
# (num_faces, 3) corner indices and (num_vertices, 3) coordinates.
triangles = np.vstack((i, j, k)).T
vertices = np.vstack((x, y, z)).T
# (num_faces, 3 corners, 3 coordinates): the corner positions of every triangle.
tri_points = vertices[triangles]

# Flat coordinate lists describing each triangle's outline.
Xe, Ye, Ze = [], [], []
# The loop variable is deliberately not called `T`: that name is the
# torch_geometric.transforms alias and rebinding it here would clobber the
# module for any cell executed afterwards.
for tri in tri_points:
    for axis, coords in enumerate((Xe, Ye, Ze)):
        # Corners 0, 1, 2 and back to 0 close the outline; the trailing None
        # separates this triangle's entries from the next one's.
        coords.extend([tri[corner % 3][axis] for corner in range(4)] + [None])

    

In [None]:
import plotly.graph_objects as go
import numpy as np

# Surface trace: triangles given by (i, j, k) over the vertices (x, y, z),
# coloured with the palette rows selected by `label`.
mesh = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, facecolor=pallete[label])

# Line trace drawn from the outline coordinate lists.
line = go.Scatter3d(
    x=Xe,
    y=Ye,
    z=Ze,
    mode='lines',
    name='',
    line={'color': 'rgb(70,70,70)', 'width': 1},
)

fig = go.Figure(data=[mesh, line])
# Use the data's own aspect ratio and hide all three axes.
fig.update_scenes(
    aspectmode='data',
    xaxis_visible=False,
    yaxis_visible=False,
    zaxis_visible=False,
)
fig.show()

In [10]:
# model = torch.load('/home/ntphat/thuan/ICLR/eman/results/human/trained_model_best.h5').cuda()