In [None]:
import os
import random
import h5py
import json
import torch
import logging
from torch.utils.data import Dataset
import numpy as np
# from .build import DATASETS
# from utils.logger import print

def rotate_point_cloud_z(pc):
    """Return ``pc`` rotated about the z axis by a random angle.

    The angle is drawn uniformly from [0, 2*pi) with NumPy's global RNG.
    Rows of ``pc`` are points, so the result is ``pc @ R``; z is unchanged.
    """
    theta = np.random.uniform() * 2 * np.pi
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
    return np.dot(pc, rot)

def jitter_point_cloud(pc, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise independently to every coordinate.

    Args:
        pc: array of points. The noise has the same shape as ``pc``
            (typically ``(N, C)``, but batched ``(B, N, C)`` also works).
        sigma: standard deviation of the Gaussian noise.
        clip: noise is clipped to ``[-clip, clip]``; must be positive.

    Returns:
        A new array ``pc + noise``; ``pc`` itself is not modified.

    Raises:
        ValueError: if ``clip`` is not positive.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if clip <= 0:
        raise ValueError(f"clip must be positive, got {clip}")
    # randn(*shape) draws exactly the same numbers as randn(N, C) for 2-D input.
    jittered_data = np.clip(sigma * np.random.randn(*pc.shape), -1 * clip, clip)
    jittered_data += pc
    return jittered_data

def random_scale_point_cloud(pc, scale_low=0.8, scale_high=1.25):
    """Multiply the whole cloud by one factor drawn from U(scale_low, scale_high)."""
    factor = np.random.uniform(scale_low, scale_high)
    scaled = pc * factor
    return scaled


# @DATASETS.register_module()
class ShapeNetPartH5(Dataset):
    """
    Dataloader for the HDF5 version of ShapeNetPart.
    This is the standard dataset format used in PointNet/PointNet++ and subsequent works.

    Each item is a tuple ``(points, cls_label, seg_labels)``:
      * ``points``: float32 tensor ``(npoints, 3)``, centred and scaled to unit
        radius, then (train split only) rotated about z, jittered and scaled.
      * ``cls_label``: long tensor of shape ``(1,)`` holding the category index.
      * ``seg_labels``: long tensor ``(npoints,)`` of per-point part labels.

    ``config`` must provide ``DATA_PATH`` (folder containing the ``.h5`` files),
    ``N_POINTS`` and ``subset``. ``subset`` is matched as a substring of the file
    names (e.g. ``'train'``, ``'val'``, ``'test'``); augmentation is enabled only
    when it equals ``'train'``.
    """
    # 50-class mapping for ShapeNetPart
    seg_classes = {
        'Airplane': [0, 1, 2, 3], 'Bag': [4, 5], 'Cap': [6, 7], 'Car': [8, 9, 10, 11],
        'Chair': [12, 13, 14, 15], 'Earphone': [16, 17, 18], 'Guitar': [19, 20, 21],
        'Knife': [22, 23], 'Lamp': [24, 25, 26, 27], 'Laptop': [28, 29],
        'Motorbike': [30, 31, 32, 33, 34, 35], 'Mug': [36, 37], 'Pistol': [38, 39, 40],
        'Rocket': [41, 42, 43], 'Skateboard': [44, 45, 46], 'Table': [47, 48, 49]
    }

    # Mapping from category name to the class index (0-15)
    classes_map = {
        'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5,
        'Guitar': 6, 'Knife': 7, 'Lamp': 8, 'Laptop': 9, 'Motorbike': 10,
        'Mug': 11, 'Pistol': 12, 'Rocket': 13, 'Skateboard': 14, 'Table': 15
    }

    def __init__(self, config):
        self.root = config.DATA_PATH
        self.npoints = config.N_POINTS
        self.split = config.subset
        self.use_augmentation = (self.split == 'train')

        # Find all H5 files for the given split (train/test/val); sorted for a stable order.
        h5_files = sorted(
            f for f in os.listdir(self.root) if f.endswith('.h5') and self.split in f
        )
        if not h5_files:
            raise FileNotFoundError(f"No H5 files found for split '{self.split}' in '{self.root}'")

        print(f"Loading H5 files for '{self.split}' split: {h5_files}")

        points_list, seg_list, cls_list = [], [], []
        for h5_filename in h5_files:
            # The context manager closes the file even if a dataset key is missing.
            with h5py.File(os.path.join(self.root, h5_filename), 'r') as f:
                points_list.append(f['data'][:])
                seg_list.append(f['seg'][:])
                cls_list.append(f['label'][:])

        # Concatenate data from all loaded files
        self.all_points = np.concatenate(points_list, axis=0)
        self.all_seg_labels = np.concatenate(seg_list, axis=0)
        # reshape(-1) always gives a 1-D array (squeeze() would give 0-d for one sample).
        self.all_cls_labels = np.concatenate(cls_list, axis=0).reshape(-1)

        print(f'The size of {self.split} data is {len(self.all_points)}')
        print(f'Number of points per sample: {self.npoints}')

        self.classes = self.classes_map

    def __len__(self):
        return len(self.all_points)

    def __getitem__(self, index):
        points = self.all_points[index][:self.npoints].copy()
        seg_labels = self.all_seg_labels[index][:self.npoints].copy()
        cls_label = self.all_cls_labels[index].copy()

        # Normalize BEFORE augmenting. pc_normalize rescales to unit radius, so
        # normalizing afterwards would completely undo random_scale_point_cloud.
        points = self.pc_normalize(points)

        # Augmentation is applied only to the training set
        if self.use_augmentation:
            points = rotate_point_cloud_z(points)
            points = jitter_point_cloud(points)
            points = random_scale_point_cloud(points)

        return (
            torch.from_numpy(points).float(),
            torch.from_numpy(np.array([cls_label])).long(),  # Wrap in array for consistent shape
            torch.from_numpy(seg_labels).long()
        )

    @staticmethod
    def pc_normalize(pc):
        """Centre ``pc`` on its centroid and scale so the farthest point has radius ~1."""
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        pc = pc / (m + 1e-9)  # epsilon avoids division by zero for degenerate clouds
        return pc
    


if __name__ == "__main__":
    import random

    class DummyConfig1:
        def __init__(self):
            # --- IMPORTANT ---
            # This path must point to the folder containing the .h5 files
            self.DATA_PATH = '/kaggle/input/shapenetpart/shapenetpart_hdf5_2048'
            self.N_POINTS = 2048
            self.subset = 'train'  # can be 'train' or 'test'

    config = DummyConfig1()

    print(f"--- Testing ShapeNetPart HDF5 Dataset ---")
    print(f"Loading data from: {config.DATA_PATH}")
    print(f"Subset: {config.subset}, Points per sample: {config.N_POINTS}")

    # 2. --- Instantiate the dataset ---
    try:
        dataset_train = ShapeNetPartH5(config)
    except Exception as e:
        print(f"\n[ERROR] Failed to initialize dataset. Please check the DATA_PATH in the script.")
        print(f"Details: {e}")
        # Re-raise instead of exit(): exit() reports success (status 0) from a script and
        # shuts the kernel down in Jupyter, and later cells need dataset_train anyway.
        raise

    # 3. --- Get and inspect a random sample ---
    if len(dataset_train) == 0:
        print("\n[ERROR] The dataset is empty. No data was found. Check the DATA_PATH and dataset structure.")
    else:
        print(f"\nDataset loaded successfully with {len(dataset_train)} samples.")

        # classes_map is name -> index; invert it to print category names.
        idx_to_name_map = {v: k for k, v in dataset_train.classes_map.items()}

        # Get a random item
        random_index = random.randint(0, len(dataset_train) - 1)
        print(f"Fetching random sample at index: {random_index}")

        # __getitem__ returns (points (N, 3), category label (1,), per-point part labels (N,))
        points_tensor, cls_label_tensor, seg_labels_tensor = dataset_train[random_index]

        # Category index -> name
        cat_idx = cls_label_tensor.item()
        cat_name = idx_to_name_map.get(cat_idx, f"UnknownCategory_{cat_idx}")

        print(f"\n--- Sample Details for Category: {cat_name} ---")

        # Check shapes
        print(f"Points tensor shape:      {points_tensor.shape} (Expected: [{config.N_POINTS}, 3])")
        print(f"Class label tensor shape:   {cls_label_tensor.shape} (Expected: [1])")
        print(f"Seg labels tensor shape:  {seg_labels_tensor.shape} (Expected: [{config.N_POINTS}])")

        # Check dtypes
        print(f"\nPoints tensor dtype:      {points_tensor.dtype}")
        print(f"Class label tensor dtype:   {cls_label_tensor.dtype}")
        print(f"Seg labels tensor dtype:  {seg_labels_tensor.dtype}")

        # Check content
        print(f"\nClass label value: {cat_idx}")

        # A sample with several part labels shows the segmentation data is being read.
        unique_labels, counts = np.unique(seg_labels_tensor.numpy(), return_counts=True)
        print("\n--- Segmentation Label Analysis ---")
        print("This is the most crucial check. If you see multiple labels, the loader is working.")
        print(f"Unique part labels found in sample: {unique_labels}")
        print(f"Point counts for each label:      {counts}")

        if len(unique_labels) <= 1:
            print("\n[WARNING] Only one unique label was found. The part segmentation data may not be loading correctly.")
        else:
            print("\n[SUCCESS] Multiple unique labels found. The dataset appears to be loading part data correctly.")

        print(type(dataset_train))

In [None]:
class DummyConfig2:
    """Minimal config object that selects the 'test' split of ShapeNetPart HDF5."""

    def __init__(self):
        # --- IMPORTANT ---
        # This path must point to the folder containing the .h5 files
        self.DATA_PATH = '/kaggle/input/shapenetpart/shapenetpart_hdf5_2048'
        self.N_POINTS = 2048
        self.subset = 'test'


config2 = DummyConfig2()
dataset_test = ShapeNetPartH5(config2)
print(len(dataset_test))



In [None]:
# Lookup tables taken from the dataset class.
category_to_parts = dataset_train.seg_classes   # category name -> list of part label ids
category_to_index = dataset_train.classes_map   # category name -> category index (0-15)
index_to_category = {idx: name for name, idx in category_to_index.items()}  # inverse map
index_to_category

In [None]:
class DummyConfig3:
    """Minimal config object that selects the 'val' split of ShapeNetPart HDF5."""

    def __init__(self):
        # --- IMPORTANT ---
        # This path must point to the folder containing the .h5 files
        self.DATA_PATH = '/kaggle/input/shapenetpart/shapenetpart_hdf5_2048'
        self.N_POINTS = 2048
        self.subset = 'val'


config3 = DummyConfig3()
dataset_val = ShapeNetPartH5(config3)
print(len(dataset_val))

In [None]:
import plotly.express as px
import pandas as pd
import numpy as np

def visualize_point_cloud_interactive(sample):
    """
    Show an interactive Plotly 3-D scatter of one point cloud, coloured by label.

    sample: a tuple from the dataset. ``sample[0]`` holds the (N, 3) point
        coordinates. Colour labels come from the LAST element: for
        ``(points, cls_label, seg_labels)`` these are the per-point part labels,
        and for ``(points, labels)`` it is ``labels``. A single label (scalar or
        length-1) is repeated for every point; a 1-tuple is coloured with zeros.
        Torch tensors and NumPy arrays are both accepted.
    """
    def _as_numpy(value):
        # Torch tensors: drop autograd and move to CPU before converting.
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    points = _as_numpy(sample[0])
    num_points = points.shape[0]
    if len(sample) > 1:
        labels = _as_numpy(sample[-1]).reshape(-1)
    else:
        labels = np.zeros(num_points, dtype=np.int64)
    if labels.size == 1:
        # e.g. a per-shape class label: colour the whole cloud with it
        labels = np.full(num_points, labels[0])

    df = pd.DataFrame({
        "x": points[:, 0],
        "y": points[:, 1],
        "z": points[:, 2],
        "label": labels
    })

    fig = px.scatter_3d(
        df, x="x", y="y", z="z", color="label",
        labels={"label": "Classes"}, opacity=0.7
    )
    fig.update_traces(marker=dict(size=3, line=dict(width=1, color='DarkSlateGrey')), selector=dict(mode='markers'))
    fig.update_layout(
        title="Interactive Point Cloud Visualization",
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        legend_title="Labels"
    )
    fig.show()

# Example usage
visualize_point_cloud_interactive(dataset_train[100])
# visualize_point_cloud_interactive(dataset_train[300])


## Object classification (PointNet without T-Nets)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# ======= CONFIG =======
EPOCHS = 30
BATCH_SIZE = 32
NUM_POINTS = 2048  # not referenced below; the dataset's N_POINTS controls sampling
NUM_CLASSES = len(category_to_index)  # one output per object category
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ======= POINTNET CLASSIFIER =======
class PointNetCls(nn.Module):
    """Vanilla PointNet classifier without input/feature T-Nets.

    Shared per-point MLP 3 -> 64 -> 128 -> 1024 (1x1 convs + BatchNorm + ReLU),
    a global max-pool over points, then an FC head 1024 -> 512 -> 256 ->
    ``num_classes`` with dropout (p=0.3) before the last layer.
    ``forward`` returns log-probabilities of shape (B, num_classes).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Per-point feature extractor (kernel size 1 == shared MLP)
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """x: (B, N, 3) coordinates -> (B, num_classes) log-probabilities."""
        feats = x.transpose(2, 1)  # (B, 3, N)
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            feats = F.relu(bn(conv(feats)))
        pooled = feats.max(dim=2).values  # global max pooling -> (B, 1024)

        hidden = F.relu(self.bn4(self.fc1(pooled)))
        hidden = self.dropout(F.relu(self.bn5(self.fc2(hidden))))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)

# ======= MODEL + OPTIMIZER =======
model = PointNetCls(NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

# ======= DATALOADERS =======
train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)

# ======= TRAINING LOOP =======
for epoch in range(EPOCHS):
    model.train()
    total_loss, total_correct, total_seen = 0.0, 0, 0

    for points, category, _seg in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        points = points.to(DEVICE, dtype=torch.float32)
        # view(-1) keeps a (B,) target even for B == 1 (squeeze() would give a 0-d tensor).
        category = category.to(DEVICE, dtype=torch.long).view(-1)

        optimizer.zero_grad()
        preds = model(points)
        loss = criterion(preds, category)
        loss.backward()
        optimizer.step()

        n = points.size(0)
        total_loss += loss.item() * n
        total_correct += preds.argmax(1).eq(category).sum().item()
        total_seen += n

    # Divide by the samples actually seen: drop_last=True skips the final partial batch,
    # so dividing by len(dataset_train) would understate loss and accuracy.
    avg_loss = total_loss / max(total_seen, 1)
    avg_acc = total_correct / max(total_seen, 1)
    print(f"Train Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")

    # ======= VALIDATION =======
    model.eval()
    val_correct, val_loss, val_seen = 0, 0.0, 0
    with torch.no_grad():
        for points, category, _seg in val_loader:
            points = points.to(DEVICE, dtype=torch.float32)
            category = category.to(DEVICE, dtype=torch.long).view(-1)
            preds = model(points)
            loss = criterion(preds, category)
            n = points.size(0)
            val_loss += loss.item() * n
            val_correct += preds.argmax(1).eq(category).sum().item()
            val_seen += n

    val_loss /= max(val_seen, 1)
    val_acc = val_correct / max(val_seen, 1)
    print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")


## Part segmentation without input and feature transforms (no T-Nets)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PointNetSeg(nn.Module):
    """PointNet part-segmentation network WITHOUT input/feature T-Nets.

    A per-point MLP 3 -> 64 -> 128 -> 128 -> 512 -> 1024 extracts features; the
    128-d output of ``conv3`` is kept as the local feature. A max-pool over points
    gives a 1024-d global feature, and a learned 16-d embedding of the one-hot
    object category is appended. The concatenation (128 + 1024 + 16 = 1168
    channels per point) goes through a per-point MLP to ``num_classes`` part
    log-probabilities.
    """

    def __init__(self, num_classes, num_categories):
        super(PointNetSeg, self).__init__()
        # Local feature extractor (1x1 convs == shared per-point MLP)
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 128, 1)
        self.bn3 = nn.BatchNorm1d(128)
        self.conv4 = nn.Conv1d(128, 512, 1)
        self.bn4 = nn.BatchNorm1d(512)
        self.conv5 = nn.Conv1d(512, 1024, 1)
        self.bn5 = nn.BatchNorm1d(1024)

        # MLP for segmentation prediction; input = local (128) + global (1024) + category (16)
        self.conv6 = nn.Conv1d(128 + 1024 + 16, 512, 1)
        self.bn6 = nn.BatchNorm1d(512)
        self.conv7 = nn.Conv1d(512, 256, 1)
        self.bn7 = nn.BatchNorm1d(256)
        self.conv8 = nn.Conv1d(256, 128, 1)
        self.bn8 = nn.BatchNorm1d(128)
        self.conv9 = nn.Conv1d(128, num_classes, 1)

        # Category embedding (each object category as one-hot)
        self.category_embed = nn.Linear(num_categories, 16)

    def forward(self, x, category_label):
        """
        Args:
            x: (B, N, 3) point coordinates.
            category_label: integer category ids with B elements, shaped (B,),
                (B, 1), or 0-d when B == 1 (e.g. after ``squeeze()``).

        Returns:
            (B, N, num_classes) per-point log-probabilities.
        """
        B, N, _ = x.size()
        x = x.transpose(2, 1)  # (B, 3, N)

        # Extract features
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        pointfeat = F.relu(self.bn3(self.conv3(x)))  # (B, 128, N) local features
        x = F.relu(self.bn4(self.conv4(pointfeat)))
        x = self.bn5(self.conv5(x))

        # Global feature
        x_global = torch.max(x, 2, keepdim=True)[0]  # (B, 1024, 1)

        # Flatten so (B,), (B, 1) and a squeezed 0-d label all become (B,); otherwise
        # one_hot of a 0-d label yields a 1-D tensor and unsqueeze(2) fails.
        category_label = category_label.reshape(-1).long().to(x.device)
        category_onehot = F.one_hot(category_label, num_classes=self.category_embed.in_features).float()
        category_embed = self.category_embed(category_onehot)  # (B, 16)
        category_embed = category_embed.unsqueeze(2).expand(-1, -1, N)  # (B, 16, N)

        # Concatenate features: pointfeat + global + category
        x_global_expanded = x_global.expand(-1, -1, N)  # (B, 1024, N)
        concat_feat = torch.cat([pointfeat, x_global_expanded, category_embed], 1)  # (B, 1168, N)

        # MLP for per-point segmentation
        x = F.relu(self.bn6(self.conv6(concat_feat)))
        x = F.relu(self.bn7(self.conv7(x)))
        x = F.relu(self.bn8(self.conv8(x)))
        x = self.conv9(x)
        x = x.transpose(2, 1).contiguous()  # (B, N, num_classes)
        return F.log_softmax(x, dim=-1)


In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import torch

NUM_CLASSES = 50      # number of part segmentation labels (ShapeNetPart)
NUM_CATEGORIES = len(category_to_index)
EPOCHS = 50
BATCH_SIZE = 16
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)

model = PointNetSeg(num_classes=NUM_CLASSES, num_categories=NUM_CATEGORIES).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

for epoch in range(EPOCHS):
    model.train()
    total_loss, total_correct, total_points = 0, 0, 0

    # ========== TRAINING ==========
    for points, category, seg in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        points = points.to(DEVICE, dtype=torch.float32)
        # view(-1) keeps a (B,) label vector even for a batch of one.
        category = category.to(DEVICE, dtype=torch.long).view(-1)
        seg = seg.to(DEVICE, dtype=torch.long)

        optimizer.zero_grad()
        preds = model(points, category)  # (B, N, num_classes)
        preds_flat = preds.reshape(-1, NUM_CLASSES)
        seg_flat = seg.reshape(-1)

        loss = criterion(preds_flat, seg_flat)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred_labels = preds_flat.argmax(1)
        total_correct += pred_labels.eq(seg_flat).sum().item()
        total_points += seg_flat.numel()

    # drop_last=True: every training batch has the same size, so a mean of batch means is exact.
    train_loss = total_loss / max(len(train_loader), 1)
    train_acc = total_correct / max(total_points, 1)

    # ========== VALIDATION ==========
    model.eval()
    val_loss, val_correct, val_points = 0.0, 0, 0
    with torch.no_grad():
        for points, category, seg in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            points = points.to(DEVICE, dtype=torch.float32)
            category = category.to(DEVICE, dtype=torch.long).view(-1)
            seg = seg.to(DEVICE, dtype=torch.long)

            preds = model(points, category)
            preds_flat = preds.reshape(-1, NUM_CLASSES)
            seg_flat = seg.reshape(-1)

            loss = criterion(preds_flat, seg_flat)
            # Weight by point count: the last validation batch can be smaller than BATCH_SIZE.
            val_loss += loss.item() * seg_flat.numel()
            val_correct += preds_flat.argmax(1).eq(seg_flat).sum().item()
            val_points += seg_flat.numel()

    val_loss /= max(val_points, 1)
    val_acc = val_correct / max(val_points, 1)
    val_error = 1 - val_acc

    print(f"\nEpoch [{epoch+1}/{EPOCHS}]")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f} | Val Error: {val_error:.4f}\n")


In [None]:
def random_rotate(points):
    """Rotate ``points`` (..., 3) about the z axis by an angle drawn from U(0, 2*pi)."""
    angle = np.random.uniform(0, 2 * np.pi)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rz = np.array([[cos_a, -sin_a, 0],
                   [sin_a, cos_a, 0],
                   [0, 0, 1]])
    return points @ rz  # rows are points

def random_scale(points, scale_low=0.8, scale_high=1.25):
    """Scale all points by a single factor drawn from U(scale_low, scale_high)."""
    return points * np.random.uniform(scale_low, scale_high)

def random_jitter(points, sigma=0.01, clip=0.05):
    """Add per-coordinate Gaussian noise (std ``sigma``) clipped to [-clip, clip]."""
    noise = sigma * np.random.randn(*points.shape)
    noise = np.clip(noise, -clip, clip)
    return points + noise


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------- Transform Network (T-Net) ----------
class TNet(nn.Module):
    """Predict a k x k alignment matrix for each input cloud.

    Input ``(B, k, N)`` -> shared MLP k -> 64 -> 128 -> 1024, max-pool over the
    N points, FC 1024 -> 512 -> 256 -> k*k, reshaped to ``(B, k, k)``.
    The last layer starts with zero weights and an identity bias, so a freshly
    constructed TNet returns the identity matrix for every input.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn3 = nn.BatchNorm1d(1024)

        self.fc1 = nn.Linear(1024, 512)
        self.bn4 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, k * k)

        # Start as the identity transform: zero weights, bias = flattened eye(k).
        with torch.no_grad():
            self.fc3.weight.zero_()
            self.fc3.bias.copy_(torch.eye(k).flatten())

    def forward(self, x):
        batch = x.shape[0]
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = F.relu(bn(conv(x)))
        pooled = x.max(dim=2).values  # (B, 1024)

        hidden = F.relu(self.bn4(self.fc1(pooled)))
        hidden = F.relu(self.bn5(self.fc2(hidden)))
        return self.fc3(hidden).view(batch, self.k, self.k)


# ---------- PointNet Segmentation Network ----------
class PointNetSeg(nn.Module):
    """PointNet part-segmentation network WITH input (3x3) and feature (64x64) T-Nets.

    Points are aligned by ``input_transform``, lifted to 64-d, aligned again by
    ``feature_transform``, then passed through a per-point MLP 64 -> 128 -> 128
    -> 512 -> 1024. The 128-d output of ``conv2`` is the local feature, a max-pool
    gives the 1024-d global feature, and a 16-d category embedding is appended
    (128 + 1024 + 16 = 1168 channels) before the per-point segmentation MLP.

    ``forward`` returns ``(log_probs, trans_input, trans_feat)`` where
    ``log_probs`` is (B, N, num_classes) and the transforms are (B, 3, 3) and
    (B, 64, 64).
    """

    def __init__(self, num_classes, num_categories):
        super(PointNetSeg, self).__init__()
        self.input_transform = TNet(k=3)
        self.feature_transform = TNet(k=64)

        # Local feature extractor
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 128, 1)
        self.bn3 = nn.BatchNorm1d(128)
        self.conv4 = nn.Conv1d(128, 512, 1)
        self.bn4 = nn.BatchNorm1d(512)
        self.conv5 = nn.Conv1d(512, 1024, 1)
        self.bn5 = nn.BatchNorm1d(1024)

        # Category embedding
        self.category_embed = nn.Linear(num_categories, 16)

        # MLP for segmentation prediction; input = local (128) + global (1024) + category (16)
        self.conv6 = nn.Conv1d(128 + 1024 + 16, 512, 1)
        self.bn6 = nn.BatchNorm1d(512)
        self.conv7 = nn.Conv1d(512, 256, 1)
        self.bn7 = nn.BatchNorm1d(256)
        self.conv8 = nn.Conv1d(256, 128, 1)
        self.bn8 = nn.BatchNorm1d(128)
        self.conv9 = nn.Conv1d(128, num_classes, 1)

    def forward(self, x, category_label):
        """
        Args:
            x: (B, N, 3) point coordinates.
            category_label: integer category ids with B elements, shaped (B,),
                (B, 1), or 0-d when B == 1 (e.g. after ``squeeze()``).
        """
        B, N, _ = x.size()
        x = x.transpose(2, 1)  # (B, 3, N)

        # -------- Input Transform --------
        trans_input = self.input_transform(x)  # (B, 3, 3)
        x = torch.bmm(trans_input, x)          # (B, 3, N) aligned points

        # -------- Local Feature Extraction --------
        x = F.relu(self.bn1(self.conv1(x)))    # (B, 64, N)

        # -------- Feature Transform --------
        trans_feat = self.feature_transform(x)  # (B, 64, 64)
        x = torch.bmm(trans_feat, x)            # (B, 64, N)

        # Continue feature extraction
        pointfeat = F.relu(self.bn2(self.conv2(x)))  # (B, 128, N) local features
        x = F.relu(self.bn3(self.conv3(pointfeat)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.bn5(self.conv5(x))
        x_global = torch.max(x, 2, keepdim=True)[0]  # (B, 1024, 1)

        # -------- Category Embedding --------
        # Flatten so (B,), (B, 1) and a squeezed 0-d label all become (B,); otherwise
        # one_hot of a 0-d label yields a 1-D tensor and unsqueeze(2) fails.
        category_label = category_label.reshape(-1).long().to(x.device)
        category_onehot = F.one_hot(category_label, num_classes=self.category_embed.in_features).float()
        category_embed = self.category_embed(category_onehot)  # (B, 16)
        category_embed = category_embed.unsqueeze(2).expand(-1, -1, N)  # (B, 16, N)

        # -------- Concatenate Features --------
        x_global_expanded = x_global.expand(-1, -1, N)  # (B, 1024, N)
        concat_feat = torch.cat([pointfeat, x_global_expanded, category_embed], 1)  # (B, 1168, N)

        # -------- Per-point Segmentation --------
        x = F.relu(self.bn6(self.conv6(concat_feat)))
        x = F.relu(self.bn7(self.conv7(x)))
        x = F.relu(self.bn8(self.conv8(x)))
        x = self.conv9(x)
        x = x.transpose(2, 1).contiguous()  # (B, N, num_classes)
        return F.log_softmax(x, dim=-1), trans_input, trans_feat


In [None]:
def feature_transform_regularizer(trans):
    """Penalty pushing each (K, K) transform in the (B, K, K) batch towards orthogonality.

    Returns the batch mean of the Frobenius norm ``||A A^T - I||_F``.
    """
    k = trans.size(-1)
    identity = torch.eye(k, device=trans.device)
    gram = torch.bmm(trans, trans.transpose(2, 1))
    # torch.norm (not a hand-written sqrt) keeps gradients finite when A A^T == I exactly.
    per_sample = torch.norm(gram - identity, dim=(1, 2))
    return torch.mean(per_sample)


In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import torch

NUM_CLASSES = 50      # number of part segmentation labels (ShapeNetPart)
NUM_CATEGORIES = len(category_to_index)
EPOCHS = 30
BATCH_SIZE = 16
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feat_loss_weight = 0.001  # weight of the feature-transform orthogonality penalty

train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)

model = PointNetSeg(num_classes=NUM_CLASSES, num_categories=NUM_CATEGORIES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
# Multiplies the LR by 0.7 every 20 epochs -- only if scheduler.step() is called once per epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

criterion = nn.NLLLoss()

for epoch in range(EPOCHS):
    model.train()
    total_loss, total_correct, total_points = 0, 0, 0

    # ========== TRAINING ==========
    for points, category, seg in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        # random_rotate/random_scale/random_jitter are NumPy helpers (random_rotate returns an
        # ndarray), so augment a NumPy copy of the batch and convert back to a tensor before
        # .to(); calling .to() on the ndarray raised AttributeError.
        # NOTE: ShapeNetPartH5 already rotates/jitters/scales each train sample; this adds a
        # second, batch-wide augmentation (one rotation and one scale shared by the batch).
        augmented = random_jitter(random_scale(random_rotate(points.numpy())))
        points = torch.from_numpy(augmented).to(DEVICE, dtype=torch.float32)
        # view(-1) keeps a (B,) label vector even for a batch of one.
        category = category.to(DEVICE, dtype=torch.long).view(-1)
        seg = seg.to(DEVICE, dtype=torch.long)

        optimizer.zero_grad()
        preds, trans_input, trans_feat = model(points, category)  # preds: (B, N, num_classes)
        feat_loss = feature_transform_regularizer(trans_feat)
        preds_flat = preds.reshape(-1, NUM_CLASSES)
        seg_flat = seg.reshape(-1)

        seg_loss = criterion(preds_flat, seg_flat)
        loss = seg_loss + feat_loss_weight * feat_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred_labels = preds_flat.argmax(1)
        total_correct += pred_labels.eq(seg_flat).sum().item()
        total_points += seg_flat.numel()

    # Advance the LR schedule once per epoch (previously never called, so the LR stayed fixed).
    scheduler.step()

    train_loss = total_loss / max(len(train_loader), 1)
    train_acc = total_correct / max(total_points, 1)

    # ========== VALIDATION ==========
    model.eval()
    val_loss, val_correct, val_points = 0.0, 0, 0
    with torch.no_grad():
        for points, category, seg in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            points = points.to(DEVICE, dtype=torch.float32)
            category = category.to(DEVICE, dtype=torch.long).view(-1)
            seg = seg.to(DEVICE, dtype=torch.long)

            preds, trans_input, trans_feat = model(points, category)
            preds_flat = preds.reshape(-1, NUM_CLASSES)
            seg_flat = seg.reshape(-1)
            feat_loss = feature_transform_regularizer(trans_feat)
            seg_loss = criterion(preds_flat, seg_flat)
            loss = seg_loss + feat_loss_weight * feat_loss
            # Weight by point count: the last validation batch can be smaller than BATCH_SIZE.
            val_loss += loss.item() * seg_flat.numel()
            val_correct += preds_flat.argmax(1).eq(seg_flat).sum().item()
            val_points += seg_flat.numel()

    val_loss /= max(val_points, 1)
    val_acc = val_correct / max(val_points, 1)
    val_error = 1 - val_acc

    print(f"\nEpoch [{epoch+1}/{EPOCHS}]")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f} | Val Error: {val_error:.4f}")
    print(f"LR: {scheduler.get_last_lr()[0]:.6f}\n")


In [None]:
import torch
import torch.nn.functional as F
import time
import numpy as np
import psutil
import os
test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)


def compute_metrics(model, dataloader, num_classes, device, feat_loss_weight=None):
    """Evaluate a PointNet part-segmentation model on ``dataloader``.

    ``model(points, category)`` must return ``(log_probs, trans_input, trans_feat)``
    with ``log_probs`` of shape (B, N, num_classes).

    Shape IoU follows the usual ShapeNetPart protocol when the dataset exposes
    ``seg_classes``/``classes_map``. For each shape, IoU is computed for every part
    of that shape's category, and a part absent from both ground truth and prediction
    scores 1. Without those tables, only the parts present in the ground truth are
    scored.

    Args:
        feat_loss_weight: weight of the feature-transform penalty in the reported
            loss. ``None`` uses the module-level ``feat_loss_weight`` (set by the
            training cell), falling back to 0.001.

    Returns:
        dict with
          loss: mean per-batch NLL + weighted feature-transform penalty;
          overall_acc: fraction of correctly labelled points;
          instance_miou: mean IoU over all shapes;
          class_miou: mean over object categories of the category's mean shape IoU;
          part_miou: dataset-level IoU per part (intersection/union summed over
              all shapes), averaged over parts that occur at least once;
          avg_inference_time: wall-clock seconds per sample for the whole loop
              (forward pass AND metric bookkeeping);
          memory_mb: resident set size of this process;
          model_size_mb: size of parameters + buffers.
    """
    if feat_loss_weight is None:
        feat_loss_weight = globals().get("feat_loss_weight", 0.001)

    model.eval()
    criterion = torch.nn.NLLLoss()

    # Category index -> list of that category's part labels (empty if unavailable).
    dataset = dataloader.dataset
    seg_classes = getattr(dataset, "seg_classes", None)
    classes_map = getattr(dataset, "classes_map", None)
    cat_to_parts = {}
    if seg_classes is not None and classes_map is not None:
        cat_to_parts = {
            classes_map[name]: list(parts)
            for name, parts in seg_classes.items() if name in classes_map
        }

    total_correct, total_seen, num_batches = 0, 0, 0
    total_loss = 0.0
    shape_ious = []        # one mean part-IoU per shape
    shape_categories = []  # category index of each entry in shape_ious
    part_intersection = np.zeros(num_classes)
    part_union = np.zeros(num_classes)

    start_time = time.time()
    with torch.no_grad():
        for points, category, seg in dataloader:
            points = points.to(device, dtype=torch.float32)
            # view(-1) keeps a (B,) label vector even for a batch of one.
            category = category.to(device, dtype=torch.long).view(-1)
            seg = seg.to(device, dtype=torch.long)

            preds, _trans_input, trans_feat = model(points, category)  # (B, N, num_classes)
            preds_flat = preds.reshape(-1, num_classes)
            seg_flat = seg.reshape(-1)
            feat_loss = feature_transform_regularizer(trans_feat)
            seg_loss = criterion(preds_flat, seg_flat)
            total_loss += (seg_loss + feat_loss_weight * feat_loss).item()
            num_batches += 1

            pred_choice = preds_flat.argmax(1)
            total_correct += pred_choice.eq(seg_flat).sum().item()
            total_seen += seg_flat.numel()

            # --- Compute IoU per instance ---
            batch_size = points.size(0)
            preds_np = pred_choice.cpu().numpy().reshape(batch_size, -1)
            seg_np = seg.cpu().numpy().reshape(batch_size, -1)
            cats_np = category.cpu().numpy()
            for pred_i, gt_i, cat in zip(preds_np, seg_np, cats_np):
                cat = int(cat)
                parts = cat_to_parts.get(cat)
                if parts is None:
                    parts = np.unique(gt_i)
                part_iou = []
                for part in parts:
                    pred_mask = pred_i == part
                    gt_mask = gt_i == part
                    inter = np.sum(pred_mask & gt_mask)
                    union = np.sum(pred_mask | gt_mask)
                    # Part absent from both GT and prediction: perfect score by convention.
                    part_iou.append(1.0 if union == 0 else inter / float(union))
                    part_intersection[part] += inter
                    part_union[part] += union
                shape_ious.append(np.mean(part_iou))
                shape_categories.append(cat)

    # --- Metrics ---
    total_time = time.time() - start_time
    num_samples = len(dataloader.dataset)
    avg_inference_time = total_time / max(num_samples, 1)

    overall_acc = total_correct / max(total_seen, 1)
    instance_miou = float(np.mean(shape_ious)) if shape_ious else 0.0

    # Category (class) mIoU: mean shape IoU within each category, then mean over categories.
    per_category = {}
    for cat, iou in zip(shape_categories, shape_ious):
        per_category.setdefault(cat, []).append(iou)
    class_miou = float(np.mean([np.mean(v) for v in per_category.values()])) if per_category else 0.0

    # Dataset-level per-part IoU. Parts that never occur (union 0) are excluded so they
    # do not drag the mean towards 0 as they did when averaging over all num_classes parts.
    seen_parts = part_union > 0
    part_miou = float(np.mean(part_intersection[seen_parts] / part_union[seen_parts])) if seen_parts.any() else 0.0

    # --- Memory ---
    process = psutil.Process(os.getpid())
    memory_mb = process.memory_info().rss / 1024 ** 2  # in MB

    # --- Model size ---
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    model_size_mb = (param_size + buffer_size) / 1024 ** 2

    return {
        'loss': total_loss / max(num_batches, 1),
        'overall_acc': overall_acc,
        'instance_miou': instance_miou,
        'class_miou': class_miou,
        'part_miou': part_miou,
        'avg_inference_time': avg_inference_time,
        'memory_mb': memory_mb,
        'model_size_mb': model_size_mb
    }


In [None]:
# Evaluated on test_loader (the held-out test split), not the validation set.
metrics = compute_metrics(model, test_loader, NUM_CLASSES, DEVICE)

print(f"\nTest Metrics:")
print(f"Loss:              {metrics['loss']:.4f}")
print(f"Overall Accuracy:  {metrics['overall_acc']:.4f}")
print(f"Instance mIoU:     {metrics['instance_miou']:.4f}")
print(f"Class mIoU:        {metrics['class_miou']:.4f}")
print(f"Inference Time:    {metrics['avg_inference_time']*1000:.2f} ms per sample")
print(f"Memory Usage:      {metrics['memory_mb']:.2f} MB")
print(f"Model Size:        {metrics['model_size_mb']:.2f} MB\n")


## Qualitative comparison: ground truth vs. prediction

In [None]:
    # Persist only the learned weights (state_dict), not the pickled module object.
    import torch
    torch.save(obj=model.state_dict(), f='pointnetwithtnet2.pth')

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_single_segmentation_with_category(points, seg_gt, seg_pred, category_name="Unknown", shape_idx=0):
    """
    Side-by-side 3-D scatter of ground-truth vs. predicted part labels.

    Both panels share ONE colour scale with the same cmin/cmax, so a given part
    label is drawn in the same colour on both sides. With different colour maps
    or independently scaled colour ranges, identical labels would look different
    and the comparison would be misleading.

    points: (N, 3) numpy array of point cloud coordinates
    seg_gt: (N,) numpy array of ground truth labels
    seg_pred: (N,) numpy array of predicted labels
    category_name: string label for the object category
    shape_idx: integer index of the object
    """
    seg_gt = np.asarray(seg_gt)
    seg_pred = np.asarray(seg_pred)
    # Shared colour range covering labels from both panels.
    cmin = float(min(seg_gt.min(), seg_pred.min()))
    cmax = float(max(seg_gt.max(), seg_pred.max()))

    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
        subplot_titles=[
            f"Ground Truth - {category_name} (Object {shape_idx})",
            f"Prediction - {category_name} (Object {shape_idx})"
        ]
    )

    # Ground truth segmentation
    fig.add_trace(go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color=seg_gt,
            colorscale='Viridis',
            cmin=cmin,
            cmax=cmax,
            opacity=0.8,
            showscale=True,
            colorbar=dict(title="Part Labels")
        ),
        name="Ground Truth"
    ), row=1, col=1)

    # Predicted segmentation: same colour scale and range, no second (overlapping) colorbar
    fig.add_trace(go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color=seg_pred,
            colorscale='Viridis',
            cmin=cmin,
            cmax=cmax,
            opacity=0.8,
            showscale=False
        ),
        name="Prediction"
    ), row=1, col=2)

    fig.update_layout(
        title=f"Segmentation Comparison for Category: {category_name}",
        width=1200,
        height=600
    )

    fig.show()


In [None]:
model.eval()
device = DEVICE
num_classes = NUM_CLASSES
num_categories = len(index_to_category)  # 16 for ShapeNetPart

# Keep track of which categories we've already visualized
shown_categories = set()
# test_loader uses shuffle=False, so offset + b is the sample's index in dataset_test.
sample_offset = 0

with torch.no_grad():
    for points, category, seg in test_loader:
        points = points.to(device, dtype=torch.float32)
        # view(-1) keeps category 1-D so category[b] also works for a batch of one.
        category = category.to(device, dtype=torch.long).view(-1)
        seg = seg.to(device, dtype=torch.long)

        preds, input_trans, feat_trans = model(points, category)
        batch_size = points.size(0)
        pred_choice = preds.reshape(-1, num_classes).argmax(1)
        preds_np = pred_choice.cpu().numpy().reshape(batch_size, -1)
        seg_np = seg.cpu().numpy().reshape(batch_size, -1)

        # Iterate over batch items
        for b in range(batch_size):
            cat_name = index_to_category[category[b].item()]

            # Skip if already visualized this category
            if cat_name in shown_categories:
                continue

            # Plot one sample of this category
            plot_single_segmentation_with_category(
                points[b].cpu().numpy(),
                seg_np[b],
                preds_np[b],
                category_name=cat_name,
                shape_idx=sample_offset + b
            )

            shown_categories.add(cat_name)
            print(f"✅ Shown category: {cat_name} ({len(shown_categories)}/{num_categories})")

            # Stop once every category has been shown
            if len(shown_categories) == num_categories:
                print(f"\n🎉 Displayed one sample from each of the {num_categories} categories. Done!")
                break

        if len(shown_categories) == num_categories:
            break
        sample_offset += batch_size

        
