In [1]:
import torch 
from torch.utils.data import DataLoader
from src import full_pcd_dataset
import importlib
importlib.reload(full_pcd_dataset)
import json
# Create dataset instance
dataset = full_pcd_dataset.FullPCDDataset("full_pcd_30000_samples.npz")
with open("label_dict.json", "r") as f:
    label_dict = json.load(f)
num_classes = len(label_dict.keys())
# Split into train and validation (example: 80-20 split)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create data loaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16,pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16,pin_memory=True)
num_classes

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


5

In [2]:
train_dataset[0][1][0]

tensor(2)

In [3]:
from src.model import PoseWithClassModel
from torch.optim.lr_scheduler import StepLR
model = PoseWithClassModel(num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=5, gamma=0.75) 


In [4]:
from src.train_utils import train_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.is_available()
torch.backends.cudnn.version()
torch.cuda.is_available()
torch.cuda.get_device_properties(0)
print(f"Using device: {device}")


Using device: cuda


In [5]:
torch.backends.cudnn.is_available()

True

In [6]:
torch.backends.cudnn.version()

90100

In [7]:
torch.cuda.is_available()

True

In [8]:
torch.cuda.get_device_properties(0)

_CudaDeviceProperties(name='NVIDIA GeForce RTX 4070', major=8, minor=9, total_memory=12281MB, multi_processor_count=46, uuid=9e60166a-e60f-451d-3cb0-2063affabbf5, L2_cache_size=36MB)

In [9]:
train_model(model, train_loader, val_loader, optimizer, scheduler, device, epochs=100)

Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 1/100 - Total Loss: 341.8648 - Cls: 232.6820 - Pose: 129.0958


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: 2.3245 | Cls Loss: 86.4706 | Pose Loss: 35.6950 | Accuracy: 43.28%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.08s/it]


Epoch 2/100 - Total Loss: 138.3232 - Cls: 93.3445 - Pose: 147.1104


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: 0.5113 | Cls Loss: 24.0987 | Pose Loss: 38.3579 | Accuracy: 78.17%


Training: 100%|██████████| 188/188 [03:27<00:00,  1.11s/it]


Epoch 3/100 - Total Loss: 38.6380 - Cls: 76.8314 - Pose: 165.0941


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -0.0160 | Cls Loss: 16.7361 | Pose Loss: 45.1198 | Accuracy: 88.12%


Training: 100%|██████████| 188/188 [03:23<00:00,  1.08s/it]


Epoch 4/100 - Total Loss: -40.5158 - Cls: 71.8711 - Pose: 182.5490


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: 0.3786 | Cls Loss: 45.7234 | Pose Loss: 56.7163 | Accuracy: 72.82%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 5/100 - Total Loss: -140.1259 - Cls: 49.1567 - Pose: 191.2418


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -1.0439 | Cls Loss: 8.6559 | Pose Loss: 48.4720 | Accuracy: 96.93%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 6/100 - Total Loss: -220.5969 - Cls: 47.8953 - Pose: 197.6656


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -1.2575 | Cls Loss: 14.8519 | Pose Loss: 52.6350 | Accuracy: 95.48%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 7/100 - Total Loss: -297.5972 - Cls: 48.1143 - Pose: 199.7748


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -1.7287 | Cls Loss: 13.7210 | Pose Loss: 51.2066 | Accuracy: 96.58%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 8/100 - Total Loss: -375.5608 - Cls: 45.0207 - Pose: 201.3974


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -2.1225 | Cls Loss: 15.4061 | Pose Loss: 50.7006 | Accuracy: 96.55%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 9/100 - Total Loss: -466.7382 - Cls: 35.7524 - Pose: 200.6284


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -2.6891 | Cls Loss: 9.9639 | Pose Loss: 49.6020 | Accuracy: 98.57%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 10/100 - Total Loss: -527.7989 - Cls: 44.9511 - Pose: 207.6323


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -2.7544 | Cls Loss: 14.3814 | Pose Loss: 60.7119 | Accuracy: 98.07%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 11/100 - Total Loss: -569.5696 - Cls: 65.9389 - Pose: 214.4158


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -3.3354 | Cls Loss: 11.0472 | Pose Loss: 52.7963 | Accuracy: 98.87%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 12/100 - Total Loss: -623.8526 - Cls: 74.7163 - Pose: 216.3327


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -3.1131 | Cls Loss: 35.6003 | Pose Loss: 54.3326 | Accuracy: 96.43%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 13/100 - Total Loss: -732.7850 - Cls: 32.7417 - Pose: 214.9718


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -4.1088 | Cls Loss: 5.7027 | Pose Loss: 56.4634 | Accuracy: 99.68%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 14/100 - Total Loss: -750.6712 - Cls: 76.6754 - Pose: 226.6521


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -2.9127 | Cls Loss: 69.5783 | Pose Loss: 63.6421 | Accuracy: 96.43%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 15/100 - Total Loss: -772.0130 - Cls: 100.9812 - Pose: 235.3810


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -4.2483 | Cls Loss: 21.6551 | Pose Loss: 61.6431 | Accuracy: 98.78%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 16/100 - Total Loss: -920.6128 - Cls: 23.3246 - Pose: 225.6432


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -5.1421 | Cls Loss: 4.8717 | Pose Loss: 55.4698 | Accuracy: 99.78%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 17/100 - Total Loss: -918.9978 - Cls: 73.5497 - Pose: 240.7497


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -5.2919 | Cls Loss: 6.3765 | Pose Loss: 61.0891 | Accuracy: 99.75%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 18/100 - Total Loss: -993.8244 - Cls: 54.0496 - Pose: 251.0754


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: 14.7567 | Cls Loss: 937.1408 | Pose Loss: 87.8426 | Accuracy: 86.63%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 19/100 - Total Loss: -1027.9007 - Cls: 66.1551 - Pose: 256.6985


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -6.0814 | Cls Loss: 6.0864 | Pose Loss: 53.8260 | Accuracy: 99.78%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 20/100 - Total Loss: -1017.8511 - Cls: 112.7566 - Pose: 270.5049


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -5.9446 | Cls Loss: 4.5813 | Pose Loss: 72.4999 | Accuracy: 99.92%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 21/100 - Total Loss: -1176.8491 - Cls: 24.8159 - Pose: 257.0816


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -6.6643 | Cls Loss: 2.0336 | Pose Loss: 58.0826 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 22/100 - Total Loss: -1061.7917 - Cls: 161.0589 - Pose: 284.8808


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -6.5294 | Cls Loss: 3.2209 | Pose Loss: 69.4995 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 23/100 - Total Loss: -1192.2041 - Cls: 61.8846 - Pose: 292.1617


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -5.6751 | Cls Loss: 29.1486 | Pose Loss: 96.0888 | Accuracy: 99.42%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 24/100 - Total Loss: -1255.7806 - Cls: 46.1567 - Pose: 290.4978


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.11it/s]


Validation - Loss: -6.9336 | Cls Loss: 10.3467 | Pose Loss: 69.3581 | Accuracy: 99.88%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 25/100 - Total Loss: -1250.8889 - Cls: 88.4812 - Pose: 307.0922


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.11it/s]


Validation - Loss: -3.6435 | Cls Loss: 138.5331 | Pose Loss: 106.8435 | Accuracy: 98.02%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 26/100 - Total Loss: -1061.8350 - Cls: 230.7695 - Pose: 356.0369


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -7.3278 | Cls Loss: 3.8001 | Pose Loss: 66.4828 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 27/100 - Total Loss: -1347.8794 - Cls: 19.9324 - Pose: 313.5872


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -7.4986 | Cls Loss: 5.6604 | Pose Loss: 68.1178 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 28/100 - Total Loss: -1294.3629 - Cls: 101.2048 - Pose: 328.5195


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: 0.7992 | Cls Loss: 248.3968 | Pose Loss: 222.6892 | Accuracy: 96.17%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 29/100 - Total Loss: -1243.3084 - Cls: 120.2570 - Pose: 361.9075


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -7.4246 | Cls Loss: 8.9316 | Pose Loss: 76.5451 | Accuracy: 99.90%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 30/100 - Total Loss: -1397.2727 - Cls: 21.9970 - Pose: 335.9230


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -6.7376 | Cls Loss: 17.0858 | Pose Loss: 109.4979 | Accuracy: 99.83%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 31/100 - Total Loss: -1337.6768 - Cls: 84.3530 - Pose: 355.5448


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.09it/s]


Validation - Loss: -7.7923 | Cls Loss: 5.1110 | Pose Loss: 76.3496 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 32/100 - Total Loss: -1266.7340 - Cls: 172.4303 - Pose: 360.4932


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -7.5566 | Cls Loss: 15.5640 | Pose Loss: 78.3790 | Accuracy: 99.82%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 33/100 - Total Loss: -1445.2091 - Cls: 33.4055 - Pose: 335.8844


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -7.8275 | Cls Loss: 5.3318 | Pose Loss: 85.5147 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 34/100 - Total Loss: -1391.7606 - Cls: 81.7400 - Pose: 368.2607


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -3.3799 | Cls Loss: 202.9294 | Pose Loss: 101.1384 | Accuracy: 98.37%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 35/100 - Total Loss: -1305.5790 - Cls: 199.8413 - Pose: 348.5791


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -8.6317 | Cls Loss: 4.5125 | Pose Loss: 53.7812 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 36/100 - Total Loss: -1417.4050 - Cls: 104.0200 - Pose: 351.2185


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -7.5304 | Cls Loss: 4.9339 | Pose Loss: 110.9463 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 37/100 - Total Loss: -1458.2973 - Cls: 78.4976 - Pose: 350.6789


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.5724 | Cls Loss: 5.8555 | Pose Loss: 66.8115 | Accuracy: 99.92%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 38/100 - Total Loss: -1513.4732 - Cls: 25.5244 - Pose: 377.2277


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.1568 | Cls Loss: 4.1481 | Pose Loss: 94.3612 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 39/100 - Total Loss: -1278.8292 - Cls: 277.0868 - Pose: 373.3331


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -8.0926 | Cls Loss: 25.6708 | Pose Loss: 72.1765 | Accuracy: 99.83%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 40/100 - Total Loss: -1392.3379 - Cls: 129.0022 - Pose: 395.1087


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.5067 | Cls Loss: 9.8700 | Pose Loss: 70.1975 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 41/100 - Total Loss: -1478.6488 - Cls: 83.8596 - Pose: 363.8336


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.0354 | Cls Loss: 10.2654 | Pose Loss: 95.6154 | Accuracy: 99.92%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 42/100 - Total Loss: -1517.8692 - Cls: 71.2173 - Pose: 352.8933


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.0703 | Cls Loss: 0.7073 | Pose Loss: 62.4259 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 43/100 - Total Loss: -1562.1737 - Cls: 53.7487 - Pose: 357.6862


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -6.3635 | Cls Loss: 85.6157 | Pose Loss: 110.6699 | Accuracy: 99.62%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 44/100 - Total Loss: -1330.0444 - Cls: 256.5711 - Pose: 380.4654


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -8.9331 | Cls Loss: 5.3990 | Pose Loss: 67.9661 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 45/100 - Total Loss: -1609.3439 - Cls: 26.6476 - Pose: 350.6333


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.2131 | Cls Loss: 2.8930 | Pose Loss: 63.6944 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 46/100 - Total Loss: -1540.6009 - Cls: 80.7252 - Pose: 380.4918


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -9.3323 | Cls Loss: 3.2404 | Pose Loss: 60.3008 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 47/100 - Total Loss: -1636.7134 - Cls: 18.2756 - Pose: 365.9967


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -7.6117 | Cls Loss: 21.4370 | Pose Loss: 128.4990 | Accuracy: 99.92%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 48/100 - Total Loss: -632.6840 - Cls: 908.7077 - Pose: 452.9524


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.5315 | Cls Loss: 17.0334 | Pose Loss: 71.2319 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 49/100 - Total Loss: -1574.8787 - Cls: 56.0208 - Pose: 337.5521


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -8.8279 | Cls Loss: 3.8132 | Pose Loss: 76.3441 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 50/100 - Total Loss: -1616.7071 - Cls: 28.9142 - Pose: 346.6532


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.7091 | Cls Loss: 5.4115 | Pose Loss: 85.1480 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 51/100 - Total Loss: -1558.5735 - Cls: 89.7849 - Pose: 359.9162


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -7.6604 | Cls Loss: 26.1245 | Pose Loss: 116.5676 | Accuracy: 99.90%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 52/100 - Total Loss: -1435.1318 - Cls: 195.9697 - Pose: 377.8373


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.4180 | Cls Loss: 3.0757 | Pose Loss: 56.3522 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 53/100 - Total Loss: -1641.7622 - Cls: 28.4106 - Pose: 348.9871


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.4579 | Cls Loss: 3.0313 | Pose Loss: 59.0749 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 54/100 - Total Loss: -1664.6861 - Cls: 18.9449 - Pose: 355.1251


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.3810 | Cls Loss: 6.8234 | Pose Loss: 63.7808 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 55/100 - Total Loss: -1490.8794 - Cls: 164.9955 - Pose: 386.5219


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -8.0846 | Cls Loss: 19.5695 | Pose Loss: 111.1141 | Accuracy: 99.88%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 56/100 - Total Loss: -1370.7619 - Cls: 257.5146 - Pose: 404.7060


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.6736 | Cls Loss: 17.8200 | Pose Loss: 81.1511 | Accuracy: 99.90%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 57/100 - Total Loss: -1571.9800 - Cls: 97.2440 - Pose: 362.2624


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.09it/s]


Validation - Loss: -9.3019 | Cls Loss: 8.1494 | Pose Loss: 64.7366 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 58/100 - Total Loss: -1635.3728 - Cls: 37.1125 - Pose: 372.5102


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.09it/s]


Validation - Loss: -7.6848 | Cls Loss: 87.7311 | Pose Loss: 64.1320 | Accuracy: 99.63%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 59/100 - Total Loss: -1496.5661 - Cls: 186.2161 - Pose: 374.4478


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -8.4081 | Cls Loss: 16.5246 | Pose Loss: 102.3535 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 60/100 - Total Loss: -1691.8628 - Cls: 30.8187 - Pose: 342.2365


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.4929 | Cls Loss: 8.0521 | Pose Loss: 64.6428 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 61/100 - Total Loss: -1689.5050 - Cls: 24.8225 - Pose: 370.6864


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.5393 | Cls Loss: 13.1836 | Pose Loss: 61.5783 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 62/100 - Total Loss: -1563.8311 - Cls: 159.0208 - Pose: 371.0948


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -9.4002 | Cls Loss: 9.7101 | Pose Loss: 72.5059 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 63/100 - Total Loss: -1723.9190 - Cls: 28.2392 - Pose: 351.5475


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -9.8230 | Cls Loss: 2.4359 | Pose Loss: 64.8616 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 64/100 - Total Loss: -1761.9226 - Cls: 2.8923 - Pose: 363.9457


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -9.5854 | Cls Loss: 0.4954 | Pose Loss: 83.3494 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 65/100 - Total Loss: -1243.0578 - Cls: 439.2330 - Pose: 452.1172


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -8.7049 | Cls Loss: 15.5913 | Pose Loss: 98.2979 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 66/100 - Total Loss: -1456.9521 - Cls: 239.5652 - Pose: 387.1808


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.09it/s]


Validation - Loss: -9.7635 | Cls Loss: 1.3042 | Pose Loss: 62.3922 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 67/100 - Total Loss: -1713.9932 - Cls: 29.0305 - Pose: 357.7446


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -9.4596 | Cls Loss: 5.1653 | Pose Loss: 77.4468 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 68/100 - Total Loss: -1746.1037 - Cls: 16.9154 - Pose: 356.4617


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -10.0288 | Cls Loss: 0.5942 | Pose Loss: 59.9357 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 69/100 - Total Loss: -1549.4099 - Cls: 193.3936 - Pose: 385.8146


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -7.9349 | Cls Loss: 78.4294 | Pose Loss: 78.9157 | Accuracy: 99.78%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 70/100 - Total Loss: -1522.8680 - Cls: 214.0964 - Pose: 378.5658


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.09it/s]


Validation - Loss: -9.5325 | Cls Loss: 7.3859 | Pose Loss: 75.1572 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 71/100 - Total Loss: -1767.5536 - Cls: 7.0308 - Pose: 355.3685


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -10.0903 | Cls Loss: 0.6994 | Pose Loss: 60.6566 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 72/100 - Total Loss: -1759.9366 - Cls: 23.3910 - Pose: 365.8384


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -10.1187 | Cls Loss: 2.7396 | Pose Loss: 61.0680 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 73/100 - Total Loss: -1774.3725 - Cls: 15.4957 - Pose: 372.4686


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -10.1741 | Cls Loss: 0.3007 | Pose Loss: 64.9237 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 74/100 - Total Loss: -1803.5432 - Cls: 20.1691 - Pose: 361.2668


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -9.8803 | Cls Loss: 0.4048 | Pose Loss: 83.9424 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 75/100 - Total Loss: -1844.4776 - Cls: 4.1703 - Pose: 355.5427


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -10.2199 | Cls Loss: 4.4339 | Pose Loss: 70.0030 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 76/100 - Total Loss: -1843.5420 - Cls: 2.3300 - Pose: 381.7605


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.10it/s]


Validation - Loss: -10.3356 | Cls Loss: 0.5406 | Pose Loss: 72.6461 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:29<00:00,  1.12s/it]


Epoch 77/100 - Total Loss: -352.8233 - Cls: 1395.9220 - Pose: 468.3674


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: 2.6469 | Cls Loss: 422.3008 | Pose Loss: 241.5537 | Accuracy: 99.07%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 78/100 - Total Loss: -1660.6555 - Cls: 129.8892 - Pose: 369.2439


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -9.7353 | Cls Loss: 5.9460 | Pose Loss: 78.0146 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.08s/it]


Epoch 79/100 - Total Loss: -1797.2927 - Cls: 33.2033 - Pose: 343.0989


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -10.2481 | Cls Loss: 0.7831 | Pose Loss: 62.7026 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 80/100 - Total Loss: -1804.3317 - Cls: 22.2029 - Pose: 361.5443


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -10.0338 | Cls Loss: 2.4359 | Pose Loss: 74.2777 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 81/100 - Total Loss: -1783.5601 - Cls: 47.6798 - Pose: 365.3470


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.07it/s]


Validation - Loss: -10.1117 | Cls Loss: 9.3181 | Pose Loss: 66.0489 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.06s/it]


Epoch 82/100 - Total Loss: -1823.5510 - Cls: 22.3806 - Pose: 362.0114


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -10.3002 | Cls Loss: 0.2640 | Pose Loss: 68.5680 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 83/100 - Total Loss: -1860.3063 - Cls: 8.8319 - Pose: 351.3562


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.09it/s]


Validation - Loss: -10.1897 | Cls Loss: 1.5936 | Pose Loss: 75.6689 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 84/100 - Total Loss: -1705.1949 - Cls: 163.7263 - Pose: 363.0420


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -7.7143 | Cls Loss: 121.4603 | Pose Loss: 73.5018 | Accuracy: 99.77%


Training: 100%|██████████| 188/188 [03:25<00:00,  1.09s/it]


Epoch 85/100 - Total Loss: -1804.3644 - Cls: 55.8970 - Pose: 372.1948


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -10.6231 | Cls Loss: 0.4565 | Pose Loss: 60.0735 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:23<00:00,  1.08s/it]


Epoch 86/100 - Total Loss: -1752.3001 - Cls: 125.5601 - Pose: 369.3586


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: 6.1839 | Cls Loss: 710.4048 | Pose Loss: 140.7241 | Accuracy: 98.47%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.08s/it]


Epoch 87/100 - Total Loss: -1822.2771 - Cls: 31.3103 - Pose: 390.6129


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -10.4420 | Cls Loss: 0.7625 | Pose Loss: 70.0887 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:28<00:00,  1.11s/it]


Epoch 88/100 - Total Loss: -1813.7651 - Cls: 77.2045 - Pose: 362.9069


Validating: 100%|██████████| 47/47 [00:47<00:00,  1.01s/it]


Validation - Loss: -10.7388 | Cls Loss: 1.0292 | Pose Loss: 58.8014 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:28<00:00,  1.11s/it]


Epoch 89/100 - Total Loss: -1707.1692 - Cls: 180.4388 - Pose: 375.4852


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -10.0029 | Cls Loss: 23.3510 | Pose Loss: 71.3834 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:25<00:00,  1.09s/it]


Epoch 90/100 - Total Loss: -1785.1115 - Cls: 111.7653 - Pose: 368.5080


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.11it/s]


Validation - Loss: -10.0510 | Cls Loss: 18.5958 | Pose Loss: 75.8544 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 91/100 - Total Loss: -1623.4554 - Cls: 259.4008 - Pose: 382.0358


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.11it/s]


Validation - Loss: -3.8590 | Cls Loss: 293.9946 | Pose Loss: 89.3080 | Accuracy: 99.43%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 92/100 - Total Loss: -1454.1601 - Cls: 394.5098 - Pose: 396.7217


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -9.6421 | Cls Loss: 43.5219 | Pose Loss: 63.3717 | Accuracy: 99.92%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 93/100 - Total Loss: -1740.8448 - Cls: 111.2461 - Pose: 385.8374


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -10.5161 | Cls Loss: 0.7352 | Pose Loss: 65.8621 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:19<00:00,  1.06s/it]


Epoch 94/100 - Total Loss: -1776.4163 - Cls: 100.2409 - Pose: 369.1258


Validating: 100%|██████████| 47/47 [00:43<00:00,  1.08it/s]


Validation - Loss: -9.2885 | Cls Loss: 37.0051 | Pose Loss: 88.6263 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.08s/it]


Epoch 95/100 - Total Loss: -733.3520 - Cls: 1043.9932 - Pose: 438.1289


Validating: 100%|██████████| 47/47 [00:42<00:00,  1.11it/s]


Validation - Loss: -10.4332 | Cls Loss: 3.1225 | Pose Loss: 58.3588 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 96/100 - Total Loss: -1824.2753 - Cls: 39.0550 - Pose: 351.3603


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -10.5461 | Cls Loss: 2.7998 | Pose Loss: 56.8969 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 97/100 - Total Loss: -1742.8379 - Cls: 106.0771 - Pose: 375.8960


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.3166 | Cls Loss: 21.8193 | Pose Loss: 96.1486 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 98/100 - Total Loss: -1874.2763 - Cls: 5.3767 - Pose: 348.1301


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -10.5548 | Cls Loss: 1.9766 | Pose Loss: 61.7440 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:26<00:00,  1.10s/it]


Epoch 99/100 - Total Loss: -1885.3819 - Cls: 4.6325 - Pose: 356.8773


Validating: 100%|██████████| 47/47 [00:46<00:00,  1.00it/s]


Validation - Loss: -10.1684 | Cls Loss: 0.5290 | Pose Loss: 83.7949 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:28<00:00,  1.11s/it]


Epoch 100/100 - Total Loss: -1599.9534 - Cls: 246.2531 - Pose: 396.7757


Validating: 100%|██████████| 47/47 [00:48<00:00,  1.02s/it]

Validation - Loss: -9.6938 | Cls Loss: 2.9434 | Pose Loss: 101.0880 | Accuracy: 100.00%





In [58]:
one_pcd = dataset[2][0]

model.eval()
with torch.no_grad():
    # Move the point cloud to the same device as the model
    one_pcd = one_pcd.to(device)
    # Forward pass through the model
    # Note: Add batch dimension
    class_logits, pred_pose = model(one_pcd.unsqueeze(0).to(device))
one_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(one_pcd.cpu().numpy()))
one_pcd.paint_uniform_color([0.5, 0.5, 0.5])

PointCloud with 1000 points.

In [84]:
from pathlib import Path
import open3d as o3d
import numpy as np
model_dir = Path("./data/models")
model_list = list(model_dir.glob("*"))

def check_model_on_instance(model, dataset, idx, device, paths):
    pcd = dataset[idx][0]
    gt_class = dataset[idx][1][0]
    gt_rotation_quaternion = dataset[idx][1][1]
    gt_translation = dataset[idx][1][2]
    gt_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pcd))
    
    model.eval()
    with torch.no_grad():
        # Move the point cloud to the same device as the model
        pcd = pcd.to(device)
        # Forward pass through the model
        # Note: Add batch dimension
        class_logits, pred_pose = model(pcd.unsqueeze(0).to(device))
        preds = class_logits.argmax(dim=1)
        model_path = paths[preds.item()]
        print(f"Predicted class: {preds.item()}", 
              f"GT class: {gt_class.item()}")
        print(f"Model path: {model_path}")
    pcd_predicted = o3d.io.read_triangle_mesh(model_path)
    pcd_predicted.paint_uniform_color([1, 0, 0])  # Red color for predicted model
    predicted_quaternion = pred_pose[0, 3:].cpu().numpy()

    q1, q2, q3, q4 = predicted_quaternion/np.linalg.norm(predicted_quaternion)
    print(np.linalg.norm(np.array([q1, q2, q3, q4])))
    x, y, z = pred_pose[0, :3].cpu().numpy()
    # Convert quaternion to rotation matrix
    T = np.array([
        [1 - 2*(q3**2 + q4**2),     2*(q2*q3 - q1*q4),
            2*(q2*q4 + q1*q3), x],
        [2*(q2*q3 + q1*q4), 1 - 2*(q2**2 + q4**2),     2*(q3*q4 - q1*q2), y],
        [2*(q2*q4 - q1*q3),     2*(q3*q4 + q1*q2), 1 - 2*(q2**2 + q3**2), z],
        [0, 0, 0, 1]
    ])
    pcd_predicted.transform(T)
    translation_diff = np.array([x, y, z]) - gt_translation.numpy()
    print(f"Translation diff: {translation_diff}")
    translation_error = sum((np.array([x, y, z]) - gt_translation.cpu().numpy())**2)/3
    print(f"Translation error: {translation_error}")
    rotation_error = 1 - np.dot(predicted_quaternion, gt_rotation_quaternion)
    print(f"Rotation error: {rotation_error}")
    # Visualize the point cloud and the predicted model
    o3d.visualization.draw_geometries([pcd_predicted, gt_pcd])


In [86]:
check_model_on_instance(model, train_dataset, 0, device, model_list)

Predicted class: 2 GT class: 2
Model path: data\models\hammer.obj
0.99999994
Translation diff: [ 0.14558786 -0.05423252  0.06600547]
Translation error: 0.00949790452917417
Rotation error: -0.47630763053894043


In [59]:
preds = class_logits.argmax(dim=1)
preds.item()

3

In [None]:



model_path = model_list[preds.item()]

In [69]:
q1, q2, q3, q4 = pred_pose[0, 3:].cpu().numpy()
print(q1, q2, q3, q4)

0.10637765 -2.4367216 1.1922573 6.49378


In [None]:

pcd_predicted = o3d.io.read_triangle_mesh(model_path)
pcd_predicted.paint_uniform_color([1, 0, 0])  # Red color for predicted model
translation = pred_pose[0, :3].cpu().numpy()
rotation = pred_pose[0, 3:].cpu().numpy()
pcd_predicted.rotate(pcd_predicted.get_rotation_matrix_from_quaternion(rotation))
pcd_predicted.translate(translation)

pcd_predicted

TriangleMesh with 746 points and 1476 triangles.

In [62]:
((dataset[2][1][2] - translation)**2).sum()/3

tensor(0.0035)

In [64]:
import numpy as np
print(dataset[2][1][1], rotation/np.linalg.norm(rotation))
1-np.dot(dataset[2][1][1], rotation/np.linalg.norm(rotation))

tensor([-0.3478, -0.0184, -0.8122,  0.4679]) [ 0.01511382 -0.3462021   0.16939235  0.92261684]


0.7047482132911682

In [65]:
np.linalg.norm(dataset[1][1][1])

1.0

In [66]:
o3d.visualization.draw_geometries([pcd_predicted, one_pcd])