In [None]:
from PointUNet import *
from PointNet2FPUNet import *
import pickle

In [None]:
def train_model(model, train_loader, optimizer, num_epochs=10, device='cuda', save_model=True,
                save_path="pointnetpp_unet.pth", loss_kwargs=None):
    """Train ``model`` with ``discriminative_loss`` and optionally save its weights.

    Args:
        model: network called as ``model(points)``; the inline comment suggests it
            returns per-point embeddings of shape (B, N, output_dim) — TODO confirm.
        train_loader: iterable of ``(points, labels)`` batches that supports ``len``.
        optimizer: optimizer over ``model.parameters()``; stepped once per batch.
        num_epochs: number of passes over ``train_loader``.
        device: device the model and every batch are moved to.
        save_model: if True, write ``model.state_dict()`` to ``save_path`` after training.
        save_path: destination file; its parent directory is created if missing.
        loss_kwargs: optional dict overriding the keyword arguments passed to
            ``discriminative_loss`` (``delta_v``, ``delta_d``, ``alpha``, ``beta``, ``gamma``).
            ``None`` keeps the defaults below.

    Returns:
        list[float]: mean training loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` contains no batches.
    """
    import os

    num_batches = len(train_loader)
    if num_batches == 0:
        # Fail before training instead of dividing by zero at the end of epoch 1.
        raise ValueError("train_loader is empty; nothing to train on")

    # Defaults: small regularisation weight gamma, delta_v = 0.3 and delta_d = 1.0.
    loss_params = dict(delta_v=0.3, delta_d=1.0, alpha=1.0, beta=1.0, gamma=0.0001)
    if loss_kwargs:
        loss_params.update(loss_kwargs)

    model.to(device)
    model.train()
    loss_history = []

    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for points, labels in progress_bar:
            points, labels = points.to(device), labels.to(device)
            optimizer.zero_grad()

            embeddings = model(points)  # (B, N, output_dim)
            loss = discriminative_loss(embeddings, labels, **loss_params)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # single device->host sync per batch
            total_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

        avg_loss = total_loss / num_batches
        loss_history.append(avg_loss)
        print(f"✅ Epoch {epoch + 1}: Loss = {avg_loss:.4f}")

    if save_model:
        # Make sure e.g. "model/" exists so a long run is not lost at the final save.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), save_path)
        print("💾 Model saved as "+save_path)

    return loss_history

In [None]:
# Define dataset paths
points_folder = "data/roofNTNU/train_test_split/points_train"
labels_folder = "data/roofNTNU/train_test_split/labels_train"

# Create dataset instance (max_points presumably caps/pads each cloud to 512 points — TODO confirm)
train_dataset = LiDARPointCloudDataset(points_folder, labels_folder, max_points=512, mode="train")

# Fail fast: an empty dataset would otherwise only surface as an error deep inside training.
if len(train_dataset) == 0:
    raise ValueError(f"No training samples found in {points_folder!r} / {labels_folder!r}")

# Create DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=8,         # Use multiple CPU cores for faster loading
    pin_memory=True,       # Page-locked host memory speeds up host->GPU copies
    collate_fn=collate_fn,
)

# Sanity-check a single batch before starting a long training run
points, labels = next(iter(train_loader))
print("Batch Point Cloud Shape:", points.shape)  # Expected: (batch_size, max_points, 3)
print("Batch Labels Shape:", labels.shape)  # Expected: (batch_size, max_points)

In [None]:
# Use the GPU when present; fall back to CPU so the cell does not crash on CUDA-less machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = PointNetPPUNet(emb_dim=128, output_dim=64)

optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,  # small L2 penalty
)

# NOTE(review): relies on the (points, labels) variant of train_model; if the later
# normals-aware train_model cell has already been run, this call fails with a TypeError.
loss_history = train_model(model, train_loader, optimizer, num_epochs=50, device=device, save_model=True, save_path="model/pointnetpp_unet.pth")

In [None]:
def train_model(model, train_loader, optimizer, discriminative_loss, num_epochs=10, device='cuda', log_dir='logs',
                vis_every=10, cluster_distance_threshold=1.5):
    """Train a normals-aware embedding model, checkpointing every epoch and logging diagnostics.

    NOTE(review): this redefines ``train_model`` from an earlier cell with an incompatible
    signature (``discriminative_loss`` is the 4th positional argument; no ``save_model`` /
    ``save_path``). After this cell runs, the earlier training cell raises TypeError —
    consider giving this variant its own name.

    Args:
        model: network called as ``model(points, normals)``; the inline comment suggests it
            returns (B, N, D) embeddings — TODO confirm.
        train_loader: iterable of ``(points, labels, normals)`` batches that supports ``len``.
        optimizer: optimizer over ``model.parameters()``; stepped once per batch.
        discriminative_loss: callable ``(embeddings, labels) -> scalar loss tensor``.
        num_epochs: number of passes over ``train_loader``.
        device: device the model and every batch are moved to.
        log_dir: directory (created if missing) for ``model_epoch{k}.pth`` checkpoints,
            diagnostic PNGs and ``loss_history.pkl``.
        vis_every: run the clustering/plotting diagnostic every this many epochs;
            0 or None disables it.
        cluster_distance_threshold: ``distance_threshold`` for the AgglomerativeClustering
            used by the diagnostic.

    Returns:
        list[float]: mean training loss of each epoch (also pickled into ``log_dir``).

    Raises:
        ValueError: if ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Without this, the epoch average divides by zero and the diagnostic hits an unbound name.
        raise ValueError("train_loader is empty; nothing to train on")

    model.to(device)
    model.train()
    loss_history = []

    os.makedirs(log_dir, exist_ok=True)

    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for points, labels, normals in progress_bar:
            points, labels, normals = points.to(device), labels.to(device), normals.to(device)

            optimizer.zero_grad()
            embeddings = model(points, normals)  # (B, N, D)
            loss = discriminative_loss(embeddings, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # single device->host sync per batch
            total_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

        avg_loss = total_loss / num_batches
        loss_history.append(avg_loss)
        print(f"✅ Epoch {epoch + 1} complete. Avg Loss: {avg_loss:.4f}")

        # Save model checkpoint
        checkpoint_path = os.path.join(log_dir, f'model_epoch{epoch + 1}.pth')
        torch.save(model.state_dict(), checkpoint_path)

        if vis_every and (epoch + 1) % vis_every == 0:
            # Diagnostic on the first sample of the LAST batch of this epoch (the loop
            # variables still hold it). Slicing keeps both tensors on `device`.
            _debug_visualize_sample(model, points[0:1], normals[0:1], epoch + 1, log_dir,
                                    cluster_distance_threshold)

    # Save training loss
    with open(os.path.join(log_dir, 'loss_history.pkl'), 'wb') as f:
        pickle.dump(loss_history, f)

    return loss_history


def _debug_visualize_sample(model, sample_points, sample_normals, epoch_num, log_dir, distance_threshold):
    """Embed one sample, cluster it, save t-SNE and cluster plots, and print embedding variance.

    ``sample_points`` / ``sample_normals`` keep a leading batch dimension of 1 and must live on
    the model's device. The model is put back into train mode even if plotting raises.
    """
    model.eval()
    try:
        with torch.no_grad():
            embed = model(sample_points, sample_normals).squeeze(0).cpu().numpy()
        embed = normalize_embeddings(embed)

        # Clustering
        clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=distance_threshold)
        pred_labels = remap_labels(clustering.fit_predict(embed))

        # Save visuals
        visualize_embeddings(embed, pred_labels,
                             method='tsne',
                             save_path=os.path.join(log_dir, f'emb_epoch{epoch_num}.png'))
        visualize_clusters_matplotlib(sample_points[0].cpu().numpy(), pred_labels,
                                      save_path=os.path.join(log_dir, f'clusters_epoch{epoch_num}.png'))

        # Print embedding variance
        variance = np.var(embed, axis=0).mean()
        print(f"📊 Embedding variance at epoch {epoch_num}: {variance:.6f}")
    finally:
        model.train()

