# Point Transformer - ModelNet10 Classification

This notebook implements the Point Transformer architecture for 3D point cloud classification using the ModelNet10 dataset.

**Workflow:**
1. Download Dataset via `kagglehub`
2. Data Preprocessing (read the vertices of `.off` mesh files as point clouds)
3. Model Architecture (Point Transformer)
4. Training & Evaluation

In [None]:
import kagglehub
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import logging
import sys

# 1. Download Dataset
print("Downloading ModelNet10 dataset...")
path = kagglehub.dataset_download("balraj98/modelnet10-princeton-3d-object-dataset")
print("Path to dataset files:", path)

# The download may unpack into a nested "ModelNet10" folder; descend into it
# when it exists, otherwise use the download root directly.
DATA_PATH = (
    os.path.join(path, 'ModelNet10')
    if os.path.exists(os.path.join(path, 'ModelNet10'))
    else path
)

print(f"Target Data Path: {DATA_PATH}")

## 2. Data Processing Utils
The Kaggle dataset provides meshes as `.off` files, so we need helpers to read their vertices, center and scale them into the unit sphere, and sample a fixed number of points from them.

In [None]:
def read_off(file):
    """
    Read an OFF mesh from an open text file and return its vertices.

    Only the vertex list is parsed; faces are ignored because this notebook
    treats the mesh vertices directly as a point cloud.

    Some ModelNet files fuse the header and the counts on one line
    (e.g. ``OFF490 518 0``); both that layout and the standard two-line
    layout are accepted. Blank lines and ``#`` comments are skipped, and
    numbers may be separated by any amount of whitespace.

    Args:
        file: A readable text file object positioned at the start of the OFF data.

    Returns:
        np.ndarray with one row per vertex (float64). Each row holds the numbers
        on that vertex line, normally the x, y, z coordinates.

    Raises:
        ValueError: If the header is not OFF, the counts line is missing, or the
            file ends before all declared vertices were read.
    """
    # Non-empty lines with '#' comments stripped.
    stripped = (raw.split('#', 1)[0].strip() for raw in file)
    lines = (line for line in stripped if line)

    header = next(lines, '')
    if not header.startswith('OFF'):
        raise ValueError('Not a valid OFF header')

    # Counts normally sit on their own line but may directly follow "OFF".
    counts_line = header[3:].strip() or next(lines, '')
    counts = counts_line.split()
    if not counts:
        raise ValueError('Missing OFF vertex/face counts')
    n_verts = int(counts[0])

    verts = []
    for _ in range(n_verts):
        line = next(lines, None)
        if line is None:
            raise ValueError(f'OFF file ended after {len(verts)} of {n_verts} vertices')
        verts.append([float(s) for s in line.split()])
    # Face lines that follow are intentionally not parsed.

    return np.array(verts)

def pc_normalize(pc):
    """
    Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        pc: Array of shape (N, D) of point coordinates.

    Returns:
        A new array of the same shape, centered on the origin, whose farthest
        point lies at distance 1. If every point coincides with the centroid
        (zero radius), the centered cloud is returned unscaled instead of
        dividing by zero and producing NaNs.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if m > 0:
        pc = pc / m
    return pc

def farthest_point_sample(point, npoint):
    """
    Select ``npoint`` rows of ``point`` by iterative farthest point sampling.

    Distances are measured on the first three columns only; any additional
    columns travel with each selected row. The first pick comes from
    ``np.random.randint``, so the result depends on NumPy's global RNG state.

    Input:
        point: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud data as float64, [npoint, D]
    """
    n_points, _ = point.shape
    coords = point[:, :3]
    # Squared distance from every point to its closest already-chosen sample.
    min_sq_dist = np.full(n_points, 1e10)
    chosen = np.empty(npoint, dtype=np.intp)

    current = np.random.randint(0, n_points)
    for slot in range(npoint):
        chosen[slot] = current
        sq_dist = ((coords - coords[current]) ** 2).sum(axis=-1)
        closer = sq_dist < min_sq_dist
        min_sq_dist[closer] = sq_dist[closer]
        current = np.argmax(min_sq_dist)

    return point[chosen].astype(np.float64)

In [None]:
class ModelNet10Dataset(Dataset):
    """
    ModelNet10 point-cloud dataset backed by a folder of ``.off`` meshes.

    Expected layout: ``root/<shape_name>/<split>/*.off``. Class names come from
    ``root/modelnet10_shape_names.txt`` when that file exists, otherwise from
    the sorted, non-hidden sub-directories of ``root``.

    Each item is ``(points, label)``:
      * ``points`` — float32 array of shape (num_point, D) built from the mesh
        vertices, with its first three columns centered and scaled by
        ``pc_normalize``;
      * ``label`` — int64 array of shape (1,) holding the class index.
    """

    def __init__(self, root, num_point=1024, split='train'):
        self.root = root
        self.num_point = num_point
        self.split = split

        self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
        if os.path.exists(self.catfile):
            # Close the file deterministically and ignore blank lines.
            with open(self.catfile) as names_file:
                self.cat = [line.rstrip() for line in names_file if line.strip()]
        else:
            # The Kaggle copy may lack the names file, so infer classes from the
            # folders. Hidden folders (e.g. ".ipynb_checkpoints") are skipped so
            # they cannot turn into spurious classes.
            self.cat = sorted(
                d for d in os.listdir(self.root)
                if not d.startswith('.') and os.path.isdir(os.path.join(self.root, d))
            )

        self.classes = {name: idx for idx, name in enumerate(self.cat)}

        self.datapath = []
        for shape_name in self.cat:
            # Folder structure: root/shape_name/<split>/shape_name_0001.off
            shape_dir = os.path.join(self.root, shape_name, split)
            # glob order is filesystem-dependent; sort for a reproducible index.
            for off_path in sorted(glob.glob(os.path.join(shape_dir, '*.off'))):
                self.datapath.append((shape_name, off_path))

        print(f"Loaded {len(self.datapath)} {split} samples.")

    def __len__(self):
        return len(self.datapath)

    def __getitem__(self, index):
        shape_name, file_path = self.datapath[index]
        cls_idx = np.array([self.classes[shape_name]], dtype=np.int64)

        with open(file_path, 'r') as f:
            point_set = read_off(f).astype(np.float32)

        if len(point_set) < self.num_point:
            # Fewer vertices than requested: resample indices with replacement.
            choice = np.random.choice(len(point_set), self.num_point, replace=True)
            point_set = point_set[choice, :]
        else:
            # Random subset for speed; farthest_point_sample gives better coverage
            # but is too slow to run on the fly every epoch.
            np.random.shuffle(point_set)
            point_set = point_set[:self.num_point, :]

        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])

        return point_set, cls_idx

## 3. Point Transformer Architecture
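This section contains the geometry helpers (pairwise distances, point gathering, farthest point sampling, k-nearest neighbours), the `TransformerBlock` layer, and the main `PointTransformerCls` model. There is no separate `TransitionDown` module: the transition-down steps (FPS, k-NN grouping, MLP and max-pooling) are implemented inside `PointTransformerCls`.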

In [None]:
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion ||a - b||^2 = -2 a.b + ||a||^2 + ||b||^2, so values
    that should be zero can come out slightly negative from float rounding.

    Args:
        src: tensor [B, N, C]
        dst: tensor [B, M, C]
    Returns:
        tensor [B, N, M] where out[b, i, j] = ||src[b, i] - dst[b, j]||^2
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    src_sq = (src ** 2).sum(dim=-1).view(batch, n_src, 1)
    dst_sq = (dst ** 2).sum(dim=-1).view(batch, 1, n_dst)
    # Same summation order as (-2ab) + |a|^2 + |b|^2, evaluated left to right.
    return -2 * torch.matmul(src, dst.transpose(1, 2)) + src_sq + dst_sq

def index_points(points, idx):
    """
    Gather points per batch element using an index tensor.

    Args:
        points: tensor [B, N, C]
        idx: integer tensor [B, ...] of indices into the N dimension
    Returns:
        tensor [B, ..., C] with out[b, ...] = points[b, idx[b, ...]]
    """
    # Shape the batch index as [B, 1, ..., 1]; advanced indexing broadcasts it
    # against idx, which is equivalent to explicitly repeating it to idx's shape.
    batch_shape = (idx.shape[0],) + (1,) * (idx.dim() - 1)
    batch_idx = torch.arange(points.shape[0], dtype=torch.long, device=points.device).view(batch_shape)
    return points[batch_idx, idx, :]

def farthest_point_sample_tensor(xyz, npoint):
    """
    Batched farthest point sampling.

    Args:
        xyz: tensor [B, N, C] of point coordinates (C is 3 in this notebook,
            but any channel count works).
        npoint: number of points to keep per batch element.
    Returns:
        LongTensor [B, npoint] of indices into the N dimension. The first index
        of each batch element is drawn with torch.randint, so the output depends
        on torch's global RNG state.
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # Squared distance from each point to its nearest already-selected point.
    distance = torch.ones(B, N).to(device) * 1e10
    # Drawn on the default generator and then moved, as before, so the random
    # stream is unchanged.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the real channel count instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

def knn_point(nsample, xyz, new_xyz):
    """
    Indices of the ``nsample`` nearest points in ``xyz`` for every query point.

    Input:
        nsample: neighbours per query; torch.topk raises if it exceeds N
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: LongTensor [B, S, nsample]; the order of neighbours along the
        last dimension is not guaranteed (topk runs with sorted=False).
    """
    pairwise = square_distance(new_xyz, xyz)
    return pairwise.topk(nsample, dim=-1, largest=False, sorted=False).indices

class TransformerBlock(nn.Module):
    """
    Point Transformer layer: vector self-attention over each point's k nearest
    neighbours, with a learned relative-position encoding and a residual link.

    Args:
        d_points: width of the input/output (residual) feature stream.
        d_model: internal width of queries, keys, values and attention.
        k: neighbourhood size; clipped to the number of points at runtime.
    """

    def __init__(self, d_points, d_model, k) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_points, d_model)
        self.fc2 = nn.Linear(d_model, d_points)
        # Relative position encoding: maps xyz offsets (3-D) to d_model.
        self.fc_delta = nn.Sequential(
            nn.Linear(3, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        # Attention MLP applied to (q - k + pos_enc).
        self.fc_gamma = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        self.w_qs = nn.Linear(d_model, d_model, bias=False)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = nn.Linear(d_model, d_model, bias=False)
        self.k = k

    def forward(self, xyz, features):
        """
        Args:
            xyz: [B, N, 3] point coordinates.
            features: [B, N, d_points] point features.
        Returns:
            (res, attn): ``res`` is [B, N, d_points] with the residual added;
            ``attn`` holds the softmax weights, [B, N, k', d_model] where
            k' = min(k, N).
        """
        n_neighbors = min(self.k, xyz.shape[1])
        dists = square_distance(xyz, xyz)
        # topk picks only the k' nearest instead of fully sorting each N-long
        # row. Clipping to N keeps small clouds (later, downsampled stages) valid.
        knn_idx = dists.topk(n_neighbors, dim=-1, largest=False).indices  # b x n x k'
        knn_xyz = index_points(xyz, knn_idx)

        residual = features
        x = self.fc1(features)
        query = self.w_qs(x)
        key = index_points(self.w_ks(x), knn_idx)
        value = index_points(self.w_vs(x), knn_idx)
        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)  # b x n x k' x f

        attn = self.fc_gamma(query[:, :, None] - key + pos_enc)
        # Scale by sqrt(d_model), then normalise over the neighbour axis
        # separately for every channel (vector attention).
        attn = F.softmax(attn / np.sqrt(key.size(-1)), dim=-2)  # b x n x k' x f

        res = torch.einsum('bmnf,bmnf->bmf', attn, value + pos_enc)
        res = self.fc2(res) + residual
        return res, attn

class PointTransformerCls(nn.Module):
    """
    Point Transformer classifier.

    Five TransformerBlock stages are separated by four transition-down steps
    (FPS downsampling, k-NN grouping, per-neighbour MLP and max-pooling). They
    are followed by global average pooling and an MLP classification head.

    Args:
        num_class: number of output logits.
        num_point: stored on the module; forward() uses the actual N of its input.
        input_dim: input channels fed to the embedding MLP. When 3, only the
            first three channels of the input are used.
        k: neighbourhood size for attention and transition-down grouping.
    """

    def __init__(self, num_class, num_point=1024, input_dim=3, k=16):
        super().__init__()
        self.num_point = num_point
        self.k = k
        self.input_dim = input_dim

        # Initial embedding
        self.fc1 = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )

        # Blocks and Transitions (attribute names kept for state_dict compatibility)
        self.transformer1 = TransformerBlock(32, 512, k)
        self.trans1_stride = 4
        self.trans1_fc = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))

        self.transformer2 = TransformerBlock(64, 512, k)
        self.trans2_stride = 4
        self.trans2_fc = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 128))

        self.transformer3 = TransformerBlock(128, 512, k)
        self.trans3_stride = 4
        self.trans3_fc = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 256))

        self.transformer4 = TransformerBlock(256, 512, k)
        self.trans4_stride = 4
        self.trans4_fc = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 512))

        self.transformer5 = TransformerBlock(512, 512, k)

        # Classification Head
        self.fc_layer = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_class)
        )

    def _transition_down(self, xyz, features, stride, mlp):
        """
        Downsample with FPS and max-pool an MLP over each kept point's neighbours.

        Args:
            xyz: [B, N, 3] coordinates; features: [B, N, C_in] features.
            stride: downsampling factor; keeps max(1, N // stride) points.
            mlp: module mapping C_in -> C_out, applied per grouped neighbour.
        Returns:
            (new_xyz, new_features) of shapes [B, N', 3] and [B, N', C_out].
        """
        n_points = xyz.shape[1]
        # Keep at least one point so tiny inputs do not collapse to empty tensors.
        n_keep = max(1, n_points // stride)
        fps_idx = farthest_point_sample_tensor(xyz, n_keep)
        new_xyz = index_points(xyz, fps_idx)
        # Clip k to the available points; topk rejects k > N (num_point < 1024).
        knn_idx = knn_point(min(self.k, n_points), xyz, new_xyz)
        grouped = mlp(index_points(features, knn_idx))
        new_features = torch.max(grouped, dim=2)[0]
        return new_xyz, new_features

    def forward(self, x):
        """
        Args:
            x: [B, C, N] input; xyz is taken from the first three channels.
        Returns:
            [B, num_class] logits.
        """
        # Transformer stages expect (B, N, C) features and (B, N, 3) xyz.
        xyz = x[:, :3, :].permute(0, 2, 1)
        features = x.permute(0, 2, 1)
        if self.input_dim == 3:
            features = features[:, :, :3]

        features = self.fc1(features)
        features, _ = self.transformer1(xyz, features)

        xyz, features = self._transition_down(xyz, features, self.trans1_stride, self.trans1_fc)
        features, _ = self.transformer2(xyz, features)

        xyz, features = self._transition_down(xyz, features, self.trans2_stride, self.trans2_fc)
        features, _ = self.transformer3(xyz, features)

        xyz, features = self._transition_down(xyz, features, self.trans3_stride, self.trans3_fc)
        features, _ = self.transformer4(xyz, features)

        xyz, features = self._transition_down(xyz, features, self.trans4_stride, self.trans4_fc)
        features, _ = self.transformer5(xyz, features)

        # Global Avg Pooling
        features = torch.mean(features, dim=1)
        return self.fc_layer(features)

## 4. Training Loop

In [None]:
class Args:
    """Training hyper-parameters read by ``train()``."""

    # --- data ---
    num_point = 1024        # points sampled per shape
    num_category = 10       # number of output classes

    # --- optimisation ---
    optimizer = 'SGD'       # informational; train() always builds SGD
    learning_rate = 0.05
    decay_rate = 1e-4       # passed to SGD as weight_decay
    batch_size = 16
    epoch = 50   # Set to 200 for full convergence

def random_rotate_y(points):
    """
    Rotate every cloud in a batch by an independent random angle about the y axis.

    The rotation is applied in place to the first three channels; any extra
    channels are left untouched. Angles come from ``np.random.uniform``.
    NOTE(review): confirm y is the up axis of these OFF meshes.

    Args:
        points: float array [B, C, N] with C >= 3.
    Returns:
        The same (mutated) array, for convenience.
    """
    theta = np.random.uniform(0, 2 * np.pi, size=points.shape[0])
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.zeros((points.shape[0], 3, 3))
    rotation[:, 0, 0] = cos_t
    rotation[:, 0, 2] = sin_t
    rotation[:, 1, 1] = 1
    rotation[:, 2, 0] = -sin_t
    rotation[:, 2, 2] = cos_t
    points[:, :3, :] = np.matmul(rotation, points[:, :3, :])
    return points


def evaluate_accuracy(classifier, loader, device):
    """
    Compute instance accuracy of ``classifier`` over ``loader``.

    The model is switched to eval mode and run without gradients.

    Args:
        classifier: module mapping [B, C, N] points to [B, num_class] logits.
        loader: iterable of (points [B, N, C], target [B, 1]) batches.
        device: torch device to run on.
    Returns:
        Fraction of correctly classified samples.
    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    classifier.eval()
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for points, target in loader:
            points = points.permute(0, 2, 1).float().to(device)
            target = target[:, 0].long().to(device)
            pred_choice = classifier(points).argmax(dim=1)
            total_correct += (pred_choice == target).sum().item()
            total_seen += points.size(0)
    if total_seen == 0:
        raise ValueError('Evaluation loader yielded no samples')
    return total_correct / float(total_seen)


def train():
    """
    Train PointTransformerCls on ModelNet10 and report accuracy every epoch.

    Hyper-parameters come from ``Args`` and data from ``DATA_PATH``. Each epoch
    trains with random y-axis rotation augmentation, steps the LR scheduler,
    evaluates on the test split and prints train/test/best-test accuracy.
    """
    args = Args()

    # Check GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Datasets
    train_dataset = ModelNet10Dataset(root=DATA_PATH, num_point=args.num_point, split='train')
    test_dataset = ModelNet10Dataset(root=DATA_PATH, num_point=args.num_point, split='test')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # Model
    classifier = PointTransformerCls(num_class=args.num_category, num_point=args.num_point, input_dim=3).to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # Optimizer
    optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.decay_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    best_instance_acc = 0.0

    print("Start Training...")
    for epoch in range(args.epoch):
        classifier.train()
        batch_accs = []

        for points, target in tqdm(train_loader, desc=f'Epoch {epoch+1}/{args.epoch}'):
            # DataLoader yields [B, N, C]; the model expects [B, C, N].
            points = random_rotate_y(np.transpose(points.numpy(), (0, 2, 1)))
            points = torch.from_numpy(np.ascontiguousarray(points)).float().to(device)
            target = target[:, 0].long().to(device)

            optimizer.zero_grad()
            pred = classifier(points)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()

            correct = (pred.argmax(dim=1) == target).sum().item()
            batch_accs.append(correct / float(points.size(0)))

        scheduler.step()
        train_acc = np.mean(batch_accs)

        # Evaluation
        test_acc = evaluate_accuracy(classifier, test_loader, device)
        best_instance_acc = max(best_instance_acc, test_acc)

        print(f'Epoch {epoch+1}: Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}, Best Test: {best_instance_acc:.4f}')

# Run training when executed as a notebook/script, but not when imported.
if __name__ == "__main__":
    train()