In [7]:
import os
import pandas as pd
import torch
import numpy as np
import trimesh
from torch.utils.data import Dataset, DataLoader
from torch.utils.data._utils.collate import default_collate


class MeshDataset(Dataset):
    """Map-style dataset of triangle meshes listed in a header-less CSV file.

    Every CSV row holds a mesh filename (first column) and a class label
    (second column). Meshes are read from disk with trimesh each time an
    item is requested; nothing is cached.
    """

    def __init__(self, csv_file, mesh_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with mesh filenames and labels
            mesh_dir (string): Directory with all the mesh files
            transform (callable, optional): Optional transform for the mesh data
        """
        # The CSV has no header row, so column names are assigned by position.
        self.data_frame = pd.read_csv(csv_file, header=None, names=['filename', 'label'])
        self.mesh_dir = mesh_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (one per mesh) in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load the mesh at row ``idx`` and return it as a sample dict.

        Args:
            idx (int or 0-dim tensor): Row position in the CSV.

        Returns:
            dict: With keys
                'filename' (str, whitespace-stripped),
                'vertices' (float32 tensor, typically (V, 3)),
                'faces' (int64 tensor, typically (F, 3)),
                'label' (0-dim int64 tensor),
                'num_vertices' (int) and 'num_faces' (int).
            If a transform was given, its return value is returned instead.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get mesh filename and label; strip stray whitespace around the name
        # (e.g. from "name.obj, 2" style rows).
        mesh_name = self.data_frame.iloc[idx, 0].strip()
        mesh_path = os.path.join(self.mesh_dir, mesh_name)
        label = self.data_frame.iloc[idx, 1]

        # Load the mesh. Files containing several objects/groups would
        # otherwise come back as a trimesh.Scene, which has no .vertices or
        # .faces; force='mesh' asks trimesh to coerce them into a single mesh.
        mesh = trimesh.load(mesh_path, force='mesh')

        # Extract mesh features. np.asarray drops trimesh's array subclass
        # and torch.tensor copies, so the tensors do not alias mesh storage.
        vertices = torch.tensor(np.asarray(mesh.vertices), dtype=torch.float32)
        faces = torch.tensor(np.asarray(mesh.faces), dtype=torch.long)

        sample = {
            'filename': mesh_name,
            'vertices': vertices,
            'faces': faces,
            'label': torch.tensor(label, dtype=torch.long),
            'num_vertices': vertices.shape[0],
            'num_faces': faces.shape[0]
        }

        if self.transform:
            sample = self.transform(sample)

        return sample


def mesh_collate_fn(batch):
    """
    Collate a list of mesh sample dicts into a single batch dict.

    'vertices' and 'faces' differ in size from mesh to mesh, so they are
    returned as plain Python lists (one tensor per sample). Every other
    field is merged with PyTorch's default collation. The field set is
    taken from the first sample in the batch.
    """
    ragged_fields = ('vertices', 'faces')
    merged = {}

    for field in batch[0].keys():
        column = [sample[field] for sample in batch]
        # Ragged tensors cannot be stacked; leave them as a list.
        merged[field] = column if field in ragged_fields else default_collate(column)

    return merged

In [None]:
# Directory holding the .obj mesh files (relative to the working directory)
mesh_dir = "obj_files/"

# Dataset that pairs each mesh file with its label from the CSV
dataset = MeshDataset(csv_file='labels.csv', mesh_dir=mesh_dir)

# Shuffled batches of 4, loaded in the main process; mesh_collate_fn keeps
# the variable-sized vertex/face tensors as per-mesh lists
dataloader = DataLoader(
    dataset,
    collate_fn=mesh_collate_fn,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

# Walk through the first batches to show what the loader yields
for idx, batch in enumerate(dataloader):
    print(f"Batch {idx}:")
    print(f"Filenames: {batch['filename']}")
    print(f"Labels: {batch['label']}")
    print(f"Number of vertices per mesh: {batch['num_vertices']}")

    # Variable-sized mesh data arrives as one tensor per mesh
    vertices_list, faces_list = batch['vertices'], batch['faces']

    print(f"First mesh in batch has {vertices_list[0].size(0)} vertices")

    # Your model forward pass here
    # ...

    # Stop after 2 batches for this example
    if idx == 1:
        break

Batch 0:
Filenames: ['chair_066_13.obj', 'chair_018_07.obj', 'chair_015_08.obj', 'chair_001_10.obj']
Labels: tensor([2, 0, 0, 0])
Number of vertices per mesh: tensor([1114,   18,  110, 1416])
First mesh in batch has 1114 vertices
Batch 1:
Filenames: ['chair_019_06.obj', 'chair_001_12.obj', 'chair_011_10.obj', 'chair_124_18.obj']
Labels: tensor([1, 0, 0, 2])
Number of vertices per mesh: tensor([ 798, 1416,    8, 9661])
First mesh in batch has 798 vertices
