# K-Means Clustering
Cluster splats using K-Means clustering.
This is the offline "oracle" version of the clustering.

In [1]:
# Location of the CSV file that holds the collected splats.
CSV_FILE_PATH: str = "data/collected_splats.csv"

# How many clusters to use.
CLUSTERS: int = 8

In [2]:
import numpy as np
import polars as pl

from constants import IMAGE_HEIGHT, IMAGE_WIDTH


In [3]:
# Column layout: the seven per-sample columns come first, followed by five
# attributes for each of the 500 gaussians.
column_names = ["sample_index", "out_color_r", "out_color_g", "out_color_b", "background_r", "background_g",
                "background_b"]
column_names.extend(
    f"gaussian_{index}_{attribute}"
    for index in range(500)
    for attribute in ("alpha", "depth", "color_r", "color_g", "color_b")
)

# Schema: every column is Float32 unless overridden below.
schema_dict = dict.fromkeys(column_names, pl.Float32)
schema_dict.update(
    sample_index=pl.UInt32,
    background_r=pl.UInt8,
    background_g=pl.UInt8,
    background_b=pl.UInt8,
)
schema = pl.Schema(schema_dict)

In [4]:
# Load data.
# The scan is lazy: column names and dtypes come from `schema`, and the rows
# are only materialized when the query is collected further down.
data = pl.scan_csv(CSV_FILE_PATH, schema=schema)

## Extract the Splats from the data


In [5]:
# Per-sample columns that do not belong to any individual splat.
non_splat_columns = ("sample_index", "out_color_r", "out_color_g", "out_color_b", "background_r", "background_g",
                     "background_b")

# Materialize only the splat columns as a 2D array.
raw_splats = data.select(pl.all().exclude(*non_splat_columns)).collect().to_numpy()

# Reshape so each pixel gets its own table of splats, five values per splat.
splat_count = raw_splats.shape[1] // 5
splats = raw_splats.reshape((IMAGE_HEIGHT * IMAGE_WIDTH, splat_count, 5))


# Cluster Splats
For each pixel,
1. Run K-Means clustering on the depth values of all splats.
2. Within each cluster, commutatively sum the alpha and color values.

In [7]:
from tqdm.auto import tqdm
from joblib import Parallel, delayed
import os
from sklearn.cluster import KMeans

In [8]:
# Pin OpenMP to a single thread for K-Means.
# NOTE(review): this runs after sklearn has already been imported above;
# confirm the setting still takes effect where the clustering actually runs.
os.environ.update(OMP_NUM_THREADS="1")

In [None]:
def cluster_pixel(pixel_splats, clusters):
    """Compute clustering on a single pixel."""
    
    # Get depth values.
    depths = pixel_splats[:, 1]
    
    # Run K-Means clustering.
    kmeans = KMeans(n_clusters=clusters).fit(depths.reshape(-1, 1))
    
    # Initialize clustering 2D list: cluster -> [[alpha, depth, *color], ...].
    pixel_clustering = [[] for _ in range(clusters)]
    
    # Loop through each splat and place it in the appropriate cluster.
    for splat_index, cluster_index in enumerate(kmeans.labels_):
        pixel_clustering[cluster_index].append(pixel_splats[splat_index])
    
    # Sum the alpha and color values within each cluster.