In [2]:
import torch, platform, numpy as np

print(f"Python : {platform.python_version()}")
print(f"Torch  : {torch.__version__}")
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device :", torch.cuda.get_device_name(0))
    print("Capability  :", torch.cuda.get_device_capability(0))
    # dtype-consistent CUDA op
    a = torch.ones(2,2, device='cuda', dtype=torch.float32)
    b = torch.ones(2,2, device='cuda', dtype=torch.float32)
    two = torch.full((2,2), 2.0, device='cuda', dtype=torch.float32)
    print("Tensor add OK on CUDA:", torch.allclose(a+b, two))


Python : 3.11.14
Torch  : 2.5.1+cu121
CUDA available: True
CUDA device : NVIDIA GeForce GTX 1080
Capability  : (6, 1)
Tensor add OK on CUDA: True


In [4]:
import os, glob, random
import numpy as np
import pandas as pd
from collections import Counter
from plyfile import PlyData

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Repro
seed = 42
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = True

# Paths
data_folder = "/mnt/hdd1/desktop/SSS_03/CAPSTONE_SSS_03/Data/train_sphere_ascii_roi"  # <-- put your ROI folder (16 PLYs)
file_list = sorted(glob.glob(os.path.join(data_folder, "*.ply")))
print(f"Found {len(file_list)} PLY files for training!")

# Device / classes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 3
CLASS_MAP = {"Background": 0, "Track": 1, "Object": 2}
SCALAR_MAP = {1: "Background", 3: "Track", 9: "Object"}


Found 16 PLY files for training!


In [5]:
class Rail3DDataset(Dataset):
    def __init__(self, file_list, num_points=4096):
        self.file_list = file_list
        self.num_points = num_points
        self.class_map = CLASS_MAP
        self.scalar_map = SCALAR_MAP

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        ply_path = self.file_list[idx]
        ply = PlyData.read(ply_path)
        df = pd.DataFrame(ply['vertex'].data)

        # Points
        if {'x','y','z'}.issubset(df.columns):
            points = df[['x','y','z']].values.astype(np.float32)
        else:
            points = np.zeros((0,3), dtype=np.float32)

        # Labels
        col = next((c for c in ["scalar_NewClassification","scalar_Classification","classification","label"]
                    if c in df.columns), None)
        if col is None:
            # default BG if absent
            raw = pd.Series([1]*len(points))
        else:
            raw = df[col]
        try:
            raw = raw.astype(int)
        except Exception:
            raw = pd.Series([1]*len(points))

        mapped = raw.map(self.scalar_map).fillna("Background")
        labels = mapped.map(self.class_map).astype(np.int64).values

        # Handle empty
        if points.shape[0] == 0:
            pts = np.zeros((self.num_points, 3), dtype=np.float32)
            lbl = np.zeros((self.num_points,), dtype=np.int64)
            return torch.from_numpy(pts), torch.from_numpy(lbl)

        # Per-point inverse-frequency sampling with boosts
        uniq, cnts = np.unique(labels, return_counts=True)
        freq = {int(u): int(c) for u,c in zip(uniq,cnts)}
        inv = np.array([1.0 / max(freq.get(int(l),1),1) for l in labels], dtype=np.float32)

        # Emphasize Track (1) and a bit for Object (2)
        boost = {0:1.0, 1:3.0, 2:1.5}
        inv *= np.vectorize(boost.get)(labels.astype(int))

        alpha = 0.3  # lower alpha = stronger focus on rare classes
        inv_sum = float(inv.sum())
        if not np.isfinite(inv_sum) or inv_sum <= 0:
            inv = np.ones_like(inv) / len(inv)
        p = (1 - alpha) * (inv / inv.sum()) + alpha * (1.0/len(inv))
        p = p / p.sum()

        choice = np.random.choice(len(points), self.num_points, replace=True, p=p)
        points = points[choice]
        labels = labels[choice]

        # Normalize to unit sphere + tiny jitter
        centroid = points.mean(axis=0, keepdims=True)
        points = points - centroid
        max_dist = np.sqrt((points**2).sum(axis=1)).max()
        if max_dist > 0:
            points = points / max_dist
        points = points + np.random.normal(scale=0.001, size=points.shape).astype(np.float32)

        return torch.from_numpy(points), torch.from_numpy(labels)

dataset = Rail3DDataset(file_list, num_points=4096)

# DataLoader: if workers>0 ever crash on your laptop, keep 0
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,          # try 4 on the GPU box; use 0 if you see worker crashes
    pin_memory=True,
    persistent_workers=True
)


In [6]:
def count_labels_from_files(file_list, scalar_map, class_map):
    cnt = Counter()
    for p in file_list:
        ply = PlyData.read(p)
        df = pd.DataFrame(ply['vertex'].data)
        col = next((c for c in ["scalar_NewClassification","scalar_Classification","classification","label"]
                    if c in df.columns), None)
        if col is None: 
            continue
        raw = df[col].astype(int)
        mapped = raw.map(scalar_map).map(class_map).dropna().astype(int)
        cnt.update(mapped.tolist())
    return cnt

full_cnt = count_labels_from_files(file_list, SCALAR_MAP, CLASS_MAP)
print("Full dataset counts:", dict(full_cnt))

tot = sum(full_cnt.values()) if full_cnt else 1
raw = np.array([tot / max(full_cnt.get(c,1),1) for c in range(num_classes)], dtype=np.float32)
w = np.sqrt(raw); w = w / w.mean(); w = np.clip(w, 0.5, 5.0)
print("Using smoothed class weights:", w)

weights = torch.tensor(w, dtype=torch.float32, device=device)


Full dataset counts: {0: 19983204, 1: 692048, 2: 4431265}
Using smoothed class weights: [0.5        1.8971925  0.74974865]


In [7]:
class STN3d(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):           # x: (B,3,N)
        B = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=False)[0]     # (B,1024)
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        iden = torch.eye(3, device=x.device).view(1, 9).repeat(B, 1)
        x = x + iden
        return x.view(-1, 3, 3)

class STNkd(nn.Module):
    def __init__(self, k=64):
        super().__init__()
        self.k = k
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):           # x: (B,k,N)
        B = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=False)[0]     # (B,1024)
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        iden = torch.eye(self.k, device=x.device).view(1, self.k*self.k).repeat(B, 1)
        x = x + iden
        return x.view(-1, self.k, self.k)

class PointNetSeg(nn.Module):
    def __init__(self, num_classes=3, feature_transform=True):
        super().__init__()
        self.feature_transform = feature_transform
        self.stn = STN3d()
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)

        self.fstn = STNkd(k=64)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        # 64 (pointfeat) + 128 (local) + 1024 (global)
        self.conv4 = nn.Conv1d(64+128+1024, 512, 1)
        self.conv5 = nn.Conv1d(512, 256, 1)
        self.conv6 = nn.Conv1d(256, 128, 1)
        self.conv7 = nn.Conv1d(128, num_classes, 1)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):           # x: (B,N,3)
        x = x.transpose(2,1).contiguous()        # (B,3,N)

        trans = self.stn(x)
        x = torch.bmm(trans, x)                  # (B,3,N)

        x = F.relu(self.bn1(self.conv1(x)))      # (B,64,N)
        pointfeat = x

        trans_feat = None
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(trans_feat, x)         # (B,64,N)

        local128 = F.relu(self.bn2(self.conv2(x)))   # (B,128,N)
        global1024 = self.bn3(self.conv3(local128))  # (B,1024,N)
        global_pool = torch.max(global1024, 2, keepdim=True)[0]  # (B,1024,1)
        global_tile = global_pool.repeat(1, 1, pointfeat.size(2))# (B,1024,N)

        feat = torch.cat([pointfeat, local128, global_tile], dim=1)  # (B,1216,N)

        x = F.relu(self.bn4(self.conv4(feat)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.dropout(F.relu(self.bn6(self.conv6(x))))
        x = self.conv7(x)                         # (B,C,N)

        return x.transpose(2,1).contiguous(), trans, trans_feat

model = PointNetSeg(num_classes=num_classes, feature_transform=True).to(device)


In [8]:
def feature_transform_regularizer(trans):
    if trans is None:
        return torch.tensor(0.0, device=device)
    B, k, _ = trans.size()
    I = torch.eye(k, device=trans.device).unsqueeze(0).expand(B, -1, -1)
    diff = torch.bmm(trans, trans.transpose(2,1)) - I
    return torch.mean(torch.norm(diff, dim=(1,2)))

# CE with class weights + small FT regularizer
ce_loss = nn.CrossEntropyLoss(weight=weights)

def seg_loss(logits, targets, trans_feat, ft_weight=1e-3):
    ce = ce_loss(logits.view(-1, logits.size(-1)), targets.view(-1))
    reg = feature_transform_regularizer(trans_feat) * ft_weight
    return ce + reg

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

@torch.no_grad()
def epoch_iou_full(model, dataloader, device, num_classes=3):
    inter = torch.zeros(num_classes, dtype=torch.long)
    union = torch.zeros(num_classes, dtype=torch.long)
    model.eval()
    for P, L in dataloader:
        P = P.to(device); L = L.to(device)
        logits, _, _ = model(P)
        preds = logits.argmax(dim=-1)
        p = preds.view(-1).cpu()
        g = L.view(-1).cpu()
        for c in range(num_classes):
            inter[c] += ((p==c) & (g==c)).sum()
            union[c] += ((p==c) | (g==c)).sum()
    return [(inter[c].item()/union[c].item()) if union[c] > 0 else float('nan')
            for c in range(num_classes)]


In [9]:
epochs = 30
save_dir = "./checkpoints_pointnet"
os.makedirs(save_dir, exist_ok=True)
best_loss = float('inf')

for epoch in range(1, epochs+1):
    model.train()
    run_loss, total_pts, correct = 0.0, 0, 0

    for P, L in dataloader:
        P = P.to(device, non_blocking=True)  # (B,N,3)
        L = L.to(device, non_blocking=True)  # (B,N)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            logits, trans, trans_feat = model(P)          # logits: (B,N,C)
            loss = seg_loss(logits, L, trans_feat, ft_weight=1e-3)

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        scaler.step(optimizer)
        scaler.update()

        run_loss += loss.item() * P.size(0)
        preds = logits.argmax(dim=-1)
        correct += (preds == L).sum().item()
        total_pts += L.numel()

    avg_loss = run_loss / len(dataset)
    acc = 100.0 * correct / total_pts

    # full-epoch IoU (uses model.eval() internally)
    ious = epoch_iou_full(model, dataloader, device, num_classes=num_classes)

    print(f"Epoch [{epoch}/{epochs}] Loss: {avg_loss:.4f} | Acc: {acc:.2f}% | IoU: {ious}")
    scheduler.step(avg_loss)

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), os.path.join(save_dir, "best_pointnet.pth"))


Epoch [1/30] Loss: 1.2002 | Acc: 33.82% | IoU: [0.0, 0.015365844358866817, 0.2462753461918892]
Epoch [2/30] Loss: 1.1778 | Acc: 34.25% | IoU: [0.0, 0.389404296875, 0.0]
Epoch [3/30] Loss: 1.1852 | Acc: 33.23% | IoU: [0.0, 0.3923492431640625, 0.0]
Epoch [4/30] Loss: 1.1479 | Acc: 34.81% | IoU: [0.0, 0.3924713134765625, 0.0]
Epoch [5/30] Loss: 1.1251 | Acc: 36.01% | IoU: [0.0, 0.3929443359375, 0.0]
Epoch [6/30] Loss: 1.0878 | Acc: 38.84% | IoU: [0.0, 0.3906707763671875, 0.0]
Epoch [7/30] Loss: 1.0844 | Acc: 36.90% | IoU: [0.0, 0.38897705078125, 0.0]
Epoch [8/30] Loss: 1.0551 | Acc: 37.80% | IoU: [0.0, 0.3886260986328125, 0.0]
Epoch [9/30] Loss: 1.0732 | Acc: 36.31% | IoU: [0.0, 0.3912811279296875, 0.0]
Epoch [10/30] Loss: 1.0409 | Acc: 38.46% | IoU: [0.0, 0.3903656005859375, 0.0]
Epoch [11/30] Loss: 1.0302 | Acc: 36.89% | IoU: [0.0, 0.3922882080078125, 0.0]
Epoch [12/30] Loss: 1.0014 | Acc: 39.30% | IoU: [0.0, 0.3901214599609375, 0.0]
Epoch [13/30] Loss: 0.9864 | Acc: 39.63% | IoU: [0.0,