# Epsilon Clustering
Cluster splats that are within an epsilon distance of each other.

In [1]:
from clustering_exploration.basic_viewer import transmittance

# Path to the CSV file containing the collected splats.
# The path is relative, so it is resolved against the notebook's working directory.
CSV_FILE_PATH = "data/collected_splats.csv"

# Epsilon distance used for clustering.
# NOTE(review): presumably the depth threshold passed to `cluster_splats` as
# `epsilon` (splats closer than this in depth share a cluster) — confirm at the
# call site; units are those of the depth column.
EPSILON = 0.11

In [2]:
import numpy as np
import polars as pl
from constants import IMAGE_HEIGHT, IMAGE_WIDTH

## Load Data From CSV

In [3]:
# Column layout: seven per-sample columns followed by five fields
# (alpha, depth, color_r, color_g, color_b) for each of 500 gaussians.
column_names = [
    "sample_index",
    "out_color_r",
    "out_color_g",
    "out_color_b",
    "background_r",
    "background_g",
    "background_b",
    *(
        f"gaussian_{i}_{part}"
        for i in range(500)
        for part in ("alpha", "depth", "color_r", "color_g", "color_b")
    ),
]

# Schema: every column starts out as Float32, then the sample index and the
# background channels are overridden with unsigned integer types.
schema_dict = dict.fromkeys(column_names, pl.Float32)
schema_dict.update(
    sample_index=pl.UInt32,
    background_r=pl.UInt8,
    background_g=pl.UInt8,
    background_b=pl.UInt8,
)
schema = pl.Schema(schema_dict)

In [4]:
# Load data.
# `scan_csv` builds a lazy query using the explicit schema defined for the
# columns; the file contents are only read once the query is collected.
data = pl.scan_csv(CSV_FILE_PATH, schema=schema)

## Extract the Splats from the data

In [5]:
# Materialise only the per-gaussian columns (everything except the per-sample
# metadata) as a 2D NumPy array with one row per sample.
raw_splats = (
    data
    .select(
        pl.exclude(
            "sample_index",
            "out_color_r",
            "out_color_g",
            "out_color_b",
            "background_r",
            "background_g",
            "background_b",
        )
    )
    .collect()
    .to_numpy()
)

# View the flat rows as a per-pixel table: (pixel, splat, 5 values per splat).
splats = raw_splats.reshape(IMAGE_HEIGHT * IMAGE_WIDTH, raw_splats.shape[1] // 5, 5)

# Cluster Splats
For each pixel, take each non-transparent splat in turn and
1. If there are no clusters yet, create a new dictionary entry with the splat's depth as the key and a list holding its (alpha, color) pair as the value.
2. If there are clusters, find the cluster whose depth is closest to the splat's depth; if the difference is less than EPSILON, add the splat to that cluster's dictionary entry.
3. Otherwise, create a new dictionary entry keyed by the splat's depth.

In [6]:
from tqdm.auto import tqdm
from joblib import Parallel, delayed

In [None]:
def cluster_splats(splats_to_cluster, epsilon):
    """Cluster splats that are within an epsilon distance of each other."""

    def cluster_pixel(pixel_splats):
        """Cluster one pixel's splats by depth and composite each cluster.

        Args:
            pixel_splats: Iterable of splat rows indexed as
                ``(alpha, depth, color...)`` with three colour components.
                Rows whose alpha equals 0 are ignored.

        Returns:
            A ``(num_clusters, 4)`` array ordered by ascending cluster depth;
            each row holds ``(cluster_alpha, r, g, b)``.
        """
        # Clustering dictionary: cluster depth -> [(alpha, color), ...].
        # A cluster is keyed by the depth of the first splat assigned to it.
        pixel_clustering = {}

        for splat in pixel_splats:
            alpha, depth, splat_color = splat[0], splat[1], splat[2:]

            # Skip transparent splats.
            if alpha == 0:
                continue

            # Closest existing cluster by depth, or None when there are no clusters yet.
            nearest_depth = min(
                pixel_clustering,
                key=lambda cluster_depth: abs(cluster_depth - depth),
                default=None,
            )

            if nearest_depth is not None and abs(depth - nearest_depth) < epsilon:
                # Join the nearest cluster. The entry must be looked up by the
                # cluster's key: indexing by the splat's own depth raises
                # KeyError unless the two depths happen to be identical.
                pixel_clustering[nearest_depth].append((alpha, splat_color))
            else:
                # No cluster within epsilon (or none at all): start a new one.
                # setdefault keeps an existing entry intact if the depth is
                # already a key (only possible when epsilon <= 0).
                pixel_clustering.setdefault(depth, []).append((alpha, splat_color))

        # Compose the splats of each cluster, visiting clusters by ascending depth.
        pixel_output = np.zeros((len(pixel_clustering), 4))
        for index, cluster_depth in enumerate(sorted(pixel_clustering)):
            # Named so it does not shadow the module-level `transmittance` import.
            remaining_transmittance = 1
            color = np.zeros(3)

            # NOTE(review): colour is accumulated as alpha * color without the
            # running transmittance weight used by front-to-back compositing;
            # kept as-is, confirm this is intended.
            for splat_alpha, splat_color in pixel_clustering[cluster_depth]:
                color += splat_alpha * splat_color
                remaining_transmittance *= 1 - splat_alpha

            # Normalize by total alpha.
            cluster_alpha = np.clip(1 - remaining_transmittance, 0, 1)
            cluster_color = np.clip(color / cluster_alpha, 0, 1) if cluster_alpha > 0 else np.zeros(3)

            # Add to the output.
            pixel_output[index, 0] = cluster_alpha
            pixel_output[index, 1:] = cluster_color
        return pixel_output
    