# Semantic Dictionary Segmentation

This notebook tests semantic segmentation using a dictionary of terms with positive queries. Each term's queries are contrasted against negative queries from other dictionary items to segment the environment into semantic compartments.

For example:
- 'tree': ['green', 'leaves', 'bark', 'trunk', 'branches']
- 'ground': ['dirt', 'rocks', 'floor', 'soil', 'path']
- 'sky': ['blue', 'clouds', 'air', 'above']
- 'feeder': ['bird feeder', 'feeding station', 'container']

In [1]:
%load_ext autoreload
%autoreload 2

import pyvista as pv
import numpy as np
import open3d as o3d
from pathlib import Path
from collab_splats.wrapper import Splatter
from collab_splats.utils.mesh import mesh_clustering

# pv.start_xvfb()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
[Taichi] version 1.7.4, llvm 15.0.4, commit b4b956fd, linux, python 3.10.18


[I 12/10/25 18:52:39.459 41554] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


## Load the splatter from configuration

Load the dataset configuration from YAML and ensure preprocessing/training steps are complete:

In [75]:
# Load splatter from YAML config
splatter = Splatter.from_config_file(
    dataset='birds_date-02062024_video-C0043',
    config_dir='/workspace/collab-splats/docs/splats/configs'
)

# Ensure preprocessing and training are done
# (if already completed, these will skip automatically)
splatter.preprocess()
splatter.extract_features()
splatter.mesh()

transforms.json already exists at /workspace/fieldwork-data/birds/2024-02-06/environment/C0043/preproc/transforms.json
To rerun preprocessing, set overwrite=True
Output already exists for rade-features
To rerun feature extraction, set overwrite=True

Available runs:
[0] 2025-07-25_040743


In [76]:
splatter.load_model()

Loading model from /workspace/fieldwork-data/birds/2024-02-06/environment/C0043/rade-features/2025-07-25_040743/config.yml


✓ Model loaded: RadegsFeaturesModel (step 29999)


(_TrainerConfig(_target=<class 'nerfstudio.engine.trainer.Trainer'>, output_dir=PosixPath('/workspace/fieldwork-data/birds/2024-02-06/environment/C0043'), method_name='rade-features', experiment_name='', project_name='nerfstudio-project', timestamp='2025-07-25_040743', machine=MachineConfig(seed=42, num_devices=1, num_machines=1, machine_rank=0, dist_url='auto', device_type='cuda'), logging=LoggingConfig(relative_log_dir=PosixPath('.'), steps_per_log=10, max_buffer_size=20, local_writer=LocalWriterConfig(_target=<class 'nerfstudio.utils.writer.LocalWriter'>, enable=True, stats_to_track=(<EventName.ITER_TRAIN_TIME: 'Train Iter (time)'>, <EventName.TRAIN_RAYS_PER_SEC: 'Train Rays / Sec'>, <EventName.CURR_TEST_PSNR: 'Test PSNR'>, <EventName.VIS_RAYS_PER_SEC: 'Vis Rays / Sec'>, <EventName.TEST_RAYS_PER_SEC: 'Test Rays / Sec'>, <EventName.ETA: 'ETA (time)'>), max_log_size=10), profiler='basic'), viewer=ViewerConfig(relative_log_filename='viewer_log_filename.txt', websocket_port=None, websoc

## Define Semantic Dictionary

Create a dictionary mapping semantic categories to positive query terms:

In [138]:
# # Semantic dictionary: maps categories to positive query terms
semantic_dictionary = {
    'tree': ['green', 'leaves', 'bark', 'trunk'],
    # 'ground': ['rocks'],
    'feeder': ['bird feeder', 'container', 'food'],
    'brush': ['leaves', 'plants', 'thicket', 'bramble'],
    'gravel': ['gravel', 'rock', 'concrete'],
}

In [131]:
category = 'brush'
positive_queries = semantic_dictionary[category]
negative_queries = [other_category for other_category in semantic_dictionary.keys() if other_category != category]


similarity = splatter.query_mesh(
    positive_queries=positive_queries,
    negative_queries=negative_queries,
    output_fn=f"query-{category}.ply"
)


In [132]:
# Load query mesh
mesh_dir = splatter.config["mesh_info"]["mesh"].parent
query_mesh_path = mesh_dir / f"query-{category}.ply"

splatter.plot_mesh(similarity)

Number of points: 351349
Number of cells: 661879
Bounds: BoundsTuple(x_min=-1.093526840209961, x_max=0.6880093812942505, y_min=-0.27054786682128906, y_max=1.2863465547561646, z_min=-0.42523476481437683, z_max=0.5941177606582642)


Widget(value='<iframe src="http://localhost:34769/index.html?ui=P_0x7fb7308a2ad0_25&reconnect=auto" class="pyv…

## Query Mesh for Each Semantic Category

For each semantic term, use its positive queries and contrast them against negative queries (all terms from other categories in the dictionary):

In [140]:
# Store results for each category
semantic_results = {}

# Query mesh for each semantic category
for category, positive_queries in semantic_dictionary.items():
    # Gather negative queries from all other categories
    negative_queries = []
    positive_queries += [category] # Add category to positive queries

    for other_category, other_queries in semantic_dictionary.items():
        if other_category != category:
            negative_queries.extend([other_category])
            # negative_queries.extend(other_queries)
    
    # Query the mesh with positive vs negative queries
    print(f"\nQuerying '{category}'...")
    print(f"  Positive: {positive_queries}")
    print(f"  Negative: {len(negative_queries)} terms from other categories")
    
    similarity = splatter.query_mesh(
        positive_queries=positive_queries,
        negative_queries=negative_queries,
        output_fn=f"query-{category}.ply"
    )
    
    semantic_results[category] = similarity
    print(f"  Done! Saved to query-{category}.ply")


Querying 'tree'...
  Positive: ['green', 'leaves', 'bark', 'trunk', 'tree']
  Negative: 3 terms from other categories
  Done! Saved to query-tree.ply

Querying 'feeder'...
  Positive: ['bird feeder', 'container', 'food', 'feeder']
  Negative: 3 terms from other categories
  Done! Saved to query-feeder.ply

Querying 'brush'...
  Positive: ['leaves', 'plants', 'thicket', 'bramble', 'brush']
  Negative: 3 terms from other categories
  Done! Saved to query-brush.ply

Querying 'gravel'...
  Positive: ['gravel', 'rock', 'concrete', 'gravel']
  Negative: 3 terms from other categories
  Done! Saved to query-gravel.ply


## Aggregate All Clusters into Categorical Mesh

Cluster all semantic categories and combine them into a single categorical mesh with unique labels for each cluster:

In [153]:
# Process all categories and aggregate clusters
all_clusters = {}  # Store {category: [(cluster_idx, vertex_indices), ...]}
cluster_counter = 0

mesh_dir = splatter.config["mesh_info"]["mesh"].parent

for category in semantic_dictionary.keys():
    print(f"\nProcessing '{category}'...")
    
    # Load query mesh
    query_mesh_path = mesh_dir / f"query-{category}.ply"
    mesh_cat = o3d.io.read_triangle_mesh(str(query_mesh_path))
    similarity = np.asarray(mesh_cat.vertex_colors)[:, 0]
    
    # Cluster
    clusters = mesh_clustering(
        mesh=mesh_cat,
        similarity_values=similarity,
        similarity_threshold=0.95,
        spatial_radius=0.02,
    )
    
    print(f"  Found {len(clusters)} clusters")
    
    # Store clusters with global cluster IDs
    all_clusters[category] = []
    for cluster_vertices in clusters:
        all_clusters[category].append((cluster_counter, np.array(cluster_vertices)))
        cluster_counter += 1
        print(f"    Cluster {cluster_counter-1}: {len(cluster_vertices)} vertices")

print(f"\n\nTotal clusters across all categories: {cluster_counter}")


Processing 'tree'...


Building adjacency matrix: 75777it [00:07, 9928.75it/s] 


  Found 134 clusters
    Cluster 0: 12445 vertices
    Cluster 1: 7667 vertices
    Cluster 2: 23324 vertices
    Cluster 3: 2405 vertices
    Cluster 4: 1110 vertices
    Cluster 5: 141 vertices
    Cluster 6: 66 vertices
    Cluster 7: 122 vertices
    Cluster 8: 144 vertices
    Cluster 9: 65 vertices
    Cluster 10: 358 vertices
    Cluster 11: 86 vertices
    Cluster 12: 319 vertices
    Cluster 13: 66 vertices
    Cluster 14: 40 vertices
    Cluster 15: 47 vertices
    Cluster 16: 1254 vertices
    Cluster 17: 17 vertices
    Cluster 18: 498 vertices
    Cluster 19: 4160 vertices
    Cluster 20: 42 vertices
    Cluster 21: 20 vertices
    Cluster 22: 527 vertices
    Cluster 23: 23 vertices
    Cluster 24: 21 vertices
    Cluster 25: 296 vertices
    Cluster 26: 17 vertices
    Cluster 27: 30 vertices
    Cluster 28: 28 vertices
    Cluster 29: 41 vertices
    Cluster 30: 77 vertices
    Cluster 31: 22 vertices
    Cluster 32: 36 vertices
    Cluster 33: 59 vertices
    Cluster 3

Building adjacency matrix: 6440it [00:00, 12110.98it/s]


  Found 79 clusters
    Cluster 134: 16 vertices
    Cluster 135: 38 vertices
    Cluster 136: 25 vertices
    Cluster 137: 11 vertices
    Cluster 138: 28 vertices
    Cluster 139: 24 vertices
    Cluster 140: 103 vertices
    Cluster 141: 105 vertices
    Cluster 142: 63 vertices
    Cluster 143: 26 vertices
    Cluster 144: 1602 vertices
    Cluster 145: 90 vertices
    Cluster 146: 28 vertices
    Cluster 147: 48 vertices
    Cluster 148: 13 vertices
    Cluster 149: 82 vertices
    Cluster 150: 12 vertices
    Cluster 151: 26 vertices
    Cluster 152: 117 vertices
    Cluster 153: 20 vertices
    Cluster 154: 64 vertices
    Cluster 155: 185 vertices
    Cluster 156: 53 vertices
    Cluster 157: 12 vertices
    Cluster 158: 22 vertices
    Cluster 159: 54 vertices
    Cluster 160: 42 vertices
    Cluster 161: 29 vertices
    Cluster 162: 90 vertices
    Cluster 163: 50 vertices
    Cluster 164: 13 vertices
    Cluster 165: 34 vertices
    Cluster 166: 54 vertices
    Cluster 167: 

Building adjacency matrix: 9800it [00:00, 10073.02it/s]


  Found 82 clusters
    Cluster 213: 64 vertices
    Cluster 214: 15 vertices
    Cluster 215: 55 vertices
    Cluster 216: 22 vertices
    Cluster 217: 76 vertices
    Cluster 218: 85 vertices
    Cluster 219: 44 vertices
    Cluster 220: 37 vertices
    Cluster 221: 28 vertices
    Cluster 222: 1805 vertices
    Cluster 223: 53 vertices
    Cluster 224: 64 vertices
    Cluster 225: 153 vertices
    Cluster 226: 73 vertices
    Cluster 227: 11 vertices
    Cluster 228: 62 vertices
    Cluster 229: 403 vertices
    Cluster 230: 134 vertices
    Cluster 231: 20 vertices
    Cluster 232: 83 vertices
    Cluster 233: 55 vertices
    Cluster 234: 28 vertices
    Cluster 235: 58 vertices
    Cluster 236: 377 vertices
    Cluster 237: 820 vertices
    Cluster 238: 15 vertices
    Cluster 239: 54 vertices
    Cluster 240: 13 vertices
    Cluster 241: 23 vertices
    Cluster 242: 324 vertices
    Cluster 243: 143 vertices
    Cluster 244: 31 vertices
    Cluster 245: 22 vertices
    Cluster 24

Building adjacency matrix: 14056it [00:01, 9912.70it/s] 


  Found 31 clusters
    Cluster 295: 62 vertices
    Cluster 296: 45 vertices
    Cluster 297: 228 vertices
    Cluster 298: 506 vertices
    Cluster 299: 6395 vertices
    Cluster 300: 4053 vertices
    Cluster 301: 379 vertices
    Cluster 302: 50 vertices
    Cluster 303: 13 vertices
    Cluster 304: 1308 vertices
    Cluster 305: 14 vertices
    Cluster 306: 13 vertices
    Cluster 307: 97 vertices
    Cluster 308: 93 vertices
    Cluster 309: 16 vertices
    Cluster 310: 127 vertices
    Cluster 311: 84 vertices
    Cluster 312: 32 vertices
    Cluster 313: 62 vertices
    Cluster 314: 11 vertices
    Cluster 315: 28 vertices
    Cluster 316: 19 vertices
    Cluster 317: 16 vertices
    Cluster 318: 15 vertices
    Cluster 319: 79 vertices
    Cluster 320: 14 vertices
    Cluster 321: 19 vertices
    Cluster 322: 31 vertices
    Cluster 323: 33 vertices
    Cluster 324: 20 vertices
    Cluster 325: 14 vertices


Total clusters across all categories: 326


### Create Categorical Mesh with Cluster Labels

Create a single mesh with cluster labels for each vertex:

In [154]:
# Load the base mesh (we'll use the first category's mesh as the base)
base_category = list(semantic_dictionary.keys())[0]
base_mesh_path = mesh_dir / f"query-{base_category}.ply"
categorical_mesh = o3d.io.read_triangle_mesh(str(base_mesh_path))

# Create cluster label array (-1 means no cluster assigned)
num_vertices = len(categorical_mesh.vertices)
cluster_labels = -np.ones(num_vertices, dtype=np.int32)

# Assign cluster labels
for category_id, (category, category_clusters) in enumerate(all_clusters.items()):
    for cluster_id, vertex_indices in category_clusters:
        cluster_labels[vertex_indices] = category_id

# Count how many vertices are in clusters
num_clustered = (cluster_labels >= 0).sum()
print(f"Assigned {num_clustered}/{num_vertices} vertices to clusters")
print(f"Cluster labels range: {cluster_labels.min()} to {cluster_labels.max()}")

# Save cluster labels as a scalar field in the mesh
# Store as colors for now (we'll create proper visualization later)
# Normalize cluster IDs to 0-1 range for color mapping
cluster_labels_normalized = cluster_labels.copy().astype(float)
cluster_labels_normalized[cluster_labels >= 0] = cluster_labels[cluster_labels >= 0] / cluster_counter

# Save categorical mesh with cluster labels
categorical_mesh_path = mesh_dir / "categorical_mesh.ply"
# We'll store the actual cluster labels separately
np.save(mesh_dir / "cluster_labels.npy", cluster_labels)
print(f"\nSaved cluster labels to {mesh_dir / 'cluster_labels.npy'}")

Assigned 103504/351349 vertices to clusters
Cluster labels range: -1 to 3

Saved cluster labels to /workspace/fieldwork-data/birds/2024-02-06/environment/C0043/rade-features/2025-07-25_040743/mesh/cluster_labels.npy


## Visualize All Clusters

Create a comprehensive visualization showing:
1. All clusters together with unique colors
2. Individual cluster views (each cluster highlighted, rest black)

In [169]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Generate unique colors for each cluster
def generate_distinct_colors(n):
    """Generate n visually distinct colors"""
    if n <= 10:
        # Use tab10 for small number of clusters
        cmap = plt.cm.get_cmap('tab10')
        return [cmap(i) for i in range(n)]
    elif n <= 20:
        # Use tab20 for medium number
        cmap = plt.cm.get_cmap('tab20')
        return [cmap(i) for i in range(n)]
    else:
        # Use hsv for many clusters
        cmap = plt.cm.get_cmap('hsv')
        return [cmap(i / n) for i in range(n)]

# Generate colors for all clusters
n_clusters = len(all_clusters.items())
cluster_colors = generate_distinct_colors(n_clusters)
print(f"Generated {len(cluster_colors)} distinct colors for {n_clusters} clusters")

Generated 4 distinct colors for 4 clusters


  cmap = plt.cm.get_cmap('tab10')


### View All Clusters Together

Show all clusters in a single view with unique colors:

In [172]:
import copy
# Create mesh with all clusters colored
all_clusters_mesh = copy.deepcopy(categorical_mesh)
colors_all = np.zeros((num_vertices, 3))

# Color each vertex by its cluster
for vertex_idx in range(num_vertices):
    cluster_id = cluster_labels[vertex_idx]
    if cluster_id >= 0:
        # Use cluster color (take RGB from the color tuple)
        colors_all[vertex_idx] = cluster_colors[cluster_id][:3]
    else:
        # Non-clustered vertices are black
        colors_all[vertex_idx] = [0,0,0]

all_clusters_mesh.vertex_colors = o3d.utility.Vector3dVector(colors_all)

# Save the all-clusters mesh
all_clusters_path = mesh_dir / "all_clusters.ply"
o3d.io.write_triangle_mesh(str(all_clusters_path), all_clusters_mesh)
print(f"Saved all-clusters mesh to {all_clusters_path}")

# Visualize
print("\nVisualizing all clusters together:")
splatter.plot_mesh(colors_all)

Saved all-clusters mesh to /workspace/fieldwork-data/birds/2024-02-06/environment/C0043/rade-features/2025-07-25_040743/mesh/all_clusters.ply

Visualizing all clusters together:
Number of points: 351349
Number of cells: 661879
Bounds: BoundsTuple(x_min=-1.093526840209961, x_max=0.6880093812942505, y_min=-0.27054786682128906, y_max=1.2863465547561646, z_min=-0.42523476481437683, z_max=0.5941177606582642)


Widget(value='<iframe src="http://localhost:34769/index.html?ui=P_0x7fb71c68ae00_36&reconnect=auto" class="pyv…

### Create Grid Visualization of Individual Clusters

Generate a grid showing each cluster individually (highlighted with unique color, rest black):

In [None]:
# Calculate grid dimensions for subplots
import math

n_clusters = cluster_counter
n_cols = min(4, n_clusters)  # Max 4 columns
n_rows = math.ceil(n_clusters / n_cols)

print(f"Creating grid visualization: {n_rows} rows x {n_cols} columns for {n_clusters} clusters")

# Create plotter with subplots
plotter = pv.Plotter(shape=(n_rows, n_cols), window_size=[1600, 400 * n_rows])

# Convert Open3D mesh to PyVista mesh
vertices = np.asarray(categorical_mesh.vertices)
faces = np.asarray(categorical_mesh.triangles)
# PyVista format: [n_points, p0, p1, p2, n_points, p0, p1, p2, ...]
faces_pv = np.hstack([np.full((len(faces), 1), 3), faces]).flatten()
pv_mesh = pv.PolyData(vertices, faces_pv)

# Flatten all_clusters into a list for easier iteration
all_clusters_flat = []
for category, category_clusters in all_clusters.items():
    for cluster_id, vertex_indices in category_clusters:
        all_clusters_flat.append({
            'category': category,
            'cluster_id': cluster_id,
            'vertices': vertex_indices
        })

# Add each cluster to a subplot
for idx, cluster_info in enumerate(all_clusters_flat):
    row = idx // n_cols
    col = idx % n_cols
    
    plotter.subplot(row, col)
    
    # Create color array: black background, cluster in its unique color
    colors_single = np.zeros((num_vertices, 3))
    cluster_id = cluster_info['cluster_id']
    vertex_indices = cluster_info['vertices']
    
    # Color the cluster vertices
    colors_single[vertex_indices] = cluster_colors[cluster_id][:3]
    
    # Add colors to mesh
    pv_mesh_copy = pv_mesh.copy()
    pv_mesh_copy['colors'] = colors_single
    
    # Add to plot
    plotter.add_mesh(pv_mesh_copy, rgb=True, scalars='colors')
    plotter.add_text(f"{cluster_info['category']} - Cluster {cluster_id}", 
                     font_size=10, position='upper_edge')
    plotter.camera_position = 'iso'

print("Rendering grid visualization...")

In [None]:
# Show the grid visualization
plotter.show()

# Optional: Save as image
# output_image_path = mesh_dir / "clusters_grid.png"
# plotter.screenshot(output_image_path)
# print(f"Saved grid visualization to {output_image_path}")

### Cluster Summary

Display a summary of all clusters organized by category:

In [None]:
# Print cluster summary
print("=" * 60)
print("CLUSTER SUMMARY")
print("=" * 60)

for category, category_clusters in all_clusters.items():
    print(f"\n{category.upper()}: {len(category_clusters)} cluster(s)")
    for cluster_id, vertex_indices in category_clusters:
        color_rgb = [int(c * 255) for c in cluster_colors[cluster_id][:3]]
        print(f"  Cluster {cluster_id}: {len(vertex_indices)} vertices | Color: RGB{tuple(color_rgb)}")

print("\n" + "=" * 60)
print(f"Total: {cluster_counter} clusters across {len(all_clusters)} categories")
print(f"Total clustered vertices: {num_clustered}/{num_vertices} ({100*num_clustered/num_vertices:.1f}%)")
print("=" * 60)

### Alternative: Save Grid as Static Image

Optionally save the grid visualization as a static PNG image:

In [None]:
# Create offline plotter to render to image
plotter_offline = pv.Plotter(shape=(n_rows, n_cols), 
                              window_size=[1600, 400 * n_rows],
                              off_screen=True)

# Recreate the grid
for idx, cluster_info in enumerate(all_clusters_flat):
    row = idx // n_cols
    col = idx % n_cols
    
    plotter_offline.subplot(row, col)
    
    # Create color array
    colors_single = np.zeros((num_vertices, 3))
    cluster_id = cluster_info['cluster_id']
    vertex_indices = cluster_info['vertices']
    colors_single[vertex_indices] = cluster_colors[cluster_id][:3]
    
    # Add to plot
    pv_mesh_copy = pv_mesh.copy()
    pv_mesh_copy['colors'] = colors_single
    plotter_offline.add_mesh(pv_mesh_copy, rgb=True, scalars='colors')
    plotter_offline.add_text(f"{cluster_info['category']} - Cluster {cluster_id}", 
                            font_size=10, position='upper_edge')
    plotter_offline.camera_position = 'iso'

# Save as image
output_image_path = mesh_dir / "clusters_grid.png"
plotter_offline.screenshot(output_image_path)
plotter_offline.close()

print(f"Saved grid visualization to {output_image_path}")