In [1]:
%cd /ibex/user/slimhy/PADS/code
from datasets.part_occupancies import PartOccupancyDataset
from util.viz_occ import viz_queries, viz_part_pointcloud, print_occupancy_stats, convert_to_data_tuple


def convert_dataset_sample_to_viz_format(sample):
    """
    Repackage a PartOccupancyDataset sample into the tuple the viz helpers consume.

    Args:
        sample: Mapping with the keys 'query_points', 'query_labels' and
            'part_bbs'. 'part_bbs' must expose ``.numpy()`` and
            'query_points' must expose ``.unsqueeze()``.

    Returns:
        Whatever ``convert_to_data_tuple`` returns for the converted inputs.
    """
    # Arguments, in order:
    #   1. query points with a leading axis of size 1 added
    #   2. query labels, passed through untouched
    #   3. part bounding boxes converted to a NumPy array
    return convert_to_data_tuple(
        sample['query_points'].unsqueeze(0),
        sample['query_labels'],
        sample['part_bbs'].numpy(),
    )


def visualize_dataset_sample(dataset, idx=0, filter_vol=False, filter_surface=False):
    """
    Visualize the query points of one sample from the PartOccupancyDataset.

    Prints the model id, the tensor shapes and the occupancy statistics of
    the (optionally filtered) query points, then builds the query figure.

    Args:
        dataset: PartOccupancyDataset instance; indexing it must yield a dict
            with the keys 'model_id', 'query_points', 'query_labels',
            'part_points' and 'part_bbs'.
        idx: Index of the sample to visualize.
        filter_vol: If True, keep only the first half of the query points.
        filter_surface: If True, keep only the second half of the query
            points. Takes precedence over ``filter_vol`` when both are set.

    Returns:
        The figure object returned by ``viz_queries``.

    NOTE(review): the previous inline comment described the second half of
    the query points as the "volume" points, which would make the first half
    the surface points -- confirm against PartOccupancyDataset.
    """
    sample = dataset[idx]

    print("\nVisualization parameters:")
    print(f"Model ID: {sample['model_id']}")
    print("\nData shapes:")
    print(f"Query points: {sample['query_points'].shape}")
    print(f"Query labels: {sample['query_labels'].shape}")
    print(f"Part points: {sample['part_points'].shape}")
    print(f"Part bounding boxes: {sample['part_bbs'].shape}")

    # Pick which half of the queries to keep (or all of them).
    half = sample['query_points'].shape[0] // 2
    if filter_surface:
        keep = slice(half, None)
    elif filter_vol:
        keep = slice(None, half)
    else:
        keep = slice(None)

    # Slice first, then clone: every tensor handed to the viz helpers is an
    # independent copy, whichever filter is active. Only the three keys read
    # by convert_dataset_sample_to_viz_format are copied.
    viz_sample = {
        'query_points': sample['query_points'][keep].clone(),
        'query_labels': sample['query_labels'][keep].clone(),
        'part_bbs': sample['part_bbs'].clone(),
    }

    data = convert_dataset_sample_to_viz_format(viz_sample)

    print("\nOccupancy Statistics:")
    print_occupancy_stats(data)

    return viz_queries(data, unpack_bb=False)


def visualize_part(dataset, idx=0, part_idx=0):
    """
    Plot the point cloud of a single part taken from one dataset sample.

    Args:
        dataset: PartOccupancyDataset instance
        idx: Index of the sample to read
        part_idx: Index of the part inside the sample's 'part_points'

    Returns:
        The figure produced by viz_part_pointcloud, with its title set to
        the sample's model id and the part index
    """
    sample = dataset[idx]

    # Build the figure straight from the selected part's points.
    figure = viz_part_pointcloud(sample['part_points'][part_idx])

    # Label the figure so it can be identified on its own.
    title_text = f'Model: {sample["model_id"]}, Part: {part_idx}'
    figure.update_layout(title=title_text)

    return figure

/ibex/user/slimhy/PADS/code


In [2]:
# Run configuration: tune these values here rather than inside the call below.
HDF5_PATH = "/ibex/project/c2273/PADS/3DCoMPaT_occ/dataset__debug.h5"
NUM_QUERIES = 10000      # forwarded as num_queries
NUM_PART_POINTS = 2048   # forwarded as num_part_points
ROT_ANGLE = 15           # NOTE(review): units are defined by PartOccupancyDataset -- confirm
MAX_SCALE = 1.25         # forwarded as max_scale

# NOTE(review): random_transform=True presumably makes samples differ between
# runs; seed the relevant RNGs first if the figures must be reproducible.
dataset = PartOccupancyDataset(
    rank=0,
    hdf5_path=HDF5_PATH,
    num_queries=NUM_QUERIES,
    num_part_points=NUM_PART_POINTS,
    random_transform=True,
    rot_angle=ROT_ANGLE,
    max_scale=MAX_SCALE,
)

print(f"Dataset size: {len(dataset)}")

Loading HDF5 data to memory...

 Done.
Dataset size: 170


In [6]:
# Sample 60 with filter_surface enabled; filter_vol stays at its default (False).
visualize_dataset_sample(dataset, idx=60, filter_surface=True).show()


Visualization parameters:
Model ID: 0c_34b

Data shapes:
Query points: torch.Size([10000, 3])
Query labels: torch.Size([10000])
Part points: torch.Size([11, 2048, 3])
Part bounding boxes: torch.Size([11, 8, 3])

Occupancy Statistics:

Points Statistics:
Total points: 5000
Zeros: 2500 (50.00%)
Ones: 2500 (50.00%)
Ratio of zeros to ones: 1.00
(5000, 3) (5000,)


In [4]:
# Part 0 (the default part_idx) of sample 68.
visualize_part(dataset, idx=68).show()

In [5]:
# Part 0 of sample 0.
visualize_part(dataset, idx=0, part_idx=0).show()