In [1]:
from mesh_keypoints_extraction import KeypointPredictionNetwork, MeshData, train, test, custom_collate_fn, ChamferLoss, SumOfDistancesLoss, HungarianSumOfDistancesLoss

import os
import pandas as pd
import numpy as np
import trimesh
import plotly.graph_objects as go

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

# Seed NumPy and PyTorch with one shared value and request deterministic cuDNN
# kernels, so repeated runs of the notebook draw the same random numbers.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [2]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# --- Dataset locations -------------------------------------------------------
dataset_dir = 'mesh_keypoints_extraction_dataset'
meshes_dir, keypoints_dir = (os.path.join(dataset_dir, sub) for sub in ('meshes', 'keypoints'))
# Trailing slash kept on purpose: file names may be concatenated directly onto it.
model_save_dir = 'weights/'

# --- Input / model shape -----------------------------------------------------
# num_edges is forwarded to MeshData; input_channels and num_keypoints configure
# KeypointPredictionNetwork.
num_edges, input_channels, num_keypoints = 750, 5, 12

# --- Optimisation ------------------------------------------------------------
batch_size, learning_rate, num_epochs = 32, 1e-3, 90

In [4]:
dataset = MeshData(meshes_dir, keypoints_dir, device=device, num_edges=num_edges, normalize=True)

# 80 / 10 / 10 split; the test split absorbs the integer-rounding remainder.
train_set_size = int(0.8 * len(dataset))
val_set_size = int(0.1 * len(dataset))
test_set_size = len(dataset) - train_set_size - val_set_size
assert min(train_set_size, val_set_size, test_set_size) > 0, (
    f'dataset of {len(dataset)} meshes is too small for an 80/10/10 split'
)

# Use a dedicated, seeded generator instead of the global RNG. Then re-running this
# cell in the same kernel reproduces the exact same split. With the global RNG, a
# re-run reshuffles, which could move samples the warm-started checkpoint was trained
# on into the validation/test sets.
split_generator = torch.Generator().manual_seed(0)
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [train_set_size, val_set_size, test_set_size], generator=split_generator
)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
valid_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

In [5]:
keypoints_predictor = KeypointPredictionNetwork(input_channels=input_channels, num_keypoints=num_keypoints).to(device)

# Warm-start from a previously saved checkpoint when one exists. A fresh checkout
# without weights/ trains from scratch instead of raising FileNotFoundError.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint_path = os.path.join(model_save_dir, 'keypoints_predictor.pth')
if os.path.exists(checkpoint_path):
    keypoints_predictor.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
else:
    print(f'No checkpoint found at {checkpoint_path}; training from scratch.')

optimizer = optim.Adam(keypoints_predictor.parameters(), lr=learning_rate)

# Candidate losses; the train() and test() calls in this notebook use the Hungarian one.
chamfer_loss = ChamferLoss()
sum_of_distances_loss = SumOfDistancesLoss()
hungarian_sum_of_distances_loss = HungarianSumOfDistancesLoss()

# Multiply the LR by 0.2 once the monitored loss (presumably the validation loss
# passed in by train() — confirm) stops improving for more than 3 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3)
# Gradient scaling only applies to CUDA mixed precision. Disable it explicitly on CPU
# instead of relying on GradScaler's warn-and-disable fallback.
scaler = GradScaler(enabled=(device.type == 'cuda'))

In [6]:
# Arguments are positional, so their order must follow train()'s signature in
# mesh_keypoints_extraction.
train(
    keypoints_predictor,
    optimizer,
    hungarian_sum_of_distances_loss,  # presumably used for both train and valid loss — confirm
    scaler,
    scheduler,
    train_loader,
    valid_loader,
    num_epochs,
    device,
    model_save_dir,
)

Epoch 1/90


 - Training: 100%|██████████| 8/8 [01:38<00:00, 12.32s/it]
 - Validation: 100%|██████████| 1/1 [00:12<00:00, 12.04s/it]


 - Train Loss: 1.2037423104047775 - Valid Loss: 1.1974538564682007 - Learning Rate: 0.001
Epoch 2/90


 - Training: 100%|██████████| 8/8 [01:40<00:00, 12.60s/it]
 - Validation: 100%|██████████| 1/1 [00:12<00:00, 12.50s/it]


 - Train Loss: 1.078750118613243 - Valid Loss: 1.1117513179779053 - Learning Rate: 0.001
Epoch 3/90


 - Training: 100%|██████████| 8/8 [01:31<00:00, 11.49s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.84s/it]


 - Train Loss: 1.0421992242336273 - Valid Loss: 1.0834524631500244 - Learning Rate: 0.001
Epoch 4/90


 - Training: 100%|██████████| 8/8 [01:27<00:00, 10.92s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.77s/it]


 - Train Loss: 1.0041489377617836 - Valid Loss: 1.0533183813095093 - Learning Rate: 0.001
Epoch 5/90


 - Training: 100%|██████████| 8/8 [01:27<00:00, 10.90s/it]
 - Validation: 100%|██████████| 1/1 [00:11<00:00, 11.65s/it]


 - Train Loss: 0.9928848370909691 - Valid Loss: 1.049733281135559 - Learning Rate: 0.001
Epoch 6/90


 - Training: 100%|██████████| 8/8 [01:28<00:00, 11.07s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.78s/it]


 - Train Loss: 0.9922189638018608 - Valid Loss: 1.052846074104309 - Learning Rate: 0.001
Epoch 7/90


 - Training: 100%|██████████| 8/8 [01:28<00:00, 11.09s/it]
 - Validation: 100%|██████████| 1/1 [00:11<00:00, 11.06s/it]


 - Train Loss: 0.9485499933362007 - Valid Loss: 1.0279960632324219 - Learning Rate: 0.001
Epoch 8/90


 - Training: 100%|██████████| 8/8 [01:24<00:00, 10.61s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.82s/it]


 - Train Loss: 0.9404351338744164 - Valid Loss: 1.009911060333252 - Learning Rate: 0.001
Epoch 9/90


 - Training: 100%|██████████| 8/8 [01:20<00:00, 10.03s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.14s/it]


 - Train Loss: 0.9225394427776337 - Valid Loss: 1.0018820762634277 - Learning Rate: 0.001
Epoch 10/90


 - Training: 100%|██████████| 8/8 [01:21<00:00, 10.22s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.21s/it]


 - Train Loss: 0.9222111776471138 - Valid Loss: 0.9781063795089722 - Learning Rate: 0.001
Epoch 11/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.64s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.82s/it]


 - Train Loss: 0.8770388439297676 - Valid Loss: 0.957394540309906 - Learning Rate: 0.001
Epoch 12/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.68s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.54s/it]


 - Train Loss: 0.8599269017577171 - Valid Loss: 0.9626719951629639 - Learning Rate: 0.001
Epoch 13/90


 - Training: 100%|██████████| 8/8 [01:26<00:00, 10.78s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.74s/it]


 - Train Loss: 0.8474657982587814 - Valid Loss: 0.9571987390518188 - Learning Rate: 0.001
Epoch 14/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.69s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.75s/it]


 - Train Loss: 0.844787172973156 - Valid Loss: 0.9260306358337402 - Learning Rate: 0.001
Epoch 15/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.69s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.58s/it]


 - Train Loss: 0.8187152296304703 - Valid Loss: 0.919293999671936 - Learning Rate: 0.001
Epoch 16/90


 - Training: 100%|██████████| 8/8 [01:26<00:00, 10.78s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.73s/it]


 - Train Loss: 0.8042222261428833 - Valid Loss: 0.9185143113136292 - Learning Rate: 0.001
Epoch 17/90


 - Training: 100%|██████████| 8/8 [01:24<00:00, 10.55s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.97s/it]


 - Train Loss: 0.8124155551195145 - Valid Loss: 0.9140294194221497 - Learning Rate: 0.001
Epoch 18/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.92s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.86s/it]


 - Train Loss: 0.7866013273596764 - Valid Loss: 0.904214084148407 - Learning Rate: 0.001
Epoch 19/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.90s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.89s/it]


 - Train Loss: 0.7762119770050049 - Valid Loss: 0.8690861463546753 - Learning Rate: 0.001
Epoch 20/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.99s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.09s/it]


 - Train Loss: 0.7395830005407333 - Valid Loss: 0.8445146679878235 - Learning Rate: 0.001
Epoch 21/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.85s/it]


 - Train Loss: 0.724909745156765 - Valid Loss: 0.8581646084785461 - Learning Rate: 0.001
Epoch 22/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.92s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.94s/it]


 - Train Loss: 0.7086169347167015 - Valid Loss: 0.857692539691925 - Learning Rate: 0.001
Epoch 23/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.94s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.6955022886395454 - Valid Loss: 0.8673642873764038 - Learning Rate: 0.001
Epoch 24/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.93s/it]


 - Train Loss: 0.6986258327960968 - Valid Loss: 0.8405507802963257 - Learning Rate: 0.001
Epoch 25/90


 - Training: 100%|██████████| 8/8 [01:20<00:00, 10.02s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00, 10.00s/it]


 - Train Loss: 0.7230284959077835 - Valid Loss: 0.8624316453933716 - Learning Rate: 0.001
Epoch 26/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.6772418394684792 - Valid Loss: 0.832781970500946 - Learning Rate: 0.001
Epoch 27/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.6891056522727013 - Valid Loss: 0.8119814395904541 - Learning Rate: 0.001
Epoch 28/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.90s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.81s/it]


 - Train Loss: 0.6654737889766693 - Valid Loss: 0.840646505355835 - Learning Rate: 0.001
Epoch 29/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.98s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.94s/it]


 - Train Loss: 0.639567669481039 - Valid Loss: 0.803973913192749 - Learning Rate: 0.001
Epoch 30/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.96s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.93s/it]


 - Train Loss: 0.6283758357167244 - Valid Loss: 0.8079999685287476 - Learning Rate: 0.001
Epoch 31/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.02s/it]


 - Train Loss: 0.6148069202899933 - Valid Loss: 0.8159990906715393 - Learning Rate: 0.001
Epoch 32/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.95s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.88s/it]


 - Train Loss: 0.6146356239914894 - Valid Loss: 0.78824383020401 - Learning Rate: 0.001
Epoch 33/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.94s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.86s/it]


 - Train Loss: 0.6087429597973824 - Valid Loss: 0.8168215751647949 - Learning Rate: 0.001
Epoch 34/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.92s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.28s/it]


 - Train Loss: 0.6036015227437019 - Valid Loss: 0.808120608329773 - Learning Rate: 0.001
Epoch 35/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.86s/it]


 - Train Loss: 0.6004839017987251 - Valid Loss: 0.7927780151367188 - Learning Rate: 0.001
Epoch 36/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.93s/it]


 - Train Loss: 0.5820300839841366 - Valid Loss: 0.7894269824028015 - Learning Rate: 0.001
Epoch 37/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.88s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.81s/it]


 - Train Loss: 0.5902483463287354 - Valid Loss: 0.7948395013809204 - Learning Rate: 0.001
Epoch 38/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.72s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.61s/it]


 - Train Loss: 0.5691095814108849 - Valid Loss: 0.7960222363471985 - Learning Rate: 0.001
Epoch 39/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.63s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.60s/it]


 - Train Loss: 0.5713254064321518 - Valid Loss: 0.7823653221130371 - Learning Rate: 0.001
Epoch 40/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.64s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.78s/it]


 - Train Loss: 0.5466000065207481 - Valid Loss: 0.7859560251235962 - Learning Rate: 0.001
Epoch 41/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.73s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.62s/it]


 - Train Loss: 0.5508902296423912 - Valid Loss: 0.7644219994544983 - Learning Rate: 0.001
Epoch 42/90


 - Training: 100%|██████████| 8/8 [01:27<00:00, 10.91s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.74s/it]


 - Train Loss: 0.5283750854432583 - Valid Loss: 0.7926411032676697 - Learning Rate: 0.001
Epoch 43/90


 - Training: 100%|██████████| 8/8 [01:26<00:00, 10.84s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.78s/it]


 - Train Loss: 0.5358378998935223 - Valid Loss: 0.7760211229324341 - Learning Rate: 0.001
Epoch 44/90


 - Training: 100%|██████████| 8/8 [01:25<00:00, 10.69s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.94s/it]


 - Train Loss: 0.5377487242221832 - Valid Loss: 0.7884353995323181 - Learning Rate: 0.001
Epoch 45/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.99s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.90s/it]


 - Train Loss: 0.5122401602566242 - Valid Loss: 0.7672119140625 - Learning Rate: 0.001
Epoch 46/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.87s/it]


 - Train Loss: 0.49991942569613457 - Valid Loss: 0.7550032734870911 - Learning Rate: 0.001
Epoch 47/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.93s/it]


 - Train Loss: 0.49096501618623734 - Valid Loss: 0.760787844657898 - Learning Rate: 0.001
Epoch 48/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.96s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.5066098906099796 - Valid Loss: 0.7730379700660706 - Learning Rate: 0.001
Epoch 49/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.03s/it]


 - Train Loss: 0.47746488079428673 - Valid Loss: 0.7701312303543091 - Learning Rate: 0.001
Epoch 50/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.96s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.84s/it]


 - Train Loss: 0.5006338730454445 - Valid Loss: 0.7896760702133179 - Learning Rate: 0.001
Epoch 51/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.82s/it]


 - Train Loss: 0.4935762584209442 - Valid Loss: 0.7607199549674988 - Learning Rate: 0.001
Epoch 52/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.97s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.93s/it]


 - Train Loss: 0.4859509579837322 - Valid Loss: 0.7659782767295837 - Learning Rate: 0.001
Epoch 53/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.96s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.93s/it]


 - Train Loss: 0.47427939996123314 - Valid Loss: 0.7671984434127808 - Learning Rate: 0.001
Epoch 54/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.95s/it]


 - Train Loss: 0.46105340123176575 - Valid Loss: 0.7536234855651855 - Learning Rate: 0.001
Epoch 55/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.92s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.80s/it]


 - Train Loss: 0.45985571295022964 - Valid Loss: 0.7581197619438171 - Learning Rate: 0.001
Epoch 56/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.94s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.16s/it]


 - Train Loss: 0.47676269337534904 - Valid Loss: 0.7800440192222595 - Learning Rate: 0.001
Epoch 57/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.99s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.85s/it]


 - Train Loss: 0.46702102571725845 - Valid Loss: 0.7614597678184509 - Learning Rate: 0.001
Epoch 58/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.89s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.95s/it]


 - Train Loss: 0.4586131162941456 - Valid Loss: 0.7494335174560547 - Learning Rate: 0.001
Epoch 59/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.98s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.84s/it]


 - Train Loss: 0.4481998458504677 - Valid Loss: 0.7406191825866699 - Learning Rate: 0.001
Epoch 60/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.4505455158650875 - Valid Loss: 0.7686368227005005 - Learning Rate: 0.001
Epoch 61/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.96s/it]


 - Train Loss: 0.4491770975291729 - Valid Loss: 0.7543013095855713 - Learning Rate: 0.001
Epoch 62/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.99s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.99s/it]


 - Train Loss: 0.44315047934651375 - Valid Loss: 0.7444552183151245 - Learning Rate: 0.001
Epoch 63/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.96s/it]


 - Train Loss: 0.4267473891377449 - Valid Loss: 0.7571097612380981 - Learning Rate: 0.001
Epoch 64/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.94s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.84s/it]


 - Train Loss: 0.433060459792614 - Valid Loss: 0.7649828791618347 - Learning Rate: 0.001
Epoch 65/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.82s/it]


 - Train Loss: 0.4385242611169815 - Valid Loss: 0.7544180750846863 - Learning Rate: 0.001
Epoch 66/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.95s/it]


 - Train Loss: 0.4267318770289421 - Valid Loss: 0.7397533655166626 - Learning Rate: 0.001
Epoch 67/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.85s/it]


 - Train Loss: 0.41632179915905 - Valid Loss: 0.7239422798156738 - Learning Rate: 0.001
Epoch 68/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.95s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.98s/it]


 - Train Loss: 0.41502850130200386 - Valid Loss: 0.7399117946624756 - Learning Rate: 0.001
Epoch 69/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.86s/it]


 - Train Loss: 0.4167853966355324 - Valid Loss: 0.7435340285301208 - Learning Rate: 0.001
Epoch 70/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.94s/it]


 - Train Loss: 0.4113617576658726 - Valid Loss: 0.7558354139328003 - Learning Rate: 0.001
Epoch 71/90


 - Training: 100%|██████████| 8/8 [01:20<00:00, 10.02s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.82s/it]


 - Train Loss: 0.4136025458574295 - Valid Loss: 0.7482450008392334 - Learning Rate: 0.001
Epoch 72/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.97s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.95s/it]


 - Train Loss: 0.40756820887327194 - Valid Loss: 0.7570326328277588 - Learning Rate: 0.001
Epoch 73/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.99s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.89s/it]


 - Train Loss: 0.4210543632507324 - Valid Loss: 0.7625183463096619 - Learning Rate: 0.001
Epoch 74/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.89s/it]


 - Train Loss: 0.40614110231399536 - Valid Loss: 0.7452932000160217 - Learning Rate: 0.001
Epoch 75/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.99s/it]


 - Train Loss: 0.4052892029285431 - Valid Loss: 0.7275116443634033 - Learning Rate: 0.001
Epoch 76/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.95s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.3885125331580639 - Valid Loss: 0.7319245934486389 - Learning Rate: 0.001
Epoch 77/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.92s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.03s/it]


 - Train Loss: 0.38197940960526466 - Valid Loss: 0.7425163388252258 - Learning Rate: 0.001
Epoch 78/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.94s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.87s/it]


 - Train Loss: 0.36722124367952347 - Valid Loss: 0.7501716613769531 - Learning Rate: 0.001
Epoch 79/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.96s/it]


 - Train Loss: 0.37918340787291527 - Valid Loss: 0.731433629989624 - Learning Rate: 0.001
Epoch 80/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.95s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.14s/it]


 - Train Loss: 0.3790147490799427 - Valid Loss: 0.7372281551361084 - Learning Rate: 0.001
Epoch 81/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.77s/it]


 - Train Loss: 0.3721465654671192 - Valid Loss: 0.7490694522857666 - Learning Rate: 0.001
Epoch 82/90


 - Training: 100%|██████████| 8/8 [01:20<00:00, 10.03s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.07s/it]


 - Train Loss: 0.377389308065176 - Valid Loss: 0.7401422262191772 - Learning Rate: 0.0002
Epoch 83/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.3677656874060631 - Valid Loss: 0.7224131226539612 - Learning Rate: 0.0002
Epoch 84/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.95s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.03s/it]


 - Train Loss: 0.3301656059920788 - Valid Loss: 0.722255527973175 - Learning Rate: 0.0002
Epoch 85/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.99s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.88s/it]


 - Train Loss: 0.32306862995028496 - Valid Loss: 0.7236632108688354 - Learning Rate: 0.0002
Epoch 86/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.91s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.90s/it]


 - Train Loss: 0.3206007666885853 - Valid Loss: 0.7177476286888123 - Learning Rate: 0.0002
Epoch 87/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.94s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.88s/it]


 - Train Loss: 0.31806263886392117 - Valid Loss: 0.7144190073013306 - Learning Rate: 0.0002
Epoch 88/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.90s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.83s/it]


 - Train Loss: 0.2924596667289734 - Valid Loss: 0.7143061757087708 - Learning Rate: 0.0002
Epoch 89/90


 - Training: 100%|██████████| 8/8 [01:19<00:00,  9.93s/it]
 - Validation: 100%|██████████| 1/1 [00:09<00:00,  9.89s/it]


 - Train Loss: 0.3084025904536247 - Valid Loss: 0.7145311832427979 - Learning Rate: 0.0002
Epoch 90/90


 - Training: 100%|██████████| 8/8 [01:20<00:00, 10.02s/it]
 - Validation: 100%|██████████| 1/1 [00:10<00:00, 10.96s/it]

 - Train Loss: 0.2903407607227564 - Valid Loss: 0.715191662311554 - Learning Rate: 0.0002





In [7]:
# Evaluate a fresh model instance loaded from disk, independent of the in-memory
# training model. The file is presumably the one train() writes into model_save_dir.
keypoints_predictor_test = KeypointPredictionNetwork(input_channels=input_channels, num_keypoints=num_keypoints).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
keypoints_predictor_test.load_state_dict(
    torch.load(os.path.join(model_save_dir, 'keypoints_predictor.pth'), map_location=device, weights_only=True)
)
test(keypoints_predictor_test, test_loader, hungarian_sum_of_distances_loss, device)

Testing: 100%|██████████| 1/1 [00:11<00:00, 11.40s/it]

Test Loss: 0.6218799948692322





In [22]:
sample_idx = 14  # arbitrary test-set sample to inspect
mesh, edge_features, keypoints = test_set[sample_idx]
keypoints = keypoints.cpu().detach().numpy()

# Visualise the checkpoint whose test loss was reported by the test cell.
# Run it in eval mode so dropout/batch-norm layers (if any) are frozen, and without
# building an autograd graph.
keypoints_predictor_test.eval()
with torch.no_grad():
    model_input = edge_features.unsqueeze(0).to(device=device, dtype=torch.float32)
    # Assumes the model output is (1, num_keypoints, 3). squeeze(0) drops only the
    # batch axis, so a single-keypoint model still yields a 2-D array.
    predicted_keypoints = keypoints_predictor_test(model_input).squeeze(0).cpu().numpy()

fig = go.Figure()
fig.update_layout(
    scene=dict(aspectmode='data'),
    title=f'Test sample {sample_idx}: predicted (blue) vs ground-truth (red) keypoints',
)

fig.add_trace(go.Mesh3d(
    x=mesh.vertices[:, 0], y=mesh.vertices[:, 1], z=mesh.vertices[:, 2],
    i=mesh.faces[:, 0], j=mesh.faces[:, 1], k=mesh.faces[:, 2],
    color='lightgrey', opacity=0.5, name='Mesh',
))

# One trace per point set (instead of one per point) keeps the legend readable.
fig.add_trace(go.Scatter3d(
    x=predicted_keypoints[:, 0], y=predicted_keypoints[:, 1], z=predicted_keypoints[:, 2],
    mode='markers', marker=dict(size=5, color='blue'), name='Predicted',
))
fig.add_trace(go.Scatter3d(
    x=keypoints[:, 0], y=keypoints[:, 1], z=keypoints[:, 2],
    mode='markers', marker=dict(size=3, color='red'), name='Ground truth',
))

fig.show()