![image.png](https://i.imgur.com/a3uAqnb.png)

# 3D Point Cloud Classification with PointNet

This notebook demonstrates how to build a 3D point cloud classification model using **PointNet**, a neural network architecture that processes raw 3D point clouds directly.

## 📌 The Core Idea: PointNet Architecture

PointNet consumes unordered point sets directly, instead of first converting them into regular structures such as voxel grids or rendered multi-view images. Its key ideas are:

1. **Permutation Invariance**: The network produces the same output for any ordering of the input points
2. **Spatial Transformation Networks (T-Nets)**: A small sub-network predicts a 3×3 matrix that aligns the input points before feature extraction
3. **Symmetric Aggregation**: Max pooling over all points produces a global feature vector that does not depend on point order
4. **Feature Transform**: A second T-Net predicts a 64×64 matrix that aligns the intermediate per-point features
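
Ideas 1 and 3 work together: if every point passes through the same function and the results are combined with an element-wise max, shuffling the points cannot change the output. A minimal sketch, with a randomly initialized shared linear layer standing in for PointNet's per-point MLP (the layer sizes are arbitrary illustration choices):

```python
import torch

torch.manual_seed(0)

# Shared per-point function: the same weights are applied to every point (sizes chosen arbitrarily)
shared_mlp = torch.nn.Sequential(torch.nn.Linear(3, 16), torch.nn.ReLU())

def global_feature(points):
    """points: (N, 3) -> (16,) global feature via a symmetric max over points."""
    per_point = shared_mlp(points)        # (N, 16), computed independently for each point
    return per_point.max(dim=0).values    # element-wise max does not depend on point order

points = torch.randn(128, 3)
shuffled = points[torch.randperm(points.shape[0])]

with torch.no_grad():
    f1 = global_feature(points)
    f2 = global_feature(shuffled)

print(torch.allclose(f1, f2))  # True: same global feature for any point order
```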

### **🎯 Dataset: ModelNet40**
We'll use the ModelNet40 dataset, a standard benchmark for 3D object classification:
- **40 object categories**: chairs, tables, airplanes, cars, etc.
- **12,311 CAD models**: 9,843 for training, 2,468 for testing
- **OFF file format**: Contains 3D vertices and face information

In [None]:
import kagglehub
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
from tqdm.notebook import tqdm
import time

In [None]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


## 0️⃣ 3D Data Processing

### 🔹 OFF File Reader
The Object File Format (OFF) stores 3D geometry as vertices and faces. We'll extract only the vertex coordinates for point cloud processing.
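
For reference, an OFF file starts with the keyword `OFF`, then a line `n_vertices n_faces n_edges`, then one `x y z` line per vertex and one `k i1 ... ik` line per face. The sketch below parses that layout from a string; `parse_off_text` is a hypothetical helper, not a library function. It also accepts a header in which `OFF` is fused with the counts (e.g. `OFF4 4 0`), a quirk reported in some ModelNet files, whereas `read_off_file` in the next cell rejects such files because it requires the first line to be exactly `OFF`.

```python
import numpy as np

def parse_off_text(text):
    """Parse OFF text into (vertices as a (V, 3) float array, faces as lists of vertex indices)."""
    # Drop blank lines and comment lines
    lines = [ln.strip() for ln in text.splitlines()]
    lines = [ln for ln in lines if ln and not ln.startswith('#')]

    if lines[0] == 'OFF':
        counts_line, body = lines[1], lines[2:]
    elif lines[0].startswith('OFF'):
        # Fused header such as "OFF4 4 0": the counts follow the keyword on the same line
        counts_line, body = lines[0][3:], lines[1:]
    else:
        raise ValueError('not an OFF file')

    n_vertices, n_faces, _n_edges = map(int, counts_line.split()[:3])
    vertices = np.array([list(map(float, ln.split()[:3])) for ln in body[:n_vertices]])

    faces = []
    for ln in body[n_vertices:n_vertices + n_faces]:
        tokens = ln.split()
        k = int(tokens[0])                               # number of vertices in this face
        faces.append([int(t) for t in tokens[1:1 + k]])  # ignore optional trailing color values
    return vertices, faces

tetrahedron = """OFF4 4 0
0 0 0
1 0 0
0 1 0
0 0 1
3 0 1 2
3 0 1 3
3 0 2 3
3 1 2 3
"""
vertices, faces = parse_off_text(tetrahedron)
print(vertices.shape, len(faces))  # (4, 3) 4
```

Whether this quirk accounts for the files filtered out later in this notebook is an assumption worth checking by printing the first line of a few rejected files.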


In [None]:
def read_off_file(file_path):
    """Read OFF file and return vertices as numpy array."""
    try:
        with open(file_path, 'r') as file:
            lines = file.readlines()
        
        # Clean lines and remove comments
        lines = [line.strip() for line in lines if line.strip() and not line.startswith('#')]
        
        if lines[0] != 'OFF':
            return None, None
        
        # Parse header: number of vertices, faces, edges
        n_vertices, n_faces, n_edges = map(int, lines[1].split())
        
        # Extract vertices (x, y, z coordinates)
        vertices = []
        for i in range(2, 2 + n_vertices):
            vertex = list(map(float, lines[i].split()[:3]))
            vertices.append(vertex)
        
        return np.array(vertices), None
    except Exception:
        return None, None

## 1️⃣ Dataset Download and Exploration

First, let's download the ModelNet40 dataset and explore its structure.

In [None]:
# Download dataset and setup paths
path = kagglehub.dataset_download("balraj98/modelnet40-princeton-3d-object-dataset")
dataset_path = Path(path)
modelnet_path = dataset_path / "ModelNet40" 

In [None]:
# Get all class directories
classes = []
for item in modelnet_path.iterdir():
    if item.is_dir() and not item.name.startswith('.'):
        train_subdir = item / "train"
        if train_subdir.exists():
            classes.append(item.name)

classes.sort()
print(f"Found {len(classes)} classes: {classes}")  

Note: the next cell takes about 4-5 minutes, because it parses every OFF file once to filter out unreadable ones.

In [None]:
# Load dataset samples
def get_all_samples(base_path, classes, split='train'):
    """Load all sample file paths and their corresponding labels, filtering out corrupted files."""
    all_samples = []
    corrupted_count = 0
    
    for class_name in classes:
        class_path = base_path / class_name / split
        if class_path.exists():
            off_files = list(class_path.glob("*.off"))
            for file_path in off_files:
                # Test if the file can be read successfully
                vertices, _ = read_off_file(file_path)
                if vertices is not None:
                    all_samples.append({
                        'file_path': file_path,
                        'class_name': class_name,
                        'class_id': classes.index(class_name)
                    })
                else:
                    corrupted_count += 1
    
    print(f"Filtered out {corrupted_count} corrupted files from {split} set")
    return all_samples

train_samples = get_all_samples(modelnet_path, classes, split='train')
test_samples = get_all_samples(modelnet_path, classes, split='test')
print(f"Train samples: {len(train_samples)}, Test samples: {len(test_samples)}")

Note: this filtering leaves readable training samples for only 33 of the 40 classes; for the other 7 classes, `read_off_file` returns `None` for every file, so they contribute no data. We still give the model 40 outputs so that class IDs stay consistent with the full dataset; the missing classes simply never appear as targets.
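
To see which classes were dropped, compare the full class list with the classes that still have samples. A small helper sketch (the name `missing_classes` is hypothetical); in this notebook you would call it as `missing_classes(classes, train_samples)`. Because each `class_id` is computed from the full `classes` list, the IDs of the surviving classes are unaffected by the dropped ones.

```python
def missing_classes(all_classes, samples):
    """Return the classes that have no sample left after filtering, in their original order."""
    present = {sample['class_name'] for sample in samples}
    return [name for name in all_classes if name not in present]

# Toy example using the same sample-dict layout that get_all_samples produces
toy_classes = ['airplane', 'bathtub', 'bed']
toy_samples = [
    {'file_path': 'a.off', 'class_name': 'airplane', 'class_id': 0},
    {'file_path': 'b.off', 'class_name': 'bed', 'class_id': 2},
]
print(missing_classes(toy_classes, toy_samples))  # ['bathtub']
```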

### 🔹 Data Visualization
Let's visualize examples from different object categories to understand our data better.

In [None]:
def plot_class_examples(samples, n_examples=6):
    """Plot examples from different classes showing full point clouds."""
    # Group samples by class and select one example from each
    class_samples = {}
    for sample in samples:
        class_name = sample['class_name']
        if class_name not in class_samples:
            class_samples[class_name] = []
        class_samples[class_name].append(sample)
    
    selected_samples = []
    for class_name, class_sample_list in list(class_samples.items())[:n_examples]:
        selected_samples.append(class_sample_list[0])
    
    # Create subplots
    n_cols = 3
    n_rows = (len(selected_samples) + n_cols - 1) // n_cols
    
    fig = plt.figure(figsize=(15, 5 * n_rows))
    
    for i, sample in enumerate(selected_samples):
        vertices, _ = read_off_file(sample['file_path'])
        
        if vertices is not None:
            ax = fig.add_subplot(n_rows, n_cols, i + 1, projection='3d')
            
            # Plot all vertices with color mapping based on Z-coordinate
            ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], 
                      c=vertices[:, 2], cmap='viridis', s=1, alpha=0.7)
            
            ax.set_title(f"Class: {sample['class_name']}\n({len(vertices)} vertices)", 
                        fontsize=12, pad=20)
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            
            # Set equal aspect ratio for better visualization
            max_range = np.array([vertices[:, 0].max()-vertices[:, 0].min(),
                                vertices[:, 1].max()-vertices[:, 1].min(),
                                vertices[:, 2].max()-vertices[:, 2].min()]).max() / 2.0
            mid_x = (vertices[:, 0].max()+vertices[:, 0].min()) * 0.5
            mid_y = (vertices[:, 1].max()+vertices[:, 1].min()) * 0.5
            mid_z = (vertices[:, 2].max()+vertices[:, 2].min()) * 0.5
            
            ax.set_xlim(mid_x - max_range, mid_x + max_range)
            ax.set_ylim(mid_y - max_range, mid_y + max_range)
            ax.set_zlim(mid_z - max_range, mid_z + max_range)
    
    plt.tight_layout()
    plt.show()

In [None]:
# Visualize examples from different classes
plot_class_examples(train_samples, n_examples=6)

## 2️⃣ Data Preprocessing and Augmentation
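Point clouds must be turned into fixed-size, consistently scaled inputs before they reach the network. The pipeline below samples `num_points` points per object (duplicating points when a mesh has fewer vertices), centers the cloud and scales it into the unit sphere, and, for training only, randomly applies a rotation about the Y-axis and Gaussian jitter to improve generalization.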

In [None]:
def normalize_point_cloud(points):
    """Center and scale point cloud to unit sphere."""
    centroid = torch.mean(points, dim=0)
    points = points - centroid
    distances = torch.norm(points, dim=1)
    max_distance = torch.max(distances)
    if max_distance > 0:
        points = points / max_distance
    return points

def random_sample_points(points, num_points=1024):
    """Sample fixed number of points from point cloud."""
    num_vertices = points.shape[0]
    if num_vertices >= num_points:
        # Random sampling without replacement
        indices = torch.randperm(num_vertices)[:num_points]
        return points[indices]
    else:
        # Duplicate points if we have fewer than needed
        repeat_factor = (num_points // num_vertices) + 1
        repeated_points = points.repeat(repeat_factor, 1)
        indices = torch.randperm(repeated_points.shape[0])[:num_points]
        return repeated_points[indices]

def random_rotation(points):
    """Apply a random rotation around the Y-axis for data augmentation."""
    angle = torch.rand(()) * 2 * torch.pi  # 0-dim tensor, uniform in [0, 2*pi)
    # Convert to Python floats so torch.tensor receives a plain 3x3 nested list of numbers
    cos_angle = torch.cos(angle).item()
    sin_angle = torch.sin(angle).item()
    rotation_matrix = torch.tensor([
        [cos_angle, 0.0, sin_angle],
        [0.0, 1.0, 0.0],
        [-sin_angle, 0.0, cos_angle]
    ], dtype=points.dtype)
    return torch.matmul(points, rotation_matrix.T)

def add_noise(points, noise_std=0.01):
    """Add Gaussian noise for data augmentation."""
    noise = torch.randn_like(points) * noise_std
    return points + noise

class PointCloudTransform:
    """Comprehensive point cloud preprocessing pipeline."""
    def __init__(self, num_points=1024, normalize=True, augment=True, noise_std=0.01):
        self.num_points = num_points
        self.normalize = normalize
        self.augment = augment
        self.noise_std = noise_std
    
    def __call__(self, points):
        # Sample fixed number of points
        points = random_sample_points(points, self.num_points)
        
        # Normalize to unit sphere
        if self.normalize:
            points = normalize_point_cloud(points)
        
        # Apply augmentations during training
        if self.augment:
            if torch.rand(1) > 0.5:
                points = random_rotation(points)
            if torch.rand(1) > 0.5:
                points = add_noise(points, self.noise_std)
        
        return points
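
A quick numerical check of what this preprocessing guarantees, written as a standalone sketch that re-implements the two geometric steps: after normalization the centroid is at the origin and the farthest point lies at distance 1, and a rotation about the Y-axis keeps every point's Y coordinate and its distance from the origin unchanged. The point cloud and angle are arbitrary test values.

```python
import math
import torch

torch.manual_seed(0)
points = torch.rand(500, 3) * 10 + 5   # arbitrary cloud, far from the origin

# Normalization: center on the centroid, then divide by the largest distance
centered = points - points.mean(dim=0)
normalized = centered / centered.norm(dim=1).max()

# Rotation about the Y-axis by a fixed angle (same matrix form as random_rotation)
theta = math.pi / 3
c, s = math.cos(theta), math.sin(theta)
R = torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
rotated = normalized @ R.T

print(normalized.mean(dim=0).abs().max().item() < 1e-5)                        # centroid at origin
print(abs(normalized.norm(dim=1).max().item() - 1.0) < 1e-5)                    # farthest point at radius 1
print(torch.allclose(rotated[:, 1], normalized[:, 1]))                          # Y unchanged
print(torch.allclose(rotated.norm(dim=1), normalized.norm(dim=1), atol=1e-6))   # norms preserved
```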

## 3️⃣ PointNet Architecture Implementation

### 🔹 Spatial Transformer Networks (T-Nets)
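A T-Net is a small PointNet-style sub-network that predicts a k×k matrix from its input; multiplying the points (k = 3) or the per-point features (k = 64) by this matrix is meant to bring them into a canonical alignment, which can make the model more robust to rotations. The 3×3 matrix has no translation term, so translation is handled separately by the centering step in `normalize_point_cloud`.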

In [None]:
class TNet(nn.Module):
    """Spatial Transformer Network for learning optimal transformations."""
    def __init__(self, k=3):
        super(TNet, self).__init__()
        self.k = k
        
        # Convolutional layers for feature extraction
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        
        # Fully connected layers for transformation matrix prediction
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        
        # Batch normalization for stable training
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        batch_size = x.size(0)
        
        # Feature extraction
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Global max pooling
        x = torch.max(x, 2)[0]
        
        # Predict transformation matrix
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        # Add identity matrix for stable training
        identity = torch.eye(self.k, device=x.device).view(1, self.k * self.k).repeat(batch_size, 1)
        x = x + identity
        x = x.view(-1, self.k, self.k)
        
        return x
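
The `+ identity` step means the T-Net only has to learn a correction to the identity: whenever the last layer outputs zeros, the predicted matrix is exactly `I` and `torch.bmm` leaves the points untouched. A standalone sketch of that mechanism; the stand-in last layer is zero-initialized purely for this demonstration (the `TNet` above does not zero-initialize `fc3`).

```python
import torch
import torch.nn as nn

k, batch_size, num_points = 3, 2, 5

# Stand-in for TNet.fc3: maps a 256-d feature to k*k numbers; zeroed only for this demo
last_layer = nn.Linear(256, k * k)
nn.init.zeros_(last_layer.weight)
nn.init.zeros_(last_layer.bias)

features = torch.randn(batch_size, 256)
identity = torch.eye(k).view(1, k * k).repeat(batch_size, 1)
trans = (last_layer(features) + identity).view(-1, k, k)   # (B, k, k)

points = torch.randn(batch_size, num_points, k)            # (B, N, 3), as PointNetEncoder receives
aligned = torch.bmm(points, trans)                         # (B, N, 3)

print(trans.shape, torch.allclose(aligned, points))  # torch.Size([2, 3, 3]) True
```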

### 🔹 PointNet Encoder
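The encoder applies the same small MLP to every point (implemented as `Conv1d` layers with kernel size 1) and then max-pools over the points, so the resulting 1024-dimensional global feature does not depend on point order. It expects input of shape `(batch, num_points, 3)`.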

In [None]:
class PointNetEncoder(nn.Module):
    """PointNet encoder for feature extraction from point clouds."""
    def __init__(self, global_feat=True, feature_transform=True):
        super(PointNetEncoder, self).__init__()
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        
        # Input transformation network
        self.input_transform = TNet(k=3)
        
        # Feature transformation network (optional)
        if self.feature_transform:
            self.feature_transform_net = TNet(k=64)
        
        # Point-wise feature extraction
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        
    def forward(self, x):
        batch_size, num_points, _ = x.size()
        
        # Apply input transformation
        trans_input = self.input_transform(x.transpose(2, 1))
        x = torch.bmm(x, trans_input)
        x = x.transpose(2, 1)
        
        # First feature extraction
        x = F.relu(self.bn1(self.conv1(x)))
        
        # Apply feature transformation if enabled
        if self.feature_transform:
            trans_feat = self.feature_transform_net(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2, 1)
        else:
            trans_feat = None
        
        # Continue feature extraction
        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        
        # Global feature aggregation via max pooling
        x = torch.max(x, 2)[0]
        
        if self.global_feat:
            return x, trans_input, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, num_points)
            return torch.cat([pointfeat, x], 1), trans_input, trans_feat
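
`nn.Conv1d(in, out, 1)` is how the encoder applies one shared MLP layer to every point: with kernel size 1, each output column depends only on the matching input column. A sketch that checks this equivalence by copying the convolution's weights into an `nn.Linear` (shapes follow the encoder's `conv1`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(3, 64, kernel_size=1)   # like PointNetEncoder.conv1
linear = nn.Linear(3, 64)

# Conv1d weight has shape (64, 3, 1); Linear weight has shape (64, 3)
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(8, 1024, 3)                          # (batch, num_points, 3)
out_conv = conv(x.transpose(2, 1)).transpose(2, 1)   # Conv1d expects (batch, channels, points)
out_linear = linear(x)                               # Linear acts on the last dim of every point

print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-5))  # torch.Size([8, 1024, 64]) True
```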

### 🔹 PointNet Classifier
The complete classification model combining encoder and classification head.

In [None]:
class PointNetClassifier(nn.Module):
    """Complete PointNet model for 3D point cloud classification."""
    def __init__(self, num_classes=40, dropout=0.3, feature_transform=True):
        super(PointNetClassifier, self).__init__()
        self.feature_transform = feature_transform
        
        # PointNet encoder
        self.encoder = PointNetEncoder(global_feat=True, feature_transform=feature_transform)
        
        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x, trans_input, trans_feat = self.encoder(x)
        
        # Classification layers
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return F.log_softmax(x, dim=1), trans_input, trans_feat

def feature_transform_regularizer(trans):
    """Regularization term for feature transformation matrix."""
    d = trans.size()[1]
    identity = torch.eye(d, device=trans.device)
    identity = identity.unsqueeze(0).repeat(trans.size()[0], 1, 1)
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - identity, dim=(1, 2)))
    return loss
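
The regularizer pushes the 64×64 feature transform $A$ toward an orthogonal matrix by penalizing $\lVert AA^\top - I\rVert_F$. The PointNet paper states this term with the squared Frobenius norm; the function above uses the unsquared norm, averaged over the batch. A standalone sketch of the same formula on 3×3 matrices: an orthogonal rotation costs essentially nothing, while $A = 2I$ gives $\lVert 4I - I\rVert_F = 3\sqrt{d}$.

```python
import math
import torch

def orthogonality_penalty(trans):
    """Batch mean of ||A A^T - I||_F (same formula as feature_transform_regularizer)."""
    d = trans.size(1)
    identity = torch.eye(d).unsqueeze(0).expand(trans.size(0), d, d)
    return torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - identity, dim=(1, 2)).mean()

d = 3
theta = 0.7
c, s = math.cos(theta), math.sin(theta)
rotation = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]).unsqueeze(0)  # orthogonal
scaled = (2 * torch.eye(d)).unsqueeze(0)                                          # not orthogonal

print(orthogonality_penalty(rotation).item())   # close to 0
print(orthogonality_penalty(scaled).item())     # 3 * sqrt(3), about 5.196
```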

## 4️⃣ Dataset Implementation

In [None]:
class ModelNet40Dataset(Dataset):
    """PyTorch Dataset for ModelNet40 point cloud data."""
    def __init__(self, samples, transform=None, num_points=1024):
        self.samples = samples
        self.transform = transform
        self.num_points = num_points
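        # Only classes with at least one sample in `samples` (fewer than the full class list when files
        # were filtered out); class IDs still index the full list used in get_all_samples.
        self.class_names = sorted(list(set(sample['class_name'] for sample in samples)))
        self.num_classes = len(self.class_names)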
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        vertices, _ = read_off_file(sample['file_path'])
        
        # Handle corrupted files
        if vertices is None:
            vertices = np.random.randn(self.num_points, 3).astype(np.float32)
        
        vertices = torch.FloatTensor(vertices)
        
        # Apply transformations
        if self.transform:
            vertices = self.transform(vertices)
        else:
            vertices = random_sample_points(vertices, self.num_points)
            vertices = normalize_point_cloud(vertices)
        
        label = torch.LongTensor([sample['class_id']])[0]
        return vertices, label

## 5️⃣ Training Setup and Execution

### 🔹 Hyperparameters and Data Loaders

In [None]:
# Training configuration
NUM_POINTS = 1024
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 15 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Data transformations
train_transform = PointCloudTransform(num_points=NUM_POINTS, normalize=True, augment=True, noise_std=0.01)  # with augmentation
test_transform = PointCloudTransform(num_points=NUM_POINTS, normalize=True, augment=False)                  # without augmentation

# Create datasets and data loaders
train_dataset = ModelNet40Dataset(train_samples, transform=train_transform, num_points=NUM_POINTS)
test_dataset = ModelNet40Dataset(test_samples, transform=test_transform, num_points=NUM_POINTS)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                         pin_memory=True if torch.cuda.is_available() else False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
                        pin_memory=True if torch.cuda.is_available() else False)

# Initialize model, loss, and optimizer
model = PointNetClassifier(num_classes=len(classes), dropout=0.3, feature_transform=True).to(device)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
# Halve the LR every 20 epochs; note that with NUM_EPOCHS = 15 this decay is never reached
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")
print(f"Number of classes: {len(classes)}")
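
With `step_size=20` and `NUM_EPOCHS = 15`, `StepLR` never reaches its first decay, so training runs at a constant learning rate of 0.001. A standalone sketch that steps the same scheduler configuration on a dummy parameter and records the rate used in each epoch:

```python
import torch
import torch.optim as optim

# Dummy parameter: only the optimizer's learning-rate bookkeeping matters here
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

lrs = []
for epoch in range(25):
    lrs.append(optimizer.param_groups[0]['lr'])  # rate used during this epoch
    optimizer.step()    # no gradient, so the parameter is unchanged; keeps the step order scheduler expects
    scheduler.step()

print(lrs[:15] == [0.001] * 15)   # True: no decay within 15 epochs
print(lrs[20])                    # 0.0005: first decay after 20 epochs
```

If you want the decay to happen within a 15-epoch run, `step_size` is the parameter to lower.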

### 🔹 Training and Evaluation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device, reg_weight=0.001):
    """Train the model for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    train_bar = tqdm(train_loader, desc='Training')
    
    for batch_idx, (data, target) in enumerate(train_bar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        pred, trans_input, trans_feat = model(data)
        loss = criterion(pred, target)
        
        # Add regularization for feature transformation
        if trans_feat is not None:
            reg_loss = feature_transform_regularizer(trans_feat)
            loss += reg_weight * reg_loss
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Calculate metrics
        total_loss += loss.item()
        pred_choice = pred.argmax(dim=1)
        correct += pred_choice.eq(target).sum().item()  # Python int, so accuracy is a plain float
        total += target.size(0)
        
        # Update progress bar
        train_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{100. * correct / total:.2f}%'
        })
    
    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

def test_epoch(model, test_loader, criterion, device):
    """Evaluate the model."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        test_bar = tqdm(test_loader, desc='Testing')
        
        for data, target in test_bar:
            data, target = data.to(device), target.to(device)
            pred, _, _ = model(data)
            loss = criterion(pred, target)
            total_loss += loss.item()
            
            pred_choice = pred.argmax(dim=1)
            correct += pred_choice.eq(target).sum().item()
            total += target.size(0)
            
            all_preds.extend(pred_choice.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            
            test_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100. * correct / total:.2f}%'
            })
    
    avg_loss = total_loss / len(test_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy, all_preds, all_targets
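
The classifier returns `F.log_softmax` outputs and these loops score them with `nn.NLLLoss`; that pair computes the same value as `nn.CrossEntropyLoss` applied to raw logits, and `argmax` picks the same class either way because log-softmax only subtracts a per-row constant. A short standalone check with random logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 40)                 # 4 samples, 40 classes
targets = torch.tensor([0, 5, 39, 12])

nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
ce = nn.CrossEntropyLoss()(logits, targets)

print(torch.allclose(nll, ce))                                                 # True
print(torch.equal(F.log_softmax(logits, dim=1).argmax(1), logits.argmax(1)))   # True
```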

### 🔹 Main Training Loop


In [None]:
# Training history
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
best_test_acc = 0
best_model_state = None

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 50)
    
    # Train and evaluate
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc, _, _ = test_epoch(model, test_loader, criterion, device)
    scheduler.step()
    
    # Save best model
    if test_acc > best_test_acc:
        best_test_acc = test_acc
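        # state_dict() returns references to the live weights, so clone them; a shallow dict copy
        # would be overwritten by later optimizer steps.
        # (Choosing the checkpoint by test accuracy makes the reported best accuracy optimistic;
        # a separate validation split would avoid that.)
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}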
    
    # Record history
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)
    
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    print(f'Best Test Acc: {best_test_acc:.2f}%')

total_time = time.time() - start_time
print(f'\nTraining completed in {total_time:.2f} seconds')
print(f'Best test accuracy: {best_test_acc:.2f}%')
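
Why a best-model snapshot needs cloned tensors: `state_dict()` returns tensors that share storage with the module's parameters, and `dict.copy()` copies only the dictionary, so an in-place optimizer update later also changes the saved weights. A small sketch of the difference, using an in-place `add_` to simulate an optimizer step:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)

shallow = model.state_dict().copy()                                      # new dict, same tensors
cloned = {k: v.detach().clone() for k, v in model.state_dict().items()}  # independent tensors

with torch.no_grad():
    model.weight.add_(1.0)   # simulate an in-place optimizer update

print(torch.equal(shallow['weight'], model.weight))   # True: the shallow snapshot moved with the model
print(torch.equal(cloned['weight'], model.weight))    # False: the clone kept the old values
```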

## 6️⃣ Results Analysis and Visualization

### 🔹 Training Progress


In [None]:
# Plot training progress
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
ax1.plot(train_losses, label='Train Loss', color='blue')
ax1.plot(test_losses, label='Test Loss', color='red')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Test Loss')
ax1.legend()
ax1.grid(True)

# Accuracy plot
ax2.plot(train_accuracies, label='Train Accuracy', color='blue')
ax2.plot(test_accuracies, label='Test Accuracy', color='red')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Test Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

### 🔹 Final Evaluation with Best Model


In [None]:
# Load best model and perform final evaluation
model.load_state_dict(best_model_state)
final_test_loss, final_test_acc, final_preds, final_targets = test_epoch(model, test_loader, criterion, device)
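# Class IDs index the full `classes` list; name only the IDs that occur in targets or predictions
# (the same labels, in the same order, that confusion_matrix and classification_report use by default)
present_ids = sorted(set(final_targets) | set(final_preds))
class_names = [classes[i] for i in present_ids]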

print(f"Final Test Accuracy: {final_test_acc:.2f}%")
print(f"Final Test Loss: {final_test_loss:.4f}")

### 🔹 Confusion Matrix

In [None]:
# Generate and plot confusion matrix
cm = confusion_matrix(final_targets, final_preds)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=False, cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

### 🔹 Per-Class Performance Analysis

In [None]:
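# Per-class accuracy (recall); a class with no test samples gets 0 instead of a division-by-zero NaN
row_totals = cm.sum(axis=1)
per_class_acc = np.divide(cm.diagonal(), row_totals, out=np.zeros(len(cm)), where=row_totals > 0)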
plt.figure(figsize=(15, 6))
bars = plt.bar(class_names, per_class_acc * 100)
plt.xlabel('Class')
plt.ylabel('Accuracy (%)')
plt.title('Per-Class Accuracy')
plt.xticks(rotation=45, ha='right')
plt.ylim(0, 100)

# Color code bars based on performance
for i, bar in enumerate(bars):
    acc = per_class_acc[i] * 100
    if acc >= 80:
        bar.set_color('green')
    elif acc >= 60:
        bar.set_color('orange')
    else:
        bar.set_color('red')

plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Print detailed classification report
print("\nDetailed Classification Report:")
print(classification_report(final_targets, final_preds, target_names=class_names, zero_division=0))
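
Per-class accuracy here is recall: row `i` of the confusion matrix counts the test samples whose true class is `i`, so its diagonal entry divided by the row sum is the fraction of that class predicted correctly. A toy example with three classes (the labels are made up for illustration):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2]

cm = confusion_matrix(y_true, y_pred)   # rows: actual class, columns: predicted class
per_class_acc = cm.diagonal() / cm.sum(axis=1)

print(cm.tolist())             # [[3, 1, 0], [0, 1, 1], [0, 0, 4]]
print(per_class_acc.tolist())  # [0.75, 0.5, 1.0]
```

A class that never occurs as a true label has a row sum of zero, so this ratio is undefined for it.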

## 7️⃣ Model Inference and Prediction Visualization

### 🔹 Single Object Prediction Function

In [None]:
def predict_single_object(model, vertices, transform, class_names, device):
    """Predict class for a single 3D object."""
    model.eval()
    with torch.no_grad():
        if isinstance(vertices, np.ndarray):
            vertices = torch.FloatTensor(vertices)
        
        # Apply preprocessing
        processed_points = transform(vertices)
        batch_points = processed_points.unsqueeze(0).to(device)
        
        # Get prediction
        pred, _, _ = model(batch_points)
        pred_probs = torch.exp(pred)
        top_prob, top_class = torch.max(pred_probs, 1)
        
        predicted_class = class_names[top_class.item()]
        confidence = top_prob.item()
        
        return predicted_class, confidence

### 🔹 Prediction Visualization

In [None]:
def plot_predictions_vs_actual(model, test_samples, transform, class_names, device, n_examples=6):
    """Visualize model predictions on test samples."""
    # Sample valid test examples
    selected_samples = []
    max_attempts = n_examples * 3
    attempts = 0
    
    while len(selected_samples) < n_examples and attempts < max_attempts:
        sample = random.choice(test_samples)
        vertices, _ = read_off_file(sample['file_path'])
        
        if vertices is not None:
            selected_samples.append(sample)
        attempts += 1
    
    if len(selected_samples) == 0:
        print("No valid samples could be loaded for visualization")
        return
    
    actual_n_examples = len(selected_samples)
    n_cols = 3
    n_rows = (actual_n_examples + n_cols - 1) // n_cols
    
    fig = plt.figure(figsize=(18, 6 * n_rows))
    correct_predictions = 0
    
    for i, sample in enumerate(selected_samples):
        vertices, _ = read_off_file(sample['file_path'])
        
        # Get prediction
        predicted_class, confidence = predict_single_object(
            model, vertices, transform, class_names, device
        )
        
        actual_class = sample['class_name']
        is_correct = predicted_class == actual_class
        if is_correct:
            correct_predictions += 1
        
        # Plot 3D visualization
        ax = fig.add_subplot(n_rows, n_cols, i + 1, projection='3d')
        ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], 
                  c=vertices[:, 2], cmap='viridis', s=1, alpha=0.7)
        
        # Create title with prediction results
        title_color = 'green' if is_correct else 'red'
        title = f"Actual: {actual_class}\nPredicted: {predicted_class}\nConfidence: {confidence:.3f}"
        
        ax.set_title(title, fontsize=11, pad=20, color=title_color, weight='bold')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        
        # Set equal aspect ratio
        max_range = np.array([vertices[:, 0].max()-vertices[:, 0].min(),
                            vertices[:, 1].max()-vertices[:, 1].min(),
                            vertices[:, 2].max()-vertices[:, 2].min()]).max() / 2.0
        mid_x = (vertices[:, 0].max()+vertices[:, 0].min()) * 0.5
        mid_y = (vertices[:, 1].max()+vertices[:, 1].min()) * 0.5
        mid_z = (vertices[:, 2].max()+vertices[:, 2].min()) * 0.5
        
        ax.set_xlim(mid_x - max_range, mid_x + max_range)
        ax.set_ylim(mid_y - max_range, mid_y + max_range)
        ax.set_zlim(mid_z - max_range, mid_z + max_range)
    
    plt.suptitle(f'Actual vs Predicted Classes ({correct_predictions}/{actual_n_examples} correct)', 
                fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    print(f"Prediction accuracy on sample: {correct_predictions}/{actual_n_examples} = {100*correct_predictions/actual_n_examples:.1f}%")

In [None]:
print(len(class_names))

In [None]:
classes = [
    'airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
    'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
    'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
    'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
    'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'
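]
# Model outputs are indices into all 40 classes, so predictions must be mapped with this full list
# (it should match the sorted directory names collected into `classes` earlier), not with
# `train_dataset.class_names`, which only lists the classes that had readable training files.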


In [None]:
print("Visualizing predictions vs actual classes...")
plot_predictions_vs_actual(model, test_samples, test_transform, classes, device, n_examples=6)


## 8️⃣ Model Saving and Loading

### 🔹 Save the Trained Model

In [None]:
# Save complete model information
torch.save({
    'model_state_dict': model.state_dict(),
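    'num_classes': model.fc3.out_features,  # size of the output layer (all 40 classes), needed to rebuild the model
    'class_names': classes,                 # full class list; predicted indices map into this list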
    'best_test_acc': best_test_acc,
    'num_points': NUM_POINTS,
    'embed_size': 512,
    'hidden_size': 256,
    'num_decoder_layers': 3,
    'training_history': {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'test_losses': test_losses,
        'test_accuracies': test_accuracies
    }
}, 'pointnet_modelnet40.pth')

In [None]:
print(f"Best test accuracy achieved: {best_test_acc:.2f}%")

### 🔹 Load and Use Saved Model

In [None]:
def load_pointnet_model(model_path, device='cpu'):
    """Load a saved PointNet model for inference."""
    checkpoint = torch.load(model_path, map_location=device)
    
    # Recreate model architecture
    model = PointNetClassifier(
        num_classes=checkpoint['num_classes'], 
        dropout=0.3, 
        feature_transform=True
    ).to(device)
    
    # Load trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    return model, checkpoint

In [None]:
# Example usage:
# loaded_model, checkpoint = load_pointnet_model('pointnet_modelnet40.pth', device)
# class_names = checkpoint['class_names']
# print(f"Loaded model with {checkpoint['best_test_acc']:.2f}% test accuracy")
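
Rebuilding the model from a checkpoint only works if the rebuilt output layer has the same shape as the saved one: `load_state_dict` raises a `RuntimeError` on a size mismatch, so the stored `num_classes` must describe the output layer the weights came from (40 here), not the number of classes that happened to have data. A sketch with a small `nn.Linear` standing in for the classifier head and an in-memory buffer instead of a file:

```python
import io
import torch
import torch.nn as nn

head = nn.Linear(256, 40)   # stand-in for the classifier's last layer (40 outputs)
buffer = io.BytesIO()
torch.save({'model_state_dict': head.state_dict(), 'num_classes': head.out_features}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, weights_only=True)
rebuilt = nn.Linear(256, checkpoint['num_classes'])
rebuilt.load_state_dict(checkpoint['model_state_dict'])   # works: shapes match

mismatch_raised = False
try:
    nn.Linear(256, 33).load_state_dict(checkpoint['model_state_dict'])
except RuntimeError as err:
    mismatch_raised = 'size mismatch' in str(err)
print('size mismatch detected:', mismatch_raised)   # True
```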


## 9️⃣ Conclusion and Future Improvements


### **📝 Exercises for Further Exploration**

- Experiment with different numbers of points (512, 2048, 4096)
- Try different aggregation functions (mean, attention-based)
- Add more layers to the classification head
- Add more augmentations (random scaling, clipped jitter, point dropout) beyond the rotation and Gaussian noise used here
- Look for a different dataset and see what you get
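
As a starting point for the augmentation exercise, here is a sketch of two common extras: random uniform scaling and Gaussian jitter clipped to a maximum offset. The ranges (`0.8` to `1.25` for scale, `sigma=0.01`, `clip=0.05`) are illustrative assumptions, not values tuned for this notebook; both functions take an `(N, 3)` tensor like the ones produced by `PointCloudTransform`.

```python
import torch

def random_scale(points, low=0.8, high=1.25):
    """Scale the whole cloud by one factor drawn uniformly from [low, high)."""
    scale = torch.empty(1).uniform_(low, high)
    return points * scale

def clipped_jitter(points, sigma=0.01, clip=0.05):
    """Add per-coordinate Gaussian noise, clipped so no coordinate moves by more than `clip`."""
    noise = torch.clamp(torch.randn_like(points) * sigma, -clip, clip)
    return points + noise

torch.manual_seed(0)
cloud = torch.randn(1024, 3)
augmented = clipped_jitter(random_scale(cloud))
print(augmented.shape)   # torch.Size([1024, 3])
```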

### **📚 Further Reading**
- Original PointNet paper: "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation"
- PointNet++: "PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space"
- Survey paper: "Deep Learning for 3D Point Clouds: A Survey"

### Contributed by: Ali Habibullah