In [1]:
# --- 1. Import Libraries ---
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from pointMAE import PointMAEModel  # Ensure PointMAE.py is in the same directory
import os
from pytorch3d.loss import chamfer_distance
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

In [2]:
# --- 2. Define Hyperparameters ---
BATCH_SIZE = 32
EPOCHS = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "./checkpoints_pointmae"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [4]:
synset_to_class = {
    "02691156": "airplane",
    "02747177": "ashcan",
    "02773838": "bag",
    "02801938": "bicycle",
    "02808440": "boat",
    "02818832": "bookcase",
    "02828884": "bus",
    "02843684": "cabinet",
    "02871439": "car",
    "02876657": "cellphone",
    "02880940": "chair",
    "02924116": "cone",
    "02933112": "cup",
    "02942699": "bench",
    "02946921": "gun",
    "02954340": "lamp",
    "02958343": "laptop",
    "02992529": "motorcycle",
    "03001627": "piano",
    "03046257": "rifle",
    "03085013": "rocket",
    "03207941": "skateboard",
    "03211117": "sofa",
    "03261776": "table",
    "03325088": "tower",
    "03337140": "train",
    "03467517": "vehicle",
    "03513137": "display",
    "03593526": "washer",
    "03624134": "clock",
    "03636649": "dishwasher",
    "03642806": "earphone",
    "03691459": "firearm",
    "03710193": "furniture",
    "03759954": "fan",
    "03761084": "hat",
    "03790512": "helmet",
    "03797390": "knife",
    "03928116": "lamp",
    "03938244": "loudspeaker",
    "03948459": "mailbox",
    "03991062": "microphone",
    "04004475": "microwave",
    "04074963": "mug",
    "04099429": "pistol",
    "04225987": "pot",
    "04256520": "printer",
    "04330267": "remote",
    "04379243": "bathtub",
    "04401088": "stove",
    "04460130": "dishwasher",
    "04468005": "telephone",
    "04530566": "watercraft",
    "04554684": "guitar"
}


In [5]:
# --- 3. Initialize Dataset and DataLoader ---
import os
import torch
from torch.utils.data import Dataset, DataLoader
from plyfile import PlyData
import numpy as np
import random
import os
import torch
from torch.utils.data import Dataset, DataLoader
from plyfile import PlyData
import numpy as np
import random

class ShapeNetPointCloudDataset(Dataset):
    def __init__(self, root_dir, split='train', train_ratio=0.7, val_ratio=0.15, seed=42, augment=False, scale_range=(0.8, 1.2), translation_range=(-0.1, 0.1)):
        """
        Initializes the ShapeNetPointCloudDataset with data split options based on class.

        Args:
            root_dir (str): Path to the ShapeNetCore directory.
            split (str): Which data split to use ('train', 'val', or 'test').
            train_ratio (float): Proportion of data for training.
            val_ratio (float): Proportion of data for validation.
            seed (int): Random seed for reproducibility.
        """
        self.root_dir = root_dir
        self.split = split
        self.data = []  # Stores tuples of (file_path, class_label)
        self.augment = augment
        self.scale_range = scale_range
        self.translation_range = translation_range

        # Collect file paths organized by class
        class_files = self.get_class_files()

        # Generate class_name_to_index mapping
        self.class_name_to_index = {class_name: idx for idx, class_name in enumerate(class_files.keys())}
        
        # Split each class's files into train, val, and test sets
        self.create_splits(class_files, train_ratio, val_ratio, seed)

    def get_class_files(self):
        """
        Collects .ply file paths for each class in the ShapeNet directory.
        Returns:
            dict: A dictionary where keys are class labels, and values are lists of file paths.
        """
        class_files = {synset_to_class[synset]: [] for synset in synset_to_class}
        for synset_id, class_name in synset_to_class.items():
            synset_dir = os.path.join(self.root_dir, synset_id)
            for dirpath, _, files in os.walk(synset_dir):
                for file_name in files:
                    if file_name.endswith('.ply'):
                        class_files[class_name].append(os.path.join(dirpath, file_name))
        return class_files

    def create_splits(self, class_files, train_ratio, val_ratio, seed):
        """
        Splits the dataset into train, val, and test sets based on specified ratios.
        
        Args:
            class_files (dict): Dictionary of class labels to lists of file paths.
            train_ratio (float): Proportion of data for training.
            val_ratio (float): Proportion of data for validation.
            seed (int): Random seed for reproducibility.
        """
        random.seed(seed)
        for class_name, files in class_files.items():
            random.shuffle(files)
            total_files = len(files)
            train_len = int(total_files * train_ratio)
            val_len = int(total_files * val_ratio)

            if self.split == 'train':
                split_files = files[:train_len]
            elif self.split == 'val':
                split_files = files[train_len:train_len + val_len]
            elif self.split == 'test':
                split_files = files[train_len + val_len:]
            else:
                raise ValueError("Invalid split; choose from 'train', 'val', or 'test'.")

            self.data.extend([(file_path, class_name) for file_path in split_files])

    def apply_augmentations(self, points):
        """
        Applies random scaling and translation to the point cloud.
        
        Args:
            points (np.ndarray): Point cloud data, shape (num_points, 3).
        
        Returns:
            np.ndarray: Augmented point cloud.
        """
        # Random scaling
        scale_factor = np.random.uniform(*self.scale_range)
        points *= scale_factor

        # Random translation
        translation = np.random.uniform(*self.translation_range, size=(1, 3))
        points += translation

        return points

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Reads a .ply file, subsamples the point cloud, and returns the point cloud and its class label as a tensor.
        
        Args:
            idx (int): Index of the item in the dataset.
        
        Returns:
            tuple: A tuple containing:
                - points (torch.Tensor): Subsampled point cloud data of shape (num_points, 3).
                - class_index (int): Class index for the point cloud.
        """
        num_points = 1024  # Desired number of points after subsampling
        file_path, class_label = self.data[idx]
        ply_data = PlyData.read(file_path)
        points = np.vstack([
            np.array(ply_data['vertex'][axis]) for axis in ['x', 'y', 'z']
        ]).T  # Shape (N, 3)

        # Subsample points if necessary
        if points.shape[0] > num_points:
            indices = np.random.choice(points.shape[0], num_points, replace=False)
            points = points[indices]
        elif points.shape[0] < num_points:
            # Pad with zeros if there are fewer than num_points
            padding = np.zeros((num_points - points.shape[0], 3))
            points = np.vstack([points, padding])

        # Apply augmentations if enabled and split is 'train'
        if self.augment and self.split == 'train':
            points = self.apply_augmentations(points)

        # Convert class label to class index
        class_index = self.class_name_to_index[class_label]
        
        return torch.tensor(points, dtype=torch.float32), class_index



# Usage

# Create datasets for each split
root_dir = './ShapeNetCore.v2'
train_dataset = ShapeNetPointCloudDataset(root_dir, split='train', augment=True)
val_dataset = ShapeNetPointCloudDataset(root_dir, split='val', augment=False)
test_dataset = ShapeNetPointCloudDataset(root_dir, split='test', augment=False)

# Create DataLoaders for each split
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# --- 4. Initialize PointMAE Model, Optimizer, and Scheduler ---
token_dim = 256
num_heads = 8
num_layers = 6
num_patches = 64
num_pts_per_patch = 32
num_channels = 3
mask_ratio = 0.65

# Initialize PointMAE Model
model = PointMAEModel(
    input_dim=num_channels,
    token_dim=token_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    mask_ratio=mask_ratio,
    num_patches=num_patches,
    num_pts_per_patch=num_pts_per_patch,
).to(DEVICE)

# Initialize optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
scheduler = CosineAnnealingLR(optimizer, T_max=300, eta_min=0)

In [7]:
# --- 5. Training Loop ---
for epoch in range(EPOCHS):
    model.train()
    total_reconstruction_loss = 0.0

    for batch_idx, data in enumerate(train_loader):
        pointclouds, _ = data  # Only pointcloud data needed for PointMAE pretraining
        pointclouds = pointclouds.to(DEVICE)
        
        # Forward pass through PointMAE model
        reconstructed_patches, original_patches = model(pointclouds)

        # Calculate Chamfer Distance as reconstruction loss
        reconstruction_loss = model.get_loss(reconstructed_patches, original_patches)

        # Backpropagation
        optimizer.zero_grad()
        reconstruction_loss.backward()
        optimizer.step()
        
        # Accumulate losses for logging
        total_reconstruction_loss += reconstruction_loss.item()

        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch+1}/{EPOCHS}], Batch [{batch_idx}/{len(train_loader)}], "
                f"Reconstruction Loss: {reconstruction_loss.item():.4f}"
            )

    # Step the scheduler after each epoch
    scheduler.step()

    # Log epoch-wise loss
    avg_reconstruction_loss = total_reconstruction_loss / len(train_loader)
    print(
        f"Epoch [{epoch+1}/{EPOCHS}] - Avg Reconstruction Loss: {avg_reconstruction_loss:.4f}"
    )

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"pointmae_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

print("Training complete.")

Epoch [1/100], Batch [0/1096], Reconstruction Loss: 1.8902
Epoch [1/100], Batch [10/1096], Reconstruction Loss: 2.0603
Epoch [1/100], Batch [20/1096], Reconstruction Loss: 0.5427
Epoch [1/100], Batch [30/1096], Reconstruction Loss: 0.2794


KeyboardInterrupt: 