In [1]:
# --- 1. Import Libraries ---
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from pointRTD import PointRTDModel  # Import the PointRTD model
import os
import random
import numpy as np
from plyfile import PlyData
import trimesh  # For loading .off files as point clouds
import time
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm  # tqdm for Jupyter notebooks

In [2]:
# --- 2. Define Hyperparameters ---
# Optimisation settings for fine-tuning the classifier.
BATCH_SIZE = 64
EPOCHS = 300
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.05
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Task settings. CORRUPTION_RATIO is embedded in both the pretrained checkpoint
# file name and the output checkpoint directory name.
NUM_CLASSES = 40
CORRUPTION_RATIO = 0.6
CHECKPOINT_DIR = f"./checkpoints_modelnet40/PointRTD/CR_{CORRUPTION_RATIO}"
PRETRAINED_CHECKPOINT = f"./checkpoints_pointrtd/pointrtd_epoch_62_CR_{CORRUPTION_RATIO}.pth"
# This run is ModelNet40 (NUM_CLASSES = 40, checkpoints under
# checkpoints_modelnet40), so the TensorBoard logs go to a matching directory
# rather than the stale "modelnet10" one, which would mix runs from two tasks.
LOG_DIR = "./tensorboard_logs_modelnet40"

# Create directories if they don't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
print("Device: ", DEVICE)

# Initialize TensorBoard writer
writer = SummaryWriter(log_dir=LOG_DIR)

Device:  cuda


In [3]:
class ModelNet40Dataset(Dataset):
    """Classification dataset over ModelNet40 ``.off`` meshes.

    The preset split expects the layout ``root_dir/<class>/<train|test>/*.off``.
    Each item is ``(points, class_idx)``; ``points`` is a float32 tensor obtained
    by calling ``mesh.sample(num_points)`` on the loaded mesh at access time, so
    repeated accesses of the same index return different samples.
    """

    def __init__(self, root_dir, split='train', random_split=False, num_points=1024, seed=42, augment=False):
        """
        Args:
            root_dir (str): Root directory containing one sub-directory per class.
            split (str): 'train' or 'test'.
            random_split (bool): If True, build a seeded 80/20 split per class
                instead of using the preset train/test folders.
            num_points (int): Number of points to sample from each mesh.
            seed (int): Seed for the random split (ignored for the preset split).
            augment (bool): Apply random scaling/translation to training items.
        """
        self.root_dir = root_dir
        self.split = split
        self.num_points = num_points
        self.augment = augment
        self.data = []

        # Populate self.data / self.class_to_idx with the requested split.
        if random_split:
            self.random_seed_split(seed)
        else:
            self.preset_split()

    def _discover_classes(self):
        """Set ``self.class_to_idx`` and return the sorted class directory names.

        Only real, non-hidden directories count as classes, so stray files or
        folders such as ``.ipynb_checkpoints`` in ``root_dir`` cannot shift the
        label indices.
        """
        classes = sorted(
            entry for entry in os.listdir(self.root_dir)
            if not entry.startswith('.') and os.path.isdir(os.path.join(self.root_dir, entry))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        return classes

    @staticmethod
    def _collect_off_files(directory):
        """Return the sorted paths of every ``.off`` file found under ``directory``."""
        found = []
        for dirpath, _dirnames, filenames in os.walk(directory):
            found.extend(os.path.join(dirpath, name) for name in filenames if name.endswith('.off'))
        return sorted(found)

    def preset_split(self):
        """Use preset train/test split from ModelNet40."""
        for cls_name in self._discover_classes():
            class_path = os.path.join(self.root_dir, cls_name, self.split)
            label = self.class_to_idx[cls_name]
            # Sorted so the sample order does not depend on the filesystem.
            for file_name in sorted(os.listdir(class_path)):
                if file_name.endswith('.off'):
                    self.data.append((os.path.join(class_path, file_name), label))

    def random_seed_split(self, seed):
        """Create a seeded 80/20 split per class over all of that class's meshes.

        Meshes are gathered from anywhere below the class directory (including
        the preset ``train``/``test`` folders), sorted, and shuffled with a
        private ``random.Random(seed)``. The 'train' and any other split built
        with the same seed are therefore complementary, and the global
        ``random`` state is left untouched.
        """
        rng = random.Random(seed)
        for cls_name in self._discover_classes():
            label = self.class_to_idx[cls_name]
            class_files = self._collect_off_files(os.path.join(self.root_dir, cls_name))
            rng.shuffle(class_files)
            split_idx = int(len(class_files) * 0.8)
            if self.split == 'train':
                chosen = class_files[:split_idx]
            else:
                chosen = class_files[split_idx:]
            self.data.extend((path, label) for path in chosen)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path, class_idx = self.data[idx]
        mesh = trimesh.load(file_path)
        # The float64 -> float32 conversion yields a fresh array owned by this item.
        points = np.asarray(mesh.sample(self.num_points), dtype=np.float32)

        # Augmentation is only ever applied to the training split.
        if self.augment and self.split == 'train':
            points = self.apply_augmentations(points)

        return torch.from_numpy(points), class_idx

    def apply_augmentations(self, points):
        """Return a randomly scaled and translated copy of ``points``.

        One scale factor in [0.8, 1.2] and one per-axis shift in [-0.1, 0.1] are
        drawn per call. The input array is not modified and its dtype is kept.
        """
        scale = np.random.uniform(0.8, 1.2)
        shift = np.random.uniform(-0.1, 0.1, size=(1, 3)).astype(points.dtype, copy=False)
        return points * points.dtype.type(scale) + shift
    
# --- 4. Initialize Dataloaders ---
# Both datasets use the preset train/test folders; only the training set is
# augmented, and only the training loader is shuffled.
root_dir = './ModelNet40'
train_dataset = ModelNet40Dataset(root_dir=root_dir, split='train', random_split=False, augment=True)
test_dataset = ModelNet40Dataset(root_dir=root_dir, split='test', random_split=False, augment=False)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# --- 5. Load Pretrained Model and Set Up Classifier ---
# Architecture settings for the PointRTD model. The checkpoint is loaded below
# with load_state_dict's default strict matching, so these values have to agree
# with the ones the checkpoint was saved with.
num_channels = 3
token_dim = 256
hidden_dim = 256
num_heads = 8
num_layers = 6
num_patches = 64
num_pts_per_patch = 32
corruption_ratio = CORRUPTION_RATIO
noise_scale = 1

# Build the full PointRTD model; only its encoder is reused further down.
pointrtd_model = PointRTDModel(
    input_dim=num_channels,
    token_dim=token_dim,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    corruption_ratio=corruption_ratio,
    noise_scale=noise_scale,
    num_patches=num_patches,
    num_pts_per_patch=num_pts_per_patch,
    finetune=True,  # Disable masking
)
pointrtd_model = pointrtd_model.to(DEVICE)

# Restore the pretrained weights when the checkpoint file exists. When it does
# not, only a message is printed and the model keeps its initial weights.
if not os.path.isfile(PRETRAINED_CHECKPOINT):
    print(f"Checkpoint not found at {PRETRAINED_CHECKPOINT}")
else:
    print(f"Loading checkpoint from {PRETRAINED_CHECKPOINT}")
    pretrained_checkpoint = torch.load(PRETRAINED_CHECKPOINT, map_location=DEVICE, weights_only=True)
    state_dict = pretrained_checkpoint['model_state_dict']
    pointrtd_model.load_state_dict(state_dict)
    print("Checkpoint loaded successfully.")

encoder = pointrtd_model.encoder

class EncoderWithClassifier(nn.Module):
    """Encoder followed by a single bias-free linear classification head.

    The head consumes the concatenation of three summaries of the encoder's
    token sequence: the first token, the mean over tokens and the max over
    tokens.

    Args:
        encoder: Module whose output unpacks to the token tensor first; any
            further outputs are ignored.
        token_dim (int): Size of the last dimension of the token tensor.
        num_classes (int): Number of output logits.
    """

    def __init__(self, encoder, token_dim=256, num_classes=10):
        super().__init__()
        self.encoder = encoder
        # Three pooled views are concatenated, hence 3 * token_dim input features.
        self.classifier = nn.Linear(3 * token_dim, num_classes, bias=False)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for the batch ``x``."""
        # Only the first encoder output is used; it is indexed as (B, tokens, token_dim).
        tokens, *_unused = self.encoder(x)

        first_token = tokens[:, 0, :]          # (B, token_dim)
        token_mean = tokens.mean(dim=1)        # (B, token_dim)
        token_max = tokens.max(dim=1).values   # (B, token_dim)

        pooled = torch.cat((first_token, token_mean, token_max), dim=-1)  # (B, 3 * token_dim)
        return self.classifier(pooled)

# Pass the same token_dim the encoder was built with, so the classifier's input
# width cannot drift from the encoder if that setting is changed.
classification_model = EncoderWithClassifier(encoder, token_dim=token_dim, num_classes=NUM_CLASSES).to(DEVICE)

Loading checkpoint from ./checkpoints_pointrtd/pointrtd_epoch_62_CR_0.6.pth
Checkpoint loaded successfully.


In [5]:
# --- Calculate class distribution for cross entropy class weights
# The per-class counts below are hard-coded ("Precomputed"). The commented
# recipe recounts them by iterating train_dataset -- MAY TAKE A WHILE.
# from collections import Counter
# class_counts = Counter()
# for _, label in train_dataset:
#     class_counts[label] += 1
# class_count_list = [class_counts[i] for i in range(len(class_counts))]
class_count_list = [
    626, 106, 515, 173, 572, 335, 64, 197, 889, 167,
    79, 138, 200, 109, 200, 149, 171, 155, 145, 124,
    149, 284, 465, 200, 88, 231, 240, 104, 115, 128,
    680, 124, 90, 392, 163, 344, 267, 475, 87, 103,
]  # Precomputed

# Inverse-frequency weights: total sample count divided by each class count.
total_samples = sum(class_count_list)
class_weights = [total_samples / n_samples for n_samples in class_count_list]
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=DEVICE)

print("Class Counts:", class_count_list)
print("Class Weights:", class_weights)


Class Counts: [626, 106, 515, 173, 572, 335, 64, 197, 889, 167, 79, 138, 200, 109, 200, 149, 171, 155, 145, 124, 149, 284, 465, 200, 88, 231, 240, 104, 115, 128, 680, 124, 90, 392, 163, 344, 267, 475, 87, 103]
Class Weights: [15.723642172523961, 92.85849056603773, 19.1126213592233, 56.895953757225435, 17.208041958041957, 29.382089552238806, 153.796875, 49.964467005076145, 11.071991001124859, 58.94011976047904, 124.59493670886076, 71.32608695652173, 49.215, 90.30275229357798, 49.215, 66.06040268456375, 57.56140350877193, 63.50322580645161, 67.88275862068966, 79.37903225806451, 66.06040268456375, 34.65845070422535, 21.16774193548387, 49.215, 111.85227272727273, 42.61038961038961, 41.0125, 94.64423076923077, 85.59130434782608, 76.8984375, 14.475, 79.37903225806451, 109.36666666666666, 25.10969387755102, 60.38650306748466, 28.613372093023255, 36.86516853932584, 20.722105263157896, 113.13793103448276, 95.5631067961165]


In [6]:
# --- 6. Define Optimizer and Scheduler ---
# AdamW over every parameter of classification_model, with a cosine schedule
# that spans all EPOCHS and anneals the learning rate down to 0.
optimizer = optim.AdamW(
    params=classification_model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS, eta_min=0)

In [7]:
# --- 7. Define Training Loop with tqdm ---
def train_one_epoch(model, loader, optimizer, epoch):
    """Run one optimisation pass over ``loader`` and report its metrics.

    Args:
        model: Network mapping a batch of points to class logits.
        loader: Iterable of ``(points, labels)`` batches that supports ``len``.
        optimizer: Optimizer stepped once per batch.
        epoch (int): Zero-based epoch index, used only in the progress-bar label.

    Returns:
        tuple: ``(avg_loss, accuracy)`` -- the mean of the per-batch losses and
        the fraction of samples classified correctly.
    """
    model.train()

    # Class-weighted cross entropy, using the module-level weight tensor.
    loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

    num_batches = len(loader)
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    progress_bar = tqdm(enumerate(loader), total=num_batches, desc=f"Epoch [{epoch+1}/{EPOCHS}]")
    for batch_idx, (points, labels) in progress_bar:
        points = points.to(DEVICE)
        labels = labels.to(DEVICE)

        # Forward, backward, parameter update.
        optimizer.zero_grad()
        logits = model(points)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the running metrics.
        batch_loss = loss.item()
        loss_sum += batch_loss
        predicted = logits.max(1)[1]
        num_seen += labels.size(0)
        num_correct += (predicted == labels).sum().item()

        progress_bar.set_postfix(Batch=f"{batch_idx+1}/{num_batches}", Loss=f"{batch_loss:.4f}")

    avg_loss = loss_sum / num_batches
    accuracy = num_correct / num_seen
    return avg_loss, accuracy


def validate_one_epoch(model, loader, epoch):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network mapping a batch of points to class logits.
        loader: Iterable of ``(points, labels)`` batches that supports ``len``.
        epoch (int): Zero-based epoch index, used only in the progress-bar label.

    Returns:
        tuple: ``(avg_loss, accuracy)`` -- the mean of the per-batch losses and
        the fraction of samples classified correctly.
    """
    model.eval()

    # Same class-weighted loss as in training, so the two losses are comparable.
    loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

    num_batches = len(loader)
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        progress_bar = tqdm(enumerate(loader), total=num_batches, desc=f"Validation Epoch [{epoch+1}/{EPOCHS}]")
        for batch_idx, (points, labels) in progress_bar:
            points = points.to(DEVICE)
            labels = labels.to(DEVICE)

            logits = model(points)
            batch_loss = loss_fn(logits, labels).item()

            loss_sum += batch_loss
            predicted = logits.max(1)[1]
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()

            progress_bar.set_postfix(Batch=f"{batch_idx+1}/{num_batches}", Loss=f"{batch_loss:.4f}")

    avg_loss = loss_sum / num_batches
    accuracy = num_correct / num_seen
    return avg_loss, accuracy

In [None]:
# Track the lowest validation loss seen so far and the (1-based) epoch it came from.
best_val_loss = float('inf')
best_epoch = -1


def _save_training_checkpoint(path, epoch):
    """Write the model, optimizer and scheduler state for ``epoch`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': classification_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }, path)


# --- 8. Training and Validation Loop ---
# NOTE(review): validation runs on test_loader, so the "best" checkpoint below is
# selected on test data.
try:
    for epoch in range(EPOCHS):
        train_loss, train_acc = train_one_epoch(classification_model, train_loader, optimizer, epoch)
        val_loss, val_acc = validate_one_epoch(classification_model, test_loader, epoch)

        # Read the learning rate BEFORE stepping the scheduler: after the step,
        # get_last_lr() already reports the next epoch's value, which would shift
        # the logged curve by one epoch.
        epoch_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # Log metrics to TensorBoard
        writer.add_scalar("Train/Loss", train_loss, epoch)
        writer.add_scalar("Train/Accuracy", train_acc, epoch)
        writer.add_scalar("Validation/Loss", val_loss, epoch)
        writer.add_scalar("Validation/Accuracy", val_acc, epoch)
        writer.add_scalar("Learning_Rate", epoch_lr, epoch)

        # Print epoch stats
        print(f"Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f"classification_epoch_{epoch+1}.pth")
            _save_training_checkpoint(checkpoint_path, epoch)
            print(f"Checkpoint saved at {checkpoint_path}")

        # Save the best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch + 1  # 1-based, matches the printed epoch number
            best_checkpoint_path = os.path.join(CHECKPOINT_DIR, f"best_model_epoch_{best_epoch}.pth")
            _save_training_checkpoint(best_checkpoint_path, epoch)
            print(f"New best model saved with Val Loss: {val_loss:.4f} at epoch {best_epoch}")
finally:
    # Close the TensorBoard writer even when the run is interrupted or fails, so
    # the scalars logged so far are flushed to disk.
    writer.close()

print("Training complete.")


Epoch [1/300]:   0%|          | 0/154 [00:00<?, ?it/s]

  stacked = np.column_stack(stacked).round().astype(np.int64)


Validation Epoch [1/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [1/300] - Train Loss: 3.7254, Train Acc: 0.0617, Val Loss: 3.8695, Val Acc: 0.0543
New best model saved with Val Loss: 3.8695 at epoch 1


Epoch [2/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [2/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [2/300] - Train Loss: 3.3452, Train Acc: 0.1390, Val Loss: 3.4053, Val Acc: 0.1090
New best model saved with Val Loss: 3.4053 at epoch 2


Epoch [3/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [3/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [3/300] - Train Loss: 2.9239, Train Acc: 0.2326, Val Loss: 3.1765, Val Acc: 0.1718
New best model saved with Val Loss: 3.1765 at epoch 3


Epoch [4/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [4/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [4/300] - Train Loss: 2.5840, Train Acc: 0.3116, Val Loss: 2.7125, Val Acc: 0.3035
New best model saved with Val Loss: 2.7125 at epoch 4


Epoch [5/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [5/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [5/300] - Train Loss: 2.3352, Train Acc: 0.3751, Val Loss: 2.2632, Val Acc: 0.3975
New best model saved with Val Loss: 2.2632 at epoch 5


Epoch [6/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [6/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [6/300] - Train Loss: 2.1153, Train Acc: 0.4357, Val Loss: 2.4116, Val Acc: 0.3578


Epoch [7/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [7/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [7/300] - Train Loss: 2.0530, Train Acc: 0.4630, Val Loss: 2.3565, Val Acc: 0.4007


Epoch [8/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [8/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [8/300] - Train Loss: 1.8834, Train Acc: 0.5046, Val Loss: 1.8305, Val Acc: 0.4765
New best model saved with Val Loss: 1.8305 at epoch 8


Epoch [9/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [9/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [9/300] - Train Loss: 1.6472, Train Acc: 0.5628, Val Loss: 1.7747, Val Acc: 0.5142
New best model saved with Val Loss: 1.7747 at epoch 9


Epoch [10/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [10/300]:   0%|          | 0/39 [00:00<?, ?it/s]

Epoch [10/300] - Train Loss: 1.5762, Train Acc: 0.5873, Val Loss: 1.7935, Val Acc: 0.5008
Checkpoint saved at ./checkpoints_modelnet40/PointRTD/CR_0.6/classification_epoch_10.pth


Epoch [11/300]:   0%|          | 0/154 [00:00<?, ?it/s]

Validation Epoch [11/300]:   0%|          | 0/39 [00:00<?, ?it/s]

In [None]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import os

# --- Load Classification Model Checkpoint ---
def load_classification_checkpoint(model, checkpoint_path):
    """Load the ``'model_state_dict'`` entry of a saved checkpoint into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        checkpoint_path (str): Path to a checkpoint file containing a
            ``'model_state_dict'`` entry.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file. The
            evaluation loop catches exactly this exception; merely printing a
            message would let it go on to score whatever weights ``model``
            already holds.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Checkpoint loaded successfully.")


# --- Evaluate Model on Test Set with Voting ---
def evaluate_model_with_voting(model, loader, num_votes=10, class_names=None):
    """Score ``model`` on ``loader`` using a per-sample majority vote.

    Every batch is passed through ``apply_test_augmentations`` and the model
    ``num_votes`` times; the most frequent predicted class per sample is the
    final prediction.

    Args:
        model: Network mapping a batch of points to class logits.
        loader: Iterable of ``(points, labels)`` batches.
        num_votes (int): Number of forward passes per batch.
        class_names: Optional tick labels. When given, the confusion matrix is
            drawn as a heatmap; otherwise it is printed.

    Returns:
        tuple: ``(test_accuracy, test_f1_score, conf_matrix)``, with the F1 score
        using ``average='weighted'``.
    """
    model.eval()
    true_labels = []
    voted_labels = []

    with torch.no_grad():
        for points, labels in tqdm(loader, desc="Evaluating Test Set with Voting"):
            points = points.to(DEVICE)
            labels = labels.to(DEVICE)
            batch_size = points.size(0)

            # One row of hard predictions per vote.
            per_vote = []
            for _ in range(num_votes):
                augmented = apply_test_augmentations(points.clone())
                logits = model(augmented)
                per_vote.append(logits.max(1)[1].cpu().numpy())
            vote_matrix = np.array(per_vote)  # Shape: (num_votes, batch_size)

            # Majority vote, one sample (column) at a time.
            voted_labels.extend(
                Counter(vote_matrix[:, sample_idx]).most_common(1)[0][0]
                for sample_idx in range(batch_size)
            )
            true_labels.extend(labels.cpu().numpy())

    test_accuracy = accuracy_score(true_labels, voted_labels)
    test_f1_score = f1_score(true_labels, voted_labels, average='weighted')
    conf_matrix = confusion_matrix(true_labels, voted_labels)

    if class_names is None:
        print("Confusion Matrix:")
        print(conf_matrix)
    else:
        plt.figure(figsize=(10, 8))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title('Confusion Matrix')
        plt.show()

    return test_accuracy, test_f1_score, conf_matrix

def apply_test_augmentations(points):
    """Return ``points`` unchanged (placeholder for test-time augmentation).

    No rotation, jitter or any other transform is applied here, so each vote in
    ``evaluate_model_with_voting`` is computed on an unmodified copy of the
    batch.
    """
    return points

In [None]:
# --- Main Evaluation Loop ---
# Periodic checkpoints to score, evaluated from the latest epoch to the earliest.
checkpoint_epochs = [10, 50, 100, 150, 200, 250, 300]
checkpoint_epochs.reverse()
results = []

class_names = [
    'airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 
    'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 
    'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp', 'laptop', 
    'mantel', 'monitor', 'night_stand', 'person', 'piano', 'plant', 
    'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool', 'table', 
    'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'
]

for epoch in checkpoint_epochs:
    # Read from CHECKPOINT_DIR, the directory the training loop saves into; the
    # previous hard-coded "./checkpoints_modelnet10/..." path never matched it.
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"classification_epoch_{epoch}.pth")
    # To score a best-validation-loss checkpoint instead, use
    # os.path.join(CHECKPOINT_DIR, "best_model_epoch_<N>.pth").
    print(f"\nEvaluating Model at Epoch {epoch}...")

    # Skip epochs without a saved checkpoint; otherwise the model would be
    # scored with whatever weights it currently holds and the result would be
    # recorded under the wrong epoch.
    if not os.path.isfile(checkpoint_path):
        print(f"Checkpoint not found at {checkpoint_path}; skipping.")
        continue

    try:
        load_classification_checkpoint(classification_model, checkpoint_path)
        test_accuracy_voting, test_f1_score_voting, conf_matrix = evaluate_model_with_voting(
            classification_model, test_loader, num_votes=10, class_names=class_names
        )
        print(f"Accuracy: {test_accuracy_voting}; F1 Score: {test_f1_score_voting};")
        results.append((epoch, test_accuracy_voting, test_f1_score_voting))
    except FileNotFoundError as e:
        print(e)

# --- Print Results ---
print("\nEvaluation Results:")
print("Epoch\tAccuracy\tF1 Score")
for epoch, acc, f1 in results:
    print(f"{epoch}\t{acc:.4f}\t\t{f1:.4f}")