In [2]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class ObjectDataset(Dataset):
    """
    A PyTorch Dataset for 3D object data with labels.

    The CSV file is read without a header row: column 0 holds the object file
    name and column 1 holds its label. Each sample is a dict with keys
    ``'obj_name'`` (leading/trailing whitespace stripped) and ``'label'``.
    """
    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string or file-like): Path to (or buffer of) the csv file with
                object paths and labels. It is parsed with no header row.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_frame = pd.read_csv(csv_file, header=None)
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """
        Return the sample at positional row ``idx``.

        Args:
            idx (int, list of int, or torch.Tensor): Positional row index. A tensor
                is first converted with ``tolist()``. A list of indices yields a
                list of samples, one per index.

        Returns:
            A dict with 'obj_name' and 'label' (passed through ``transform`` if one
            was given), or a list of such dicts when ``idx`` is a list.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Several indices: iloc[list, 0] returns a Series, which has no .strip(),
        # so build each sample individually instead of crashing.
        if isinstance(idx, list):
            return [self[i] for i in idx]

        # Get object file name and label
        obj_name = self.data_frame.iloc[idx, 0].strip()
        label = self.data_frame.iloc[idx, 1]

        # TODO: load the actual 3D object data here, e.g. with trimesh or pytorch3d:
        # sample = load_object(obj_name)
        # For now, only the object name and label are returned.
        sample = {'obj_name': obj_name, 'label': label}

        # Compare against None so a callable that happens to be falsy is still applied.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

In [3]:
# Configuration for this preview
CSV_PATH = 'labels.csv'   # no header; columns: object file name, label
BATCH_SIZE = 4
N_PREVIEW_BATCHES = 2     # only show the first few batches
SEED = 0                  # fixes the shuffle order so printed output is reproducible

# Create the dataset
dataset = ObjectDataset(csv_file=CSV_PATH)

# Create a dataloader; the seeded generator makes shuffle=True deterministic
# across Restart & Run All.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(SEED),
)

# Example of iterating through the dataloader
for i_batch, sample_batched in enumerate(dataloader):
    print(f"Batch {i_batch}:")
    print(f"Object names: {sample_batched['obj_name']}")
    print(f"Labels: {sample_batched['label']}")

    # Stop after N_PREVIEW_BATCHES batches for this example
    if i_batch + 1 >= N_PREVIEW_BATCHES:
        break

Batch 0:
Object names: ['chair_019_11.obj', 'chair_001_16.obj', 'chair_018_02.obj', 'chair_018_09.obj']
Labels: tensor([0, 0, 0, 0])
Batch 1:
Object names: ['chair_001_07.obj', 'chair_066_19.obj', 'chair_001_20.obj', 'chair_066_03.obj']
Labels: tensor([0, 2, 0, 0])
