In [1]:
from extract_edge_features import extract_edge_features

import os
import pandas as pd
import trimesh
import plotly.graph_objects as go
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

In [3]:
class2colors = {'ankles': 'red', 'knees': 'blue', 'hips': 'green', 'shoulders': 'purple', 'elbows': 'orange', 'wrists': 'yellow', 'neck': 'brown'}

In [4]:
dataset_dir = 'datasets/mesh_skeleton_extraction'
mesh_train_dir = os.path.join(dataset_dir, 'mesh_train')
mesh_test_dir = os.path.join(dataset_dir, 'mesh_test')
keypoints_train_dir = os.path.join(dataset_dir, 'keypoints_train')
model_save_dir = 'models/'

num_edges = 1024
input_channels = 5
num_features = 7
num_keypoints = 12

batch_size = 1
learning_rate = 0.0001
num_epochs = 10

In [5]:
mesh = trimesh.load(os.path.join(mesh_train_dir, 'mesh_0000.obj'))
keypoints_df = pd.read_csv(os.path.join(keypoints_train_dir, 'keypoints_0000.csv'))
keypoints = keypoints_df[['x', 'y', 'z', 'class']].values

centroid = mesh.centroid
scale = mesh.scale

mesh.vertices -= centroid
mesh.vertices /= scale

keypoints[:, :3] -= centroid
keypoints[:, :3] /= scale

fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='data'))

# plot mesh
fig.add_trace(go.Mesh3d(x=mesh.vertices[:, 0], y=mesh.vertices[:, 1], z=mesh.vertices[:, 2], i=mesh.faces[:, 0], j=mesh.faces[:, 1], k=mesh.faces[:, 2], color='lightgrey', opacity=0.5))

# # plot keypoints
for i, keypoint in enumerate(keypoints):
    fig.add_trace(go.Scatter3d(x=[keypoint[0]], y=[keypoint[1]], z=[keypoint[2]], mode='markers', marker=dict(size=5, color=class2colors[keypoint[3]])))
    
fig.show()

In [6]:
class MeshData(Dataset):
  def __init__(self, mesh_dir, keypoints_dir, device, num_edges=1024, normalize=True, only_keypoints_coordinates=False, only_keypoints_classes=False):
    self.mesh_dir = mesh_dir
    self.keypoints_dir = keypoints_dir
    self.device = device
    self.num_edges = num_edges
    self.normalize = normalize  
    
    if only_keypoints_coordinates and only_keypoints_classes: 
      raise ValueError('only_keypoints_coordinates and only_keypoints_classes cannot be True at the same time')
    self.only_keypoints_coordinates = only_keypoints_coordinates  
    self.only_keypoints_classes = only_keypoints_classes

    self.mesh_files = sorted(os.listdir(mesh_dir))
    self.keypoints_files = sorted(os.listdir(keypoints_dir))
    
  def __len__(self):
    return len(self.mesh_files)
  
  def __getitem__(self, idx):
    mesh_path = os.path.join(self.mesh_dir, self.mesh_files[idx]) 
    mesh_obj = trimesh.load(mesh_path)
    
    edge_features = extract_edge_features(mesh_obj, normalize=self.normalize, device=self.device)
    if edge_features.shape[0] > self.num_edges:
      edge_features = edge_features[:self.num_edges, :]
    else:
      padding = torch.zeros(self.num_edges - edge_features.shape[0], edge_features.shape[1]).to(self.device)  
      edge_features = torch.cat([edge_features, padding], dim=0)
    
    if self.only_keypoints_coordinates or self.only_keypoints_classes:
      if self.only_keypoints_coordinates: keypoints = pd.read_csv(os.path.join(self.keypoints_dir, self.keypoints_files[idx]))[['x', 'y', 'z']].values
      if self.only_keypoints_classes: keypoints = pd.read_csv(os.path.join(self.keypoints_dir, self.keypoints_files[idx]))[['class']].values
    else:
      keypoints = pd.read_csv(os.path.join(self.keypoints_dir, self.keypoints_files[idx]))[['x', 'y', 'z', 'class']].values

    if self.normalize:    
      centroid = mesh_obj.centroid
      scale = mesh_obj.scale
      mesh_obj.vertices -= centroid
      mesh_obj.vertices /= scale
      if self.only_keypoints_coordinates or (not self.only_keypoints_coordinates and not self.only_keypoints_classes):
        keypoints[:, :3] -= centroid
        keypoints[:, :3] /= scale

    return mesh_obj, edge_features, keypoints

In [7]:
def custom_collate_fn(batch):
    meshes = []
    edge_features_list = []
    keypoints_list = []

    for mesh_obj, edge_features, keypoints in batch:
        meshes.append(mesh_obj)
        edge_features_list.append(edge_features)
        keypoints_list.append(keypoints)

    edge_features_batch = torch.stack(edge_features_list)
    keypoints_batch = torch.stack([torch.tensor(kp) for kp in keypoints_list])

    return meshes, edge_features_batch, keypoints_batch

In [8]:
train_set = MeshData(mesh_train_dir, keypoints_train_dir, device=device, num_edges=num_edges, normalize=True, only_keypoints_coordinates=True, only_keypoints_classes=False)
train_set, val_set = torch.utils.data.random_split(train_set, [int(0.9*len(train_set)), len(train_set)-int(0.9*len(train_set))])
test_set = MeshData(mesh_test_dir, keypoints_train_dir, device=device, num_edges=num_edges, normalize=True, only_keypoints_coordinates=True, only_keypoints_classes=False)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
valid_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

In [9]:
class KeypointPredictorCNN(nn.Module):
    def __init__(self, input_channels=5, num_features=7, num_keypoints=12):
        super(KeypointPredictorCNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=8, kernel_size=(3, 3), padding=1)  
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=1)      
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=1)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=1)
        
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        
        self.fc1 = nn.Linear(196608, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64) 
        self.fc_last = nn.Linear(64, num_keypoints * 3)  

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = x.permute(0, 2, 1, 3)
        
        x = self.relu(self.conv1(x)) # (batch_size, 8, num_edges, num_features)
        x = self.relu(self.conv2(x)) # (batch_size, 16, num_edges/2, num_features/2)       
        x = self.relu(self.conv3(x)) # (batch_size, 32, num_edges/4, num_features/4)
        x = self.relu(self.conv4(x)) # (batch_size, 64, num_edges/8, num_features/8)
        x = self.relu(self.conv5(x)) # (batch_size, 128, num_edges/16, num_features/16)
        x = self.pool(x)             # (batch_size, 128, num_edges/32, num_features/32) 
        x = x.view(x.size(0), -1)    # (batch_size, 128 * num_edges/32 * num_features/32)
        
        x = self.relu(self.fc1(x))   # (batch_size, 1024)
        x = self.dropout(x)         
        x = self.relu(self.fc2(x))   # (batch_size, 512)
        x = self.dropout(x)
        x = self.relu(self.fc3(x))   # (batch_size, 256)
        x = self.dropout(x)
        x = self.relu(self.fc4(x))   # (batch_size, 128)
        x = self.dropout(x)
        x = self.relu(self.fc5(x))   # (batch_size, 64)
        x = self.dropout(x)
        x = self.fc_last(x)          # (batch_size, num_keypoints * 3)
        
        x = x.view(x.size(0), -1, 3) # (batch_size, num_keypoints, 3)

        return x


In [10]:
keypoint_predictor = KeypointPredictorCNN(input_channels=input_channels, num_features=num_features, num_keypoints=num_keypoints).to(device)
optimizer = optim.Adam(keypoint_predictor.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

In [None]:
for epoch in range(num_epochs):
    keypoint_predictor.train()
    train_loss = 0.0
    for _, inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        inputs, labels = inputs.to(torch.float32).to(device), labels.to(torch.float32).to(device)
        optimizer.zero_grad()
        outputs = keypoint_predictor(inputs)
        norms = torch.norm(outputs - labels, dim=2)
        mse = torch.mean(norms**2)
        print(mse)
        loss = criterion(outputs, labels)
        print(loss)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
        print('-'*50)
    train_loss = train_loss/len(train_loader)

    keypoint_predictor.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for _, inputs, labels in valid_loader:
            inputs, labels = inputs.to(torch.float32).to(device), labels.to(torch.float32).to(device)
            outputs = keypoint_predictor(inputs)
            norms = torch.norm(outputs - labels, dim=2)
            mse = torch.mean(norms**2)
            print(mse)
            loss = criterion(outputs, labels)
            print(loss)
            valid_loss += loss.item()
            print('-'*50)
    valid_loss = valid_loss/len(valid_loader)
    
    print(f'Train Loss: {train_loss}, Valid Loss: {valid_loss}')
    torch.save(keypoint_predictor.state_dict(), model_save_dir + f'checkpoints/keypoint_predictor_{epoch}.pth')

torch.save(keypoint_predictor.state_dict(), model_save_dir + 'keypoint_predictor.pth')

Epoch 1/10 - Training:  12%|█▎        | 1/8 [00:07<00:52,  7.49s/it]

tensor(921.8857, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(307.2952, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 1/10 - Training:  25%|██▌       | 2/8 [00:14<00:44,  7.35s/it]

tensor(0.9587, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3196, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 1/10 - Training:  38%|███▊      | 3/8 [00:22<00:36,  7.36s/it]

tensor(1.1447, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3816, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 1/10 - Training:  50%|█████     | 4/8 [00:29<00:29,  7.38s/it]

tensor(0.9777, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3259, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 1/10 - Training:  62%|██████▎   | 5/8 [00:36<00:21,  7.32s/it]

tensor(0.8225, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2742, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 1/10 - Training:  75%|███████▌  | 6/8 [00:43<00:14,  7.23s/it]

tensor(0.8723, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2908, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 1/10 - Training:  88%|████████▊ | 7/8 [00:51<00:07,  7.25s/it]

tensor(1.0037, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3346, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 1/10 - Training: 100%|██████████| 8/8 [00:57<00:00,  7.18s/it]

tensor(0.9659, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3220, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------





tensor(1.0185, device='cuda:0')
tensor(0.3395, device='cuda:0')
--------------------------------------------------
Train Loss: 38.692964643239975, Valid Loss: 0.3394979238510132


Epoch 2/10 - Training:  12%|█▎        | 1/8 [00:07<00:53,  7.62s/it]

tensor(0.9640, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3213, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 2/10 - Training:  25%|██▌       | 2/8 [00:15<00:45,  7.58s/it]

tensor(1.1055, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3685, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 2/10 - Training:  38%|███▊      | 3/8 [00:22<00:38,  7.61s/it]

tensor(1.0656, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3552, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 2/10 - Training:  50%|█████     | 4/8 [00:30<00:30,  7.58s/it]

tensor(921.4128, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(307.1376, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 2/10 - Training:  62%|██████▎   | 5/8 [00:37<00:22,  7.59s/it]

tensor(0.8700, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2900, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 2/10 - Training:  75%|███████▌  | 6/8 [00:45<00:15,  7.53s/it]

tensor(0.8341, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2780, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 2/10 - Training:  88%|████████▊ | 7/8 [00:52<00:07,  7.48s/it]

tensor(0.8538, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2846, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 2/10 - Training: 100%|██████████| 8/8 [00:59<00:00,  7.39s/it]

tensor(0.9572, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3191, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------





tensor(0.9937, device='cuda:0')
tensor(0.3312, device='cuda:0')
--------------------------------------------------
Train Loss: 38.66929629072547, Valid Loss: 0.3312237858772278


Epoch 3/10 - Training:  12%|█▎        | 1/8 [00:07<00:50,  7.28s/it]

tensor(0.9178, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3059, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 3/10 - Training:  25%|██▌       | 2/8 [00:14<00:44,  7.40s/it]

tensor(0.9762, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3254, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 3/10 - Training:  38%|███▊      | 3/8 [00:22<00:37,  7.43s/it]

tensor(0.9338, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3113, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 3/10 - Training:  50%|█████     | 4/8 [00:29<00:29,  7.40s/it]

tensor(0.7205, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2402, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 3/10 - Training:  62%|██████▎   | 5/8 [00:37<00:22,  7.42s/it]

tensor(0.9094, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3031, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 3/10 - Training:  75%|███████▌  | 6/8 [00:44<00:14,  7.39s/it]

tensor(0.9111, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3037, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 3/10 - Training:  88%|████████▊ | 7/8 [00:51<00:07,  7.38s/it]

tensor(920.9849, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(306.9950, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 3/10 - Training: 100%|██████████| 8/8 [00:58<00:00,  7.30s/it]

tensor(0.9839, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3280, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------





tensor(0.9636, device='cuda:0')
tensor(0.3212, device='cuda:0')
--------------------------------------------------
Train Loss: 38.63906716182828, Valid Loss: 0.32119807600975037


Epoch 4/10 - Training:  12%|█▎        | 1/8 [00:07<00:51,  7.39s/it]

tensor(0.8497, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2832, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 4/10 - Training:  25%|██▌       | 2/8 [00:14<00:44,  7.46s/it]

tensor(1.1267, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3756, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 4/10 - Training:  38%|███▊      | 3/8 [00:22<00:37,  7.41s/it]

tensor(0.9658, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3219, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 4/10 - Training:  50%|█████     | 4/8 [00:29<00:29,  7.37s/it]

tensor(920.3573, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(306.7857, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 4/10 - Training:  62%|██████▎   | 5/8 [00:36<00:21,  7.31s/it]

tensor(0.7964, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2655, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 4/10 - Training:  75%|███████▌  | 6/8 [00:43<00:14,  7.20s/it]

tensor(0.9006, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3002, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 4/10 - Training:  88%|████████▊ | 7/8 [00:50<00:07,  7.16s/it]

tensor(0.9165, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3055, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 4/10 - Training: 100%|██████████| 8/8 [00:57<00:00,  7.13s/it]

tensor(0.7759, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2586, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------





tensor(0.9271, device='cuda:0')
tensor(0.3090, device='cuda:0')
--------------------------------------------------
Train Loss: 38.612031769007444, Valid Loss: 0.30904126167297363


Epoch 5/10 - Training:  12%|█▎        | 1/8 [00:07<00:53,  7.64s/it]

tensor(0.9167, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3056, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 5/10 - Training:  25%|██▌       | 2/8 [00:14<00:44,  7.40s/it]

tensor(0.7799, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2600, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 5/10 - Training:  38%|███▊      | 3/8 [00:22<00:36,  7.39s/it]

tensor(0.9623, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3208, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 5/10 - Training:  50%|█████     | 4/8 [00:29<00:29,  7.34s/it]

tensor(0.9689, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3230, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 5/10 - Training:  62%|██████▎   | 5/8 [00:36<00:21,  7.24s/it]

tensor(0.8094, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2698, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 5/10 - Training:  75%|███████▌  | 6/8 [00:44<00:14,  7.39s/it]

tensor(0.9389, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3130, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 5/10 - Training:  88%|████████▊ | 7/8 [00:51<00:07,  7.41s/it]

tensor(0.9963, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3321, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------


Epoch 5/10 - Training: 100%|██████████| 8/8 [00:58<00:00,  7.27s/it]

tensor(1051.1263, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(350.3755, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------





tensor(0.9029, device='cuda:0')
tensor(0.3010, device='cuda:0')
--------------------------------------------------
Train Loss: 44.06245068460703, Valid Loss: 0.3009769022464752


Epoch 6/10 - Training:  12%|█▎        | 1/8 [00:07<00:55,  7.96s/it]

tensor(0.8939, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.2980, device='cuda:0', grad_fn=<MseLossBackward0>)
--------------------------------------------------
