In [1]:
import importlib
import json

import torch
from torch.utils.data import DataLoader

from src import full_pcd_dataset

# Re-import the dataset module so edits to src/full_pcd_dataset.py are picked up
# without restarting the kernel.
importlib.reload(full_pcd_dataset)

# Seed torch's global RNG early so later draws from it (e.g. default layer init in the
# model cell, DataLoader shuffling) are deterministic for a given cell execution order.
SEED = 42
torch.manual_seed(SEED)

# Create dataset instance
dataset = full_pcd_dataset.FullPCDDataset("data/full_pcd_100000_samples_6d.npz")

# JSON is UTF-8 by spec; be explicit so non-ASCII labels load on any platform locale.
with open("data/label_dict.json", "r", encoding="utf-8") as f:
    label_dict = json.load(f)
num_classes = len(label_dict)

# Split into train and validation (80-20). The dedicated, seeded generator pins the
# partition: without it every kernel restart yields a different validation set, so a
# resumed run or a later evaluation would score on samples that were trained on before.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Create data loaders
batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
num_classes

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


5

In [2]:
from src.model_translation import TranslationModel
from torch.optim.lr_scheduler import StepLR

# Optimisation hyperparameters.
LEARNING_RATE = 1e-3
LR_STEP_SIZE = 5  # StepLR: scale the LR by LR_GAMMA every LR_STEP_SIZE scheduler.step() calls
LR_GAMMA = 0.5
# NOTE(review): if train_model calls scheduler.step() once per epoch (not visible here),
# the LR after e epochs is 1e-3 * 0.5 ** (e // 5), i.e. below 1e-6 from epoch 50 on —
# most of a 100-epoch run would then train with a near-zero LR. Confirm this is intended.

model = TranslationModel()
trainable_params = model.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)
scheduler = StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

In [3]:
def count_parameters(model):
    """Count the trainable scalar parameters of a module.

    Args:
        model: any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: total number of elements across all parameters whose
        ``requires_grad`` flag is set; frozen parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen weights are not optimised, so they are not counted
        total += param.numel()
    return total

In [4]:
# Report how many trainable scalars the freshly built model has.
n_trainable_params = count_parameters(model)
print(n_trainable_params)

276899


In [5]:
from src.train_position import train_model

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
# Run configuration for this training session.
EPOCHS = 100
RUN_DIRECTORY = "new_run_translation"
# NOTE(review): the recorded run was stopped manually (KeyboardInterrupt) during epoch 32;
# whether train_model saves checkpoints into RUN_DIRECTORY cannot be confirmed from here.

train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    device,
    epochs=EPOCHS,
    directory=RUN_DIRECTORY,
)

Training: 100%|██████████| 157/157 [02:50<00:00,  1.08s/it]


Epoch 1/100 - Total Loss: 0.2105


Validating: 100%|██████████| 40/40 [00:28<00:00,  1.42it/s]


Validation - Loss: 0.0081 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 2/100 - Total Loss: 0.0442


Validating: 100%|██████████| 40/40 [00:30<00:00,  1.31it/s]


Validation - Loss: 0.0059 | 


Training: 100%|██████████| 157/157 [02:51<00:00,  1.09s/it]


Epoch 3/100 - Total Loss: 0.0353


Validating: 100%|██████████| 40/40 [00:27<00:00,  1.48it/s]


Validation - Loss: 0.0074 | 


Training: 100%|██████████| 157/157 [02:49<00:00,  1.08s/it]


Epoch 4/100 - Total Loss: 0.0313


Validating: 100%|██████████| 40/40 [00:27<00:00,  1.47it/s]


Validation - Loss: 0.0060 | 


Training: 100%|██████████| 157/157 [02:53<00:00,  1.11s/it]


Epoch 5/100 - Total Loss: 0.0286


Validating: 100%|██████████| 40/40 [00:29<00:00,  1.38it/s]


Validation - Loss: 0.0084 | 


Training: 100%|██████████| 157/157 [02:55<00:00,  1.12s/it]


Epoch 6/100 - Total Loss: 0.0262


Validating: 100%|██████████| 40/40 [00:28<00:00,  1.40it/s]


Validation - Loss: 0.0038 | 


Training: 100%|██████████| 157/157 [02:53<00:00,  1.11s/it]


Epoch 7/100 - Total Loss: 0.0251


Validating: 100%|██████████| 40/40 [00:28<00:00,  1.40it/s]


Validation - Loss: 0.0134 | 


Training: 100%|██████████| 157/157 [02:51<00:00,  1.09s/it]


Epoch 8/100 - Total Loss: 0.0240


Validating: 100%|██████████| 40/40 [00:31<00:00,  1.27it/s]


Validation - Loss: 0.0025 | 


Training: 100%|██████████| 157/157 [02:50<00:00,  1.09s/it]


Epoch 9/100 - Total Loss: 0.0230


Validating: 100%|██████████| 40/40 [00:27<00:00,  1.47it/s]


Validation - Loss: 0.0166 | 


Training: 100%|██████████| 157/157 [02:52<00:00,  1.10s/it]


Epoch 10/100 - Total Loss: 0.0223


Validating: 100%|██████████| 40/40 [00:27<00:00,  1.45it/s]


Validation - Loss: 0.0031 | 


Training: 100%|██████████| 157/157 [02:50<00:00,  1.08s/it]


Epoch 11/100 - Total Loss: 0.0218


Validating: 100%|██████████| 40/40 [00:28<00:00,  1.40it/s]


Validation - Loss: 0.0042 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 12/100 - Total Loss: 0.0211


Validating: 100%|██████████| 40/40 [00:29<00:00,  1.34it/s]


Validation - Loss: 0.0035 | 


Training: 100%|██████████| 157/157 [02:50<00:00,  1.08s/it]


Epoch 13/100 - Total Loss: 0.0212


Validating: 100%|██████████| 40/40 [00:28<00:00,  1.40it/s]


Validation - Loss: 0.0035 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.08s/it]


Epoch 14/100 - Total Loss: 0.0210


Validating: 100%|██████████| 40/40 [00:27<00:00,  1.46it/s]


Validation - Loss: 0.0084 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 15/100 - Total Loss: 0.0207


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.51it/s]


Validation - Loss: 0.0050 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 16/100 - Total Loss: 0.0197


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.49it/s]


Validation - Loss: 0.0015 | 


Training: 100%|██████████| 157/157 [02:49<00:00,  1.08s/it]


Epoch 17/100 - Total Loss: 0.0200


Validating: 100%|██████████| 40/40 [00:29<00:00,  1.34it/s]


Validation - Loss: 0.0091 | 


Training: 100%|██████████| 157/157 [02:49<00:00,  1.08s/it]


Epoch 18/100 - Total Loss: 0.0193


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.49it/s]


Validation - Loss: 0.0040 | 


Training: 100%|██████████| 157/157 [02:51<00:00,  1.09s/it]


Epoch 19/100 - Total Loss: 0.0192


Validating: 100%|██████████| 40/40 [00:30<00:00,  1.32it/s]


Validation - Loss: 0.0037 | 


Training: 100%|██████████| 157/157 [02:50<00:00,  1.09s/it]


Epoch 20/100 - Total Loss: 0.0187


Validating: 100%|██████████| 40/40 [00:28<00:00,  1.41it/s]


Validation - Loss: 0.0067 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 21/100 - Total Loss: 0.0186


Validating: 100%|██████████| 40/40 [00:30<00:00,  1.33it/s]


Validation - Loss: 0.0041 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 22/100 - Total Loss: 0.0184


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.51it/s]


Validation - Loss: 0.0048 | 


Training: 100%|██████████| 157/157 [02:49<00:00,  1.08s/it]


Epoch 23/100 - Total Loss: 0.0182


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.51it/s]


Validation - Loss: 0.0135 | 


Training: 100%|██████████| 157/157 [02:47<00:00,  1.07s/it]


Epoch 24/100 - Total Loss: 0.0179


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.52it/s]


Validation - Loss: 0.0026 | 


Training: 100%|██████████| 157/157 [02:50<00:00,  1.09s/it]


Epoch 25/100 - Total Loss: 0.0175


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.52it/s]


Validation - Loss: 0.0034 | 


Training: 100%|██████████| 157/157 [02:49<00:00,  1.08s/it]


Epoch 26/100 - Total Loss: 0.0177


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.51it/s]


Validation - Loss: 0.0039 | 


Training: 100%|██████████| 157/157 [02:49<00:00,  1.08s/it]


Epoch 27/100 - Total Loss: 0.0174


Validating: 100%|██████████| 40/40 [00:29<00:00,  1.37it/s]


Validation - Loss: 0.0050 | 


Training: 100%|██████████| 157/157 [02:50<00:00,  1.09s/it]


Epoch 28/100 - Total Loss: 0.0174


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.53it/s]


Validation - Loss: 0.0092 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 29/100 - Total Loss: 0.0172


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.49it/s]


Validation - Loss: 0.0093 | 


Training: 100%|██████████| 157/157 [02:48<00:00,  1.07s/it]


Epoch 30/100 - Total Loss: 0.0171


Validating: 100%|██████████| 40/40 [00:28<00:00,  1.39it/s]


Validation - Loss: 0.0018 | 


Training: 100%|██████████| 157/157 [02:49<00:00,  1.08s/it]


Epoch 31/100 - Total Loss: 0.0170


Validating: 100%|██████████| 40/40 [00:26<00:00,  1.48it/s]


Validation - Loss: 0.0037 | 


Training:  48%|████▊     | 75/157 [01:33<01:42,  1.25s/it]


KeyboardInterrupt: 