In [1]:
# Change into the project's code directory so the local `datasets` package imported below
# resolves from the current working directory.
# NOTE(review): absolute, machine-specific path — consider making this configurable.
%cd /home/slimhy/Documents/PADS/code
# NOTE(review): a local package named `datasets` can collide with the Hugging Face `datasets`
# library if that is installed / already imported in the kernel — confirm which one is loaded.
from datasets.part_occupancies import PartOccupancyDataset, collate
import plotly.graph_objects as go
import numpy as np
import torch


def _to_numpy(value):
    """Convert a torch tensor (any device, with or without grad) or an array-like to a numpy array."""
    if isinstance(value, torch.Tensor):
        # detach() is required: .numpy() raises on tensors that require grad.
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _marker_trace(points, name, **marker):
    """Build a markers-only Scatter3d trace from a (K, 3) point array; K may be 0.

    Extra keyword arguments are passed through as the plotly ``marker`` dict
    (e.g. ``size``, ``color``, ``opacity``, ``symbol``).
    """
    return go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=marker,
        name=name,
    )


def viz_queries(data_tuple, sample_idx=0, point_size=2, corner_size=8):
    """Visualize one sample's query points colored by occupancy, plus part bounding-box corners.

    Args:
        data_tuple: Mapping (e.g. a collated batch) with keys ``'query_points'``,
            ``'query_labels'`` and ``'part_bbs'``. Each value is indexed with
            ``sample_idx``; torch tensors (on any device) and numpy arrays are accepted.
        sample_idx: Index of the sample within the batch to visualize.
        point_size: Marker size for query points.
        corner_size: Marker size for bounding-box corners.

    Returns:
        plotly.graph_objects.Figure with three traces: query points labeled 1
        (red), all other query points (green), and box corners (blue squares).

    Raises:
        ValueError: if the number of query points and labels differ.
    """
    # Flatten defensively: points to (N, 3), labels to (N,) so a [N, 1] label tensor also works.
    query_points = _to_numpy(data_tuple['query_points'][sample_idx]).reshape(-1, 3)  # [N, 3]
    query_labels = _to_numpy(data_tuple['query_labels'][sample_idx]).reshape(-1)     # [N]
    # Every part's corners are stacked into one point cloud.
    bb_points = _to_numpy(data_tuple['part_bbs'][sample_idx]).reshape(-1, 3)         # [P*8, 3]

    if len(query_points) != len(query_labels):
        raise ValueError(
            f"query_points has {len(query_points)} points but query_labels has "
            f"{len(query_labels)} labels"
        )

    # Points labeled exactly 1 are "inside"; every other label value is drawn as "outside".
    inside_mask = query_labels == 1
    # Boolean masking keeps the (K, 3) shape even when K == 0, so empty groups need no special case.
    inside_points = query_points[inside_mask]
    outside_points = query_points[~inside_mask]

    fig = go.Figure(data=[
        _marker_trace(inside_points, 'Inside Points',
                      size=point_size, color='red', opacity=0.8),
        _marker_trace(outside_points, 'Outside Points',
                      size=point_size, color='green', opacity=0.8),
        _marker_trace(bb_points, 'Box Corners',
                      size=corner_size, color='blue', symbol='square', opacity=1.0),
    ])

    fig.update_layout(
        title='Query Points and Box Corners',
        scene=dict(
            # 'data' keeps true proportions so boxes are not visually distorted.
            aspectmode='data',
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        showlegend=True,
        width=800,
        height=800
    )

    return fig

/home/slimhy/Documents/PADS/code


In [2]:
# Create dataset and dataloader.
# All tunables live here so a reader can change them in one place.
DATA_PATH = "/home/slimhy/Documents/PADS/data/dataset__debug.h5"  # TODO: avoid absolute local path
SPLIT = "train"
NUM_QUERIES = 2048
NUM_PART_POINTS = 1024
SEED = 0

# Seed the global RNGs so any random sampling done by the dataset is reproducible on
# Restart & Run All (NOTE(review): assumes the dataset draws from torch/numpy global RNGs).
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = PartOccupancyDataset(
    hdf5_path=DATA_PATH,
    split=SPLIT,
    num_queries=NUM_QUERIES,
    num_part_points=NUM_PART_POINTS,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=collate,
    # A dedicated seeded generator makes the shuffle order (and thus the sample shown below)
    # deterministic across fresh kernels.
    generator=torch.Generator().manual_seed(SEED),
)

In [25]:
# Draw the first batch from the loader for inspection.
loader_iter = iter(dataloader)
data_tuple = next(loader_iter)

In [26]:
# Render occupancy-colored query points and part box corners for the first sample in the batch.
fig = viz_queries(data_tuple, sample_idx=0)
fig.show()