In [1]:
import torch 
from torch.utils.data import DataLoader
from src import full_pcd_dataset
import importlib
importlib.reload(full_pcd_dataset)
import json
# Create dataset instance
dataset = full_pcd_dataset.FullPCDDataset("data/full_pcd_30000_samples.npz")
with open("data/label_dict.json", "r") as f:
    label_dict = json.load(f)
num_classes = len(label_dict.keys())
# Split into train and validation (example: 80-20 split)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create data loaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16,pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16,pin_memory=True)
num_classes

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


5

In [2]:
train_dataset[0][1][0]

tensor(3)

In [3]:
from src.model import PoseWithClassModel
from torch.optim.lr_scheduler import StepLR
model = PoseWithClassModel(num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=5, gamma=0.75) 


In [4]:
from src.train_utils import train_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.is_available()
torch.backends.cudnn.version()
torch.cuda.is_available()
torch.cuda.get_device_properties(0)
print(f"Using device: {device}")


Using device: cuda


In [5]:
torch.backends.cudnn.is_available()

True

In [6]:
torch.backends.cudnn.version()

90100

In [7]:
torch.cuda.is_available()

True

In [8]:
torch.cuda.get_device_properties(0)

_CudaDeviceProperties(name='NVIDIA GeForce RTX 4070', major=8, minor=9, total_memory=12281MB, multi_processor_count=46, uuid=9e60166a-e60f-451d-3cb0-2063affabbf5, L2_cache_size=36MB)

In [9]:
train_model(model, train_loader, val_loader, optimizer, scheduler, device, epochs=100)

Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 1/100 - Total Loss: 1.8444 - Cls: 1.2625 - Trans: 0.0544 - Rot: 0.6301


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: 3.8496 | Cls Loss: 3.6377 | Trans Loss: 0.1035 | Rot Loss: 0.5773 | Accuracy: 22.82%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 2/100 - Total Loss: 0.6926 - Cls: 0.4456 - Trans: 0.0461 - Rot: 0.7380


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: 0.3326 | Cls Loss: 0.2726 | Trans Loss: 0.0319 | Rot Loss: 0.5764 | Accuracy: 86.53%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 3/100 - Total Loss: 0.1856 - Cls: 0.3749 - Trans: 0.0545 - Rot: 0.8349


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: 0.1486 | Cls Loss: 0.3837 | Trans Loss: 0.0321 | Rot Loss: 0.5745 | Accuracy: 82.82%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 4/100 - Total Loss: -0.3342 - Cls: 0.2729 - Trans: 0.0590 - Rot: 0.9087


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -0.5700 | Cls Loss: 0.1360 | Trans Loss: 0.0255 | Rot Loss: 0.5765 | Accuracy: 93.93%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 5/100 - Total Loss: -0.8276 - Cls: 0.1985 - Trans: 0.0672 - Rot: 0.9556


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -0.7447 | Cls Loss: 0.1821 | Trans Loss: 0.0540 | Rot Loss: 0.5719 | Accuracy: 92.53%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 6/100 - Total Loss: -1.2044 - Cls: 0.2149 - Trans: 0.0738 - Rot: 0.9821


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -1.4854 | Cls Loss: 0.0339 | Trans Loss: 0.0340 | Rot Loss: 0.5716 | Accuracy: 98.92%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 7/100 - Total Loss: -1.6007 - Cls: 0.2028 - Trans: 0.0889 - Rot: 0.9893


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -1.8639 | Cls Loss: 0.0421 | Trans Loss: 0.0235 | Rot Loss: 0.5595 | Accuracy: 98.45%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 8/100 - Total Loss: -2.0586 - Cls: 0.1673 - Trans: 0.0944 - Rot: 0.9757


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -1.9319 | Cls Loss: 0.0934 | Trans Loss: 0.0403 | Rot Loss: 0.5522 | Accuracy: 96.62%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 9/100 - Total Loss: -2.4482 - Cls: 0.1992 - Trans: 0.1004 - Rot: 0.9696


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -1.6151 | Cls Loss: 0.2084 | Trans Loss: 0.0462 | Rot Loss: 0.5339 | Accuracy: 92.32%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 10/100 - Total Loss: -2.6198 - Cls: 0.3441 - Trans: 0.1323 - Rot: 0.9929


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -3.1741 | Cls Loss: 0.0071 | Trans Loss: 0.0156 | Rot Loss: 0.4964 | Accuracy: 99.77%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 11/100 - Total Loss: -3.2672 - Cls: 0.1318 - Trans: 0.1198 - Rot: 0.9781


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -3.4925 | Cls Loss: 0.0113 | Trans Loss: 0.0199 | Rot Loss: 0.4850 | Accuracy: 99.63%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 12/100 - Total Loss: -3.5657 - Cls: 0.1932 - Trans: 0.1245 - Rot: 0.9971


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -3.7467 | Cls Loss: 0.0148 | Trans Loss: 0.0141 | Rot Loss: 0.5302 | Accuracy: 99.40%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 13/100 - Total Loss: -3.5831 - Cls: 0.4741 - Trans: 0.1750 - Rot: 0.9987


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -3.6714 | Cls Loss: 0.0176 | Trans Loss: 0.0492 | Rot Loss: 0.4937 | Accuracy: 99.43%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 14/100 - Total Loss: -4.2429 - Cls: 0.1581 - Trans: 0.1872 - Rot: 0.9907


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -3.6109 | Cls Loss: 0.0751 | Trans Loss: 0.0199 | Rot Loss: 0.4824 | Accuracy: 97.72%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 15/100 - Total Loss: -4.4966 - Cls: 0.2436 - Trans: 0.1890 - Rot: 1.0015


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -4.6748 | Cls Loss: 0.0079 | Trans Loss: 0.0208 | Rot Loss: 0.4857 | Accuracy: 99.75%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 16/100 - Total Loss: -4.7327 - Cls: 0.3187 - Trans: 0.2405 - Rot: 0.9957


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -5.1077 | Cls Loss: 0.0053 | Trans Loss: 0.0116 | Rot Loss: 0.4960 | Accuracy: 99.85%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 17/100 - Total Loss: -4.9358 - Cls: 0.3947 - Trans: 0.2774 - Rot: 0.9998


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -5.0594 | Cls Loss: 0.0131 | Trans Loss: 0.0190 | Rot Loss: 0.4911 | Accuracy: 99.57%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 18/100 - Total Loss: -5.3976 - Cls: 0.2374 - Trans: 0.2790 - Rot: 0.9969


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -5.7893 | Cls Loss: 0.0053 | Trans Loss: 0.0067 | Rot Loss: 0.4758 | Accuracy: 99.77%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 19/100 - Total Loss: -5.0265 - Cls: 0.7059 - Trans: 0.3969 - Rot: 1.0057


Validating: 100%|██████████| 47/47 [00:46<00:00,  1.01it/s]


Validation - Loss: -5.9606 | Cls Loss: 0.0052 | Trans Loss: 0.0056 | Rot Loss: 0.4887 | Accuracy: 99.82%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.08s/it]


Epoch 20/100 - Total Loss: -5.5496 - Cls: 0.4793 - Trans: 0.3687 - Rot: 0.9974


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -6.3511 | Cls Loss: 0.0018 | Trans Loss: 0.0030 | Rot Loss: 0.4862 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 21/100 - Total Loss: -6.2466 - Cls: 0.0843 - Trans: 0.3780 - Rot: 0.9921


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -6.1699 | Cls Loss: 0.0090 | Trans Loss: 0.0105 | Rot Loss: 0.4701 | Accuracy: 99.53%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 22/100 - Total Loss: -5.5391 - Cls: 0.7512 - Trans: 0.6121 - Rot: 1.0045


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -6.5002 | Cls Loss: 0.0090 | Trans Loss: 0.0042 | Rot Loss: 0.4876 | Accuracy: 99.73%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 23/100 - Total Loss: -6.5136 - Cls: 0.1056 - Trans: 0.5298 - Rot: 0.9991


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -6.8560 | Cls Loss: 0.0015 | Trans Loss: 0.0077 | Rot Loss: 0.4710 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 24/100 - Total Loss: -6.6654 - Cls: 0.2123 - Trans: 0.5613 - Rot: 0.9975


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -6.9487 | Cls Loss: 0.0044 | Trans Loss: 0.0071 | Rot Loss: 0.4829 | Accuracy: 99.80%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 25/100 - Total Loss: -6.2966 - Cls: 0.6332 - Trans: 0.6833 - Rot: 1.0079


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -6.9529 | Cls Loss: 0.0041 | Trans Loss: 0.0090 | Rot Loss: 0.4829 | Accuracy: 99.88%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.08s/it]


Epoch 26/100 - Total Loss: -6.2388 - Cls: 0.8417 - Trans: 0.6748 - Rot: 0.9957


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -7.5417 | Cls Loss: 0.0017 | Trans Loss: 0.0026 | Rot Loss: 0.4935 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 27/100 - Total Loss: -7.0861 - Cls: 0.1751 - Trans: 0.7090 - Rot: 0.9956


Validating: 100%|██████████| 47/47 [00:46<00:00,  1.02it/s]


Validation - Loss: -7.5445 | Cls Loss: 0.0013 | Trans Loss: 0.0064 | Rot Loss: 0.4575 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 28/100 - Total Loss: -7.3351 - Cls: 0.1252 - Trans: 0.7195 - Rot: 0.9987


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -7.4426 | Cls Loss: 0.0035 | Trans Loss: 0.0074 | Rot Loss: 0.4771 | Accuracy: 99.85%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 29/100 - Total Loss: -7.0318 - Cls: 0.5431 - Trans: 0.8014 - Rot: 1.0115


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -7.5021 | Cls Loss: 0.0051 | Trans Loss: 0.0061 | Rot Loss: 0.4943 | Accuracy: 99.83%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 30/100 - Total Loss: -6.4188 - Cls: 1.0496 - Trans: 0.9443 - Rot: 0.9986


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.1449 | Cls Loss: 0.0009 | Trans Loss: 0.0026 | Rot Loss: 0.4751 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 31/100 - Total Loss: -7.2895 - Cls: 0.4526 - Trans: 0.7670 - Rot: 1.0021


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -4.7477 | Cls Loss: 0.0445 | Trans Loss: 0.0058 | Rot Loss: 0.5217 | Accuracy: 98.33%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 32/100 - Total Loss: -7.3652 - Cls: 0.4481 - Trans: 0.8226 - Rot: 0.9947


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -7.8268 | Cls Loss: 0.0046 | Trans Loss: 0.0050 | Rot Loss: 0.4843 | Accuracy: 99.85%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 33/100 - Total Loss: -7.6860 - Cls: 0.2685 - Trans: 0.8455 - Rot: 0.9915


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.1907 | Cls Loss: 0.0005 | Trans Loss: 0.0061 | Rot Loss: 0.4563 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 34/100 - Total Loss: -4.9197 - Cls: 2.6240 - Trans: 1.2085 - Rot: 1.0142


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -7.9957 | Cls Loss: 0.0034 | Trans Loss: 0.0045 | Rot Loss: 0.4669 | Accuracy: 99.93%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 35/100 - Total Loss: -7.5906 - Cls: 0.2644 - Trans: 0.8772 - Rot: 0.9910


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -8.1918 | Cls Loss: 0.0016 | Trans Loss: 0.0045 | Rot Loss: 0.4649 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.07s/it]


Epoch 36/100 - Total Loss: -7.6873 - Cls: 0.3043 - Trans: 0.8445 - Rot: 0.9999


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -8.6141 | Cls Loss: 0.0007 | Trans Loss: 0.0020 | Rot Loss: 0.4628 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 37/100 - Total Loss: -7.8933 - Cls: 0.1546 - Trans: 0.8822 - Rot: 1.0115


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.6199 | Cls Loss: 0.0003 | Trans Loss: 0.0017 | Rot Loss: 0.5407 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 38/100 - Total Loss: -7.8801 - Cls: 0.1341 - Trans: 0.9650 - Rot: 0.9987


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -7.7394 | Cls Loss: 0.0058 | Trans Loss: 0.0062 | Rot Loss: 0.4983 | Accuracy: 99.78%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 39/100 - Total Loss: -7.2362 - Cls: 0.8240 - Trans: 0.9941 - Rot: 0.9968


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.0844 | Cls Loss: 0.0046 | Trans Loss: 0.0046 | Rot Loss: 0.4811 | Accuracy: 99.82%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 40/100 - Total Loss: -8.0440 - Cls: 0.1785 - Trans: 0.9189 - Rot: 0.9939


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.5967 | Cls Loss: 0.0005 | Trans Loss: 0.0047 | Rot Loss: 0.4689 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 41/100 - Total Loss: -8.0120 - Cls: 0.3415 - Trans: 0.8853 - Rot: 1.0040


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -5.5098 | Cls Loss: 0.0291 | Trans Loss: 0.0038 | Rot Loss: 0.4790 | Accuracy: 99.27%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 42/100 - Total Loss: -7.5268 - Cls: 0.8281 - Trans: 0.9331 - Rot: 0.9984


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.4903 | Cls Loss: 0.0013 | Trans Loss: 0.0056 | Rot Loss: 0.4672 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 43/100 - Total Loss: -7.8462 - Cls: 0.5808 - Trans: 0.9296 - Rot: 1.0007


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.4916 | Cls Loss: 0.0012 | Trans Loss: 0.0057 | Rot Loss: 0.4885 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 44/100 - Total Loss: -7.7593 - Cls: 0.6945 - Trans: 0.9639 - Rot: 0.9975


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -7.6139 | Cls Loss: 0.0067 | Trans Loss: 0.0076 | Rot Loss: 0.4769 | Accuracy: 99.70%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 45/100 - Total Loss: -7.4010 - Cls: 1.0586 - Trans: 0.9622 - Rot: 1.0093


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -8.9176 | Cls Loss: 0.0012 | Trans Loss: 0.0026 | Rot Loss: 0.4951 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 46/100 - Total Loss: -8.2812 - Cls: 0.2450 - Trans: 0.9669 - Rot: 0.9934


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -8.1627 | Cls Loss: 0.0019 | Trans Loss: 0.0083 | Rot Loss: 0.4729 | Accuracy: 99.92%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 47/100 - Total Loss: -8.3692 - Cls: 0.2266 - Trans: 0.9858 - Rot: 0.9978


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -8.5169 | Cls Loss: 0.0023 | Trans Loss: 0.0058 | Rot Loss: 0.4723 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 48/100 - Total Loss: -8.2947 - Cls: 0.4428 - Trans: 0.9355 - Rot: 1.0004


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.0146 | Cls Loss: 0.0009 | Trans Loss: 0.0040 | Rot Loss: 0.4697 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 49/100 - Total Loss: -8.6852 - Cls: 0.1551 - Trans: 0.9270 - Rot: 0.9978


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.4983 | Cls Loss: 0.0003 | Trans Loss: 0.0089 | Rot Loss: 0.4817 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 50/100 - Total Loss: -6.4690 - Cls: 1.9770 - Trans: 1.2319 - Rot: 1.0095


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.9423 | Cls Loss: 0.0008 | Trans Loss: 0.0044 | Rot Loss: 0.4836 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 51/100 - Total Loss: -8.3723 - Cls: 0.3932 - Trans: 0.9294 - Rot: 0.9911


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -8.5192 | Cls Loss: 0.0026 | Trans Loss: 0.0058 | Rot Loss: 0.4829 | Accuracy: 99.93%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 52/100 - Total Loss: -8.6628 - Cls: 0.1598 - Trans: 0.9331 - Rot: 1.0068


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -9.1722 | Cls Loss: 0.0008 | Trans Loss: 0.0039 | Rot Loss: 0.4682 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 53/100 - Total Loss: -8.9840 - Cls: 0.0646 - Trans: 0.8616 - Rot: 0.9943


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.6171 | Cls Loss: 0.0001 | Trans Loss: 0.0023 | Rot Loss: 0.4749 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 54/100 - Total Loss: -8.2225 - Cls: 0.8127 - Trans: 0.9653 - Rot: 1.0054


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -6.0572 | Cls Loss: 0.0167 | Trans Loss: 0.0046 | Rot Loss: 0.4908 | Accuracy: 99.43%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 55/100 - Total Loss: -7.6354 - Cls: 1.2207 - Trans: 1.0966 - Rot: 0.9951


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -9.2811 | Cls Loss: 0.0002 | Trans Loss: 0.0046 | Rot Loss: 0.4636 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 56/100 - Total Loss: -8.4877 - Cls: 0.3187 - Trans: 1.1473 - Rot: 1.0010


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -9.2741 | Cls Loss: 0.0006 | Trans Loss: 0.0042 | Rot Loss: 0.4682 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 57/100 - Total Loss: -6.6640 - Cls: 2.2371 - Trans: 1.0874 - Rot: 1.0078


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: 356.0917 | Cls Loss: 1.8163 | Trans Loss: 0.0501 | Rot Loss: 0.5355 | Accuracy: 65.58%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 58/100 - Total Loss: -7.4917 - Cls: 1.3238 - Trans: 1.0186 - Rot: 0.9990


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.4032 | Cls Loss: 0.0002 | Trans Loss: 0.0033 | Rot Loss: 0.4778 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 59/100 - Total Loss: -8.9174 - Cls: 0.1508 - Trans: 0.8358 - Rot: 0.9948


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.4001 | Cls Loss: 0.0001 | Trans Loss: 0.0039 | Rot Loss: 0.4757 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 60/100 - Total Loss: -8.9656 - Cls: 0.0525 - Trans: 0.9660 - Rot: 0.9996


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.4780 | Cls Loss: 0.0000 | Trans Loss: 0.0037 | Rot Loss: 0.4765 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 61/100 - Total Loss: -9.1502 - Cls: 0.0692 - Trans: 0.8291 - Rot: 1.0033


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.5971 | Cls Loss: 0.0001 | Trans Loss: 0.0033 | Rot Loss: 0.4778 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 62/100 - Total Loss: -8.8265 - Cls: 0.4186 - Trans: 0.8951 - Rot: 0.9957


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.6866 | Cls Loss: 0.0007 | Trans Loss: 0.0022 | Rot Loss: 0.4743 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 63/100 - Total Loss: -9.0849 - Cls: 0.1796 - Trans: 0.9449 - Rot: 1.0009


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.5935 | Cls Loss: 0.0001 | Trans Loss: 0.0040 | Rot Loss: 0.4748 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 64/100 - Total Loss: -9.1258 - Cls: 0.1434 - Trans: 0.9985 - Rot: 0.9976


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.9100 | Cls Loss: 0.0001 | Trans Loss: 0.0024 | Rot Loss: 0.4719 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 65/100 - Total Loss: -9.2606 - Cls: 0.0589 - Trans: 1.0070 - Rot: 1.0002


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.8895 | Cls Loss: 0.0003 | Trans Loss: 0.0025 | Rot Loss: 0.4702 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 66/100 - Total Loss: -6.5925 - Cls: 2.5252 - Trans: 1.1553 - Rot: 1.0031


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.4153 | Cls Loss: 0.0015 | Trans Loss: 0.0030 | Rot Loss: 0.4804 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 67/100 - Total Loss: -8.0478 - Cls: 1.1556 - Trans: 0.9947 - Rot: 0.9991


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -9.9926 | Cls Loss: 0.0002 | Trans Loss: 0.0012 | Rot Loss: 0.4730 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 68/100 - Total Loss: -9.3369 - Cls: 0.0616 - Trans: 0.8629 - Rot: 0.9995


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.06it/s]


Validation - Loss: -9.7703 | Cls Loss: 0.0003 | Trans Loss: 0.0028 | Rot Loss: 0.4919 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 69/100 - Total Loss: -9.3296 - Cls: 0.1491 - Trans: 0.8625 - Rot: 0.9978


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -10.1480 | Cls Loss: 0.0002 | Trans Loss: 0.0013 | Rot Loss: 0.4708 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 70/100 - Total Loss: -8.8887 - Cls: 0.4709 - Trans: 1.0635 - Rot: 1.0058


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -7.6740 | Cls Loss: 0.0082 | Trans Loss: 0.0030 | Rot Loss: 0.4917 | Accuracy: 99.70%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 71/100 - Total Loss: -6.3118 - Cls: 2.6796 - Trans: 1.3136 - Rot: 0.9989


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -8.6365 | Cls Loss: 0.0004 | Trans Loss: 0.0104 | Rot Loss: 0.4765 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 72/100 - Total Loss: -9.1778 - Cls: 0.1419 - Trans: 0.9429 - Rot: 0.9965


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.7176 | Cls Loss: 0.0004 | Trans Loss: 0.0033 | Rot Loss: 0.4760 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:22<00:00,  1.07s/it]


Epoch 73/100 - Total Loss: -7.9939 - Cls: 1.3010 - Trans: 1.0311 - Rot: 1.0059


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -2.7225 | Cls Loss: 0.0175 | Trans Loss: 0.0201 | Rot Loss: 0.4986 | Accuracy: 99.48%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 74/100 - Total Loss: -7.9429 - Cls: 1.3322 - Trans: 0.9684 - Rot: 1.0014


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.8798 | Cls Loss: 0.0005 | Trans Loss: 0.0019 | Rot Loss: 0.4768 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 75/100 - Total Loss: -9.2797 - Cls: 0.2057 - Trans: 0.8306 - Rot: 0.9951


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.8206 | Cls Loss: 0.0001 | Trans Loss: 0.0035 | Rot Loss: 0.4722 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 76/100 - Total Loss: -9.2856 - Cls: 0.2126 - Trans: 0.8869 - Rot: 0.9988


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -9.9466 | Cls Loss: 0.0001 | Trans Loss: 0.0029 | Rot Loss: 0.4777 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 77/100 - Total Loss: -9.4628 - Cls: 0.1460 - Trans: 0.8704 - Rot: 1.0025


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -10.1130 | Cls Loss: 0.0001 | Trans Loss: 0.0022 | Rot Loss: 0.4803 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 78/100 - Total Loss: -8.5749 - Cls: 0.9364 - Trans: 1.0402 - Rot: 0.9985


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.8294 | Cls Loss: 0.0016 | Trans Loss: 0.0074 | Rot Loss: 0.4839 | Accuracy: 99.93%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 79/100 - Total Loss: -9.0735 - Cls: 0.4964 - Trans: 0.9530 - Rot: 1.0024


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -10.2693 | Cls Loss: 0.0001 | Trans Loss: 0.0016 | Rot Loss: 0.4752 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 80/100 - Total Loss: -9.1974 - Cls: 0.4519 - Trans: 0.9368 - Rot: 0.9964


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -10.0906 | Cls Loss: 0.0003 | Trans Loss: 0.0026 | Rot Loss: 0.4723 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 81/100 - Total Loss: -9.0688 - Cls: 0.6029 - Trans: 0.9740 - Rot: 1.0025


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.5561 | Cls Loss: 0.0012 | Trans Loss: 0.0098 | Rot Loss: 0.4805 | Accuracy: 99.93%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 82/100 - Total Loss: -9.3689 - Cls: 0.2814 - Trans: 1.0154 - Rot: 0.9942


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -10.4597 | Cls Loss: 0.0000 | Trans Loss: 0.0014 | Rot Loss: 0.4705 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 83/100 - Total Loss: -9.4480 - Cls: 0.3224 - Trans: 0.9905 - Rot: 1.0002


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: 14.1669 | Cls Loss: 0.0624 | Trans Loss: 0.0146 | Rot Loss: 0.4977 | Accuracy: 98.12%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 84/100 - Total Loss: -7.6854 - Cls: 1.7552 - Trans: 1.2032 - Rot: 1.0014


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.5637 | Cls Loss: 0.0025 | Trans Loss: 0.0076 | Rot Loss: 0.4812 | Accuracy: 99.92%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 85/100 - Total Loss: -8.7334 - Cls: 0.9553 - Trans: 0.9802 - Rot: 1.0000


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -9.1556 | Cls Loss: 0.0003 | Trans Loss: 0.0086 | Rot Loss: 0.4741 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:23<00:00,  1.08s/it]


Epoch 86/100 - Total Loss: -9.4459 - Cls: 0.2002 - Trans: 1.0348 - Rot: 0.9991


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -9.8514 | Cls Loss: 0.0011 | Trans Loss: 0.0030 | Rot Loss: 0.4714 | Accuracy: 99.95%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 87/100 - Total Loss: -9.7371 - Cls: 0.0446 - Trans: 0.9660 - Rot: 0.9988


Validating: 100%|██████████| 47/47 [00:47<00:00,  1.00s/it]


Validation - Loss: -10.4940 | Cls Loss: 0.0001 | Trans Loss: 0.0017 | Rot Loss: 0.4684 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 88/100 - Total Loss: -7.5697 - Cls: 2.2966 - Trans: 0.9560 - Rot: 1.0058


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -4.2215 | Cls Loss: 0.0154 | Trans Loss: 0.0046 | Rot Loss: 0.4746 | Accuracy: 99.78%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 89/100 - Total Loss: -9.0153 - Cls: 0.7454 - Trans: 0.9846 - Rot: 1.0046


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -4.0389 | Cls Loss: 0.0137 | Trans Loss: 0.0099 | Rot Loss: 0.4781 | Accuracy: 99.55%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 90/100 - Total Loss: -9.4943 - Cls: 0.2977 - Trans: 0.9846 - Rot: 0.9915


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.8265 | Cls Loss: 0.0015 | Trans Loss: 0.0025 | Rot Loss: 0.4685 | Accuracy: 99.97%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 91/100 - Total Loss: -8.1905 - Cls: 1.5666 - Trans: 1.0644 - Rot: 1.0025


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -2.5393 | Cls Loss: 0.0143 | Trans Loss: 0.0172 | Rot Loss: 0.5102 | Accuracy: 99.55%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 92/100 - Total Loss: -7.9965 - Cls: 1.5971 - Trans: 1.1059 - Rot: 0.9973


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -9.7561 | Cls Loss: 0.0001 | Trans Loss: 0.0059 | Rot Loss: 0.4678 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 93/100 - Total Loss: -9.3831 - Cls: 0.4056 - Trans: 0.9341 - Rot: 1.0050


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -8.3989 | Cls Loss: 0.0057 | Trans Loss: 0.0015 | Rot Loss: 0.4728 | Accuracy: 99.78%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 94/100 - Total Loss: -9.8183 - Cls: 0.1089 - Trans: 0.8640 - Rot: 0.9959


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -10.2600 | Cls Loss: 0.0000 | Trans Loss: 0.0033 | Rot Loss: 0.4692 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 95/100 - Total Loss: -9.3366 - Cls: 0.5632 - Trans: 0.9420 - Rot: 1.0022


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]


Validation - Loss: -10.1550 | Cls Loss: 0.0001 | Trans Loss: 0.0039 | Rot Loss: 0.4700 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 96/100 - Total Loss: -9.1522 - Cls: 0.7313 - Trans: 0.9874 - Rot: 0.9979


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -10.2569 | Cls Loss: 0.0001 | Trans Loss: 0.0034 | Rot Loss: 0.4760 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 97/100 - Total Loss: -9.2228 - Cls: 0.5857 - Trans: 1.0695 - Rot: 1.0006


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.04it/s]


Validation - Loss: -10.2542 | Cls Loss: 0.0003 | Trans Loss: 0.0030 | Rot Loss: 0.4720 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 98/100 - Total Loss: -9.7471 - Cls: 0.2354 - Trans: 0.9376 - Rot: 0.9988


Validating: 100%|██████████| 47/47 [00:45<00:00,  1.03it/s]


Validation - Loss: -10.4524 | Cls Loss: 0.0004 | Trans Loss: 0.0019 | Rot Loss: 0.4662 | Accuracy: 99.98%


Training: 100%|██████████| 188/188 [03:21<00:00,  1.07s/it]


Epoch 99/100 - Total Loss: -8.4545 - Cls: 1.4271 - Trans: 1.0266 - Rot: 0.9989


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.04it/s]


Validation - Loss: -10.6380 | Cls Loss: 0.0001 | Trans Loss: 0.0015 | Rot Loss: 0.4657 | Accuracy: 100.00%


Training: 100%|██████████| 188/188 [03:20<00:00,  1.07s/it]


Epoch 100/100 - Total Loss: -9.1872 - Cls: 0.7049 - Trans: 1.0262 - Rot: 1.0074


Validating: 100%|██████████| 47/47 [00:44<00:00,  1.05it/s]

Validation - Loss: -7.9882 | Cls Loss: 0.0052 | Trans Loss: 0.0043 | Rot Loss: 0.4862 | Accuracy: 99.80%





In [None]:
one_pcd = dataset[2][0]

model.eval()
with torch.no_grad():
    # Move the point cloud to the same device as the model
    one_pcd = one_pcd.to(device)
    # Forward pass through the model
    # Note: Add batch dimension
    class_logits, pred_pose = model(one_pcd.unsqueeze(0).to(device))
one_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(one_pcd.cpu().numpy()))
one_pcd.paint_uniform_color([0.5, 0.5, 0.5])

PointCloud with 1000 points.

In [None]:
from pathlib import Path
import open3d as o3d
import numpy as np
model_dir = Path("./data/models")
model_list = list(model_dir.glob("*"))

def check_model_on_instance(model, dataset, idx, device, paths):
    pcd = dataset[idx][0]
    gt_class = dataset[idx][1][0]
    gt_rotation_quaternion = dataset[idx][1][1]
    gt_translation = dataset[idx][1][2]
    gt_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pcd))
    
    model.eval()
    with torch.no_grad():
        # Move the point cloud to the same device as the model
        pcd = pcd.to(device)
        # Forward pass through the model
        # Note: Add batch dimension
        class_logits, pred_pose = model(pcd.unsqueeze(0).to(device))
        preds = class_logits.argmax(dim=1)
        model_path = paths[preds.item()]
        print(f"Predicted class: {preds.item()}", 
              f"GT class: {gt_class.item()}")
        print(f"Model path: {model_path}")
    pcd_predicted = o3d.io.read_triangle_mesh(model_path)
    pcd_predicted.paint_uniform_color([1, 0, 0])  # Red color for predicted model
    predicted_quaternion = pred_pose[0, 3:].cpu().numpy()

    q1, q2, q3, q4 = predicted_quaternion/np.linalg.norm(predicted_quaternion)
    print(np.linalg.norm(np.array([q1, q2, q3, q4])))
    x, y, z = pred_pose[0, :3].cpu().numpy()
    # Convert quaternion to rotation matrix
    T = np.array([
        [1 - 2*(q3**2 + q4**2),     2*(q2*q3 - q1*q4),
            2*(q2*q4 + q1*q3), x],
        [2*(q2*q3 + q1*q4), 1 - 2*(q2**2 + q4**2),     2*(q3*q4 - q1*q2), y],
        [2*(q2*q4 - q1*q3),     2*(q3*q4 + q1*q2), 1 - 2*(q2**2 + q3**2), z],
        [0, 0, 0, 1]
    ])
    pcd_predicted.transform(T)
    translation_diff = np.array([x, y, z]) - gt_translation.numpy()
    print(f"Translation diff: {translation_diff}")
    translation_error = sum((np.array([x, y, z]) - gt_translation.cpu().numpy())**2)/3
    print(f"Translation error: {translation_error}")
    rotation_error = 1 - np.dot(predicted_quaternion, gt_rotation_quaternion)
    print(f"Rotation error: {rotation_error}")
    # Visualize the point cloud and the predicted model
    o3d.visualization.draw_geometries([pcd_predicted, gt_pcd])


In [None]:
check_model_on_instance(model, train_dataset, 0, device, model_list)

Predicted class: 2 GT class: 2
Model path: data\models\hammer.obj
0.99999994
Translation diff: [ 0.14558786 -0.05423252  0.06600547]
Translation error: 0.00949790452917417
Rotation error: -0.47630763053894043


In [None]:
preds = class_logits.argmax(dim=1)
preds.item()

3

In [None]:



model_path = model_list[preds.item()]

In [None]:
q1, q2, q3, q4 = pred_pose[0, 3:].cpu().numpy()
print(q1, q2, q3, q4)

0.10637765 -2.4367216 1.1922573 6.49378


In [None]:

pcd_predicted = o3d.io.read_triangle_mesh(model_path)
pcd_predicted.paint_uniform_color([1, 0, 0])  # Red color for predicted model
translation = pred_pose[0, :3].cpu().numpy()
rotation = pred_pose[0, 3:].cpu().numpy()
pcd_predicted.rotate(pcd_predicted.get_rotation_matrix_from_quaternion(rotation))
pcd_predicted.translate(translation)

pcd_predicted

TriangleMesh with 746 points and 1476 triangles.

In [None]:
((dataset[2][1][2] - translation)**2).sum()/3

tensor(0.0035)

In [None]:
import numpy as np
print(dataset[2][1][1], rotation/np.linalg.norm(rotation))
1-np.dot(dataset[2][1][1], rotation/np.linalg.norm(rotation))

tensor([-0.3478, -0.0184, -0.8122,  0.4679]) [ 0.01511382 -0.3462021   0.16939235  0.92261684]


0.7047482132911682

In [None]:
np.linalg.norm(dataset[1][1][1])

1.0

In [None]:
o3d.visualization.draw_geometries([pcd_predicted, one_pcd])