In [1]:
from PointUNet import *
from PointNet2FPUNet import *
import pickle

In [2]:
def train_model(model, train_loader, optimizer, num_epochs=10, device='cuda', save_model=True, save_path="pointnetpp_unet.pth"):
    model.to(device)
    model.train()
    loss_history = []

    for epoch in range(num_epochs):
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for points, labels in progress_bar:
            points, labels = points.to(device), labels.to(device)
            optimizer.zero_grad()

            embeddings = model(points)  # (B, N, output_dim)
            # Compute the discriminative loss
            # Reduce γ (reg term), Use δ_v = 0.3 and δ_d = 1.0
            loss = discriminative_loss(embeddings, labels, delta_v=0.3, delta_d=1.0,
                        alpha=1.0, beta=1.0, gamma=0.0001)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(train_loader)
        loss_history.append(avg_loss)
        print(f"✅ Epoch {epoch + 1}: Loss = {avg_loss:.4f}")

    if save_model:
        torch.save(model.state_dict(), save_path)
        print("💾 Model saved as "+save_path)

    return loss_history

In [5]:
# Define dataset paths
points_folder = "data/roofNTNU/train_test_split/points_train"
labels_folder = "data/roofNTNU/train_test_split/labels_train"

# Create dataset instance
train_dataset = LiDARPointCloudDataset(points_folder, labels_folder, max_points=512, mode="train")

# Create DataLoader
# train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

train_loader = DataLoader(
    train_dataset, batch_size=4, shuffle=True,
    num_workers=8,  # Use multiple CPU cores for faster loading
    pin_memory=True,  # Optimizes GPU transfers
    collate_fn=collate_fn
)

# Check batch
for points, labels in train_loader:
    print("Batch Point Cloud Shape:", points.shape)  # Expected: (batch_size, max_points, 3)
    print("Batch Labels Shape:", labels.shape)  # Expected: (batch_size, max_points)
    break

Batch Point Cloud Shape: torch.Size([4, 512, 3])
Batch Labels Shape: torch.Size([4, 512])


In [4]:
model = PointNetPPUNet(emb_dim=128, output_dim=64)

optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5
)

loss_history = train_model(model, train_loader, optimizer, num_epochs=50, device='cuda', save_model=True, save_path="model/pointnetpp_unet.pth")

Epoch 1/50: 100%|██████████| 50/50 [00:17<00:00,  2.91it/s, loss=55.8]


✅ Epoch 1: Loss = 57.3797


Epoch 2/50: 100%|██████████| 50/50 [00:14<00:00,  3.56it/s, loss=47.6]


✅ Epoch 2: Loss = 53.0438


Epoch 3/50: 100%|██████████| 50/50 [00:13<00:00,  3.58it/s, loss=41.3]


✅ Epoch 3: Loss = 45.4345


Epoch 4/50: 100%|██████████| 50/50 [00:17<00:00,  2.93it/s, loss=29.9]


✅ Epoch 4: Loss = 38.9736


Epoch 5/50: 100%|██████████| 50/50 [00:18<00:00,  2.76it/s, loss=38]  


✅ Epoch 5: Loss = 35.6244


Epoch 6/50: 100%|██████████| 50/50 [00:14<00:00,  3.50it/s, loss=41.6]


✅ Epoch 6: Loss = 34.0634


Epoch 7/50: 100%|██████████| 50/50 [00:13<00:00,  3.61it/s, loss=27.5]


✅ Epoch 7: Loss = 32.4945


Epoch 8/50: 100%|██████████| 50/50 [00:13<00:00,  3.63it/s, loss=39.5]


✅ Epoch 8: Loss = 31.3878


Epoch 9/50: 100%|██████████| 50/50 [00:14<00:00,  3.52it/s, loss=34.3]


✅ Epoch 9: Loss = 30.4355


Epoch 10/50: 100%|██████████| 50/50 [00:14<00:00,  3.48it/s, loss=17.9]


✅ Epoch 10: Loss = 29.8958


Epoch 11/50: 100%|██████████| 50/50 [00:14<00:00,  3.52it/s, loss=42.4]


✅ Epoch 11: Loss = 29.1445


Epoch 12/50: 100%|██████████| 50/50 [00:15<00:00,  3.16it/s, loss=35.4]


✅ Epoch 12: Loss = 27.7717


Epoch 13/50: 100%|██████████| 50/50 [00:17<00:00,  2.80it/s, loss=29.9]


✅ Epoch 13: Loss = 27.1466


Epoch 14/50: 100%|██████████| 50/50 [00:13<00:00,  3.62it/s, loss=34.8]


✅ Epoch 14: Loss = 26.7708


Epoch 15/50: 100%|██████████| 50/50 [00:13<00:00,  3.60it/s, loss=19.3]


✅ Epoch 15: Loss = 26.2389


Epoch 16/50: 100%|██████████| 50/50 [00:16<00:00,  3.04it/s, loss=28]  


✅ Epoch 16: Loss = 26.2934


Epoch 17/50: 100%|██████████| 50/50 [00:17<00:00,  2.78it/s, loss=22.5]


✅ Epoch 17: Loss = 25.1431


Epoch 18/50: 100%|██████████| 50/50 [00:15<00:00,  3.25it/s, loss=31]  


✅ Epoch 18: Loss = 24.8287


Epoch 19/50: 100%|██████████| 50/50 [00:14<00:00,  3.55it/s, loss=18.3]


✅ Epoch 19: Loss = 24.7025


Epoch 20/50: 100%|██████████| 50/50 [00:14<00:00,  3.55it/s, loss=30.7]


✅ Epoch 20: Loss = 24.0487


Epoch 21/50:  92%|█████████▏| 46/50 [00:13<00:01,  3.39it/s, loss=17.2]


KeyboardInterrupt: 