In [1]:
from mesh_keypoints_extraction import KeypointPredictionNetwork, MeshData, train, test, test_single_mesh, custom_collate_fn, HungarianSumOfDistancesLoss
from mesh_keypoints_extraction import hungarian_mpjpe, hungarian_pck

import os
import numpy as np
import plotly.graph_objects as go

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import GradScaler

# Seed the global RNGs up front: torch's default generator drives weight initialisation,
# DataLoader shuffling and (in the original split cell) random_split.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels (may cost some speed on GPU).
torch.backends.cudnn.deterministic = True

In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# --- Paths -------------------------------------------------------------------
dataset_dir = 'datasets/mesh_keypoints_extraction_dataset'
meshes_dir, keypoints_dir = (os.path.join(dataset_dir, sub) for sub in ('meshes', 'keypoints'))
# Trailing slash kept so plain string concatenation with a file name also yields a valid path.
model_save_dir = 'weights/'

# --- Data / model shape --------------------------------------------------------
num_edges = 750        # passed as MeshData(num_edges=...); presumably edges kept per mesh — TODO confirm
input_channels = 5     # KeypointPredictionNetwork input width; must match MeshData's edge features — TODO confirm
num_keypoints = 12     # number of keypoints the network predicts per mesh

# --- Optimisation ----------------------------------------------------------------
batch_size, learning_rate, num_epochs = 32, 1e-3, 90

In [4]:
# Load all annotated meshes, then split them 80/10/10 into train/validation/test.
dataset = MeshData(meshes_dir, keypoints_dir, device=device, num_edges=num_edges, normalize=True)
train_set_size = int(0.8 * len(dataset))
val_set_size = int(0.1 * len(dataset))
test_set_size = len(dataset) - train_set_size - val_set_size
assert min(train_set_size, val_set_size, test_set_size) > 0, (
    f"dataset too small for an 80/10/10 split: {len(dataset)} meshes"
)

# A dedicated, fixed-seed generator keeps the split identical across kernel restarts whether or
# not MeshData (or anything else) has drawn from the global RNG. This matters because later cells
# reload saved weights: the test set must never contain meshes a previous run trained on.
split_generator = torch.Generator().manual_seed(0)
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [train_set_size, val_set_size, test_set_size], generator=split_generator
)

# Only the training loader shuffles; validation/test order is fixed for comparable metrics.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
valid_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

In [5]:
# Single loss instance shared by training (train), evaluation (test) and the single-mesh check
# (test_single_mesh). The name suggests predictions are matched to ground-truth keypoints by
# Hungarian assignment and the matched distances summed — TODO confirm in mesh_keypoints_extraction.
hungarian_sum_of_distances_loss = HungarianSumOfDistancesLoss()

In [6]:
# Train the keypoint predictor, resuming from an existing checkpoint when one is present.
checkpoint_path = os.path.join(model_save_dir, 'keypoints_predictor.pth')
os.makedirs(model_save_dir, exist_ok=True)  # train() is given this directory to save into

keypoints_predictor = KeypointPredictionNetwork(input_channels=input_channels, num_keypoints=num_keypoints).to(device)
# Resume only if a checkpoint exists so the notebook also runs on a fresh checkout.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: only model weights are restored; optimizer and scheduler state start fresh.
if os.path.exists(checkpoint_path):
    keypoints_predictor.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

optimizer = optim.Adam(keypoints_predictor.parameters(), lr=learning_rate)
# Cut the LR by 5x after 3 epochs without improvement of the monitored (minimised) metric.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3)
# Gradient scaling is only useful for CUDA mixed precision; disabling it explicitly on CPU
# avoids the "CUDA is not available" warning the default (CUDA) scaler emits.
scaler = GradScaler(enabled=device.type == 'cuda')

train(keypoints_predictor, optimizer, hungarian_sum_of_distances_loss, scaler, scheduler, train_loader, valid_loader, num_epochs, device, model_save_dir)

In [None]:
# Evaluate a fresh model loaded from the saved checkpoint on the held-out test split.
checkpoint_path = os.path.join(model_save_dir, 'keypoints_predictor.pth')
keypoints_predictor_test = KeypointPredictionNetwork(input_channels=input_channels, num_keypoints=num_keypoints).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
keypoints_predictor_test.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
# Inference mode for any dropout/batch-norm layers; later cells reuse this model directly.
keypoints_predictor_test.eval()

test(keypoints_predictor_test, test_loader, hungarian_sum_of_distances_loss, device)

In [None]:
def _as_numpy(points):
    """Return `points` as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if torch.is_tensor(points):
        return points.detach().cpu().numpy()
    return np.asarray(points)


# Compare predicted (blue) and ground-truth (red) keypoints on the first test mesh.
mesh, edge_features, keypoints = test_set[0]
predicted_keypoints = test_single_mesh(keypoints_predictor_test, edge_features, keypoints, hungarian_sum_of_distances_loss, device)

# assumes both keypoint arrays are (num_keypoints, >=3) with xyz in the first three columns — TODO confirm
predicted_points = _as_numpy(predicted_keypoints)
true_points = _as_numpy(keypoints)

fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='data'), title='Test mesh 0: predicted (blue) vs ground-truth (red) keypoints')
fig.add_trace(go.Mesh3d(x=mesh.vertices[:, 0], y=mesh.vertices[:, 1], z=mesh.vertices[:, 2], i=mesh.faces[:, 0], j=mesh.faces[:, 1], k=mesh.faces[:, 2], color='lightgrey', opacity=0.5, name='mesh'))
# One trace per keypoint set (instead of one per keypoint) keeps the legend readable.
fig.add_trace(go.Scatter3d(x=predicted_points[:, 0], y=predicted_points[:, 1], z=predicted_points[:, 2], mode='markers', marker=dict(size=5, color='blue'), name='predicted'))
fig.add_trace(go.Scatter3d(x=true_points[:, 0], y=true_points[:, 1], z=true_points[:, 2], mode='markers', marker=dict(size=3, color='red'), name='ground truth'))
fig.show()

In [9]:
# Extra meshes from the 'other_meshes' folder, used below for a qualitative look at predictions.
other_meshes_dir = os.path.join(dataset_dir, 'other_meshes')
if not os.path.isdir(other_meshes_dir):
    raise FileNotFoundError(f"other meshes directory not found: {other_meshes_dir}")
# Load from other_meshes_dir (previously meshes_dir was passed, which reloaded the training meshes).
# keypoints_dir is still passed because MeshData requires it — TODO confirm the other meshes have
# matching keypoint files there.
other_meshes_dataset = MeshData(other_meshes_dir, keypoints_dir, device=device, num_edges=num_edges, normalize=True)

In [None]:
# Predict keypoints for one of the extra meshes and overlay them on the mesh.
sample_index = 10
mesh, edge_features, _ = other_meshes_dataset[sample_index]

keypoints_predictor_test.eval()  # inference behaviour for dropout/batch-norm, if present
with torch.no_grad():  # no autograd graph needed for a visual check
    model_input = edge_features.unsqueeze(0).to(device=device, dtype=torch.float32)
    # Take batch element 0 rather than squeeze(), which would also collapse a size-1 keypoint axis.
    predicted_keypoints = keypoints_predictor_test(model_input)[0].cpu().numpy()

fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='data'), title=f'Other mesh {sample_index}: predicted keypoints')
fig.add_trace(go.Mesh3d(x=mesh.vertices[:, 0], y=mesh.vertices[:, 1], z=mesh.vertices[:, 2], i=mesh.faces[:, 0], j=mesh.faces[:, 1], k=mesh.faces[:, 2], color='lightgrey', opacity=0.5, name='mesh'))
# One trace for all predicted keypoints keeps the legend readable.
fig.add_trace(go.Scatter3d(x=predicted_keypoints[:, 0], y=predicted_keypoints[:, 1], z=predicted_keypoints[:, 2], mode='markers', marker=dict(size=5, color='blue'), name='predicted'))
fig.show()