In [6]:
import numpy as np
import os

# Folder of ONE preprocessed scene to inspect (layout: <data_root>/<split>/sceneXXXX_XX).
# Edit this to point at a scene on your machine.
scene_path = "data/scannet_data/val/scene0011_00"

# Per-scene .npy files to inspect.
files_to_check = [
    "coord.npy",
    "color.npy",
    "normal.npy",
    "instance.npy",
    "segment20.npy",
    "segment200.npy",
]


def inspect_scene_files(scene_dir, filenames, num_samples=2):
    """Print shape, dtype, value range and leading entries of each ``.npy`` file.

    Args:
        scene_dir (str): Folder containing the ``.npy`` files.
        filenames (Iterable[str]): File names, relative to ``scene_dir``, to inspect.
        num_samples (int): How many leading entries of each array to print.

    Missing files are reported and skipped rather than raising.
    """
    print(f"--- Inspecting scene: {scene_dir} ---\n")

    for filename in filenames:
        file_path = os.path.join(scene_dir, filename)
        if not os.path.exists(file_path):
            print(f"❌ File not found: {filename}")
            continue

        data = np.load(file_path)
        print(f"📄 File: {filename}")
        print(f"   Shape: {data.shape}")
        print(f"   Type:  {data.dtype}")

        # Min/max reveal value ranges (e.g. uint8 colours, metric coordinates, -1 ignore labels);
        # np.min/np.max raise on empty arrays, hence the guard.
        if data.size > 0:
            print(f"   Min:   {np.min(data)}")
            print(f"   Max:   {np.max(data)}")
            print(f"   Sample data:\n{data[:num_samples]}")

        print("-" * 30)


inspect_scene_files(scene_path, files_to_check)

--- Inspecting scene: data/scannet_data/val/scene0011_00 ---

📄 File: coord.npy
   Shape: (237360, 3)
   Type:  float32
   Min:   -0.032526493072509766
   Max:   8.213502883911133
   Sample data:
[[2.5091114  0.4083811  0.14877559]
 [2.5156426  0.4059527  0.14168811]]
------------------------------
📄 File: color.npy
   Shape: (237360, 3)
   Type:  uint8
   Min:   2
   Max:   255
   Sample data:
[[35 33 38]
 [34 32 39]]
------------------------------
📄 File: normal.npy
   Shape: (237360, 3)
   Type:  float32
   Min:   -0.9999940991401672
   Max:   0.9999979734420776
   Sample data:
[[0.19109516 0.92176497 0.3371149 ]
 [0.35799062 0.9311391  0.06885417]]
------------------------------
📄 File: instance.npy
   Shape: (237360,)
   Type:  int64
   Min:   -1
   Max:   32
   Sample data:
[27 27]
------------------------------
📄 File: segment20.npy
   Shape: (237360,)
   Type:  int64
   Min:   -1
   Max:   19
   Sample data:
[14 14]
------------------------------
📄 File: segme

In [1]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset

class ScanNetDataset(Dataset):
    """Map-style dataset over preprocessed ScanNet scenes stored as per-scene ``.npy`` folders.

    Expected layout: ``data_root/sceneXXXX_XX/{coord,color,normal}.npy`` plus optional
    ``segment20.npy`` and ``instance.npy`` label files.
    """

    def __init__(self, data_root, transform=None):
        """
        Args:
            data_root (str): Path to the folder containing scene folders
                             (e.g., 'data/scannet_processed/val').
            transform (callable, optional): Sonata transform pipeline, applied to the
                             sample dict in ``__getitem__``.

        Raises:
            ValueError: if ``data_root`` contains no ``scene*`` sub-directories.
        """
        self.data_root = data_root
        self.transform = transform

        # Only directories count as scenes: stray files that happen to match "scene*"
        # (e.g. split lists) would otherwise be indexed and crash in __getitem__.
        self.scene_paths = sorted(
            path
            for path in glob.glob(os.path.join(data_root, "scene*"))
            if os.path.isdir(path)
        )

        if not self.scene_paths:
            raise ValueError(f"No scene folders found in {data_root}. Check your path.")

        print(f"Loaded {len(self.scene_paths)} scenes from {data_root}")

    def __len__(self):
        return len(self.scene_paths)

    @staticmethod
    def _load_labels(path, num_points):
        """Load an int64 per-point label array, or all -1 (ignore index) if the file is absent."""
        if os.path.exists(path):
            return np.load(path).astype(np.int64)
        return np.full(num_points, -1, dtype=np.int64)

    def __getitem__(self, idx):
        """Load one scene as a dict of per-point arrays.

        Args:
            idx (int): Index into ``self.scene_paths``.

        Returns:
            dict: ``coord``/``color``/``normal`` as float32 arrays, ``segment20`` and
            ``instance`` as int64 arrays (all -1 when the label file is absent),
            ``name`` (scene folder name) and ``id`` (``idx``). When a transform was
            given, its return value for that dict is returned instead.

        Raises:
            FileNotFoundError: if ``coord.npy``, ``color.npy`` or ``normal.npy`` is missing.
            ValueError: if the per-point arrays disagree in length with ``coord``.
        """
        scene_path = self.scene_paths[idx]
        scene_name = os.path.basename(scene_path)

        # Geometry and appearance are mandatory.
        try:
            coord = np.load(os.path.join(scene_path, "coord.npy")).astype(np.float32)
            color = np.load(os.path.join(scene_path, "color.npy")).astype(np.float32)
            normal = np.load(os.path.join(scene_path, "normal.npy")).astype(np.float32)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Missing required .npy file in {scene_name}: {e}") from e

        # Labels are optional (e.g. unlabeled test scenes); missing ones become the ignore index.
        num_points = coord.shape[0]
        segment = self._load_labels(os.path.join(scene_path, "segment20.npy"), num_points)
        instance = self._load_labels(os.path.join(scene_path, "instance.npy"), num_points)

        # Every per-point array must describe the same points; fail early with a clear message
        # instead of an opaque indexing error inside the transform pipeline.
        for key, values in (
            ("color", color),
            ("normal", normal),
            ("segment20", segment),
            ("instance", instance),
        ):
            if len(values) != num_points:
                raise ValueError(
                    f"{scene_name}: '{key}' has {len(values)} entries "
                    f"but 'coord' has {num_points} points"
                )

        # Dictionary passed to the Sonata/Pointcept-style transform pipeline.
        data_dict = {
            "coord": coord,
            "color": color,
            "normal": normal,
            "segment20": segment,
            "instance": instance,
            "name": scene_name,
            "id": idx,
        }

        # Apply the optional transform (e.g. Sonata's default pipeline).
        if self.transform is not None:
            data_dict = self.transform(data_dict)

        return data_dict

In [2]:
# Index every scene folder of the ScanNet training split and report how many were found.
train_path = "data/scannet_data/train"
dataset = ScanNetDataset(train_path)
print(f"Dataset length: {len(dataset)}")

Loaded 1201 scenes from data/scannet_data/train
Dataset length: 1201


In [3]:
# Peek at the first training scene: its keys, plus shape and first two rows of each per-point array.
data_item = dataset[0]
print(
    f"Data item keys: {list(data_item.keys())}",
    *(
        f"{label} shape: {data_item[key].shape}, sample value {data_item[key][:2]}"
        for label, key in (
            ("Coordinates", "coord"),
            ("Color", "color"),
            ("Normal", "normal"),
            ("Segment", "segment20"),
            ("Instance", "instance"),
        )
    ),
    sep="\n",
)

Data item keys: ['coord', 'color', 'normal', 'segment20', 'instance', 'name', 'id']
Coordinates shape: (81369, 3), sample value [[0.5324214  4.5172734  0.26304942]
 [0.53404164 4.552089   0.262302  ]]
Color shape: (81369, 3), sample value [[101. 107.  90.]
 [ 88.  83.  78.]]
Normal shape: (81369, 3), sample value [[ 0.8616817  -0.02587385 -0.5067754 ]
 [ 0.9884297   0.14236335 -0.05226134]]
Segment shape: (81369,), sample value [13 13]
Instance shape: (81369,), sample value [5 5]


In [4]:


import open3d as o3d
import sonata
import torch

try:
    import flash_attn
except ImportError:
    flash_attn = None


def get_pca_color(feat, brightness=1.25, center=True):
    """Turn point features into RGB colours in [0, 1] via a rank-6 PCA.

    The features are projected onto the six directions returned by
    ``torch.pca_lowrank(q=6)``. Components 1-3 are blended 60/40 with components 4-6.
    Each resulting channel is min-max rescaled across points, multiplied by
    ``brightness`` and clipped to [0, 1].

    Args:
        feat (torch.Tensor): 2-D ``(num_points, num_channels)`` feature matrix
            (assumed to have at least 6 points and 6 channels so six components exist).
        brightness (float): Multiplier applied after min-max rescaling.
        center (bool): Forwarded to ``torch.pca_lowrank``.

    Returns:
        torch.Tensor: ``(num_points, 3)`` colours.
    """
    _, _, basis = torch.pca_lowrank(feat, center=center, q=6, niter=5)
    components = feat @ basis
    blended = 0.6 * components[:, :3] + 0.4 * components[:, 3:6]

    lo = blended.min(dim=-2, keepdim=True).values
    hi = blended.max(dim=-2, keepdim=True).values
    span = (hi - lo).clamp(min=1e-6)  # avoid division by zero for constant channels

    return ((blended - lo) / span * brightness).clamp(0.0, 1.0)


if __name__ == "__main__":
    # In a Jupyter kernel __name__ == "__main__", so this block runs whenever the cell runs.

    # Fix the random seed. The seed affects the PCA colours; changing it may require
    # manually re-tuning k-means. The paper's PCA figures came from a different
    # CUDA/PyTorch environment, so colours may not match exactly.
    sonata.utils.set_seed(53124)

    # Load the model: use FlashAttention when installed, otherwise disable it and
    # use an explicit encoder patch size instead.
    if flash_attn is not None:
        model = sonata.load("sonata", repo_id="facebook/sonata").cuda()
    else:
        custom_config = dict(
            enc_patch_size=[1024 for _ in range(5)],  # reduce patch size if necessary
            enable_flash=False,
        )
        model = sonata.load(
            "sonata", repo_id="facebook/sonata", custom_config=custom_config
        ).cuda()

    # Default data transform pipeline (includes the grid sampling referenced below).
    transform = sonata.transform.default()

    # Work on a copy of the cached dataset sample. Below we pop keys and the transform
    # may modify arrays in place. Aliasing `data_item` directly would corrupt it and make
    # this cell fail with KeyError('segment20') when re-run.
    point = {
        key: value.copy() if isinstance(value, np.ndarray) else value
        for key, value in data_item.items()
    }

    # ScanNet ships two label sets (segment20 / segment200); only one is used,
    # exposed under the generic "segment" key.
    point["segment"] = point.pop("segment20")
    original_coord = point["coord"].copy()  # full-resolution coordinates, before transform
    point = transform(point)

    with torch.inference_mode():
        for key in point.keys():
            if isinstance(point[key], torch.Tensor):
                point[key] = point[key].cuda(non_blocking=True)
        # Model forward; the returned Point carries all information gathered during forward.
        point = model(point)
        # Upcast point features back through the pooling hierarchy. The first two
        # un-pooling steps concatenate the child features onto the parent's. Any remaining
        # levels just broadcast the child features up via the inverse mapping.
        for _ in range(2):
            assert "pooling_parent" in point.keys()
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent
        while "pooling_parent" in point.keys():
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = point.feat[inverse]
            point = parent

        # PCA colouring of the (grid-sampled) point features.
        pca_color = get_pca_color(point.feat, brightness=1.2, center=True)

    # `point` is down-sampled by GridSampling in the default transform. `point.inverse`
    # (produced by that transform) maps every original point to its sampled point.
    # Indexing with it restores full resolution; point.feat[point.inverse] does the same
    # for the features themselves.
    original_pca_color = pca_color[point.inverse]

    # Visualise the full-resolution cloud coloured by its PCA features.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(original_coord)
    pcd.colors = o3d.utility.Vector3dVector(original_pca_color.cpu().detach().numpy())
    # o3d.visualization.draw_geometries([pcd])  # alternative: native desktop window
    o3d.visualization.draw_plotly([pcd])

    # Save the grid-sampled cloud with its PCA colours.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(point.coord.cpu().detach().numpy())
    pcd.colors = o3d.utility.Vector3dVector(pca_color.cpu().detach().numpy())
    o3d.io.write_point_cloud("pca.ply", pcd)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Loading checkpoint from HuggingFace: sonata ...
Model params: 108.46M
