<a href="https://colab.research.google.com/github/alexmoed/MastersProject_Submission/blob/main/scripts/Inference_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# If running on Colab, mount Google Drive first: every data, config, checkpoint
# and install path used by this notebook lives under /content/drive/MyDrive/Pointcept.
from google.colab import drive
drive.mount('/content/drive')


 @brief Gaussian Splat Segmentation Pipeline for ScanNet classification
 Integration and organization of existing inference code assisted by Claude AI (Anthropic).
 Multiple prompts were used to combine separate model inference workflows into a
 unified pipeline, coordinate the data flow, and write the result-merging logic (abbreviated from an extended conversation).

 Base models and inference code from:
 Pointcept Contributors (2023). Pointcept: A Codebase for Point Cloud Perception Research [online].
 [Accessed 2025]. Available from: "https://github.com/Pointcept/Pointcept".
 Original Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)

INSTALLS

In [None]:
# NumPy versions are an ongoing issue in this install: it must be below 1.25 to work
# (1.24.0 and 1.24.3 are tested and work). This cell pins exactly 1.24.0.
# Uninstall numpy first.
# Restart the session afterwards!!
!pip uninstall -y numpy
!pip uninstall -y numpy  # Run twice

# Find any installs of numpy left
# NOTE(review): this only searches site-packages, while the rm commands below also
# clear dist-packages (presumably where Colab's system packages live) — confirm.
!find /usr/local/lib/python*/site-packages -name "numpy*" -type d

# Delete ALL numpy folders (the numpy* glob also matches any other package
# directory whose name starts with "numpy").
!rm -rf /usr/local/lib/python*/site-packages/numpy*
!rm -rf /usr/local/lib/python*/dist-packages/numpy*
!rm -rf ~/.local/lib/python*/site-packages/numpy*

# Clean pip cache too
!pip cache purge


!pip install numpy==1.24.0 --no-cache-dir

# Make sure the version is correct: it needs to be exactly 1.24.0. This only prints a
# warning and does not stop execution; if it fires, restart the runtime and rerun.
# NOTE(review): if numpy was already imported in this kernel, `import numpy` returns the
# already-loaded module from sys.modules, so this reports that version, not the new install.
import numpy as np
if np.__version__ != "1.24.0":
   print(f"Wrong version: {np.__version__} Please restart runtime and rerun")


print(f"Numpy version: {np.__version__}")
print(f"Numpy location: {np.__file__}")


In [None]:
#Import universal dependencies that arent impacted by order
import os
import sys
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from collections import OrderedDict
from glob import glob
import numpy as np
from collections import OrderedDict
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import os
import gc
import json


This is a version of the CUDA install that, after the first run, saves the download and caches the installation so later runs are faster. It also handles the PyTorch installs.

---



In [None]:
# First batch of installs: CUDA and older PyTorch versions. This will take 10+ minutes.
# The script is exec'd in this notebook's global namespace, so any names it defines stay
# available to later cells (presumably including `torch`, which the pipeline cells use
# but no visible cell imports — confirm).
%cd /content/drive/MyDrive/Pointcept/Installs
exec(open('/content/drive/MyDrive/Pointcept/Installs/setup_cuda_torch.py').read())


In [None]:
# Install pointops (editable install), which is needed for the Sonata decoder head.
# This you have to run twice.
# NOTE(review): `%cd Pointcept/libs/pointops` is relative to the current directory. The
# previous cell leaves the cwd at /content/drive/MyDrive/Pointcept/Installs (unless its
# script changes it), while this cell itself ends in /content/drive/MyDrive/Pointcept, so
# the first and second runs resolve this path differently. That may be why two runs are
# needed; consider an absolute path once the real location of libs/pointops is confirmed.

%cd Pointcept/libs/pointops
!pip install -v -e.
%cd /content/drive/MyDrive/Pointcept


In [None]:
 # More installs and configs required: runs setup_build_env.py from the Installs
 # folder via exec, in this notebook's global namespace.
 %cd /content/drive/MyDrive/Pointcept/Installs
 exec(open('/content/drive/MyDrive/Pointcept/Installs/setup_build_env.py').read())

In [None]:
# Load and apply the pointops patches and wrapper (exec'd in the global namespace).
# NOTE(review): step 2 references a global `pointops` that no visible cell imports;
# presumably this script provides it — confirm.
%cd /content/drive/MyDrive/Pointcept/Installs
exec(open('/content/drive/MyDrive/Pointcept/Installs/fix_ballquery_wrapper.py').read())

## Gaussian Splat Segmentation Pipeline

In [None]:
#Imports needed just for this pipeline
import pandas as pd
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from pointcept.models import build_model
from pointcept.utils.config import Config
from pointcept.datasets.transform import Compose
from plyfile import PlyData, PlyElement

Config for pipeline

In [None]:
# Living room scene
PLY_PATH = '/content/drive/MyDrive/Pointcept/data/splat/KitchenDiner_cleaned_v014_rotation.ply'
# Kitchen scene
#PLY_PATH = '/content/drive/MyDrive/Pointcept/data/splat/KitchenDiner_cleaned_v010_rotation_just_kitchen.ply'

OUTPUT_DIR = '/content/drive/MyDrive/Pointcept/data/splat_scannet_with_normals/'
# The cache folder name is derived from the PLY file name. Step 1 reuses whatever arrays
# it finds in SCENE_PATH, so a hard-coded name would silently serve another scene's
# cached points after PLY_PATH is switched.
SCENE_NAME = Path(PLY_PATH).stem
SCENE_PATH = os.path.join(OUTPUT_DIR, SCENE_NAME)

# Instance segmentation (first pass, ScanNet 20 PointGroup)
FIRST_PASS_CONFIG = '/content/drive/MyDrive/Pointcept/configs/scannet/insseg-pointgroup-v1m2-0-ptv3-base.py'
FIRST_PASS_CHECKPOINT = '/content/drive/MyDrive/Pointcept/exp/scannet/insseg-pointgroup-sonata_3/model/epoch_1200.pth'
FIRST_PASS_GRID_SIZE = 0.023  # GridSample voxel size, in normalized scene units

# Semantic segmentation (second pass, Sonata)
SONATA_CONFIG = '/content/drive/MyDrive/Pointcept/configs/sonata/semseg-sonata-v1m1-0c-scannet-ft-fast.py'
SONATA_CHECKPOINT = '/content/drive/MyDrive/Pointcept/exp/sonata/semseg-sonata-v1m1-0c-scannet-ft-_full_scene_v003/model/epoch_800.pth'
SECOND_PASS_CONFIG = FIRST_PASS_CONFIG
SECOND_PASS_CHECKPOINT = FIRST_PASS_CHECKPOINT
SECOND_PASS_GRID_SIZE = 0.021

# Instance segmentation, ScanNet 200
SCANNET200_CONFIG = '/content/drive/MyDrive/Pointcept/Pointcept/configs/scannet200/insseg-pointgroup-v1m2-0-ptv3-base_1_scannet_200_v001.py'
SCANNET200_CHECKPOINT = '/content/drive/MyDrive/Pointcept/exp/scannet/insseg-pointgroup-sonata_config_200_1/model/epoch_800.pth'
SCANNET200_GRID_SIZE = 0.015
SCANNET200_CONFIDENCE_THRESHOLD = 0.5

# Point cloud parameters
MODEL_FILTER = 1290000  # max number of points (random subset) sent to the instance model for predictions
VIZ_MAX_POINTS = 50000  # only controls how many points plotly draws, not the point count used for predictions
CONFIDENCE_THRESHOLD = 0.6  # instances scoring below this are discarded
LOW_CONFIDENCE_THRESHOLD = 0.3  # NOTE(review): not used by the steps visible in this notebook chunk

# ScanNet 20 class names; the list index is the class id predicted by the models.
CLASS_NAMES = [
    "wall", "floor", "cabinet", "bed", "chair", "sofa", "table", "door",
    "window", "bookshelf", "picture", "counter", "desk", "curtain",
    "refrigerator", "shower curtain", "toilet", "sink", "bathtub", "otherfurniture"
]

# Convert the PLY file to NumPy (.npy) arrays

In [None]:

def step1_ply_to_numpy():
    """Convert the Gaussian-splat PLY at PLY_PATH into model-ready numpy arrays.

    First run: reads the PLY vertices and builds the arrays.
      - coordinates are centred on their mean and uniformly rescaled so the largest
        axis extent is 8.0 units;
      - per-splat normals are unit-normalised;
      - the spherical-harmonic DC term is converted to 8-bit RGB.
    The arrays are cached as coord.npy / color.npy / normal.npy in SCENE_PATH.
    Later runs load that cache instead, but only if all three files exist.

    Returns:
        (coord, color, normal): float32 (N, 3), uint8 (N, 3), float32 (N, 3).
        The dtypes are identical whether the arrays come from the PLY or from the
        cache, so the first run and later runs feed the models the same data.

    Raises:
        ValueError: if the point cloud has no spatial extent, so it cannot be rescaled.
    """
    print("\n" + "="*60)
    print("STEP 1: PLY TO NUMPY CONVERSION")
    print("="*60)

    cache_files = {name: os.path.join(SCENE_PATH, f'{name}.npy')
                   for name in ('coord', 'color', 'normal')}

    # Reuse the cache only when it is complete; a partial cache (e.g. an interrupted
    # save) would otherwise crash on the missing file. Point PLY_PATH / SCENE_NAME at
    # a new scene to generate fresh arrays.
    if all(os.path.exists(path) for path in cache_files.values()):
        print("Numpy arrays already exist, loading...")
        coord = np.load(cache_files['coord'])
        color = np.load(cache_files['color'])
        normal = np.load(cache_files['normal'])
        print(f"Loaded {len(coord):,} points")
        return coord, color, normal

    # Create output directory if it doesn't exist
    os.makedirs(SCENE_PATH, exist_ok=True)

    # Load unmodified PLY
    print("Loading PLY file...")
    plydata = PlyData.read(PLY_PATH)
    vertex_data = plydata['vertex']
    total_points = len(vertex_data)
    print(f"Total points: {total_points:,}")

    coords = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
    normals = np.vstack([vertex_data['nx'], vertex_data['ny'], vertex_data['nz']]).T

    # Zeroth-order SH (DC) coefficient -> RGB in [0, 1] -> uint8.
    # This colour is also fed to the models in later steps, not only used for plots.
    SH_C0 = 0.28209479177387814
    sh_dc = np.vstack([vertex_data['f_dc_0'], vertex_data['f_dc_1'], vertex_data['f_dc_2']]).T
    rgb_normalized = 0.5 + SH_C0 * sh_dc
    color = (rgb_normalized.clip(0, 1) * 255).astype(np.uint8)

    # Normalize position and size so the scene has a uniform scale for the models
    centroid = coords.mean(axis=0)
    coords_centered = coords - centroid
    current_extent = coords_centered.max(axis=0) - coords_centered.min(axis=0)
    max_extent = current_extent.max()
    if not max_extent > 0:  # also rejects NaN extents
        raise ValueError(f"Point cloud in {PLY_PATH} has zero extent; cannot normalise its scale")
    scale_factor = 8.0 / max_extent
    # Cast here (not only when saving) so the returned arrays match the cached ones.
    coord = (coords_centered * scale_factor).astype(np.float32)

    # The epsilon keeps all-zero normals (common in splat exports) from dividing by zero.
    normal_mags = np.linalg.norm(normals, axis=1)
    normal = (normals / (normal_mags[:, np.newaxis] + 1e-8)).astype(np.float32)

    # Save out arrays that will be used until the last step
    print("Saving numpy arrays...")
    np.save(cache_files['coord'], coord)
    np.save(cache_files['color'], color)
    np.save(cache_files['normal'], normal)

    print("Step PLY to NPY complete")
    return coord, color, normal


coord, color, normal = step1_ply_to_numpy()

Run instance segmentation predictions on the full scene: ScanNet 20


Looking for furniture and objects.

In [None]:
def step2_first_pass(coord_full, color_full, normal_full):
    """First pass: ScanNet 20 PointGroup instance segmentation of the whole scene.

    Args:
        coord_full, color_full, normal_full: per-point arrays from step 1 (same length N).

    Returns:
        instance_preds_full: int32 (N,). For each point, the id of the kept instance
            covering it, or -1.
        class_preds_full: int32 (N,). The predicted class of that instance, or -1.
        pred_scores, pred_classes: numpy arrays with one entry per predicted instance
            (all instances, not only the kept ones).

    Behaviour:
        - Only instances with score >= CONFIDENCE_THRESHOLD are kept. The same test is
          used for filtering, the printed count, the saved metadata and the plot.
        - Where kept masks overlap, the higher-numbered instance wins.
        - At most MODEL_FILTER randomly chosen points are sent to the model; points
          outside that subset stay -1.
        - Results are saved to SCENE_PATH.
    """
    print("\n" + "="*60)
    print("STEP 2: FIRST PASS INSTANCE SEGMENTATION")
    print("="*60)

    print(f"Processing {len(coord_full):,} points")

    # Cap the number of points sent to the model at MODEL_FILTER with a random subset;
    # model_indices records where each subset point lives in the full cloud.
    if len(coord_full) > MODEL_FILTER:
        print(f"Filtering to {MODEL_FILTER:,} points for model...")
        model_indices = np.random.choice(len(coord_full), MODEL_FILTER, replace=False)
        coord = coord_full[model_indices]
        color = color_full[model_indices]
        normal = normal_full[model_indices]
    else:
        coord = coord_full
        color = color_full
        normal = normal_full
        model_indices = np.arange(len(coord_full))

    # Load model. Strip a "module." key prefix, presumably left by a DataParallel/DDP
    # wrapper at training time.
    print("Loading PointGroup model...")
    cfg = Config.fromfile(FIRST_PASS_CONFIG)
    model = build_model(cfg.model).cuda()
    model.eval()

    checkpoint = torch.load(FIRST_PASS_CHECKPOINT, map_location='cuda')
    weight = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if key.startswith("module."):
            key = key[7:]
        weight[key] = value
    model.load_state_dict(weight, strict=True)

    # Input dictionary. segment/instance are placeholder labels because the transforms
    # below copy and parse those keys.
    data_dict = {
        'coord': coord.copy(),
        'color': color.copy(),
        'normal': normal.copy(),
        'segment': np.zeros(coord.shape[0], dtype=np.int32),
        'instance': np.full(coord.shape[0], -1, dtype=np.int32),
    }

    # Keep the pre-sampling coordinates as origin_coord so masks can be mapped back.
    transform_list = [
        dict(type="CenterShift", apply_z=True),
        dict(type="Copy", keys_dict={
            "coord": "origin_coord",
            "segment": "origin_segment",
            "instance": "origin_instance",
        }),
        dict(type="GridSample", grid_size=FIRST_PASS_GRID_SIZE, hash_type="fnv",
             mode="train", return_grid_coord=True),
        dict(type="CenterShift", apply_z=False),
        dict(type="NormalizeColor"),
        dict(type="InstanceParser", segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1),
    ]

    transform = Compose(transform_list)
    data_dict = transform(data_dict)

    # Converting to tensors (integer keys as long, everything else as float)
    for key in data_dict.keys():
        if isinstance(data_dict[key], np.ndarray):
            if key in ['segment', 'instance', 'grid_coord', 'origin_segment', 'origin_instance']:
                data_dict[key] = torch.from_numpy(data_dict[key]).long()
            elif key == 'bbox':
                data_dict[key] = torch.from_numpy(data_dict[key]).long()
            else:
                data_dict[key] = torch.from_numpy(data_dict[key]).float()

    for key in data_dict.keys():
        if isinstance(data_dict[key], torch.Tensor):
            data_dict[key] = data_dict[key].cuda()

    # Model features: coord | color | normal concatenated per point
    if all(k in data_dict for k in ["coord", "color", "normal"]):
        feat = torch.cat([data_dict["coord"], data_dict["color"], data_dict["normal"]], dim=1)
        data_dict["feat"] = feat

    # Single-scene batch: every point in batch 0, one offset equal to the point count
    if "coord" in data_dict:
        data_dict["batch"] = torch.zeros(len(data_dict["coord"]), dtype=torch.long).cuda()
        data_dict["offset"] = torch.tensor([len(data_dict["coord"])], dtype=torch.long).cuda()

    if "origin_coord" in data_dict:
        data_dict["origin_offset"] = torch.tensor([len(data_dict["origin_coord"])], dtype=torch.long).cuda()

    # Run inference
    print("Running inference...")
    with torch.no_grad():
        output_dict = model(data_dict)

    pred_masks = output_dict['pred_masks']
    pred_scores = output_dict['pred_scores']
    pred_classes = output_dict['pred_classes']

    # Masks are predicted on the grid-sampled points. Map every point of the model subset
    # to its nearest sampled point so the masks cover the whole subset (assumes
    # pointops.knn_query returns, per origin_coord point, its nearest neighbour in coord).
    if "origin_coord" in data_dict and pointops is not None:
        reverse, _ = pointops.knn_query(
            1, data_dict["coord"].float(), data_dict["offset"].int(),
            data_dict["origin_coord"].float(), data_dict["origin_offset"].int(),
        )
        reverse = reverse.cpu().flatten().long()
        pred_masks = pred_masks[:, reverse]

    # Converting to numpy
    pred_masks = pred_masks.cpu().numpy()
    pred_scores = pred_scores.cpu().numpy()
    pred_classes = pred_classes.cpu().numpy()

    # One threshold test, reused everywhere below.
    keep = pred_scores >= CONFIDENCE_THRESHOLD
    kept_ids = np.flatnonzero(keep)

    print(f"Found {len(pred_scores)} predicted instances")
    print(f"High confidence (>={CONFIDENCE_THRESHOLD}): {int(keep.sum())}")

    # Paint each kept instance's mask onto the model subset (later ids overwrite earlier)
    instance_preds_filtered = np.full(len(coord), -1, dtype=np.int32)
    for i in kept_ids:
        instance_preds_filtered[pred_masks[i] > 0] = i

    # Then map the predictions back to the full point cloud
    instance_preds_full = np.full(len(coord_full), -1, dtype=np.int32)
    instance_preds_full[model_indices] = instance_preds_filtered

    # Every id painted above is a kept instance, so a single gather gives each point its
    # instance's class (no per-instance scan of the full cloud).
    class_preds_full = np.full(len(coord_full), -1, dtype=np.int32)
    assigned = instance_preds_full >= 0
    class_preds_full[assigned] = pred_classes[instance_preds_full[assigned]]

    # Save
    instance_metadata = {
        'pred_scores': pred_scores,
        'pred_classes': pred_classes,
        'num_instances': len(pred_scores),
        'high_confidence_instances': int(keep.sum()),
    }
    np.save(os.path.join(SCENE_PATH, 'instance_preds_full.npy'), instance_preds_full)
    np.save(os.path.join(SCENE_PATH, 'class_preds_full.npy'), class_preds_full)
    np.save(os.path.join(SCENE_PATH, 'instance_metadata.npy'), instance_metadata)

    # Visualization
    print("\nVisualizing results...")
    instances_to_show = [(i, pred_scores[i], pred_classes[i]) for i in kept_ids]

    print(f"Showing {len(instances_to_show)} high confidence instances")

    # Per-instance point indices, each capped to an equal share of VIZ_MAX_POINTS
    per_instance_points = []
    if instances_to_show:
        per_instance_budget = VIZ_MAX_POINTS // len(instances_to_show)
        for inst_id, _, _ in instances_to_show:
            indices = np.flatnonzero(instance_preds_full == inst_id)
            if len(indices) > per_instance_budget:
                indices = np.random.choice(indices, per_instance_budget, replace=False)
            per_instance_points.append(indices)

    if any(len(points_idx) > 0 for points_idx in per_instance_points):
        instance_colors = px.colors.qualitative.Plotly

        traces = []
        for idx, ((inst_id, score, class_id), points_idx) in enumerate(
                zip(instances_to_show, per_instance_points)):
            if len(points_idx) == 0:
                continue

            class_name = CLASS_NAMES[class_id] if 0 <= class_id < len(CLASS_NAMES) else f"class_{class_id}"

            trace = go.Scatter3d(
                x=coord_full[points_idx, 0],
                y=coord_full[points_idx, 1],
                z=coord_full[points_idx, 2],
                mode='markers',
                marker=dict(
                    size=2,
                    color=instance_colors[idx % len(instance_colors)],
                ),
                name=f"{class_name} (ID:{inst_id}, Score:{score:.2f})",
                showlegend=True
            )
            traces.append(trace)

        fig = go.Figure(data=traces)
        fig.update_layout(
            title=f"First Pass Instance Segmentation Results<br>Showing {len(instances_to_show)} high confidence instances",
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                aspectmode='data'
            ),
            width=1200,
            height=800
        )
        fig.show()

    # Release GPU memory held by the model
    del model
    torch.cuda.empty_cache()

    print("Step 2 complete")
    return instance_preds_full, class_preds_full, pred_scores, pred_classes


instance_preds_full, class_preds_full, pred_scores, pred_classes = step2_first_pass(coord, color, normal)

Semantic segmentation, used only for the floor and walls

A height threshold stops floor predictions above a certain Z, so the floor is not predicted on the ceiling. The threshold is computed automatically but can be adjusted.

In [None]:
#protects high confidence predictions from previous step

# Floor height parameters - adjust these as needed.
# Z values are in the normalized scene units produced by step 1 (centred, largest extent
# rescaled to 8.0), so they are not necessarily meters.
FLOOR_HEIGHT_PERCENTILE = 15  # Use bottom 15% of points as reference for floor height, adjust as needed
FLOOR_HEIGHT_THRESHOLD = None  # Manual override - set a specific Z value if needed
FLOOR_HEIGHT_MARGIN = 0.1  # Additional margin above the threshold (same units as Z)

def step3_semantic_improved(coord, color, normal, class_preds_full, instance_preds_full, pred_scores):
    """Second pass: Sonata semantic segmentation of all points, merged into step-2 labels.

    Merge priorities, applied to every point:
      1. A point of a step-2 instance with score >= CONFIDENCE_THRESHOLD whose step-2
         class id is >= 2 (an object) keeps its step-2 class.
      2. Otherwise, semantic wall (0) overrides the label.
      3. Otherwise, semantic floor (1) overrides the label only when Z is not above the
         floor height limit. Above it, the step-2 label is kept as-is.
      4. Otherwise, any other semantic class fills only points that are still -1.

    Returns:
        semantic_preds: per-point semantic class ids (length N).
        wall_floor_attribute: uint8 (N,), 1 where rule 2 or 3 applied.
        updated_class_preds: merged per-point class ids.
    Arrays are also saved to SCENE_PATH.

    Raises:
        ValueError: if the semantic predictions do not cover exactly the N input points.
    """
    print(f"Running semantic segmentation on ALL points for better wall/floor detection")

    # Floor height limit: semantic "floor" above this Z is rejected (keeps it off the ceiling)
    z_coords = coord[:, 2]
    if FLOOR_HEIGHT_THRESHOLD is None:
        floor_height_limit = np.percentile(z_coords, FLOOR_HEIGHT_PERCENTILE) + FLOOR_HEIGHT_MARGIN
        print(f"\nUsing automatic floor height limit: {floor_height_limit:.2f}m")
        print(f"  (Based on {FLOOR_HEIGHT_PERCENTILE}th percentile + {FLOOR_HEIGHT_MARGIN}m margin)")
    else:
        floor_height_limit = FLOOR_HEIGHT_THRESHOLD + FLOOR_HEIGHT_MARGIN
        print(f"\nUsing manual floor height limit: {floor_height_limit:.2f}m")

    # Load model ("module." prefix stripped from checkpoint keys)
    print("\nLoading SONATA model...")
    cfg = Config.fromfile(SONATA_CONFIG)
    model = build_model(cfg.model).cuda()
    model.eval()

    checkpoint = torch.load(SONATA_CHECKPOINT, map_location='cuda')
    weight = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if key.startswith("module."):
            key = key[7:]
        weight[key] = value
    model.load_state_dict(weight, strict=True)

    # Prepare data for ALL points (segment is a placeholder label)
    data_dict = {
        'coord': coord.copy(),
        'color': color.copy(),
        'normal': normal.copy(),
        'segment': np.zeros(coord.shape[0], dtype=np.int32),
    }

    # Apply the config's own validation transforms
    transform = Compose(cfg.data.val.transform)
    data_dict = transform(data_dict)

    # Convert to tensors on the GPU
    for key in data_dict.keys():
        if isinstance(data_dict[key], np.ndarray):
            data_dict[key] = torch.from_numpy(data_dict[key])
        if isinstance(data_dict[key], torch.Tensor):
            data_dict[key] = data_dict[key].cuda()

    data_dict['batch'] = torch.zeros(data_dict['coord'].shape[0], dtype=torch.long).cuda()
    data_dict['offset'] = torch.tensor([data_dict['coord'].shape[0]], dtype=torch.long).cuda()

    # Run inference on ALL points in the cloud
    print("Running semantic segmentation on full cloud...")
    with torch.no_grad():
        output = model(data_dict)
        if hasattr(output, 'seg_logits'):
            predictions = output.seg_logits.argmax(dim=-1)
        else:
            predictions = output['seg_logits'].argmax(dim=-1)

    # Map predictions on the transformed points back to every input point
    if 'inverse' in data_dict:
        inverse = data_dict['inverse'].cpu()
        semantic_preds = predictions[inverse].cpu().numpy()
    else:
        semantic_preds = predictions.cpu().numpy()

    # The merge below is per point, so fail clearly instead of mis-indexing.
    if len(semantic_preds) != len(coord):
        raise ValueError(
            f"Semantic predictions cover {len(semantic_preds):,} points but the scene has "
            f"{len(coord):,}; check that the SONATA val transform returns an 'inverse' mapping")

    print(f"Got semantic predictions for all {len(semantic_preds)} points")

    # Analyze what semantic segmentation found
    print("\nSemantic classes found:")
    unique_classes, counts = np.unique(semantic_preds, return_counts=True)
    for cls, count in zip(unique_classes, counts):
        percentage = count / len(semantic_preds) * 100
        class_name = CLASS_NAMES[cls] if 0 <= cls < len(CLASS_NAMES) else f"class_{cls}"
        print(f"  {class_name}: {count:,} points ({percentage:.1f}%)")

    print(f"\nMerging with step 1 results using priority system...")

    updated_class_preds = class_preds_full.copy()
    wall_floor_attribute = np.zeros(len(coord), dtype=np.uint8)

    # Points belonging to any step-2 instance that passed the confidence threshold
    high_conf_ids = np.flatnonzero(np.asarray(pred_scores) >= CONFIDENCE_THRESHOLD)
    high_conf_mask = np.isin(instance_preds_full, high_conf_ids)

    # Vectorized form of the priority rules (see docstring); masks are mutually exclusive.
    protected = high_conf_mask & (class_preds_full >= 2)                     # rule 1
    unprotected = ~protected
    wall_mask = unprotected & (semantic_preds == 0)                          # rule 2
    floor_candidates = unprotected & (semantic_preds == 1)
    floor_rejected = floor_candidates & (z_coords > floor_height_limit)
    floor_mask = floor_candidates & ~floor_rejected                          # rule 3
    fill_mask = (unprotected & (semantic_preds != 0) & (semantic_preds != 1)
                 & (class_preds_full == -1))                                 # rule 4

    updated_class_preds[wall_mask] = 0
    updated_class_preds[floor_mask] = 1
    updated_class_preds[fill_mask] = semantic_preds[fill_mask]
    wall_floor_attribute[wall_mask | floor_mask] = 1

    walls_from_semantic = int((wall_mask & (class_preds_full == -1)).sum())
    walls_override = int((wall_mask & (class_preds_full >= 2)).sum())
    floors_from_semantic = int((floor_mask & (class_preds_full == -1)).sum())
    floors_override = int((floor_mask & (class_preds_full >= 2)).sum())
    floors_rejected_height = int(floor_rejected.sum())
    print(f"  Walls: {walls_from_semantic:,} filled, {walls_override:,} overrode unprotected objects")
    print(f"  Floors: {floors_from_semantic:,} filled, {floors_override:,} overrode unprotected objects, "
          f"{floors_rejected_height:,} rejected above height limit")

    # Show the improvement
    unclassified_before = (class_preds_full == -1).sum()
    unclassified_after = (updated_class_preds == -1).sum()
    print(f"  Unclassified points: {unclassified_before:,} → {unclassified_after:,}")

    # Save results
    np.save(os.path.join(SCENE_PATH, 'semantic_preds_full.npy'), semantic_preds)
    np.save(os.path.join(SCENE_PATH, 'wall_floor_attribute.npy'), wall_floor_attribute)
    np.save(os.path.join(SCENE_PATH, 'class_preds_updated.npy'), updated_class_preds)

    # VISUALIZATION - Show the improvements
    print("\nVisualizing semantic improvements...")

    # Sample points for visualization
    viz_sample = min(VIZ_MAX_POINTS, len(coord))
    if len(coord) > viz_sample:
        viz_indices = np.random.choice(len(coord), viz_sample, replace=False)
    else:
        viz_indices = np.arange(len(coord))
    # Membership mask built once (same as np.isin(np.arange(N), viz_indices))
    in_viz = np.zeros(len(coord), dtype=bool)
    in_viz[viz_indices] = True

    fig = go.Figure()

    # Walls after merging
    wall_points_after = np.flatnonzero((updated_class_preds == 0) & in_viz)
    if len(wall_points_after) > 0:
        fig.add_trace(
            go.Scatter3d(
                x=coord[wall_points_after, 0],
                y=coord[wall_points_after, 1],
                z=coord[wall_points_after, 2],
                mode='markers',
                marker=dict(size=2, color='blue', opacity=0.8),
                name=f'Walls ({(updated_class_preds == 0).sum():,})',
                showlegend=True
            )
        )

    # Floors after merging
    floor_points_after = np.flatnonzero((updated_class_preds == 1) & in_viz)
    if len(floor_points_after) > 0:
        fig.add_trace(
            go.Scatter3d(
                x=coord[floor_points_after, 0],
                y=coord[floor_points_after, 1],
                z=coord[floor_points_after, 2],
                mode='markers',
                marker=dict(size=2, color='green', opacity=0.8),
                name=f'Floors ({(updated_class_preds == 1).sum():,})',
                showlegend=True
            )
        )

    # Protected high-confidence predicted objects from previous step
    object_points = np.flatnonzero(high_conf_mask & (updated_class_preds >= 2) & in_viz)
    if len(object_points) > 2000:  # Sample if too many
        object_points = np.random.choice(object_points, 2000, replace=False)
    if len(object_points) > 0:
        fig.add_trace(
            go.Scatter3d(
                x=coord[object_points, 0],
                y=coord[object_points, 1],
                z=coord[object_points, 2],
                mode='markers',
                marker=dict(size=2, color='orange', opacity=0.6),
                name='Protected Objects',
                showlegend=True
            )
        )

    # Add floor height limit plane for reference
    x_range = [coord[viz_indices, 0].min(), coord[viz_indices, 0].max()]
    y_range = [coord[viz_indices, 1].min(), coord[viz_indices, 1].max()]

    xx, yy = np.meshgrid(
        np.linspace(x_range[0], x_range[1], 10),
        np.linspace(y_range[0], y_range[1], 10)
    )
    zz = np.full_like(xx, floor_height_limit)

    fig.add_trace(
        go.Surface(
            x=xx, y=yy, z=zz,
            opacity=0.2,
            colorscale=[[0, 'red'], [1, 'red']],
            showscale=False,
            name='Floor Height Limit'
        )
    )

    fig.update_layout(
        title=f"Semantic Segmentation Results with Height Limit<br>"
              f"Walls: {(updated_class_preds == 0).sum():,} | "
              f"Floors: {(updated_class_preds == 1).sum():,} | "
              f"Height Limit: Z={floor_height_limit:.2f}m",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data'
        ),
        showlegend=True,
        height=800,
        width=1200
    )

    fig.show()

    # Clear up GPU memory
    del model
    torch.cuda.empty_cache()

    print("Step 3 improved complete")
    return semantic_preds, wall_floor_attribute, updated_class_preds

semantic_preds, wall_floor_attribute, class_preds_updated = step3_semantic_improved(
    coord, color, normal, class_preds_full, instance_preds_full, pred_scores)

# Display the combined results from steps 2 and 3


In [None]:

def display_combined_results(coord, instance_preds_full, class_preds_updated):
    """Print label statistics for the merged step 2-3 results and plot them.

    Args:
        coord: (N, 3) point coordinates.
        instance_preds_full: (N,) step-2 instance id per point, -1 if none.
        class_preds_updated: (N,) merged class id per point, -1 if unclassified.

    Plots a random sample of up to 2 * VIZ_MAX_POINTS points, split into walls, floors,
    unclassified points, and the 30 largest object instances. An instance counts as an
    object when its first point (in index order) has class id >= 2.
    """
    print("\n" + "="*60)
    print("COMBINED RESULTS FROM STEPS 2-3")
    print("="*60)

    # Get statistics
    total_points = len(coord)
    classified_points = (class_preds_updated >= 0).sum()
    unclassified_points = (class_preds_updated == -1).sum()

    print(f"Total points: {total_points:,}")
    print(f"Classified points: {classified_points:,} ({classified_points/total_points*100:.1f}%)")
    print(f"Unclassified points: {unclassified_points:,} ({unclassified_points/total_points*100:.1f}%)")

    # Count by category
    wall_count = (class_preds_updated == 0).sum()
    floor_count = (class_preds_updated == 1).sum()
    object_count = (class_preds_updated >= 2).sum()

    print(f"\nPoint distribution:")
    print(f"  Walls: {wall_count:,} points")
    print(f"  Floors: {floor_count:,} points")
    print(f"  Objects (ScanNet 20): {object_count:,} points")

    # Instance ids with the index of each one's first point and its size, in one pass
    # instead of one full-cloud scan per instance.
    labelled_idx = np.flatnonzero(instance_preds_full >= 0)
    unique_instances, first_pos, instance_sizes = np.unique(
        instance_preds_full[labelled_idx], return_index=True, return_counts=True)
    first_point = labelled_idx[first_pos]
    print(f"\nTotal instances: {len(unique_instances)}")

    traces = []

    # Sample for visualization
    viz_sample = min(VIZ_MAX_POINTS * 2, total_points)
    if total_points > viz_sample:
        sample_indices = np.random.choice(total_points, viz_sample, replace=False)
    else:
        sample_indices = np.arange(total_points)
    # Membership mask built once (same as np.isin(np.arange(N), sample_indices))
    in_sample = np.zeros(total_points, dtype=bool)
    in_sample[sample_indices] = True

    # Walls
    wall_points = np.flatnonzero((class_preds_updated == 0) & in_sample)
    if len(wall_points) > 0:
        traces.append(go.Scatter3d(
            x=coord[wall_points, 0],
            y=coord[wall_points, 1],
            z=coord[wall_points, 2],
            mode='markers',
            marker=dict(size=1, color='lightblue', opacity=0.6),
            name=f"Walls ({wall_count:,} pts)",
            showlegend=True
        ))

    # Floors
    floor_points = np.flatnonzero((class_preds_updated == 1) & in_sample)
    if len(floor_points) > 0:
        traces.append(go.Scatter3d(
            x=coord[floor_points, 0],
            y=coord[floor_points, 1],
            z=coord[floor_points, 2],
            mode='markers',
            marker=dict(size=1, color='lightgreen', opacity=0.6),
            name=f"Floors ({floor_count:,} pts)",
            showlegend=True
        ))

    # Saved objects from ScanNet 20: instances whose first point is not wall/floor
    object_instances = [
        (inst_id, class_id, size)
        for inst_id, class_id, size in zip(unique_instances, class_preds_updated[first_point], instance_sizes)
        if class_id >= 2
    ]

    # Sort by size, largest first (stable, so ties stay in ascending id order)
    object_instances.sort(key=lambda x: x[2], reverse=True)

    # Color palette
    instance_colors = px.colors.qualitative.Plotly * 5

    print(f"\nShowing {len(object_instances)} object instances")

    for idx, (inst_id, class_id, size) in enumerate(object_instances[:30]):  # Show up to 30
        indices = np.flatnonzero((instance_preds_full == inst_id) & in_sample)

        if len(indices) > 0:
            class_name = CLASS_NAMES[class_id] if 0 <= class_id < len(CLASS_NAMES) else f"class_{class_id}"

            traces.append(go.Scatter3d(
                x=coord[indices, 0],
                y=coord[indices, 1],
                z=coord[indices, 2],
                mode='markers',
                marker=dict(
                    size=2,
                    color=instance_colors[idx % len(instance_colors)],
                    opacity=0.9
                ),
                name=f"{class_name} #{inst_id} ({size:,} pts)",
                showlegend=True
            ))

    # Unclassified points
    unclass_points = np.flatnonzero((class_preds_updated == -1) & in_sample)
    if len(unclass_points) > 0:
        traces.append(go.Scatter3d(
            x=coord[unclass_points, 0],
            y=coord[unclass_points, 1],
            z=coord[unclass_points, 2],
            mode='markers',
            marker=dict(size=1, color='gray', opacity=0.3),
            name=f"Unclassified ({unclassified_points:,} pts)",
            showlegend=True
        ))

    # Create the figure in plotly
    fig = go.Figure(data=traces)
    fig.update_layout(
        title=f"Combined Results: Steps 2-3 (ScanNet 20 + Walls/Floors)",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data'
        ),
        width=1400,
        height=900,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(255, 255, 255, 0.8)",
            font=dict(size=10)
        )
    )
    fig.show()

    print("\nClass distribution:")
    unique_classes, counts = np.unique(class_preds_updated[class_preds_updated >= 0], return_counts=True)
    for cls, count in zip(unique_classes, counts):
        class_name = CLASS_NAMES[cls] if cls < len(CLASS_NAMES) else f"class_{cls}"
        print(f"  {class_name}: {count:,} points")


display_combined_results(coord, instance_preds_full, class_preds_updated)

# ScanNet 200: searching for smaller objects and details that were missed

In [None]:


# Classes to exclude from ScanNet 200 that overlap from scannet 20
EXCLUDE_CLASSES = [4, 5, 6, 7, 14, 16, 21, 23, 27, 35, 44, 46, 172] #There are overlapping classes from scannet 20 view results below and feel free to add any more classes

def step4_scannet200_filtered(coord, color, normal, instance_preds_full, class_preds_updated,
                              pred_scores, pred_classes, semantic_preds, min_points=50):
    """Run the ScanNet 200 instance model and keep only the non-overlapping classes.

    Loads the model from SCANNET200_CONFIG / SCANNET200_CHECKPOINT and runs it on
    the scene. It then discards every predicted instance that:
      * scores below SCANNET200_CONFIDENCE_THRESHOLD,
      * has a class listed in EXCLUDE_CLASSES, or
      * covers fewer than ``min_points`` points.
    The surviving instances are plotted on their own.

    Args:
        coord, color, normal: per-point arrays for the full scene. They are copied
            before the transforms run, so the caller's arrays are not modified.
        instance_preds_full, class_preds_updated, pred_scores, pred_classes,
        semantic_preds: results of the earlier steps. They are accepted for
            call-site compatibility but are not used. ``pred_scores`` and
            ``pred_classes`` are rebound to the ScanNet 200 outputs.
        min_points: minimum mask size (in points) for an instance to be kept.

    Returns:
        (instance_preds_s200, class_preds_s200, pred_scores, pred_classes):
        - instance_preds_s200, class_preds_s200: int32 arrays with one entry per
          row of ``coord``. They hold the ScanNet 200 prediction index and class,
          or -1 where nothing was kept.
        - pred_scores, pred_classes: the raw per-instance ScanNet 200 outputs.

    Raises:
        RuntimeError: if the predicted masks do not have one value per input
            point. This happens when the grid-sampled result could not be
            mapped back, e.g. ``pointops`` is None.
    """
    print("\n" + "="*60)
    print("STEP 4: SCANNET 200 FILTERED RESULTS")
    print("="*60)

    print(f"Excluding classes: {EXCLUDE_CLASSES}")

    # Load ScanNet 200 model
    print("\nLoading ScanNet 200 model...")
    cfg = Config.fromfile(SCANNET200_CONFIG)
    model = build_model(cfg.model).cuda()
    model.eval()

    checkpoint = torch.load(SCANNET200_CHECKPOINT, map_location='cuda')
    # Strip the "module." key prefix so the names match the bare model
    weight = OrderedDict(
        (key[len("module."):] if key.startswith("module.") else key, value)
        for key, value in checkpoint["state_dict"].items()
    )
    model.load_state_dict(weight, strict=True)
    # The model holds its own copy of the weights now; drop the CUDA checkpoint tensors
    del checkpoint, weight

    # Prepare data (copies, so the caller's arrays are left untouched)
    data_dict = {
        'coord': coord.copy(),
        'color': color.copy(),
        'normal': normal.copy(),
        'segment': np.zeros(coord.shape[0], dtype=np.int32),
        'instance': np.full(coord.shape[0], -1, dtype=np.int32),
    }

    # Apply transforms
    transform_list = [
        dict(type="CenterShift", apply_z=True),
        dict(type="Copy", keys_dict={
            "coord": "origin_coord",
            "segment": "origin_segment",
            "instance": "origin_instance",
        }),
        dict(type="GridSample", grid_size=SCANNET200_GRID_SIZE, hash_type="fnv",
             mode="train", return_grid_coord=True),
        dict(type="CenterShift", apply_z=False),
        dict(type="NormalizeColor"),
        dict(type="InstanceParser", segment_ignore_index=(-1,), instance_ignore_index=-1),
    ]

    transform = Compose(transform_list)
    data_dict = transform(data_dict)

    # Convert numpy arrays to tensors (integer keys -> long, everything else -> float)
    # and move every tensor to the GPU
    long_keys = {'segment', 'instance', 'grid_coord', 'origin_segment', 'origin_instance', 'bbox'}
    for key, value in list(data_dict.items()):
        if isinstance(value, np.ndarray):
            tensor = torch.from_numpy(value)
            value = tensor.long() if key in long_keys else tensor.float()
        if isinstance(value, torch.Tensor):
            value = value.cuda()
        data_dict[key] = value

    # Create features (no extra local reference, so the GPU tensor can be freed later)
    if all(k in data_dict for k in ["coord", "color", "normal"]):
        data_dict["feat"] = torch.cat([data_dict["coord"], data_dict["color"], data_dict["normal"]], dim=1)

    # Add batch info (single-scene batch)
    if "coord" in data_dict:
        data_dict["batch"] = torch.zeros(len(data_dict["coord"]), dtype=torch.long).cuda()
        data_dict["offset"] = torch.tensor([len(data_dict["coord"])], dtype=torch.long).cuda()

    if "origin_coord" in data_dict:
        data_dict["origin_offset"] = torch.tensor([len(data_dict["origin_coord"])], dtype=torch.long).cuda()

    # Run inference
    print("Running ScanNet 200 inference...")
    with torch.no_grad():
        output_dict = model(data_dict)

    # Process outputs
    pred_masks = output_dict['pred_masks']
    pred_scores = output_dict['pred_scores']
    pred_classes = output_dict['pred_classes']

    print(f"\nFound {len(pred_scores)} predicted instances")

    # Map the grid-sampled masks back to every original point via nearest neighbour
    if "origin_coord" in data_dict and pointops is not None:
        reverse, _ = pointops.knn_query(
            1, data_dict["coord"].float(), data_dict["offset"].int(),
            data_dict["origin_coord"].float(), data_dict["origin_offset"].int(),
        )
        reverse = reverse.cpu().flatten().long()
        pred_masks = pred_masks[:, reverse]

    # Convert to numpy
    pred_masks = pred_masks.cpu().numpy()
    pred_scores = pred_scores.cpu().numpy()
    pred_classes = pred_classes.cpu().numpy()

    # Clean up: drop every remaining GPU reference BEFORE emptying the cache,
    # otherwise empty_cache() cannot release the memory they hold
    del model, output_dict, data_dict
    torch.cuda.empty_cache()

    if len(pred_scores) > 0 and pred_masks.shape[-1] != len(coord):
        raise RuntimeError(
            f"ScanNet 200 masks cover {pred_masks.shape[-1]:,} points but the scene has "
            f"{len(coord):,}; the grid-sampled predictions could not be mapped back "
            "(is pointops available?)"
        )

    # Create instance/class predictions for ScanNet 200 ONLY
    instance_preds_s200 = np.full(len(coord), -1, dtype=np.int32)
    class_preds_s200 = np.full(len(coord), -1, dtype=np.int32)

    # Apply filtered ScanNet 200 predictions (later instances overwrite earlier ones)
    applied_count = 0
    class_counts = {}

    for i, (score, cls) in enumerate(zip(pred_scores, pred_classes)):
        if score < SCANNET200_CONFIDENCE_THRESHOLD or cls in EXCLUDE_CLASSES:
            continue

        mask = pred_masks[i] > 0
        if mask.sum() < min_points:
            continue

        instance_preds_s200[mask] = i
        class_preds_s200[mask] = cls
        applied_count += 1
        class_counts[cls] = class_counts.get(cls, 0) + 1

    print(f"\nApplied {applied_count} instances (after filtering)")
    print(f"Points classified: {(class_preds_s200 >= 0).sum():,}")

    print(f"\nClass distribution:")
    for cls, count in sorted(class_counts.items()):
        print(f"  Class {cls}: {count} instances")

    # VISUALIZATION - Show ONLY ScanNet 200 filtered results
    print("\nVisualizing ScanNet 200 filtered results...")

    # Sample for visualization
    total_points = len(coord)
    viz_sample = min(VIZ_MAX_POINTS * 2, total_points)
    if total_points > viz_sample:
        sample_indices = np.random.choice(total_points, viz_sample, replace=False)
    else:
        sample_indices = np.arange(total_points)
    # "Is this point in the viz sample" mask, built once in O(N) instead of per class
    is_sampled = np.zeros(total_points, dtype=bool)
    is_sampled[sample_indices] = True

    traces = []

    # Color by class
    unique_classes = np.unique(class_preds_s200[class_preds_s200 >= 0])
    colors = px.colors.qualitative.Plotly + px.colors.qualitative.Set3 + px.colors.qualitative.Pastel

    print(f"\nVisualizing {len(unique_classes)} unique classes")

    for idx, cls in enumerate(unique_classes[:50]):  # Show up to 50 classes
        in_class = class_preds_s200 == cls
        points = np.flatnonzero(in_class & is_sampled)

        if len(points) > 0:
            traces.append(go.Scatter3d(
                x=coord[points, 0],
                y=coord[points, 1],
                z=coord[points, 2],
                mode='markers',
                marker=dict(
                    size=2,
                    color=colors[idx % len(colors)],
                    opacity=0.8
                ),
                name=f"Class {cls} ({in_class.sum()} pts)",
                showlegend=True
            ))

    # Show unclassified (everything not detected by filtered ScanNet 200)
    not_detected = class_preds_s200 == -1
    unclass_mask = not_detected & is_sampled
    if unclass_mask.sum() > 0:
        traces.append(go.Scatter3d(
            x=coord[unclass_mask, 0],
            y=coord[unclass_mask, 1],
            z=coord[unclass_mask, 2],
            mode='markers',
            marker=dict(size=1, color='lightgray', opacity=0.2),
            name=f"Not detected ({not_detected.sum()} pts)",
            showlegend=False
        ))

    # Create plotly figure
    fig = go.Figure(data=traces)
    fig.update_layout(
        title=f"ScanNet 200 Filtered: {applied_count} instances, {len(unique_classes)} classes (excluding: {EXCLUDE_CLASSES})",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data'
        ),
        width=1400,
        height=900,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(255, 255, 255, 0.8)",
            font=dict(size=10)
        )
    )
    fig.show()

    print("\nStep 4 complete")
    return instance_preds_s200, class_preds_s200, pred_scores, pred_classes

# Step 4: ScanNet 200 pass (only the classes not already covered by ScanNet 20)
instance_preds_s200, class_preds_s200, pred_scores_s200, pred_classes_s200 = step4_scannet200_filtered(
    coord=coord,
    color=color,
    normal=normal,
    instance_preds_full=instance_preds_full,
    class_preds_updated=class_preds_updated,
    pred_scores=pred_scores,
    pred_classes=pred_classes,
    semantic_preds=semantic_preds,
)

# Merge the ScanNet 200 results with the Step 2-3 results

In [None]:
def step5_merge_and_display(coord, instance_preds_full, class_preds_updated,
                            instance_preds_s200, class_preds_s200):
    """Overlay the ScanNet 200 predictions on the Step 2-3 results and plot them.

    Every point with a ScanNet 200 class (``class_preds_s200 >= 0``) takes the
    ScanNet 200 class and a ScanNet 200 instance ID. The ID is shifted past the
    largest Step 2-3 instance ID, so the two ID ranges never collide. All other
    points keep their Step 2-3 labels. The merged arrays are saved to SCENE_PATH
    as ``instance_preds_final.npy`` and ``class_preds_final.npy``.

    Note: ``class_preds_final`` mixes two label spaces:
      * untouched points carry ScanNet 20 IDs (0 = wall, 1 = floor);
      * replaced points carry ScanNet 200 IDs, where 2 is floor (see the export
        step).
    The plot therefore only treats classes 0/1 as walls/floors on points that
    were NOT replaced. ScanNet 200 instances are listed regardless of class.

    Returns:
        (instance_preds_final, class_preds_final): new arrays. The input arrays
        are not modified.
    """
    print("\n" + "="*60)
    print("MERGING RESULTS AND DISPLAYING")
    print("="*60)

    # Merge predictions, working on copies so the Step 2-3 arrays stay untouched
    instance_preds_final = instance_preds_full.copy()
    class_preds_final = class_preds_updated.copy()

    replaced_mask = (class_preds_s200 >= 0)  # Where ScanNet 200 has predictions
    replaced_count = replaced_mask.sum()

    # Largest existing instance ID from Steps 2-3 (-1 if there is none);
    # initial=-1 also keeps this defined for an empty array
    max_instance_id = np.max(instance_preds_full, initial=-1)

    # Replace the previous predictions with the ScanNet 200 ones (vectorised),
    # offsetting the ScanNet 200 instance IDs so they cannot clash with existing IDs
    instance_preds_final[replaced_mask] = instance_preds_s200[replaced_mask] + max_instance_id + 1
    class_preds_final[replaced_mask] = class_preds_s200[replaced_mask]

    # Get final stats
    total_points = len(coord)
    classified_points = (class_preds_final >= 0).sum()
    unclassified_points = (class_preds_final == -1).sum()

    # Save final results
    np.save(os.path.join(SCENE_PATH, 'instance_preds_final.npy'), instance_preds_final)
    np.save(os.path.join(SCENE_PATH, 'class_preds_final.npy'), class_preds_final)

    print(f"Replaced {replaced_count:,} points with ScanNet 200 classifications")
    print(f"\nFinal statistics:")
    print(f"Total points: {total_points:,}")
    print(f"Classified points: {classified_points:,} ({classified_points/total_points*100:.1f}%)")
    print(f"Unclassified points: {unclassified_points:,} ({unclassified_points/total_points*100:.1f}%)")

    traces = []

    # Sample for visualization
    viz_sample = min(VIZ_MAX_POINTS * 2, total_points)
    if total_points > viz_sample:
        sample_indices = np.random.choice(total_points, viz_sample, replace=False)
    else:
        sample_indices = np.arange(total_points)
    # "Is this point in the viz sample" mask, built once in O(N)
    is_sampled = np.zeros(total_points, dtype=bool)
    is_sampled[sample_indices] = True

    # One pass over the labelled points gives every instance ID, the position
    # of its first point and its size
    labelled = np.flatnonzero(instance_preds_final >= 0)
    unique_instances, first_pos, sizes = np.unique(
        instance_preds_final[labelled], return_index=True, return_counts=True
    )

    # Separate original and ScanNet 200 instances
    original_instances = []
    s200_instances = []
    for inst_id, pos, size in zip(unique_instances, first_pos, sizes):
        class_id = class_preds_final[labelled[pos]]  # class of the instance's first point
        if inst_id > max_instance_id:
            # ScanNet 200 label space: 0/1 do not mean wall/floor here, so keep them all
            s200_instances.append((inst_id, class_id, size))
        elif class_id >= 2:  # ScanNet 20 label space: not wall or floor
            original_instances.append((inst_id, class_id, size))

    # Sort by size (largest first)
    original_instances.sort(key=lambda x: x[2], reverse=True)
    s200_instances.sort(key=lambda x: x[2], reverse=True)

    print(f"\nOriginal instances (ScanNet 20): {len(original_instances)}")
    print(f"ScanNet 200 instances: {len(s200_instances)}")
    print(f"Points replaced by ScanNet 200: {replaced_count:,}")

    # Attempting to get more color variation
    instance_colors = (
        px.colors.qualitative.Plotly +
        px.colors.qualitative.D3 +
        px.colors.qualitative.G10 +
        px.colors.qualitative.T10 +
        px.colors.qualitative.Alphabet +
        px.colors.qualitative.Dark24 +
        px.colors.qualitative.Light24
    )

    def _scatter(indices, name, size, color, opacity):
        # One plotly marker trace for the given point indices
        return go.Scatter3d(
            x=coord[indices, 0],
            y=coord[indices, 1],
            z=coord[indices, 2],
            mode='markers',
            marker=dict(size=size, color=color, opacity=opacity),
            name=name,
            showlegend=True
        )

    # Walls / floors (unchanged from Steps 2-3): classes 0/1 only mean wall/floor
    # in the ScanNet 20 label space, so points relabelled by ScanNet 200 are excluded
    step23_points = ~replaced_mask

    wall_all = (class_preds_final == 0) & step23_points
    wall_points = np.flatnonzero(wall_all & is_sampled)
    if len(wall_points) > 0:
        traces.append(_scatter(wall_points, f"Walls ({wall_all.sum():,} pts)", 1, 'lightblue', 0.6))

    floor_all = (class_preds_final == 1) & step23_points
    floor_points = np.flatnonzero(floor_all & is_sampled)
    if len(floor_points) > 0:
        traces.append(_scatter(floor_points, f"Floors ({floor_all.sum():,} pts)", 1, 'lightgreen', 0.6))

    # Plot original instances
    for idx, (inst_id, class_id, size) in enumerate(original_instances[:20]):
        indices = np.flatnonzero((instance_preds_final == inst_id) & is_sampled)
        if len(indices) > 0:
            class_name = CLASS_NAMES[class_id] if 0 <= class_id < len(CLASS_NAMES) else f"class_{class_id}"
            traces.append(_scatter(
                indices, f"{class_name} #{inst_id} ({size:,} pts)", 2,
                instance_colors[idx % len(instance_colors)], 0.9,
            ))

    # Plot ScanNet 200 instances
    for idx, (inst_id, class_id, size) in enumerate(s200_instances[:20]):
        indices = np.flatnonzero((instance_preds_final == inst_id) & is_sampled)
        if len(indices) > 0:
            traces.append(_scatter(
                indices, f"S200: Class {class_id} ({size:,} pts)", 2,
                instance_colors[(len(original_instances) + idx) % len(instance_colors)], 0.9,
            ))

    # Unclassified points
    unclass_points = np.flatnonzero((class_preds_final == -1) & is_sampled)
    if len(unclass_points) > 0:
        traces.append(_scatter(unclass_points, f"Unclassified ({unclassified_points:,} pts)", 1, 'gray', 0.3))

    # Create figure
    fig = go.Figure(data=traces)
    fig.update_layout(
        title=f"Final Merged Results: {len(original_instances)} ScanNet 20 + {len(s200_instances)} ScanNet 200 instances",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data'
        ),
        width=1400,
        height=900,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor="rgba(255, 255, 255, 0.8)",
            font=dict(size=10)
        )
    )
    fig.show()

    print("Results are Merged")
    return instance_preds_final, class_preds_final


# Step 5: merge the ScanNet 200 labels over the Step 2-3 labels and plot the result
instance_preds_final, class_preds_final = step5_merge_and_display(
    coord=coord,
    instance_preds_full=instance_preds_full,
    class_preds_updated=class_preds_updated,
    instance_preds_s200=instance_preds_s200,
    class_preds_s200=class_preds_s200,
)

# Reload the original, unmodified PLY file, attach the classifications, and export a new PLY file

In [None]:

def step6_export_ply_final(coord, instance_preds_final, class_preds_final,
                           instance_preds_s200, class_preds_s200,
                           output_dir='/content/drive/MyDrive/Pointcept/output'):
    """Write the segmentation back onto the ORIGINAL PLY file.

    Re-reads PLY_PATH, so every original vertex property (including the
    f_dc* / f_rest* spherical-harmonic coefficients) is copied unchanged. Four
    int32 vertex properties are appended:
      scannet20_class / scannet20_instance     ScanNet 20 labels (-1 if none)
      scannet200_class / scannet200_instance   ScanNet 200 labels (-1 if none)

    Points with a ScanNet 200 class go into the scannet200_* fields. The
    exception is ScanNet 200 class 2 (floor), which is stored as ScanNet 20
    floor (1) with its merged instance ID. Other labelled points keep their
    ScanNet 20 labels. A ``<name>_summary.json`` is written next to the PLY.

    Args:
        coord: not used; kept for call-site compatibility.
        instance_preds_final, class_preds_final: merged per-point results.
        instance_preds_s200, class_preds_s200: ScanNet 200-only per-point results.
        output_dir: directory for the PLY and the JSON summary. It is created if
            it does not exist.

    Returns:
        Path of the written PLY file.
    """
    print("\n" + "="*60)
    print("EXPORTING PLY WITH SCANNET 20 AND 200 ATTRIBUTES")
    print("="*60)

    # Load the ORIGINAL PLY file - NOT the converted numpy arrays - so that every
    # vertex property survives the round trip
    print(f"Loading ORIGINAL PLY file: {PLY_PATH}")
    plydata = PlyData.read(PLY_PATH)
    vertex_data = plydata['vertex']
    original_point_count = len(vertex_data)
    print(f"Original PLY has {original_point_count:,} points")

    instance_preds_final = np.asarray(instance_preds_final)
    class_preds_final = np.asarray(class_preds_final)
    instance_preds_s200 = np.asarray(instance_preds_s200)
    class_preds_s200 = np.asarray(class_preds_s200)

    # Every per-point array must line up one-to-one with the PLY vertices
    for label, preds in (("Instance", instance_preds_final),
                         ("Class", class_preds_final),
                         ("ScanNet 200 instance", instance_preds_s200),
                         ("ScanNet 200 class", class_preds_s200)):
        assert len(preds) == original_point_count, (
            f"{label} predictions size mismatch: {len(preds):,} predictions vs "
            f"{original_point_count:,} PLY points"
        )

    # Prepare ScanNet 20 and ScanNet 200 attributes (-1 = no label)
    scannet20_class = np.full(original_point_count, -1, dtype=np.int32)
    scannet20_instance = np.full(original_point_count, -1, dtype=np.int32)
    scannet200_class = np.full(original_point_count, -1, dtype=np.int32)
    scannet200_instance = np.full(original_point_count, -1, dtype=np.int32)

    # Vectorised split of the merged labels into the two attribute sets
    has_s200 = class_preds_s200 >= 0
    s200_floor = has_s200 & (class_preds_s200 == 2)   # Floor in S200 -> floor (1) in S20
    s200_other = has_s200 & ~s200_floor               # Everything else keeps its S200 attribute
    s20_only = ~has_s200 & (class_preds_final >= 0)   # Untouched ScanNet 20 points (Steps 2-3)

    scannet20_class[s200_floor] = 1
    scannet20_instance[s200_floor] = instance_preds_final[s200_floor]
    scannet200_class[s200_other] = class_preds_s200[s200_other]
    scannet200_instance[s200_other] = instance_preds_s200[s200_other]
    scannet20_class[s20_only] = class_preds_final[s20_only]
    scannet20_instance[s20_only] = instance_preds_final[s20_only]

    # Print statistics
    s20_count = (scannet20_class >= 0).sum()
    s200_count = (scannet200_class >= 0).sum()
    unclassified = original_point_count - s20_count - s200_count

    print(f"\nClassification summary:")
    print(f"  ScanNet 20 points: {s20_count:,} ({s20_count/original_point_count*100:.1f}%)")
    print(f"  ScanNet 200 points: {s200_count:,} ({s200_count/original_point_count*100:.1f}%)")
    print(f"  Unclassified: {unclassified:,} ({unclassified/original_point_count*100:.1f}%)")

    print("\nPreparing new PLY structure...")
    original_dtype = vertex_data.data.dtype
    new_dtype_list = [(name, original_dtype[name]) for name in original_dtype.names]

    # Add our segmentation attributes
    new_dtype_list.append(('scannet20_class', 'i4'))      # ScanNet 20 class ID (-1 if none)
    new_dtype_list.append(('scannet20_instance', 'i4'))   # ScanNet 20 instance ID (-1 if none)
    new_dtype_list.append(('scannet200_class', 'i4'))     # ScanNet 200 class ID (-1 if none)
    new_dtype_list.append(('scannet200_instance', 'i4'))  # ScanNet 200 instance ID (-1 if none)

    # Create new vertex array
    new_vertex = np.zeros(original_point_count, dtype=new_dtype_list)

    # Copy ALL original data (including spherical harmonics)
    print("Copying ALL original vertex properties...")
    for prop in original_dtype.names:
        new_vertex[prop] = vertex_data[prop]
    spherical_harmonic_count = sum(
        1 for prop in original_dtype.names if 'f_dc' in prop or 'f_rest' in prop
    )

    print(f"  Total properties preserved: {len(original_dtype.names)} (including {spherical_harmonic_count} SH coefficients)")

    # Add segmentation attributes
    print("\nAdding segmentation attributes...")
    new_vertex['scannet20_class'] = scannet20_class
    new_vertex['scannet20_instance'] = scannet20_instance
    new_vertex['scannet200_class'] = scannet200_class
    new_vertex['scannet200_instance'] = scannet200_instance

    # Output path
    os.makedirs(output_dir, exist_ok=True)
    output_filename = f'{SCENE_NAME}_segmented_s20_s200.ply'
    output_path = os.path.join(output_dir, output_filename)

    # Save PLY in same format as original
    print(f"\nSaving Classified PLY to: {output_path}")

    # Check original format
    original_format_is_text = plydata.text
    print(f"Original PLY format: {'text' if original_format_is_text else 'binary'}")

    # Create element and save in the same format (text/binary AND byte order)
    el = PlyElement.describe(new_vertex, 'vertex')
    PlyData([el], text=original_format_is_text, byte_order=plydata.byte_order).write(output_path)

    # Verify file size
    file_size = os.path.getsize(output_path) / (1024 * 1024)  # MB
    print(f"Output file size: {file_size:.2f} MB")

    # ScanNet 20 class distribution
    print("\nScanNet 20 Classes:")
    unique_s20, counts_s20 = np.unique(scannet20_class[scannet20_class >= 0], return_counts=True)
    for cls, count in zip(unique_s20, counts_s20):
        class_name = CLASS_NAMES[cls] if cls < len(CLASS_NAMES) else f"class_{cls}"
        print(f"  {class_name} (class {cls}): {count:,} points")

    # ScanNet 200 class distribution
    print("\nScanNet 200 Classes:")
    unique_s200, counts_s200 = np.unique(scannet200_class[scannet200_class >= 0], return_counts=True)
    for cls, count in zip(unique_s200, counts_s200):
        print(f"  Class {cls}: {count:,} points")

    # Instance counts (np.unique returns sorted IDs, reused for the Houdini listing below)
    s20_unique_instances = np.unique(scannet20_instance[scannet20_instance >= 0])
    s200_unique_instances = np.unique(scannet200_instance[scannet200_instance >= 0])
    s20_instances = len(s20_unique_instances)
    s200_instances = len(s200_unique_instances)
    print(f"\nInstance counts:")
    print(f"  ScanNet 20 instances: {s20_instances}")
    print(f"  ScanNet 200 instances: {s200_instances}")

    # Create summary JSON
    summary = {
        'filename': output_filename,
        'original_ply': PLY_PATH,
        'export_timestamp': pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S"),
        'total_points': int(original_point_count),
        'segmentation_stats': {
            'scannet20_points': int(s20_count),
            'scannet200_points': int(s200_count),
            'unclassified_points': int(unclassified),
            'scannet20_instances': int(s20_instances),
            'scannet200_instances': int(s200_instances)
        },
        'scannet20_classes': {
            (CLASS_NAMES[cls] if cls < len(CLASS_NAMES) else f'class_{cls}'): int(count)
            for cls, count in zip(unique_s20, counts_s20)
        },
        'scannet200_classes': {
            f'class_{cls}': int(count)
            for cls, count in zip(unique_s200, counts_s200)
        }
    }

    # Save summary JSON next to the PLY. splitext only swaps the final extension,
    # unlike str.replace, which would also rewrite any ".ply" earlier in the path.
    summary_path = os.path.splitext(output_path)[0] + '_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"\nSaved summary to: {summary_path}")

    print("\n" + "="*60)
    print("PLY EXPORT COMPLETE!")
    print("="*60)
    print(f"\nOutput file: {output_path}")

    # List every instance ID (excluding -1)
    print(f"\nScanNet 20 - All {len(s20_unique_instances)} instance IDs:")
    for instance_id in s20_unique_instances:
        print(f"  {instance_id}")

    print(f"\nScanNet 200 - All {len(s200_unique_instances)} instance IDs:")
    for instance_id in s200_unique_instances:
        print(f"  {instance_id}")

    # Print Houdini group expressions
    print("\n" + "="*60)
    print("Houdini VEX EXPRESSIONS")
    print("="*60)

    # NOTE(review): Houdini group syntax normally treats space-separated patterns
    # as a union. If so, this selects points where EITHER attribute is -1 rather
    # than both — confirm in Houdini.
    print("\nTo select all unclassified points:")
    print("  @scannet20_instance==-1 @scannet200_instance==-1")

    print("\nTo select each ScanNet 20 instance:")
    for instance_id in s20_unique_instances:
        print(f"  @scannet20_instance=={instance_id}")

    print("\nTo select each ScanNet 200 instance:")
    for instance_id in s200_unique_instances:
        print(f"  @scannet200_instance=={instance_id}")

    return output_path


# Step 6: write the labels onto the original PLY and export it
output_path = step6_export_ply_final(
    coord=coord,
    instance_preds_final=instance_preds_final,
    class_preds_final=class_preds_final,
    instance_preds_s200=instance_preds_s200,
    class_preds_s200=class_preds_s200,
)

In [None]:
# Optional: print the ScanNet 20 and ScanNet 200 "<id>: <name>" tables, handy
# when deciding which IDs belong in EXCLUDE_CLASSES.
def _print_class_table(labels):
    """Print one "<id>: <name>" line per label, with IDs counted from 0."""
    for class_id, label in enumerate(labels):
        print(f"{class_id}: {label}")


# ScanNet 20 classes (the IDs written to scannet20_class in the exported PLY)
_print_class_table(CLASS_NAMES)

from pointcept.datasets.preprocessing.scannet.meta_data.scannet200_constants import CLASS_LABELS_200

# All 200 ScanNet 200 class names
_print_class_table(CLASS_LABELS_200)