In [None]:
import os
import random
import h5py
import json
import torch
import logging
from torch.utils.data import Dataset
import numpy as np
# from .build import DATASETS
# from utils.logger import print

def rotate_point_cloud_z(pc):
    """Rotate a point cloud by one random angle about the z axis.

    A single angle is drawn uniformly from [0, 2*pi) and every point is
    right-multiplied by the matching 3x3 rotation matrix, so the third
    (z) coordinate of each point is left unchanged.

    Args:
        pc: array whose last dimension has size 3 (xyz coordinates).

    Returns:
        A new array with the rotated points; ``pc`` itself is not modified.
    """
    theta = np.random.uniform() * 2 * np.pi
    c, s = np.cos(theta), np.sin(theta)
    z_rotation = np.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
    return np.dot(pc, z_rotation)

def jitter_point_cloud(pc, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise to every coordinate of every point.

    Args:
        pc: 2-D array of shape (N, C).
        sigma: standard deviation of the zero-mean normal noise.
        clip: the noise is limited to the range [-clip, clip]; must be > 0.

    Returns:
        A new array ``pc + noise``; ``pc`` itself is not modified.
    """
    num_points, num_channels = pc.shape
    assert(clip > 0)
    noise = sigma * np.random.randn(num_points, num_channels)
    bounded_noise = np.clip(noise, -1 * clip, clip)
    return bounded_noise + pc

def random_scale_point_cloud(pc, scale_low=0.8, scale_high=1.25):
    """Multiply the whole point cloud by one random scale factor.

    Args:
        pc: array of point coordinates.
        scale_low: lower bound of the uniform range the factor is drawn from.
        scale_high: upper bound of that range.

    Returns:
        ``pc`` multiplied by the sampled factor (one factor for all points).
    """
    factor = np.random.uniform(scale_low, scale_high)
    scaled = pc * factor
    return scaled


# @DATASETS.register_module()
class ShapeNetPartH5(Dataset):
    """
    Dataloader for the HDF5 version of ShapeNetPart.
    This is the standard dataset format used in PointNet/PointNet++ and subsequent works.

    Every ``.h5`` file in ``config.DATA_PATH`` whose file name contains
    ``config.subset`` is read completely into memory when the dataset is built.
    Each file must provide the datasets ``'data'`` (points), ``'seg'``
    (per-point part labels) and ``'label'`` (per-shape category index).

    Random augmentation (z rotation, jitter, scaling) is enabled only when
    ``config.subset == 'train'``; every sample is normalised afterwards.
    """
    # 50-class mapping for ShapeNetPart
    seg_classes = {
        'Airplane': [0, 1, 2, 3], 'Bag': [4, 5], 'Cap': [6, 7], 'Car': [8, 9, 10, 11],
        'Chair': [12, 13, 14, 15], 'Earphone': [16, 17, 18], 'Guitar': [19, 20, 21],
        'Knife': [22, 23], 'Lamp': [24, 25, 26, 27], 'Laptop': [28, 29],
        'Motorbike': [30, 31, 32, 33, 34, 35], 'Mug': [36, 37], 'Pistol': [38, 39, 40],
        'Rocket': [41, 42, 43], 'Skateboard': [44, 45, 46], 'Table': [47, 48, 49]
    }

    # Mapping from category name to the class index (0-15)
    classes_map = {
        'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5,
        'Guitar': 6, 'Knife': 7, 'Lamp': 8, 'Laptop': 9, 'Motorbike': 10,
        'Mug': 11, 'Pistol': 12, 'Rocket': 13, 'Skateboard': 14, 'Table': 15
    }

    def __init__(self, config):
        """Load every H5 file of the requested split into memory.

        Args:
            config: object exposing ``DATA_PATH`` (directory with the .h5
                files), ``N_POINTS`` (points kept per sample) and ``subset``
                (substring selecting the split's files, e.g. 'train').

        Raises:
            FileNotFoundError: if no ``.h5`` file name in ``DATA_PATH``
                contains ``subset``.
        """
        self.root = config.DATA_PATH
        self.npoints = config.N_POINTS
        self.split = config.subset
        self.use_augmentation = (self.split == 'train')

        # Find all H5 files for the given split (train/test/val)
        h5_files = [f for f in os.listdir(self.root) if f.endswith('.h5') and self.split in f]
        if not h5_files:
            raise FileNotFoundError(f"No H5 files found for split '{self.split}' in '{self.root}'")

        print(f"Loading H5 files for '{self.split}' split: {h5_files}")

        points_chunks, seg_chunks, cls_chunks = [], [], []
        for h5_filename in sorted(h5_files):
            # The context manager closes the file even if a dataset is missing
            # or a read fails part-way through.
            with h5py.File(os.path.join(self.root, h5_filename), 'r') as f:
                points_chunks.append(f['data'][:])
                seg_chunks.append(f['seg'][:])
                cls_chunks.append(f['label'][:])

        # Concatenate data from all loaded files
        self.all_points = np.concatenate(points_chunks, axis=0)
        self.all_seg_labels = np.concatenate(seg_chunks, axis=0)
        # Squeeze to make it 1D; atleast_1d keeps a one-sample split indexable
        # (a bare squeeze would collapse it to a 0-d array).
        self.all_cls_labels = np.atleast_1d(np.concatenate(cls_chunks, axis=0).squeeze())

        print(f'The size of {self.split} data is {len(self.all_points)}')
        print(f'Number of points per sample: {self.npoints}')

        self.classes = self.classes_map

    def __len__(self):
        """Return the number of loaded shapes."""
        return len(self.all_points)

    def __getitem__(self, index):
        """Return ``(points, cls_label, seg_labels)`` tensors for one shape.

        Only the first ``npoints`` stored points (and their labels) are used;
        no random sub-sampling takes place.

        Returns:
            points: float32 tensor with the normalised coordinates.
            cls_label: int64 tensor of shape ``[1]`` with the category index.
            seg_labels: int64 tensor with one part label per returned point.
        """
        points = self.all_points[index][:self.npoints].copy()
        seg_labels = self.all_seg_labels[index][:self.npoints].copy()
        cls_label = self.all_cls_labels[index].copy()

        # Augmentation is applied only to the training set
        if self.use_augmentation:
            points = rotate_point_cloud_z(points)
            points = jitter_point_cloud(points)
            points = random_scale_point_cloud(points)

        # Normalize points
        points = self.pc_normalize(points)

        return (
            torch.from_numpy(points).float(),
            torch.from_numpy(np.array([cls_label])).long(),  # Wrap in array for consistent shape
            torch.from_numpy(seg_labels).long()
        )

    @staticmethod
    def pc_normalize(pc):
        """Centre ``pc`` on its centroid and scale it into the unit sphere.

        The small epsilon in the divisor avoids a division by zero when all
        points coincide.
        """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        pc = pc / (m + 1e-9)
        return pc
    


if __name__ == "__main__":
    # Smoke test: load the training split and sanity-check one random sample.
    import sys
    import random

    class DummyConfig1:
        """Minimal stand-in for the config object consumed by ShapeNetPartH5."""

        def __init__(self):
            # --- IMPORTANT ---
            # This path must point to the folder containing the .h5 files
            self.DATA_PATH = '/kaggle/input/shapenetpart/shapenetpart_hdf5_2048'
            self.N_POINTS = 2048
            self.subset = 'train'  # can be 'train', 'test' or 'val'

    config = DummyConfig1()

    print("--- Testing ShapeNetPart HDF5 Dataset ---")
    print(f"Loading data from: {config.DATA_PATH}")
    print(f"Subset: {config.subset}, Points per sample: {config.N_POINTS}")

    # 2. --- Instantiate the dataset ---
    try:
        dataset_train = ShapeNetPartH5(config)
    except Exception as e:
        print("\n[ERROR] Failed to initialize dataset. Please check the DATA_PATH in the script.")
        print(f"Details: {e}")
        # sys.exit gives a non-zero status; the bare `exit()` helper is meant
        # for interactive shells and would report success (status 0).
        sys.exit(1)

    # 3. --- Get and inspect a random sample ---
    if len(dataset_train) == 0:
        print("\n[ERROR] The dataset is empty. No data was found. Check the DATA_PATH and dataset structure.")
    else:
        print(f"\nDataset loaded successfully with {len(dataset_train)} samples.")

        # The dataset maps category name -> class index; printing needs the reverse.
        idx_to_name_map = {v: k for k, v in dataset_train.classes_map.items()}

        # Get a random item
        random_index = random.randint(0, len(dataset_train) - 1)
        print(f"Fetching random sample at index: {random_index}")

        # __getitem__ returns (point coordinates, object category, per-point part labels)
        points_tensor, cls_label_tensor, seg_labels_tensor = dataset_train[random_index]

        # Get the category index from the tensor, then look up its name
        cat_idx = cls_label_tensor.item()
        cat_name = idx_to_name_map.get(cat_idx, f"UnknownCategory_{cat_idx}")

        print(f"\n--- Sample Details for Category: {cat_name} ---")

        # Check shapes
        print(f"Points tensor shape:      {points_tensor.shape} (Expected: [{config.N_POINTS}, 3])")
        print(f"Class label tensor shape:   {cls_label_tensor.shape} (Expected: [1])")
        print(f"Seg labels tensor shape:  {seg_labels_tensor.shape} (Expected: [{config.N_POINTS}])")

        # Check dtypes
        print(f"\nPoints tensor dtype:      {points_tensor.dtype}")
        print(f"Class label tensor dtype:   {cls_label_tensor.dtype}")
        print(f"Seg labels tensor dtype:  {seg_labels_tensor.dtype}")

        # Check content
        class_label = cat_idx
        print(f"\nClass label value: {class_label}")

        # A correctly loaded sample should contain more than one part label
        unique_labels, counts = np.unique(seg_labels_tensor.numpy(), return_counts=True)
        print("\n--- Segmentation Label Analysis ---")
        print("This is the most crucial check. If you see multiple labels, the loader is working.")
        print(f"Unique part labels found in sample: {unique_labels}")
        print(f"Point counts for each label:      {counts}")

        if len(unique_labels) <= 1:
            print("\n[WARNING] Only one unique label was found. The part segmentation data may not be loading correctly.")
        else:
            print("\n[SUCCESS] Multiple unique labels found. The dataset appears to be loading part data correctly.")

        print(type(dataset_train))

In [None]:
class DummyConfig2:
    """Stand-in config that selects the 'test' split for ShapeNetPartH5."""

    def __init__(self):
        # --- IMPORTANT ---
        # This path must point to the folder containing the .h5 files
        self.DATA_PATH = '/kaggle/input/shapenetpart/shapenetpart_hdf5_2048'
        self.N_POINTS = 2048
        self.subset = 'test'


# Build the test dataset and report how many shapes it holds.
config2 = DummyConfig2()
dataset_test = ShapeNetPartH5(config2)
print(len(dataset_test))



In [None]:
# Lookup tables taken from the dataset class:
#   category name -> list of part labels
category_to_parts = dataset_train.seg_classes
#   category name -> category index
category_to_index = dataset_train.classes_map
#   category index -> category name (inverse of the table above)
index_to_category = {index: name for name, index in category_to_index.items()}
index_to_category

In [None]:
class DummyConfig3:
    """Stand-in config that selects the 'val' split for ShapeNetPartH5."""

    def __init__(self):
        # --- IMPORTANT ---
        # This path must point to the folder containing the .h5 files
        self.DATA_PATH = '/kaggle/input/shapenetpart/shapenetpart_hdf5_2048'
        self.N_POINTS = 2048
        self.subset = 'val'


# Build the validation dataset and report how many shapes it holds.
config3 = DummyConfig3()
dataset_val = ShapeNetPartH5(config3)
print(len(dataset_val))

In [None]:
import plotly.express as px
import pandas as pd
import numpy as np

def visualize_point_cloud_interactive(sample):
    """Show one dataset sample as an interactive 3-D scatter plot.

    Points are coloured by their per-point label.

    Args:
        sample: a ``(points, classes, labels)`` triple. ``points`` is indexed
            as ``points[:, 0]``, ``points[:, 1]``, ``points[:, 2]`` for the
            x/y/z axes and ``labels`` supplies one value per point. Torch
            tensors and NumPy arrays are both accepted; ``classes`` is
            unpacked but not used.
    """
    points, _classes, labels = sample

    def _as_numpy(values):
        # Torch tensors expose .detach(); bring them to host memory first so
        # pandas receives plain numeric arrays instead of tensor objects.
        if hasattr(values, "detach"):
            values = values.detach().cpu()
        return np.asarray(values)

    points = _as_numpy(points)
    labels = _as_numpy(labels).reshape(-1)

    df = pd.DataFrame({
        "x": points[:, 0],
        "y": points[:, 1],
        "z": points[:, 2],
        "label": labels
    })

    fig = px.scatter_3d(
        df, x="x", y="y", z="z", color="label",
        labels={"label": "Classes"}, opacity=0.7
    )
    fig.update_traces(marker=dict(size=3, line=dict(width=1, color='DarkSlateGrey')), selector=dict(mode='markers'))
    fig.update_layout(
        title="Interactive Point Cloud Visualization",
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        legend_title="Labels"
    )
    fig.show()

# Example usage: plot the training sample at index 100.
# dataset_train is built with subset='train', so the sample is randomly
# augmented each time it is fetched and the plot differs between runs.
visualize_point_cloud_interactive(dataset_train[100])
# visualize_point_cloud_interactive(dataset_train[300])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

####################################
# DGCNN Segmentation Model
####################################
def knn(x, k):
    """Return the indices of the k nearest neighbours of every point.

    Args:
        x: tensor of shape (B, C, N) holding N C-dimensional points per batch.
        k: number of neighbours to return per point.

    Returns:
        Index tensor of shape (B, N, k). Neighbours are ranked by negative
        squared Euclidean distance, so each point's own index comes first
        unless another point coincides with it.
    """
    batch, channels, num_points = x.size()  # unpacking also requires a 3-D input
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    gram_term = -2 * torch.matmul(x.transpose(2, 1), x)
    # -(|xi|^2 - 2 xi.xj + |xj|^2) == -||xi - xj||^2
    neg_sq_dist = -sq_norm - gram_term - sq_norm.transpose(2, 1)
    return neg_sq_dist.topk(k=k, dim=-1)[1]

def get_graph_feature(x, k=20, idx=None):
    """Build edge features ``(neighbour - centre, centre)`` for every point.

    Args:
        x: tensor of shape (B, C, N).
        k: number of neighbours per point; must match the last dimension of
            ``idx`` when ``idx`` is supplied.
        idx: optional (B, N, k) neighbour indices; computed with ``knn`` when
            omitted.

    Returns:
        Tensor of shape (B, 2*C, N, k): the first C channels hold the offset
        from each point to its neighbour, the last C channels repeat the
        point itself.
    """
    batch_size, num_dims, num_points = x.size()
    if idx is None:
        idx = knn(x, k=k)

    # Shift each batch's indices so they address the flattened (B*N, C) table.
    batch_offsets = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + batch_offsets).view(-1)

    points = x.transpose(2, 1).contiguous()
    neighbours = points.view(batch_size * num_points, -1)[flat_idx, :]
    neighbours = neighbours.view(batch_size, num_points, k, num_dims)
    centres = points.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    edge_features = torch.cat((neighbours - centres, centres), dim=3)
    return edge_features.permute(0, 3, 1, 2)

class DGCNN_Seg(nn.Module):
    """DGCNN-style part-segmentation network conditioned on the object category.

    Four edge-convolution stages (k-nearest-neighbour graph features followed
    by a max over the neighbours) produce per-point features of 64, 64, 128
    and 256 channels. They are concatenated, embedded to ``emb_dims``
    channels, joined with a one-hot category vector repeated for every point
    and classified per point.

    Args:
        num_classes: number of part labels predicted per point.
        num_categories: number of object categories (width of the one-hot
            category vector).
        k: number of neighbours used for every graph feature.
        emb_dims: channel count of the per-point embedding (conv5 output).
        dropout: dropout probability used in the classification head.
    """

    def __init__(self, num_classes=50, num_categories=16, k=20, emb_dims=1024, dropout=0.5):
        super(DGCNN_Seg, self).__init__()
        self.k = k
        self.num_classes = num_classes
        # Kept so forward() builds a one-hot of the same width conv6 expects.
        self.num_categories = num_categories

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64), nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(128), nn.ReLU())
        self.conv4 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(256), nn.ReLU())
        # 512 = 64 + 64 + 128 + 256 concatenated stage outputs
        self.conv5 = nn.Sequential(nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(emb_dims), nn.ReLU())

        # combine the per-point embedding with the category one-hot
        self.conv6 = nn.Sequential(nn.Conv1d(emb_dims + num_categories, 256, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(256), nn.ReLU())
        self.dp1 = nn.Dropout(p=dropout)
        self.conv7 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(256), nn.ReLU())
        self.dp2 = nn.Dropout(p=dropout)
        self.conv8 = nn.Conv1d(256, num_classes, kernel_size=1, bias=True)

    def forward(self, x, cat_onehot):
        """Predict per-point part log-probabilities.

        Args:
            x: point coordinates of shape (B, N, 3).
            cat_onehot: integer category indices with B elements in total
                (e.g. shape (B,) or (B, 1)). Despite the name this is NOT a
                one-hot tensor; the one-hot encoding is built here.

        Returns:
            Log-softmax scores of shape (B, num_classes, N).
        """
        x = x.permute(0, 2, 1)  # -> (B, 3, N)
        B, _, N = x.size()

        x1 = self.conv1(get_graph_feature(x, k=self.k)).max(dim=-1)[0]
        x2 = self.conv2(get_graph_feature(x1, k=self.k)).max(dim=-1)[0]
        x3 = self.conv3(get_graph_feature(x2, k=self.k)).max(dim=-1)[0]
        x4 = self.conv4(get_graph_feature(x3, k=self.k)).max(dim=-1)[0]

        # All stage outputs already live on x's device, so no transfer is
        # needed (and no module-level `device` global is consulted).
        x_cat = torch.cat((x1, x2, x3, x4), dim=1)  # (B, 512, N)

        # NOTE: despite its name this stays a per-point embedding; no global
        # max-pool is applied.
        x_global = self.conv5(x_cat)  # (B, emb_dims, N)

        # Category index -> one-hot -> repeated for every point.
        cat_index = cat_onehot.reshape(B).long()
        cat_feat = F.one_hot(cat_index, num_classes=self.num_categories).float()  # (B, num_categories)
        cat_expand = cat_feat.view(B, self.num_categories, 1).repeat(1, 1, N).to(x.device)

        x = torch.cat((x_global, cat_expand), dim=1)  # (B, emb_dims + num_categories, N)
        x = self.conv6(x)
        x = self.dp1(x)
        x = self.conv7(x)
        x = self.dp2(x)
        x = self.conv8(x)  # (B, num_classes, N)
        x = F.log_softmax(x, dim=1)
        return x

####################################
# Training Loop
####################################
def train_dgcnn_seg(model, train_loader, val_loader, num_classes, num_categories, device,
                    epochs=100, checkpoint_path="best_dgcnn_seg.pth"):
    """Train a segmentation model and keep the weights with the best validation accuracy.

    Uses Adam (lr 1e-3, weight decay 1e-4) with a StepLR schedule that
    multiplies the learning rate by 0.7 every 20 epochs. After each epoch the
    per-point accuracy on ``val_loader`` is printed, and the model's
    ``state_dict`` is saved whenever that accuracy improves.

    Args:
        model: module called as ``model(points, categories)`` and returning
            per-point scores of shape (B, num_classes, N).
        train_loader: iterable yielding ``(points, categories, seg_labels)``.
        val_loader: iterable yielding batches in the same format.
        num_classes: number of part labels (size of the score axis).
        num_categories: accepted for interface compatibility; unused here.
        device: device the batches are moved to.
        epochs: number of training epochs.
        checkpoint_path: file the best ``state_dict`` is written to.
    """
    # DGCNN_Seg in this file already ends with log_softmax. CrossEntropyLoss
    # applies log_softmax again, which leaves normalised log-probabilities
    # unchanged, so the loss is correct for raw logits as well.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    best_val_acc = 0

    for epoch in range(epochs):
        model.train()
        total_loss, total_correct, total_seen = 0.0, 0, 0
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}") as t:
            for batches_done, (points, cat_onehot, seg_labels) in enumerate(t, start=1):
                # points: (B,N,3), cat_onehot: category indices, seg_labels: (B,N)
                points = points.to(device)
                cat_onehot = cat_onehot.to(device)
                seg_labels = seg_labels.to(device).long()

                optimizer.zero_grad()
                preds = model(points, cat_onehot)  # (B,num_classes,N)
                preds = preds.permute(0, 2, 1).contiguous()  # (B,N,num_classes)
                loss = criterion(preds.view(-1, num_classes), seg_labels.view(-1))
                loss.backward()
                optimizer.step()

                pred_choice = preds.argmax(dim=2)
                correct = pred_choice.eq(seg_labels).sum().item()
                total_correct += correct
                total_seen += seg_labels.numel()
                total_loss += loss.item()

                # Average over the batches processed so far; dividing by the
                # full loader length would under-report the loss mid-epoch.
                t.set_postfix({
                    "loss": f"{total_loss / batches_done:.4f}",
                    "acc": f"{total_correct / total_seen:.4f}"
                })
        scheduler.step()

        # validation
        model.eval()
        val_correct, val_seen = 0, 0
        with torch.no_grad():
            for points, cat_onehot, seg_labels in val_loader:
                points = points.to(device)
                cat_onehot = cat_onehot.to(device)
                seg_labels = seg_labels.to(device).long()
                preds = model(points, cat_onehot)
                preds = preds.permute(0, 2, 1)
                pred_choice = preds.argmax(dim=2)
                val_correct += pred_choice.eq(seg_labels).sum().item()
                val_seen += seg_labels.numel()
        # An empty validation loader counts as accuracy 0 instead of crashing.
        val_acc = val_correct / val_seen if val_seen else 0.0
        print(f"Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    print(f"✅ Training complete. Best Val Acc: {best_val_acc:.4f}")


In [None]:
# Use the GPU when one is available. DGCNN_Seg.forward in this notebook also
# reads this module-level `device`, so it must be defined before the model runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 50 part labels (0-49) and 16 object categories, matching
# ShapeNetPartH5.seg_classes and ShapeNetPartH5.classes_map.
num_classes = 50
num_categories = 16
model = DGCNN_Seg(num_classes=num_classes, num_categories=num_categories).to(device)

In [None]:
BATCH_SIZE = 8
# Training batches are shuffled and the last incomplete batch is dropped;
# validation keeps dataset order and every sample.
train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)

# Single-epoch run; the best weights are written to "best_dgcnn_seg.pth".
train_dgcnn_seg(model, train_loader, val_loader, num_classes, num_categories, device, epochs=1)

In [None]:
# Rebuild the network and restore saved weights.
# DGCNN_Seg is the class defined earlier in this notebook, so nothing is
# imported from an external `model` module, and the previously referenced
# name `DGCNN` (never defined here) is replaced by DGCNN_Seg.
num_classes = 50  # replace with your dataset's number of segmentation classes
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = DGCNN_Seg(num_classes=num_classes).to(device)
# Presumably the file written by train_dgcnn_seg - verify the working directory.
checkpoint_path = "/kaggle/working/best_dgcnn_seg.pth"  # path to your .pth file
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# model.eval()  # important for inference



In [None]:
# Recreate the model (num_categories keeps its default of 16), load the saved
# weights onto `device`, and switch to evaluation mode.
model = DGCNN_Seg(num_classes=num_classes).to(device)
model.load_state_dict(torch.load("/kaggle/working/best_dgcnn_seg.pth", map_location=device))
model.eval()


In [None]:
import torch
import torch.nn.functional as F
import time
import numpy as np
import psutil
import os

# Loader over the test split, kept in dataset order.
BATCH_SIZE = 16
test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

def compute_metrics_dgcnn(model, dataloader, num_classes, device):
    """
    Evaluate DGCNN segmentation model on a dataloader.

    Args:
        model: module called as ``model(points, category)`` that returns
            log-probabilities of shape (B, num_classes, N).
        dataloader: yields ``(points, category, seg)`` batches and exposes
            ``.dataset`` (used for the sample count).
        num_classes: number of part labels.
        device: device the batches are moved to.

    Returns:
        dict with
        ``loss``: mean NLL loss per batch;
        ``overall_acc``: fraction of correctly labelled points;
        ``instance_miou``: mean over shapes of the mean IoU of the parts
            present in that shape's ground truth;
        ``class_miou``: mean over part labels of accumulated
            intersection / union, restricted to labels that actually occurred;
        ``avg_inference_time``: wall-clock seconds per sample for the whole
            loop (data loading and metric bookkeeping included);
        ``memory_mb``: resident memory of this process (NaN without psutil);
        ``model_size_mb``: size of parameters plus buffers.
        Ratios fall back to 0.0 when the dataloader is empty.
    """
    model.eval()
    total_correct, total_seen = 0, 0
    total_loss = 0.0
    num_batches = 0
    all_iou_per_instance = []
    part_intersection = np.zeros(num_classes)
    part_union = np.zeros(num_classes)
    criterion = torch.nn.NLLLoss()  # log_softmax outputs

    start_time = time.time()
    with torch.no_grad():
        for points, category, seg in dataloader:
            points = points.to(device, dtype=torch.float32)         # (B, N, 3)
            # reshape(-1) keeps a 1-D (B,) tensor even for a batch of one,
            # where squeeze() would collapse it to 0-d.
            category = category.to(device, dtype=torch.long).reshape(-1)
            seg = seg.to(device, dtype=torch.long)                 # (B, N)

            # -------- Forward pass --------
            preds = model(points, category)             # (B, num_classes, N)
            preds = preds.transpose(1, 2).contiguous()  # (B, N, num_classes)

            # Flatten for loss computation
            preds_flat = preds.view(-1, num_classes)  # (B*N, num_classes)
            seg_flat = seg.view(-1)                   # (B*N,)

            loss = criterion(preds_flat, seg_flat)
            total_loss += loss.item()
            num_batches += 1

            # -------- Predictions --------
            pred_choice = preds_flat.argmax(dim=1)  # (B*N,)
            total_correct += pred_choice.eq(seg_flat).sum().item()
            total_seen += seg_flat.numel()

            # -------- IoU per instance --------
            # NOTE(review): only parts present in a shape's ground truth are
            # scored, and predictions are not restricted to the category's
            # own parts - this differs from the usual ShapeNetPart protocol.
            preds_np = pred_choice.cpu().numpy().reshape(points.size(0), -1)
            seg_np = seg.cpu().numpy().reshape(points.size(0), -1)
            for shape_pred, shape_gt in zip(preds_np, seg_np):
                part_iou = []
                for part in np.unique(shape_gt):
                    pred_mask = shape_pred == part
                    gt_mask = shape_gt == part
                    I = np.sum(pred_mask & gt_mask)
                    U = np.sum(pred_mask | gt_mask)
                    part_iou.append(1.0 if U == 0 else I / float(U))
                    part_intersection[part] += I
                    part_union[part] += U
                all_iou_per_instance.append(np.mean(part_iou))

    # -------- Metrics --------
    total_time = time.time() - start_time
    num_samples = len(dataloader.dataset)
    avg_inference_time = total_time / num_samples if num_samples else 0.0

    overall_acc = total_correct / total_seen if total_seen else 0.0

    instance_miou = float(np.mean(all_iou_per_instance)) if all_iou_per_instance else 0.0
    # Average only over part labels that were actually encountered; labels
    # with an empty union would otherwise count as IoU 0 and deflate the mean.
    seen_parts = part_union > 0
    if seen_parts.any():
        class_miou = float(np.mean(part_intersection[seen_parts] / part_union[seen_parts]))
    else:
        class_miou = 0.0

    # -------- Memory usage --------
    # Memory reporting is auxiliary: do not abort the evaluation when psutil
    # is unavailable.
    try:
        import psutil
        memory_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2  # in MB
    except ImportError:
        memory_mb = float("nan")

    # -------- Model size --------
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    model_size_mb = (param_size + buffer_size) / 1024 ** 2

    return {
        'loss': total_loss / num_batches if num_batches else 0.0,
        'overall_acc': overall_acc,
        'instance_miou': instance_miou,
        'class_miou': class_miou,
        'avg_inference_time': avg_inference_time,
        'memory_mb': memory_mb,
        'model_size_mb': model_size_mb
    }


In [None]:
# Evaluate the current model on the validation loader and print a report.
NUM_CLASSES = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
metrics = compute_metrics_dgcnn(model, val_loader, NUM_CLASSES, DEVICE)

report_lines = [
    "\nValidation Metrics:",
    f"Loss:              {metrics['loss']:.4f}",
    f"Overall Accuracy:  {metrics['overall_acc']:.4f}",
    f"Instance mIoU:     {metrics['instance_miou']:.4f}",
    f"Class mIoU:        {metrics['class_miou']:.4f}",
    f"Inference Time:    {metrics['avg_inference_time']*1000:.2f} ms per sample",
    f"Memory Usage:      {metrics['memory_mb']:.2f} MB",
    f"Model Size:        {metrics['model_size_mb']:.2f} MB\n",
]
for report_line in report_lines:
    print(report_line)


## Comparison: ground truth vs. predicted segmentation

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_single_segmentation_with_category(points, seg_gt, seg_pred, category_name="Unknown", shape_idx=0):
    """Plot ground-truth and predicted part labels of one shape side by side.

    Args:
        points: (N, 3) numpy array of point cloud coordinates
        seg_gt: (N,) numpy array of ground truth labels
        seg_pred: (N,) numpy array of predicted labels
        category_name: string label for the object category
        shape_idx: integer index of the object (shown in the subplot titles)

    Note: the two panels use different colour scales ('Viridis' and
    'Rainbow'), so equal labels are not drawn in equal colours.
    """
    # (subplot column, trace name, point colours, colour scale, colour-bar title)
    panels = (
        (1, "Ground Truth", seg_gt, 'Viridis', "GT Labels"),
        (2, "Prediction", seg_pred, 'Rainbow', "Predicted Labels"),
    )

    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scatter3d'} for _ in panels]],
        subplot_titles=[
            f"{trace_name} - {category_name} (Object {shape_idx})"
            for _, trace_name, *_ in panels
        ]
    )

    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    for column, trace_name, colours, colour_scale, bar_title in panels:
        marker_style = dict(
            size=2,
            color=colours,
            colorscale=colour_scale,
            opacity=0.8,
            colorbar=dict(title=bar_title)
        )
        trace = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker_style, name=trace_name)
        fig.add_trace(trace, row=1, col=column)

    fig.update_layout(
        title=f"Segmentation Comparison for Category: {category_name}",
        width=1200,
        height=600
    )

    fig.show()


In [None]:
# Plot ground truth vs. prediction for the first test sample of every category.
model.eval()
device = DEVICE
num_classes = NUM_CLASSES
total_categories = len(index_to_category)

# Keep track of which categories we've already visualized
shown_categories = set()

with torch.no_grad():
    for points, category, seg in test_loader:
        points = points.to(device, dtype=torch.float32)              # (B, N, 3)
        # reshape(-1) keeps a 1-D (B,) tensor even for a batch of one, so
        # category[b] below always works.
        category = category.to(device, dtype=torch.long).reshape(-1)
        seg = seg.to(device, dtype=torch.long)                       # (B, N)

        # -------- Forward pass --------
        preds = model(points, category)  # (B, num_classes, N) log-probabilities

        # The class axis is dim=1; flattening with view(-1, num_classes)
        # without transposing first would mix points and classes.
        pred_choice = preds.argmax(dim=1)  # (B, N)

        preds_np = pred_choice.cpu().numpy()                               # (B, N)
        seg_np = seg.cpu().numpy().reshape(points.size(0), -1)             # (B, N)

        # -------- Iterate over batch items --------
        for b in range(points.size(0)):
            cat_id = category[b].item()
            cat_name = index_to_category[cat_id]

            # Skip if already visualized this category
            if cat_name in shown_categories:
                continue

            # Plot one sample of this category (shape_idx is the index inside the batch)
            plot_single_segmentation_with_category(
                points[b].cpu().numpy(),  # (N, 3)
                seg_np[b],                # ground truth labels
                preds_np[b],              # predicted labels
                category_name=cat_name,
                shape_idx=b
            )

            shown_categories.add(cat_name)
            print(f"✅ Shown category: {cat_name} ({len(shown_categories)}/{total_categories})")

            # Stop once every category has been shown
            if len(shown_categories) == total_categories:
                print(f"\n🎉 Displayed one sample from each of the {total_categories} categories. Done!")
                break

        # Stop outer loop if done
        if len(shown_categories) == total_categories:
            break
