In [None]:
%cd /ibex/user/slimhy/PADS/code
import os
from collections import Counter

import h5py
import numpy as np
import torch
from tqdm import tqdm

# Global constants

# Input locations on the cluster filesystem
PARTS_DIR = "/ibex/project/c2273/PADS/3DCoMPaT_occ/parts"
SAMPLES_DIR = "/ibex/project/c2273/PADS/3DCoMPaT_occ/samples"

# Part-drop configurations are indexed 0 .. MAX_PART_DROP - 1 for every model
MAX_PART_DROP = 16

# Number of points in each point cloud as stored on disk
N_POINTS = 2 ** 17  # 131072

# Fraction of N_POINTS kept for each part's surface point cloud
RATIO_SUB_POINTS = 0.25
N_SUB_POINTS = int(N_POINTS * RATIO_SUB_POINTS)

# Fraction of N_POINTS kept for each set of query points
RATIO_SUB_QUERIES = 0.125
N_SUB_QUERIES = int(N_POINTS * RATIO_SUB_QUERIES)


def load_part_bbs(model_id):
    """
    Read the per-part bounding boxes of the original (no part dropped) configuration.

    The file ``{model_id}_orig_0_bbs.pkl`` under SAMPLES_DIR is expected to hold an
    iterable of ``(part_key, box)`` pairs where each ``box`` exposes a ``.vertices``
    attribute.

    Args:
        model_id (str): The model identifier

    Returns:
        dict: Mapping of part keys to the box vertices as a numpy array
              (callers in this file expect shape [8, 3], the 8 box corners)
    """
    pkl_path = os.path.join(SAMPLES_DIR, f"{model_id}_orig_0_bbs.pkl")

    # NOTE: allow_pickle=True unpickles arbitrary objects; only use on trusted files.
    boxes_by_part = dict(np.load(pkl_path, allow_pickle=True))

    corners_by_part = {}
    for part_key, box in boxes_by_part.items():
        corners_by_part[part_key] = np.array(box.vertices)
    return corners_by_part


def load_occs(model_id, part_drop_id=None):
    """
    Read the query points and their occupancy labels for one configuration of a model.

    Args:
        model_id (str): The model identifier
        part_drop_id (int, optional): The part drop identifier. When None, the original
                                    configuration (no part dropped) is read.

    Returns:
        tuple: (queries, occupancies)
            - queries: array exactly as stored on disk
                       (callers in this file expect shape [5, N_POINTS, 3])
            - occupancies: stored occupancies reshaped to the two leading query
                           dimensions, with singleton axes squeezed away
                           (callers in this file expect shape [5, N_POINTS])
    """
    # Both files of a configuration share the same filename stem
    if part_drop_id is None:
        stem = f"{model_id}_orig_0"
    else:
        stem = f"{model_id}_part_drop_{part_drop_id}"

    occupancies = np.load(os.path.join(SAMPLES_DIR, f"{stem}_occs.npy"))
    query_points = np.load(os.path.join(SAMPLES_DIR, f"{stem}_points.npy"))

    # Align occupancies with the (set, point) layout of the queries
    leading_dims = query_points.shape[:2]
    aligned_occs = occupancies.reshape(*leading_dims, -1).squeeze()
    return query_points, aligned_occs


def load_part_surf_points(model_id):
    """
    Read the surface point cloud of every part of a model.

    The file ``{model_id}.npy`` under PARTS_DIR is expected to contain a pickled
    dictionary keyed by part.

    Args:
        model_id (str): The model identifier

    Returns:
        dict: Mapping of part keys to point arrays with singleton axes removed
              (callers in this file expect shape [N_POINTS, 3])
    """
    npy_path = os.path.join(PARTS_DIR, f"{model_id}.npy")

    # np.save wraps a dict in a 0-d object array; .item() unwraps it again.
    # NOTE: allow_pickle=True unpickles arbitrary objects; only use on trusted files.
    raw_parts = np.load(npy_path, allow_pickle=True).item()

    points_by_part = {}
    for part_key, part_points in raw_parts.items():
        points_by_part[part_key] = np.array(part_points).squeeze()
    return points_by_part


def get_dropped_part_key(model_id, part_drop_id):
    """
    Identify which part is missing from a part-drop configuration.

    The part keys listed in the bounding-box file of the drop configuration are
    compared against those of the original configuration; the single key present
    only in the original is the dropped part.

    Args:
        model_id (str): The model identifier
        part_drop_id (int): The part drop identifier

    Returns:
        str: Key of the dropped part

    Raises:
        AssertionError: If the two configurations do not differ by exactly one part
    """
    def _read_part_keys(config_stem):
        # Each bounding-box file holds an iterable of (part_key, box) pairs
        pkl_path = os.path.join(SAMPLES_DIR, f"{config_stem}_bbs.pkl")
        return {part_key for part_key, _ in np.load(pkl_path, allow_pickle=True)}

    keys_after_drop = _read_part_keys(f"{model_id}_part_drop_{part_drop_id}")
    keys_original = _read_part_keys(f"{model_id}_orig_0")

    missing_keys = keys_original - keys_after_drop
    assert len(missing_keys) == 1, f"Expected exactly one dropped part for {model_id}, drop {part_drop_id}"
    return next(iter(missing_keys))


def subsample_points(p, labels=None, max_abs_value=1.0, n_samples=None):
    """
    Randomly subsample a point set without replacement.

    Two modes, selected by ``labels``:

    * Part points (``labels is None``): draw exactly ``n_samples`` points.
    * Query points (``labels`` given): first keep only points inside the cube
      [-max_abs_value, max_abs_value]^3, then draw up to ``n_samples`` of them
      together with their labels. If fewer points survive the filter, all of them
      are returned (in random order), so the caller must check the length.

    Sampling uses torch's global random generator; call ``torch.manual_seed`` beforehand
    for reproducible output.

    Args:
        p: Points array of shape [N, 3]
        labels: Optional labels of shape [N] (numpy array, sequence or torch tensor).
                If provided, indicates query point processing.
        max_abs_value: Maximum absolute coordinate value kept when filtering query points.
                       Ignored for part points. Defaults to 1.0.
        n_samples: Number of points to draw. Defaults to N_SUB_QUERIES for query points
                   and N_SUB_POINTS for part points.

    Returns:
        Part points: numpy array of shape [n_samples, 3].
        Query points: tuple (points, labels); points is a numpy array, labels keeps its
        input kind (torch tensor stays a tensor, anything else is a numpy array).

    Raises:
        AssertionError: In part-point mode, if fewer than ``n_samples`` points are available.
    """
    p = torch.as_tensor(p)

    if labels is not None:
        if n_samples is None:
            n_samples = N_SUB_QUERIES

        # Query points - filter to bounded cube first
        mask = torch.all(torch.abs(p) <= max_abs_value, dim=1)
        p = p[mask]

        # Random subset of the surviving points (shorter than n_samples if too few survive)
        idx = torch.randperm(len(p))[:n_samples]
        p = p[idx]

        # Index labels with indices of their own kind instead of relying on implicit
        # numpy <-> torch conversion inside __getitem__
        if isinstance(labels, torch.Tensor):
            labels = labels[mask][idx]
        else:
            labels = np.asarray(labels)[mask.numpy()][idx.numpy()]

        return p.numpy(), labels

    if n_samples is None:
        n_samples = N_SUB_POINTS

    # Part points - just random sampling
    idx = torch.randperm(len(p))[:n_samples]
    assert len(idx) == n_samples, f"Invalid subsampling length: {len(idx)}"

    return p[idx].numpy()
    
    
def create_stacked_matrices(model_ids):
    """
    Create stacked matrices for part points, bounding boxes, query points, and occupancies.

    Part data of all models is concatenated along the first axis; ``part_slices`` gives
    the row range of each model. Query configurations are stored back to back: for every
    model the original configuration first, followed by each part-drop configuration
    that could be processed, in increasing drop id. A part-drop configuration that fails
    to load or validate is skipped with a printed warning and leaves ``part_drops`` at -1,
    so ``part_drops`` always matches the query configurations actually stored.

    Args:
        model_ids (list): List of model identifiers to process

    Returns:
        dict: Dictionary with the following entries
            - 'model_ids': byte-string array of the processed model ids
            - 'part_slices': int32 array of length len(model_ids) + 1; parts of model i
              occupy rows part_slices[i]:part_slices[i + 1] of the part matrices
            - 'part_drops': int32 array [len(model_ids), MAX_PART_DROP]; index (into the
              model's sorted part keys) of the part removed by each stored drop
              configuration, -1 where the configuration was skipped
            - 'part_points_matrix': float32 [total_parts, N_SUB_POINTS, 3]
            - 'part_bbs_matrix': float32 [total_parts, 8, 3]
            - 'query_points_matrix': float32 [n_stored_configs, 5, N_SUB_QUERIES, 3]
            - 'query_labels_matrix': float32 [n_stored_configs, 5, N_SUB_QUERIES]

    Raises:
        AssertionError: If part keys, array shapes or final counts are inconsistent
        ValueError: If a query set of an original configuration has fewer than
                    N_SUB_QUERIES points inside its coordinate bound
    """
    # Every configuration stores this many query sets (leading axis of the arrays on disk)
    n_query_sets = 5
    # Coordinate bound applied to each query set before subsampling
    orig_max_ranges = [0.8] * n_query_sets
    drop_max_ranges = [0.75 if set_idx <= 3 else 0.5 for set_idx in range(n_query_sets)]

    # First pass: determine total sizes
    total_parts = 0
    total_query_configs = 0
    part_slices = []

    print("Calculating matrix dimensions...")
    for model_id in tqdm(model_ids):
        part_slices.append(total_parts)
        total_parts += len(load_part_surf_points(model_id))
        total_query_configs += (MAX_PART_DROP + 1)  # Include original configuration

    part_slices.append(total_parts)
    part_slices = np.array(part_slices, dtype=np.int32)

    # Initialize matrices with proper data types
    part_points_matrix = np.zeros((total_parts, N_SUB_POINTS, 3), dtype=np.float32)
    part_bbs_matrix = np.zeros((total_parts, 8, 3), dtype=np.float32)
    query_points_matrix = np.zeros((total_query_configs, n_query_sets, N_SUB_QUERIES, 3), dtype=np.float32)
    query_labels_matrix = np.zeros((total_query_configs, n_query_sets, N_SUB_QUERIES), dtype=np.float32)
    part_drops = np.full((len(model_ids), MAX_PART_DROP), -1, dtype=np.int32)
    model_ids_array = np.array(model_ids, dtype='S')

    print("Filling matrices...")
    part_idx = 0
    query_config_idx = 0

    for model_idx, model_id in enumerate(tqdm(model_ids)):
        # Load and validate data
        part_points = load_part_surf_points(model_id)
        part_bbs = load_part_bbs(model_id)

        # Validate data consistency
        assert set(part_points.keys()) == set(part_bbs.keys()), \
            f"Mismatch in part keys for model {model_id}"

        # Sorted once: defines both the row order of the parts and the meaning of
        # the indices stored in part_drops
        sorted_part_keys = sorted(part_points.keys())

        # Fill part matrices
        for part_key in sorted_part_keys:
            points = part_points[part_key]
            bbs = part_bbs[part_key]

            # Validate shapes
            assert points.shape == (N_POINTS, 3), \
                f"Invalid point shape for model {model_id}, part {part_key}: {points.shape}"
            assert bbs.shape == (8, 3), \
                f"Invalid BB shape for model {model_id}, part {part_key}: {bbs.shape}"

            part_points_matrix[part_idx] = subsample_points(points)
            part_bbs_matrix[part_idx] = bbs
            part_idx += 1

        # Fill query matrices - first the original configuration
        queries_orig, occs_orig = load_occs(model_id, part_drop_id=None)
        assert queries_orig.shape == (n_query_sets, N_POINTS, 3), \
            f"Invalid query shape for original config of model {model_id}: {queries_orig.shape}"
        assert occs_orig.shape == (n_query_sets, N_POINTS), \
            f"Invalid occupancy shape for original config of model {model_id}: {occs_orig.shape}"

        orig_points, orig_labels = _subsample_query_sets(
            queries_orig, occs_orig, orig_max_ranges,
            f"original config of model {model_id}")
        query_points_matrix[query_config_idx] = orig_points
        query_labels_matrix[query_config_idx] = orig_labels
        query_config_idx += 1

        # Then process part drop configurations
        for part_drop_id in range(MAX_PART_DROP):
            try:
                dropped_key = get_dropped_part_key(model_id, part_drop_id)
                dropped_idx = sorted_part_keys.index(dropped_key)

                # Load and validate query data
                queries, occs = load_occs(model_id, part_drop_id)
                assert queries.shape == (n_query_sets, N_POINTS, 3), \
                    f"Invalid query shape for model {model_id}, drop {part_drop_id}: {queries.shape}"
                assert occs.shape == (n_query_sets, N_POINTS), \
                    f"Invalid occupancy shape for model {model_id}, drop {part_drop_id}: {occs.shape}"

                drop_points, drop_labels = _subsample_query_sets(
                    queries, occs, drop_max_ranges,
                    f"model {model_id}, drop {part_drop_id}")
            except Exception as e:
                # Best effort: skip this configuration but keep going
                print(f"Warning: Failed to process part drop {part_drop_id} "
                      f"for model {model_id}: {str(e)}")
                continue

            # Commit only after everything succeeded, so that a skipped configuration
            # never leaves a dropped-part entry without matching query data
            query_points_matrix[query_config_idx] = drop_points
            query_labels_matrix[query_config_idx] = drop_labels
            part_drops[model_idx, part_drop_id] = dropped_idx
            query_config_idx += 1

    # Verify final counts match expected values
    assert part_idx == total_parts, \
        f"Mismatch in part count: got {part_idx}, expected {total_parts}"
    assert query_config_idx <= total_query_configs, \
        f"Mismatch in query config count: got {query_config_idx}, expected {total_query_configs}"

    return {
        'model_ids': model_ids_array,
        'part_slices': part_slices,
        'part_drops': part_drops,
        'part_points_matrix': part_points_matrix,
        'part_bbs_matrix': part_bbs_matrix,
        # Drop the unused tail reserved for configurations that were skipped
        'query_points_matrix': query_points_matrix[:query_config_idx],
        'query_labels_matrix': query_labels_matrix[:query_config_idx]
    }


def _subsample_query_sets(queries, occs, max_ranges, context):
    """Subsample each query set of one configuration into fixed-size float32 arrays.

    Set ``i`` of ``queries``/``occs`` is filtered to [-max_ranges[i], max_ranges[i]]^3 and
    subsampled via ``subsample_points``. ``context`` is only used in the error message.
    Raises ValueError if a set yields a number of points other than N_SUB_QUERIES.
    """
    sub_points = np.zeros((len(max_ranges), N_SUB_QUERIES, 3), dtype=np.float32)
    sub_labels = np.zeros((len(max_ranges), N_SUB_QUERIES), dtype=np.float32)

    for set_idx, max_range in enumerate(max_ranges):
        set_points, set_labels = subsample_points(
            queries[set_idx], occs[set_idx], max_abs_value=max_range)
        if len(set_points) != N_SUB_QUERIES:
            raise ValueError(
                f"Got {len(set_points)} query points within +/-{max_range} for {context}, "
                f"set {set_idx}; expected {N_SUB_QUERIES}")
        sub_points[set_idx] = set_points
        sub_labels[set_idx] = set_labels

    return sub_points, sub_labels

def save_to_hdf5(matrices, output_path):
    """
    Write every matrix of a dictionary into one HDF5 file, one dataset per key.

    An existing file at ``output_path`` is overwritten (the file is opened in 'w' mode).

    Args:
        matrices (dict): Mapping of dataset name to the array to store
        output_path (str): Destination path of the HDF5 file
    """
    print("Saving matrices to HDF5...")
    with h5py.File(output_path, 'w') as h5_file:
        for dataset_name in matrices:
            h5_file.create_dataset(dataset_name, data=matrices[dataset_name])

In [None]:
def initialize_processed_models(out_path, expected_count=68, id_length=6):
    """
    Collect the model IDs whose files are all present in a directory.

    A model ID is taken to be the first ``id_length`` characters of a filename. Only
    IDs that occur exactly ``expected_count`` times in the directory listing are kept;
    IDs with fewer or more files are left out. Filenames shorter than ``id_length``
    are ignored.

    Args:
        out_path (str): Path to the directory containing the files
        expected_count (int): Number of files a model must have to qualify.
            Defaults to 68 (NOTE(review): presumably 4 files for each of the
            MAX_PART_DROP + 1 configurations - confirm against the sampling script).
        id_length (int): Number of leading filename characters forming the model ID.
            Defaults to 6.

    Returns:
        set: The qualifying model IDs. No global state is modified.
    """
    files_per_model = Counter(
        filename[:id_length]
        for filename in os.listdir(out_path)
        if len(filename) >= id_length  # Ensure filename is long enough
    )

    return {
        model_id
        for model_id, count in files_per_model.items()
        if count == expected_count
    }

In [None]:
# Models with errors, excluded from the dataset
error_ids = ["25_41d", "10_01d"]

# Fully sampled models minus the ones with errors, in sorted order
model_ids = sorted(set(initialize_processed_models(SAMPLES_DIR)).difference(error_ids))

In [None]:
# Stack the data of the first 100 models only (debug-sized subset)
matrices = create_stacked_matrices(list(model_ids[:100]))

# Write all matrices into a single HDF5 file
output_path = "/ibex/project/c2273/PADS/3DCoMPaT_occ/dataset__debug.h5"
save_to_hdf5(matrices, output_path)
print(f"Dataset created successfully at {output_path}")

In [None]:
# Measure output size in GB
output_size = os.path.getsize(output_path) / 1e9

# Estimate size for full dataset by scaling linearly with the number of models.
# Use the number of models actually written rather than a hard-coded 100: the subset
# is smaller than that whenever fewer than 100 models are available.
n_models_written = len(matrices['model_ids'])
if n_models_written:
    full_size = output_size * len(model_ids) / n_models_written
else:
    full_size = float("nan")  # nothing was written, so there is nothing to extrapolate

print(f"Output size: {output_size:.3f} GB")
print(f"Estimated full size: {full_size:.3f} GB")

In [None]:
# Leave a marker file behind to confirm that the dataset was created successfully
with open('/ibex/project/c2273/PADS/3DCoMPaT_occ/dataset__debug.txt', 'w') as marker_file:
    marker_file.write("Dataset created successfully!")