Imports

In [1]:
from torch.utils.data import DataLoader, random_split
from pytorch3d.datasets import ShapeNetCore
import torch
import torch.nn as nn 

Configuration

In [2]:
# Relative path to the ShapeNetCore dataset root.
dataset_path = "../Dataset/ShapeNetCore"

# Fraction of samples used for training; the remainder is held out for validation.
train_ratio = 0.8

# Number of points each shape is resampled to when batches are collated.
max_points = 700

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


Set up datasets

In [3]:
# Load ShapeNetCore v2 without textures.
dataset = ShapeNetCore(dataset_path, version=2, load_textures=False)

# Synset-id -> class-index mapping. Sorting the ids fixes the index assignment
# regardless of the order in which the dataset's dict lists them.
CLASS_SIDS = sorted(dataset.synset_dict.keys())            # list of synset-id strings
sid2idx = {sid: i for i, sid in enumerate(CLASS_SIDS)}     # {synset_id: class index}
NUM_CLASSES = len(CLASS_SIDS)

print(CLASS_SIDS)
print(sid2idx)

# Seed the split so that train/val membership is identical after a kernel
# restart. With an unseeded split, validation numbers (and the checkpoint
# selected from them) are not comparable between runs.
SPLIT_SEED = 42
train_size = int(train_ratio * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


['02691156', '02747177', '02773838', '02801938', '02808440', '02818832', '02828884', '02843684', '02871439', '02876657', '02880940', '02924116', '02933112', '02942699', '02946921', '02954340', '02958343', '02992529', '03001627', '03046257', '03085013', '03207941', '03211117', '03261776', '03325088', '03337140', '03467517', '03513137', '03593526', '03624134', '03636649', '03642806', '03691459', '03710193', '03759954', '03761084', '03790512', '03797390', '03928116', '03938244', '03948459', '03991062', '04004475', '04074963', '04090263', '04099429', '04225987', '04256520', '04330267', '04379243', '04401088', '04460130', '04468005', '04530566', '04554684']
{'02691156': 0, '02747177': 1, '02773838': 2, '02801938': 3, '02808440': 4, '02818832': 5, '02828884': 6, '02843684': 7, '02871439': 8, '02876657': 9, '02880940': 10, '02924116': 11, '02933112': 12, '02942699': 13, '02946921': 14, '02954340': 15, '02958343': 16, '02992529': 17, '03001627': 18, '03046257': 19, '03085013': 20, '03207941': 



Custom collate

In [4]:
def custom_collate_fn(batch):
    """Collate dataset samples into a fixed-size point-cloud batch.

    For every sample, ``max_points`` rows are drawn from its ``'verts'``
    tensor (without replacement when enough vertices exist, with replacement
    otherwise). The drawn points are centred on their mean and divided by
    their largest L2 norm. The sample's ``'synset_id'`` is translated to a
    class index through ``sid2idx``.

    Args:
        batch: iterable of dicts providing ``'verts'`` (assumed [V, 3]) and
            ``'synset_id'``.

    Returns:
        dict with ``'verts'`` stacked to [B, max_points, 3] and ``'labels'``
        as a long tensor of shape [B].
    """
    clouds, class_ids = [], []

    for item in batch:
        vertices = item['verts']
        num_verts = vertices.shape[0]

        if num_verts < max_points:
            # Too few vertices: draw indices with replacement.
            chosen = torch.randint(0, num_verts, (max_points,))
        else:
            # Enough vertices: random subset without replacement.
            chosen = torch.randperm(num_verts)[:max_points]

        # Centre on the centroid, then scale by the farthest point
        # (epsilon guards against a zero radius).
        cloud = vertices[chosen]
        cloud = cloud - cloud.mean(0, keepdim=True)
        radius = cloud.norm(p=2, dim=1).max()
        clouds.append(cloud / (radius + 1e-6))

        class_ids.append(sid2idx[item['synset_id']])

    return {
        'verts': torch.stack(clouds, dim=0),
        'labels': torch.tensor(class_ids, dtype=torch.long),
    }

Dataloaders

In [5]:
# NUM_WORKERS = 0 loads batches in the main process.
NUM_WORKERS = 0
# pin_memory expects a bool. The device string is truthy even when it is
# "cpu", so derive the flag explicitly: pinning only helps host->CUDA copies.
PIN_MEMORY = device == "cuda"
BATCH_SIZE = 4

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate_fn,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    # A trailing batch holding a single sample makes BatchNorm1d raise in
    # training mode, so incomplete training batches are dropped.
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

Set up the neural network

In [6]:
class PointNet(nn.Module):
    """Minimal PointNet-style classifier.

    A per-point layer (kernel-size-1 Conv1d -> BatchNorm -> ReLU) lifts each
    3-channel point to 64 features, a max over the point axis pools them into
    one 64-d vector per cloud, and a small MLP head turns that vector into
    ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Per-point feature extractor; kernel size 1 means points are
        # processed independently of each other.
        per_point_layers = [
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        ]
        self.mlp1 = nn.Sequential(*per_point_layers)

        # Classification head operating on the pooled 64-d feature.
        head_layers = [
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        ]
        self.fc1 = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map point clouds of shape (B, N, 3) to logits of shape (B, num_classes)."""
        # Conv1d wants channels before the point axis: (B, N, 3) -> (B, 3, N).
        channels_first = x.transpose(1, 2)
        point_feats = self.mlp1(channels_first)
        # Max over the point axis yields one order-independent vector per cloud.
        global_feat = point_feats.max(dim=2).values
        return self.fc1(global_feat)

Define run_epoch()

In [7]:
def run_epoch(model, loader, optimizer=None, device="cpu", epoch_tag="train"):
    """Run one pass over ``loader`` and report sample-weighted loss and accuracy.

    The pass is a training epoch when ``optimizer`` is given and an
    evaluation epoch otherwise. In evaluation the forward pass runs with
    autograd disabled, so no graph is built for batches that are never
    back-propagated.

    Args:
        model: module mapping ``batch["verts"]`` to logits of shape (B, C).
        loader: iterable of dicts with ``"verts"`` and ``"labels"`` tensors.
        optimizer: optimizer to step after each batch; ``None`` selects
            evaluation mode.
        device: device the batch tensors are moved to.
        epoch_tag: label printed in front of the summary line.

    Returns:
        ``(avg_loss, avg_acc)`` averaged over all samples seen; both are 0
        when the loader yields nothing.
    """
    is_train = optimizer is not None
    model.train(is_train)  # also switches BatchNorm between batch and running statistics

    crit = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in loader:
        x = batch["verts"].to(device, non_blocking=True)   # (B, N, 3)
        y = batch["labels"].to(device, non_blocking=True)  # (B,)

        if is_train:
            optimizer.zero_grad()

        # Track gradients only while training; evaluation skips graph construction.
        with torch.set_grad_enabled(is_train):
            logits = model(x)                  # (B, C)
            loss = crit(logits, y)

        if is_train:
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            preds = logits.argmax(dim=1)
            total_correct += (preds == y).sum().item()
            bs = y.size(0)
            total_samples += bs
            # The criterion returns a batch mean; re-weight by batch size so the
            # epoch average is per sample even if batch sizes differ.
            total_loss += loss.item() * bs

    # max(1, ...) avoids a division by zero for an empty loader.
    avg_loss = total_loss / max(1, total_samples)
    avg_acc  = total_correct / max(1, total_samples)
    print(f"[{epoch_tag}] loss={avg_loss:.4f} | acc={avg_acc:.4f} | n={total_samples}")
    return avg_loss, avg_acc


In [8]:
# Optimisation hyper-parameters.
EPOCHS = 15
lr = 3e-4
weight_decay = 1e-4

model = PointNet(NUM_CLASSES).to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# Highest validation accuracy seen so far; a checkpoint is written each time it improves.
best_val_acc = 0

for epoch in range(1, EPOCHS + 1):
    # One training pass followed by one evaluation pass (no optimizer => eval mode).
    tr_loss, tr_acc = run_epoch(
        model, train_loader, optimizer, device, epoch_tag=f"train/ep{epoch}"
    )
    val_loss, val_acc = run_epoch(
        model, val_loader, optimizer=None, device=device, epoch_tag=f"val/ep{epoch}"
    )

    # Nothing to save unless validation accuracy strictly improved.
    if not (val_acc > best_val_acc):
        continue

    best_val_acc = val_acc
    torch.save(model.state_dict(), "best_pointnet.pth")
    print(f"saved new best (val_acc={best_val_acc:.4f})")

KeyboardInterrupt: 