In [None]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset



class PartOccupancyDataset(Dataset):
    """Dataset for part-based occupancy prediction with part dropping.

    Each model yields one sample with all of its parts, plus one extra sample
    per entry of ``part_drops`` that is not -1, in which that single part is
    removed. All HDF5 matrices are loaded into memory up front.

    Relies on the module-level constants ``MAX_PART_DROP`` and
    ``N_SUB_POINTS``, which are defined outside this class.
    """

    def __init__(self, hdf5_path, split='train', num_queries=None, seed=None):
        """
        Initialize the dataset by loading HDF5 matrices into memory.

        Args:
            hdf5_path: Path to the HDF5 file containing the dataset
            split: Dataset split ('train', 'val', 'test'). Currently not used:
                the whole file is loaded regardless of this value.
            num_queries: Number of query points to sample from each query set
                (if None, or not smaller than the number available, all
                points are used)
            seed: Seed for NumPy's global RNG, which drives query sampling

        Raises:
            ValueError: If required HDF5 datasets are missing, or if a sample's
                query configuration index lies outside ``query_points_matrix``.
            AssertionError: If the loaded arrays have inconsistent shapes.
        """
        self.num_queries = num_queries
        if seed is not None:
            # NOTE(review): seeds NumPy's *global* RNG (process-wide side
            # effect). A per-instance Generator is not a drop-in replacement:
            # it would be copied identically into every DataLoader worker.
            np.random.seed(seed)

        # Load and validate HDF5 data
        with h5py.File(hdf5_path, 'r') as f:
            # Verify required datasets exist
            required_keys = [
                'model_ids', 'part_slices', 'part_drops',
                'part_points_matrix', 'part_bbs_matrix',
                'query_points_matrix', 'query_labels_matrix'
            ]
            missing_keys = [key for key in required_keys if key not in f]
            if missing_keys:
                raise ValueError(f"Missing required datasets: {missing_keys}")

            # Load data into memory
            self.model_ids = f['model_ids'][:].astype('U')
            self.part_slices = f['part_slices'][:]
            self.part_drops = f['part_drops'][:]
            self.part_points = f['part_points_matrix'][:]
            self.part_bbs = f['part_bbs_matrix'][:]
            self.query_points = f['query_points_matrix'][:]
            self.query_labels = f['query_labels_matrix'][:]

            # Validate dimensions
            n_models = len(self.model_ids)
            total_configs = self.query_points.shape[0]

            expected_configs = n_models * (MAX_PART_DROP + 1)
            assert total_configs >= n_models, "Need at least one config per model"
            assert total_configs <= expected_configs, \
                f"Too many configs: got {total_configs}, expected <= {expected_configs}"

            # Verify array shapes
            assert self.part_slices.shape == (n_models + 1,), \
                f"Invalid part_slices shape: {self.part_slices.shape}"
            assert self.part_drops.shape == (n_models, MAX_PART_DROP), \
                f"Invalid part_drops shape: {self.part_drops.shape}"
            assert self.part_points.shape[1:] == (N_SUB_POINTS, 3), \
                f"Invalid part_points shape: {self.part_points.shape}"
            assert self.part_bbs.shape[1:] == (8, 3), \
                f"Invalid part_bbs shape: {self.part_bbs.shape}"
            assert self.query_points.shape[1:] == (5, N_SUB_POINTS, 3), \
                f"Invalid query_points shape: {self.query_points.shape}"
            assert self.query_labels.shape[1:] == (5, N_SUB_POINTS), \
                f"Invalid query_labels shape: {self.query_labels.shape}"

        # NOTE(review): query configs are addressed as fixed-size groups of
        # MAX_PART_DROP + 1 rows per model (slot 0 = no drop, slot k + 1 =
        # k-th entry of part_drops). Confirm this matches the writer.
        configs_per_model = MAX_PART_DROP + 1
        self.sample_configs = []

        for model_idx in range(n_models):
            # Parts of this model occupy rows [start_idx, end_idx) of the
            # flattened part matrices.
            start_idx = int(self.part_slices[model_idx])
            end_idx = int(self.part_slices[model_idx + 1])
            n_parts = end_idx - start_idx
            query_config_idx = model_idx * configs_per_model

            # Original configuration (no dropped parts)
            self.sample_configs.append({
                'model_idx': model_idx,
                'query_config_idx': query_config_idx,
                'part_slice': (start_idx, end_idx),
                'dropped_part_idx': None,
                'n_parts': n_parts
            })

            # Part-drop configurations; -1 marks an unused drop slot.
            for drop_idx in range(MAX_PART_DROP):
                dropped_part_idx = int(self.part_drops[model_idx, drop_idx])
                if dropped_part_idx == -1:
                    continue
                self.sample_configs.append({
                    'model_idx': model_idx,
                    'query_config_idx': query_config_idx + drop_idx + 1,
                    'part_slice': (start_idx, end_idx),
                    'dropped_part_idx': dropped_part_idx,
                    'n_parts': n_parts - 1
                })

        # Fail fast instead of raising IndexError later inside a worker.
        out_of_range = [
            cfg['query_config_idx'] for cfg in self.sample_configs
            if cfg['query_config_idx'] >= total_configs
        ]
        if out_of_range:
            raise ValueError(
                f"query_config_idx {max(out_of_range)} is out of range for "
                f"query_points_matrix with {total_configs} configs; the file may "
                f"not use {configs_per_model} configs per model"
            )

    def __len__(self):
        """Return the number of samples (model/part-drop configurations)."""
        return len(self.sample_configs)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx: Index of the sample to retrieve

        Returns:
            Dictionary containing:
                - part_points: float tensor [n_parts, N_points, 3]
                - part_bbs: float tensor [n_parts, 8, 3]
                - query_points: float tensor [5, N_queries, 3]
                - query_labels: float tensor [5, N_queries]
                - model_id: Model identifier
                - n_parts: Number of parts (int)

            N_queries is ``num_queries`` when it is set and smaller than the
            number of stored query points; otherwise all points are returned.
        """
        config = self.sample_configs[idx]

        # Copy so the returned tensors never alias the in-memory dataset.
        start_idx, end_idx = config['part_slice']
        part_points = self.part_points[start_idx:end_idx].copy()
        part_bbs = self.part_bbs[start_idx:end_idx].copy()

        # Handle dropped part (index is relative to this model's parts)
        if config['dropped_part_idx'] is not None:
            mask = np.ones(end_idx - start_idx, dtype=bool)
            mask[config['dropped_part_idx']] = False
            part_points = part_points[mask]
            part_bbs = part_bbs[mask]

            assert len(part_points) == config['n_parts'], \
                "Mismatch in number of parts after dropping"

        # Query data: [5, N, 3] points and [5, N] labels
        query_points = self.query_points[config['query_config_idx']].copy()
        query_labels = self.query_labels[config['query_config_idx']].copy()

        # Subsample along the point axis (axis 1), not the leading query-set
        # axis. The same indices are used for every query set and for points
        # and labels, so they stay aligned.
        n_available = query_labels.shape[-1]
        if self.num_queries is not None and self.num_queries < n_available:
            indices = np.random.choice(n_available, self.num_queries, replace=False)
            query_points = query_points[:, indices]
            query_labels = query_labels[:, indices]

        return {
            'part_points': torch.from_numpy(part_points).float(),
            'part_bbs': torch.from_numpy(part_bbs).float(),
            'query_points': torch.from_numpy(query_points).float(),
            'query_labels': torch.from_numpy(query_labels).float(),
            'model_id': self.model_ids[config['model_idx']],
            'n_parts': config['n_parts']
        }
        

def collate_fn(batch):
    """
    Custom collate function for batching samples with variable numbers of parts.

    Part-related tensors are zero-padded along the part axis up to the largest
    ``n_parts`` in the batch. Query tensors, which must share one shape across
    the batch, are stacked. Trailing dimensions are taken from the samples
    rather than hard-coded, so any consistent per-sample shape is supported.

    Args:
        batch: Non-empty list of sample dictionaries from the dataset, each with
            'part_points', 'part_bbs', 'query_points', 'query_labels',
            'model_id' and 'n_parts'.

    Returns:
        Dictionary containing batched and padded tensors:
            - part_points: [batch_size, max_parts, N_points, 3]
            - part_bbs: [batch_size, max_parts, 8, 3]
            - query_points: [batch_size, 5, N_queries, 3]
            - query_labels: [batch_size, 5, N_queries]
            - model_ids: List of model identifiers, length batch_size
            - n_parts: [batch_size] long tensor of actual part counts
            - padding_mask: [batch_size, max_parts] bool tensor
              (False = real part, True = padding)
        The shapes above are those produced by PartOccupancyDataset; other
        trailing shapes are passed through unchanged.

    Raises:
        ValueError: If ``batch`` is empty.
        AssertionError: If per-sample shapes are inconsistent across the batch.
    """
    if not batch:
        raise ValueError("Empty batch received")

    first = batch[0]
    part_points_shape = first['part_points'].shape[1:]    # e.g. [N_points, 3]
    part_bbs_shape = first['part_bbs'].shape[1:]          # e.g. [8, 3]
    query_points_shape = first['query_points'].shape      # e.g. [5, N_queries, 3]
    query_labels_shape = first['query_labels'].shape      # e.g. [5, N_queries]

    # Verify consistent shapes across batch
    for sample in batch:
        assert sample['part_points'].shape[1:] == part_points_shape, \
            "Inconsistent part points shape in batch"
        assert sample['part_bbs'].shape[1:] == part_bbs_shape, \
            "Inconsistent bounding boxes shape in batch"
        assert sample['query_points'].shape == query_points_shape, \
            "Inconsistent query points shape in batch"
        assert sample['query_labels'].shape == query_labels_shape, \
            "Inconsistent query labels shape in batch"

    batch_size = len(batch)
    # int() also accepts numpy integer part counts.
    part_counts = [int(sample['n_parts']) for sample in batch]
    max_parts = max(part_counts)

    part_points_batch = torch.zeros(
        (batch_size, max_parts, *part_points_shape),
        dtype=first['part_points'].dtype
    )
    part_bbs_batch = torch.zeros(
        (batch_size, max_parts, *part_bbs_shape),
        dtype=first['part_bbs'].dtype
    )
    # Start as all-padding; real parts are cleared below.
    padding_mask = torch.ones(batch_size, max_parts, dtype=torch.bool)

    for i, (sample, count) in enumerate(zip(batch, part_counts)):
        part_points_batch[i, :count] = sample['part_points']
        part_bbs_batch[i, :count] = sample['part_bbs']
        padding_mask[i, :count] = False

    # Query tensors share a shape (checked above), so no padding is needed.
    query_points_batch = torch.stack(
        [sample['query_points'].to(first['query_points'].dtype) for sample in batch]
    )
    query_labels_batch = torch.stack(
        [sample['query_labels'].to(first['query_labels'].dtype) for sample in batch]
    )

    return {
        'part_points': part_points_batch,     # [B, max_parts, N_points, 3]
        'part_bbs': part_bbs_batch,           # [B, max_parts, 8, 3]
        'query_points': query_points_batch,   # [B, 5, N_queries, 3]
        'query_labels': query_labels_batch,   # [B, 5, N_queries]
        'model_ids': [sample['model_id'] for sample in batch],
        'n_parts': torch.tensor(part_counts, dtype=torch.long),  # [B]
        'padding_mask': padding_mask          # [B, max_parts]
    }


In [None]:
from torch.utils.data import DataLoader


# Wrap the HDF5-backed dataset in a shuffling, multi-worker loader.
# `collate_fn` pads the variable number of parts per sample.
# NOTE: `output_path` must already be defined by an earlier cell.
dataset = PartOccupancyDataset(output_path)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
for batch in dataloader:
    # Sanity-check a single batch: tensor shapes first, then part bookkeeping.
    for field in ('part_points', 'part_bbs', 'query_points', 'query_labels'):
        print(batch[field].shape)
    for field in ('n_parts', 'padding_mask'):
        print(batch[field])
    break