In [1]:
import torch 
from torch.utils.data import DataLoader
import os

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
print("Device: ", device)
%load_ext autoreload
%autoreload 2

Device:  cuda


## Original Cloud Transformers 

### Classification : 


#### ModelNet40 

In [2]:
from data_processing.ModelNet40 import ModelNetDataset
from train_functions.ModelNet40_training import train_modelnet40
from cloud_transformer.classification_model import CT_Classifier
# File paths
root = "data/ModelNet40_PLY"

# Loading datasets + online data augmentation
train_ds = ModelNetDataset(root, folder="train", rotate=True, noise=True)
test_ds  = ModelNetDataset(root, folder="test")

# DataLoaders 
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# input : (B, N, 3)
model_modelnet = CT_Classifier(n_classes=40, model_dim=512, heads=16, num_layers=2, dropout=0.5, use_scales=True , use_checkpoint=False)

load = True 
if load : 
    model_modelnet.load_state_dict(torch.load("saved_models/ct_classifier_modelnet.pth", map_location=device))
    print(f"✅ Model Loaded")
    
model_modelnet.to(device)

best_model = train_modelnet40(model_modelnet, device, train_loader, test_loader, epochs=25, lr=0.001, weight_decay=0, eval_mode=True) # Eval mode to evaluate test_loader once with the loaded model

save = True 
if save : 
    save_path = "saved_models/ct_classifier_modelnet.pth"
    torch.save(best_model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
✅ Model Loaded


                                                                 

[Eval] Val Loss: 0.3523, Acc: 89.75%
✅ Model saved to saved_models/ct_classifier_modelnet.pth




#### ScanObjectNN

In [3]:
from data_processing.ScanObjectNN import ScanObjectNNDataset
from train_functions.ScanObject_training import train_scanobject
from cloud_transformer.classification_model import CT_Classifier

# File paths
train_path = 'data/h5_files/main_split/training_objectdataset_augmentedrot_scale75.h5'
test_path  = 'data/h5_files/main_split/test_objectdataset_augmentedrot_scale75.h5'

# Loading Datasets
train_ds = ScanObjectNNDataset(train_path, rotate=True, noise=True)
test_ds  = ScanObjectNNDataset(test_path)

# DataLoaders — Online Data Augmentation
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=8)
test_loader  = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=8)


model_scanobject = CT_Classifier(n_classes=15, model_dim=512, heads=16, num_layers=2, dropout=0.5, use_scales=True , use_checkpoint=True)

load = True 
if load : 
    model_scanobject.load_state_dict(torch.load("saved_models/ct_classifier_scanobject.pth", map_location=device))
    print(f"✅ Model Loaded")

model_scanobject.to(device)
best_model = train_scanobject(model_scanobject, device, train_loader, test_loader, epochs=20, lr=0.001, weight_decay=0, eval_mode=True)

save = True 
if save : 
    save_path = "saved_models/ct_classifier_scanobject.pth"
    torch.save(best_model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")

✅ Model Loaded


                                                                 

[Eval] Loss: 0.5367, Cls Acc: 82.65%, Seg Acc: 79.94%
✅ Model saved to saved_models/ct_classifier_scanobject.pth




### Point Completion :

#### ShapeNet

In [None]:
from data_processing.ShapeNet import ShapeNetCompletionDataset
from cloud_transformer.completion_model import CT_Completion
from train_functions.ShapeNet_training import train_shapenet

root = os.path.join(os.getcwd(), "data", "ShapeNetCompletion")
train_ds = ShapeNetCompletionDataset(root, split='train', n_input=1024, n_output=4096)
test_ds  = ShapeNetCompletionDataset(root, split='test', n_input=1024, n_output=4096)

print(f"Train set size: {len(train_ds)}")
print(f"Test set size:  {len(test_ds)}")

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)
test_loader  = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=4)

# Coarser Grid are used here to speed up training
ct_completion_shapenet = CT_Completion(num_latent=512, model_dim=512, heads=8, num_layers=1, use_scales=False, use_checkpoint=False)

load = False 
if load : 
    ct_completion_shapenet.load_state_dict(torch.load("saved_models/ct_completion_shapenet.pth", map_location=device))
    print(f"✅ Model Loaded")

ct_completion_shapenet.to(device)
best_model = train_shapenet(ct_completion_shapenet, device, train_loader, test_loader, epochs=20, lr=0.001, weight_decay=0, subsample=2000, eval_mode=False, chamfer_weight=0.2) 

save = True 
if save : 
    save_path = "saved_models/ct_completion_shapenet.pth"
    torch.save(best_model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Train set size: 135288
Test set size:  600


Initial Eval:  68%|██████▊   | 26/38 [00:14<00:05,  2.03it/s]

#### Visualization : 

In [None]:
from data_processing.ShapeNet import ShapeNetCompletionDataset
from cloud_transformer.completion_model import CT_Completion

root = os.path.join(os.getcwd(), "data", "ShapeNetCompletion")
train_ds = ShapeNetCompletionDataset(root, split='train', n_input=1024, n_output=8192)
test_ds  = ShapeNetCompletionDataset(root, split='test', n_input=1024, n_output=8192)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4)
test_loader  = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=4)

model = CT_Completion(num_latent=256, model_dim=256, heads=8, num_layers=1, use_scales=False, use_checkpoint=True)
model.load_state_dict(torch.load("saved_models/ct_completion_shapenet.pth", map_location=device))
print(f"✅ Model Loaded")
model = model.to(device)
from data_processing.ShapeNet import visualize_completion
visualize_completion(model, train_loader, device, num_samples=5)

