In [29]:
from extract_edge_features import extract_edge_features

import os
import pandas as pd
import numpy as np
import trimesh
import plotly.graph_objects as go
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [30]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [31]:
# --- Dataset and output locations --------------------------------------------
dataset_dir = 'datasets/mesh_keypoints_extraction'
meshes_dir, keypoints_dir = (
    os.path.join(dataset_dir, subdir) for subdir in ('meshes', 'keypoints')
)
model_save_dir = 'models/'

# --- Input / output dimensions -----------------------------------------------
num_edges = 1024       # edge rows kept per mesh (longer inputs are cut, shorter ones zero-padded)
input_channels = 5     # used as in_channels of the first convolution
num_features = 7
num_keypoints = 1      # keypoints regressed per mesh

# --- Optimisation settings ---------------------------------------------------
batch_size, learning_rate, num_epochs = 16, 0.01, 10

In [32]:
# Load one sample (mesh + annotated keypoints) to sanity-check the preprocessing.
index = '0000'
mesh = trimesh.load(os.path.join(meshes_dir, f'mesh_{index}.obj'))
keypoints_df = pd.read_csv(os.path.join(keypoints_dir, f'keypoints_{index}.csv'))
# Cast to float: if the CSV columns were parsed as integers, the in-place
# normalisation used previously would raise a casting error.
keypoints = keypoints_df[['x', 'y', 'z']].values.astype(float)

# Reorder the coordinate axes by decreasing vertex variance and apply the same
# permutation to the keypoints so they stay aligned with the mesh.
variance = mesh.vertices.var(axis=0)
order = variance.argsort()[::-1]
mesh.vertices = mesh.vertices[:, order]
keypoints = keypoints[:, order]

# Centre on the mesh centroid and divide by the mesh scale. Both values are
# captured before the vertices are modified so mesh and keypoints share them.
centroid = mesh.centroid
scale = mesh.scale

mesh.vertices -= centroid
mesh.vertices /= scale
keypoints = (keypoints - centroid) / scale

fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='data'))

# plot mesh
fig.add_trace(go.Mesh3d(x=mesh.vertices[:, 0], y=mesh.vertices[:, 1], z=mesh.vertices[:, 2], i=mesh.faces[:, 0], j=mesh.faces[:, 1], k=mesh.faces[:, 2], color='lightgrey', opacity=0.5))

# plot keypoints (all of them in a single trace)
fig.add_trace(go.Scatter3d(x=keypoints[:, 0], y=keypoints[:, 1], z=keypoints[:, 2], mode='markers', marker=dict(size=5, color='blue', opacity=0.5), name='keypoints'))

fig.show()

In [6]:
class MeshData(Dataset):
    """Dataset pairing mesh files with keypoint CSV files.

    Files in ``mesh_dir`` and ``keypoints_dir`` are listed and sorted
    independently; item ``idx`` pairs the ``idx``-th entry of each listing.

    Args:
        mesh_dir: Directory containing the mesh files.
        keypoints_dir: Directory containing CSV files with ``x``, ``y``, ``z`` columns.
        device: Device forwarded to ``extract_edge_features``.
        num_edges: Number of rows (dim 0) every edge-feature tensor is cut or
            zero-padded to.
        normalize: If true, centre the mesh on its centroid and divide by its scale.
        only_keypoints_coordinates: With ``only_keypoints_classes``, decides whether
            the keypoints receive the same normalisation as the mesh.
        only_keypoints_classes: Keypoints are left un-normalised only when this is
            true and ``only_keypoints_coordinates`` is false.
    """

    def __init__(self, mesh_dir, keypoints_dir, device, num_edges=1024, normalize=True,
                 only_keypoints_coordinates=False, only_keypoints_classes=False):
        self.mesh_dir = mesh_dir
        self.keypoints_dir = keypoints_dir
        self.device = device
        self.num_edges = num_edges
        self.normalize = normalize
        # These flags are read in __getitem__; they were previously never set,
        # which made every item access fail with AttributeError.
        self.only_keypoints_coordinates = only_keypoints_coordinates
        self.only_keypoints_classes = only_keypoints_classes

        self.mesh_files = sorted(os.listdir(mesh_dir))
        self.keypoints_files = sorted(os.listdir(keypoints_dir))

    def __len__(self):
        """Return the number of mesh files found in ``mesh_dir``."""
        return len(self.mesh_files)

    def _fit_edge_count(self, edge_features):
        """Cut or zero-pad ``edge_features`` along dim 0 to exactly ``self.num_edges`` rows."""
        n_rows = edge_features.shape[0]
        if n_rows > self.num_edges:
            return edge_features[:self.num_edges]
        if n_rows == self.num_edges:
            return edge_features
        # The padding copies every trailing dimension (plus dtype/device) of the
        # input so that concatenation also works for tensors with more than 2 dims.
        padding = torch.zeros(
            (self.num_edges - n_rows, *edge_features.shape[1:]),
            dtype=edge_features.dtype,
            device=edge_features.device,
        )
        return torch.cat([edge_features, padding], dim=0)

    def __getitem__(self, idx):
        """Load, preprocess and return sample ``idx``.

        Returns:
            tuple: ``(mesh_obj, edge_features, keypoints)`` where ``mesh_obj`` is the
            object returned by ``trimesh.load`` (axes reordered, optionally
            normalised), ``edge_features`` has ``num_edges`` rows, and ``keypoints``
            is a float array of shape ``(1, 3)`` holding the first CSV row only.

        Raises:
            IndexError: If the keypoints file contains no rows.
        """
        #* Get mesh
        mesh_path = os.path.join(self.mesh_dir, self.mesh_files[idx])
        mesh_obj = trimesh.load(mesh_path)

        #* Order coordinate axes by decreasing vertex variance
        variance = mesh_obj.vertices.var(axis=0)
        order = variance.argsort()[::-1]
        mesh_obj.vertices = mesh_obj.vertices[:, order]

        #* Get keypoints (float dtype so normalisation works for integer-valued CSVs)
        keypoints_path = os.path.join(self.keypoints_dir, self.keypoints_files[idx])
        keypoints = pd.read_csv(keypoints_path)[['x', 'y', 'z']].values.astype(float)
        if keypoints.shape[0] == 0:
            raise IndexError(f"no keypoints found in {keypoints_path}")
        # Same axis permutation as the mesh; only the first keypoint is kept.
        keypoints = keypoints[:1, order]

        #* Normalize mesh and keypoints
        if self.normalize:
            # Capture both values before touching the vertices so that the mesh
            # and the keypoints are transformed with identical parameters.
            centroid = mesh_obj.centroid
            scale = mesh_obj.scale
            mesh_obj.vertices -= centroid
            mesh_obj.vertices /= scale
            # Equivalent to the original `A or (not A and not B)` condition.
            if self.only_keypoints_coordinates or not self.only_keypoints_classes:
                keypoints = (keypoints - centroid) / scale

        #* Get edge features, forced to a fixed number of rows
        edge_features = extract_edge_features(mesh_obj, device=self.device)
        edge_features = self._fit_edge_count(edge_features)

        return mesh_obj, edge_features, keypoints

In [7]:
def custom_collate_fn(batch):
    """Collate ``(mesh, edge_features, keypoints)`` samples into one batch.

    Mesh objects are gathered into a plain list. Edge-feature tensors are
    stacked along a new leading dimension, and each keypoint array is converted
    with ``torch.tensor`` and stacked the same way.

    Returns:
        tuple: ``(meshes, edge_features_batch, keypoints_batch)``.
    """
    samples = list(batch)
    meshes = [sample[0] for sample in samples]
    edge_features_batch = torch.stack([sample[1] for sample in samples])
    keypoints_batch = torch.stack([torch.tensor(sample[2]) for sample in samples])
    return meshes, edge_features_batch, keypoints_batch

In [8]:
# Build the full dataset from the directories declared in the configuration cell
# (`mesh_train_dir` / `keypoints_train_dir` are not defined anywhere in this notebook).
full_set = MeshData(meshes_dir, keypoints_dir, device=device, num_edges=num_edges, normalize=True, only_keypoints_coordinates=True, only_keypoints_classes=False)

# 90/10 train/validation split. The seeded generator makes the split identical
# on every run, so losses are comparable across kernel restarts.
num_train = int(0.9 * len(full_set))
num_val = len(full_set) - num_train
train_set, val_set = torch.utils.data.random_split(full_set, [num_train, num_val], generator=torch.Generator().manual_seed(0))

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
valid_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

In [9]:
# Take the first training sample and assemble a preview figure:
# the mesh surface plus one marker trace per keypoint.
mesh, edge_features, keypoints = train_set[0]

mesh_trace = go.Mesh3d(
    x=mesh.vertices[:, 0], y=mesh.vertices[:, 1], z=mesh.vertices[:, 2],
    i=mesh.faces[:, 0], j=mesh.faces[:, 1], k=mesh.faces[:, 2],
    color='lightgrey', opacity=0.5,
)

fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='data'))
fig.add_trace(mesh_trace)
for i, keypoint in enumerate(keypoints):
    marker_trace = go.Scatter3d(
        x=[keypoint[0]], y=[keypoint[1]], z=[keypoint[2]],
        mode='markers', marker=dict(size=5, color='blue'),
    )
    fig.add_trace(marker_trace)

# Display is disabled here; uncomment to render the figure.
# fig.show()

In [10]:
class KeypointPredictorCNN(nn.Module):
    """CNN that regresses 3D keypoints from a per-edge feature tensor.

    Five 3x3 convolutions (padding 1, so spatial size is preserved) are followed
    by a single 2x2 max-pool and a stack of fully connected layers.

    Args:
        input_channels: Size of dim 2 of the input (becomes the conv channel axis).
        num_features: Size of the last input dimension.
        num_keypoints: Number of keypoints predicted per sample.
        num_edges: Size of dim 1 of the input. Together with ``num_features`` it
            fixes the input width of the first fully connected layer; the
            defaults reproduce the previously hard-coded 196608.
    """

    def __init__(self, input_channels=5, num_features=7, num_keypoints=12, num_edges=1024):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=8, kernel_size=(3, 3), padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=1)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=1)

        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        # Only the single max-pool changes the spatial size: both dims are
        # halved with floor rounding, e.g. 128 * (1024 // 2) * (7 // 2) = 196608.
        flattened_size = 128 * (num_edges // 2) * (num_features // 2)
        self.fc1 = nn.Linear(flattened_size, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64)
        self.fc_last = nn.Linear(64, num_keypoints * 3)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Predict keypoints for a batch.

        Args:
            x: Tensor of shape ``(batch, num_edges, input_channels, num_features)``.

        Returns:
            Tensor of shape ``(batch, num_keypoints, 3)``; within each sample the
            keypoints are ordered by ascending first (x) coordinate.
        """
        x = x.permute(0, 2, 1, 3)     # (batch, input_channels, num_edges, num_features)

        x = self.relu(self.conv1(x))  # (batch, 8,   num_edges, num_features)
        x = self.relu(self.conv2(x))  # (batch, 16,  num_edges, num_features)
        x = self.relu(self.conv3(x))  # (batch, 32,  num_edges, num_features)
        x = self.relu(self.conv4(x))  # (batch, 64,  num_edges, num_features)
        x = self.relu(self.conv5(x))  # (batch, 128, num_edges, num_features)
        x = self.pool(x)              # (batch, 128, num_edges // 2, num_features // 2)
        x = torch.flatten(x, start_dim=1)

        # fc1..fc5 each followed by ReLU and dropout: 1024 -> 512 -> 256 -> 128 -> 64
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = self.dropout(self.relu(fc(x)))
        x = self.fc_last(x)           # (batch, num_keypoints * 3)

        x = x.view(x.size(0), -1, 3)  # (batch, num_keypoints, 3)

        # Order the keypoints of every sample by x coordinate in one batched gather.
        order = x[:, :, 0].argsort(dim=1)
        return torch.gather(x, 1, order.unsqueeze(-1).expand(-1, -1, x.size(-1)))

In [11]:
# Instantiate the network from the notebook-level settings and move it to `device`.
keypoint_predictor = KeypointPredictorCNN(
    input_channels=input_channels,
    num_features=num_features,
    num_keypoints=num_keypoints,
)
keypoint_predictor = keypoint_predictor.to(device)

# Mean-squared error between predicted and annotated keypoints, minimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(keypoint_predictor.parameters(), lr=learning_rate)

In [12]:
# torch.save does not create missing directories, so make sure both the model
# folder and its checkpoint sub-folder exist before training starts.
checkpoint_dir = os.path.join(model_save_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # ---- training pass ----
    keypoint_predictor.train()
    train_loss = 0.0
    for _, inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        inputs = inputs.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.float32)
        optimizer.zero_grad()
        outputs = keypoint_predictor(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Mean of the per-batch losses; max(..., 1) avoids dividing by zero on an empty loader.
    train_loss = train_loss / max(len(train_loader), 1)

    # ---- validation pass (dropout disabled, no gradients) ----
    keypoint_predictor.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for _, inputs, labels in valid_loader:
            inputs = inputs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.float32)
            outputs = keypoint_predictor(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
    valid_loss = valid_loss / max(len(valid_loader), 1)

    print(f'Train Loss: {train_loss}, Valid Loss: {valid_loss}')
    # One checkpoint per epoch, indexed from 0.
    torch.save(keypoint_predictor.state_dict(), os.path.join(checkpoint_dir, f'keypoint_predictor_{epoch}.pth'))

# Final weights after the last epoch.
torch.save(keypoint_predictor.state_dict(), os.path.join(model_save_dir, 'keypoint_predictor.pth'))

Epoch 1/10 - Training: 100%|██████████| 8/8 [00:54<00:00,  6.86s/it]


Train Loss: 144.50507819629274, Valid Loss: 1.1845605373382568


Epoch 2/10 - Training: 100%|██████████| 8/8 [00:54<00:00,  6.77s/it]


Train Loss: 0.3393394532613456, Valid Loss: 0.03716079145669937


Epoch 3/10 - Training: 100%|██████████| 8/8 [00:54<00:00,  6.77s/it]


Train Loss: 0.03980175592005253, Valid Loss: 0.039130326360464096


Epoch 4/10 - Training: 100%|██████████| 8/8 [00:55<00:00,  6.91s/it]


Train Loss: 0.10351924365386367, Valid Loss: 0.04650364816188812


Epoch 5/10 - Training: 100%|██████████| 8/8 [00:54<00:00,  6.84s/it]


Train Loss: 0.05064139422029257, Valid Loss: 0.044000860303640366


Epoch 6/10 - Training: 100%|██████████| 8/8 [00:56<00:00,  7.03s/it]


Train Loss: 0.037510124035179615, Valid Loss: 0.03958660736680031


Epoch 7/10 - Training: 100%|██████████| 8/8 [00:55<00:00,  6.88s/it]


Train Loss: 0.03500244370661676, Valid Loss: 0.03544745594263077


Epoch 8/10 - Training: 100%|██████████| 8/8 [00:54<00:00,  6.78s/it]


Train Loss: 0.033883966971188784, Valid Loss: 0.033612821251153946


Epoch 9/10 - Training: 100%|██████████| 8/8 [00:54<00:00,  6.82s/it]


Train Loss: 0.03258819901384413, Valid Loss: 0.030618824064731598


Epoch 10/10 - Training: 100%|██████████| 8/8 [00:54<00:00,  6.79s/it]


Train Loss: 0.03309599915519357, Valid Loss: 0.03071964718401432


In [20]:
# Rebuild the network and restore the trained weights.
keypoint_predictor_test = KeypointPredictorCNN(input_channels=input_channels, num_features=num_features, num_keypoints=num_keypoints).to(device)
# map_location lets weights saved on a GPU be loaded on a CPU-only machine.
keypoint_predictor_test.load_state_dict(torch.load(os.path.join(model_save_dir, 'keypoint_predictor.pth'), map_location=device, weights_only=True))
# A freshly constructed module is in training mode; switch to eval so dropout
# does not randomly perturb the predictions.
keypoint_predictor_test.eval()

# No separate test directories are defined in this notebook (`mesh_test_dir` and
# `keypoints_train_dir` do not exist), so the held-out validation split is
# reused as the evaluation set.
test_set = val_set
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)

# Sample to visualise; use `test_set[0]` instead to inspect a held-out mesh.
mesh, edge_features, keypoints = train_set[0]
with torch.no_grad():
    model_input = edge_features.unsqueeze(0).to(device=device, dtype=torch.float32)
    # squeeze(0) drops only the batch axis, leaving (num_keypoints, 3) for any num_keypoints.
    predicted_keypoints = keypoint_predictor_test(model_input).squeeze(0).cpu().numpy()
keypoints = np.asarray(keypoints)

fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='data'))

# plot mesh
fig.add_trace(go.Mesh3d(x=mesh.vertices[:, 0], y=mesh.vertices[:, 1], z=mesh.vertices[:, 2], i=mesh.faces[:, 0], j=mesh.faces[:, 1], k=mesh.faces[:, 2], color='lightgrey', opacity=0.5))

# plot keypoints: predictions in blue, annotations in red
fig.add_trace(go.Scatter3d(x=predicted_keypoints[:, 0], y=predicted_keypoints[:, 1], z=predicted_keypoints[:, 2], mode='markers', marker=dict(size=5, color='blue'), name='predicted'))
fig.add_trace(go.Scatter3d(x=keypoints[:, 0], y=keypoints[:, 1], z=keypoints[:, 2], mode='markers', marker=dict(size=2, color='red'), name='ground truth'))

fig.show()