# Epsilon Clustering
Cluster each pixel's splats by depth: a splat joins a cluster when its depth is within an epsilon distance of that cluster's depth.

In [None]:
# Name of the collected-splat data set handed to load_splats
# (the helper presumably resolves it to a CSV file).
DATA_NAME: str = "collected_splats"

# Inclusive maximum depth distance between a splat and a cluster's depth
# for the splat to be merged into that cluster.
EPSILON: float = 0.11

In [None]:
import numpy as np
from clustering_exploration.utils.data_handler import load_splats

from clustering_exploration.utils.constants import IMAGE_HEIGHT, IMAGE_WIDTH

In [None]:
# Load the collected splats. Later cells iterate this as pixels -> splats ->
# (alpha, depth, *color) and shuffle it along axis 1, which numpy's
# Generator.shuffle only supports for ndarrays -- so presumably an ndarray of
# shape (num_pixels, splats_per_pixel, 2 + num_color_channels).
# TODO confirm against load_splats.
splats = load_splats(DATA_NAME)

# Cluster Splats
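For each pixel, visit its splats one at a time (fully transparent splats are skipped) and:
1. If there are no clusters yet, create a new dictionary entry with the splat's depth as the key and the splat's alpha and color as the value.
2. If there are clusters, find the cluster whose depth key is closest to the splat's depth; if that distance is at most EPSILON, add the splat to that cluster's dictionary entry.
3. If step 2 fails, create a new dictionary entry keyed by the splat's depth.

Finally, each cluster is merged into a single alpha, $1 - \prod_i (1 - \alpha_i)$, and an alpha-weighted mean color, and the clusters are sorted by depth.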

In [None]:
from tqdm.auto import tqdm
from joblib import Parallel, delayed

In [None]:
def cluster_pixel(pixel_splats, epsilon):
    """Greedily cluster one pixel's splats by depth and merge each cluster.

    Splats are visited in the given order; zero-alpha splats are skipped. A
    splat joins the existing cluster whose *anchor depth* (the depth of the
    splat that created the cluster) is closest to its own depth, provided that
    distance is at most ``epsilon``. Otherwise it starts a new cluster anchored
    at its own depth. Anchors never move, so the result depends on the order
    of ``pixel_splats``.

    Args:
        pixel_splats: Iterable of splats, each unpacked as
            ``(alpha, depth, *color)``.
        epsilon: Inclusive maximum depth distance between a splat and a
            cluster's anchor depth for the splat to join that cluster.

    Returns:
        ``np.ndarray`` with one row per cluster, sorted by ascending anchor
        depth, and ``1 + num_color_channels`` columns. Column 0 is the combined
        alpha ``1 - prod(1 - alpha_i)``. The remaining columns are the
        alpha-weighted mean color. A pixel with no visible splats yields an
        empty array of shape ``(0, 4)``.
    """
    # Anchor depth -> list of per-splat rows [alpha, *color].
    pixel_clustering = {}

    for splat in pixel_splats:
        splat_alpha, splat_depth, *splat_color = splat
        # Transparent splats contribute nothing to either combination.
        if splat_alpha == 0:
            continue

        combined_splat_info = np.concatenate(([splat_alpha], np.array(splat_color)))

        # Join the nearest existing cluster if it lies within epsilon.
        if pixel_clustering:
            closest_depth = min(pixel_clustering, key=lambda depth: abs(depth - splat_depth))
            if abs(splat_depth - closest_depth) <= epsilon:
                pixel_clustering[closest_depth].append(combined_splat_info)
                continue

        # No clusters yet, or none close enough: start a new cluster.
        pixel_clustering[splat_depth] = [combined_splat_info]

    # No visible splats: keep the (alpha + RGB) column count used so far.
    if not pixel_clustering:
        return np.zeros((0, 4))

    # Derive the column count from the data instead of hard-coding RGB, so
    # any number of color channels is supported.
    num_columns = len(next(iter(pixel_clustering.values()))[0])
    pixel_output = np.zeros((len(pixel_clustering), num_columns))

    for index, anchor_depth in enumerate(sorted(pixel_clustering)):
        cluster = np.array(pixel_clustering[anchor_depth])
        alphas = cluster[:, 0]
        # Order-independent ("commutative") combination of the cluster's alphas.
        pixel_output[index, 0] = 1 - np.prod(1 - alphas)
        alpha_sum = np.sum(alphas)
        # alpha_sum can only be zero if alphas cancel out (negative values).
        # In that case the color keeps its initial zeros rather than dividing by zero.
        if alpha_sum:
            pixel_output[index, 1:] = np.sum(alphas.reshape(-1, 1) * cluster[:, 1:], axis=0) / alpha_sum

    return pixel_output

## Do the Clustering
1. Shuffle the splats in each pixel (the greedy clustering depends on the order in which a pixel's splats are visited).
2. Cluster the shuffled splats.

In [None]:
# Shuffle the splats in each pixel.
# Shuffle the splats in each pixel, since the greedy clustering in
# cluster_pixel depends on the order in which a pixel's splats are visited.
# NOTE(review): Generator.shuffle(..., axis=1) draws ONE permutation and
# applies it to every pixel, so all pixels see their splats in the same
# shuffled slot order. If independent per-pixel orderings are intended,
# confirm and switch to e.g. np.argsort(rng.random(splats.shape[:2]), axis=1)
# combined with np.take_along_axis.
# The generator is unseeded, so the clustering differs between runs.
rng = np.random.default_rng()
rng.shuffle(splats, axis=1)

In [None]:
# Cluster shuffled splats.
# Cluster every pixel using all available cores (n_jobs=-1). joblib returns a
# list of results in input order: one clustered array per pixel.
# tqdm wraps the input iterable, so the bar counts pixels handed to joblib.
clustered_splats = Parallel(n_jobs=-1)(
    delayed(cluster_pixel)(pixel_splats, epsilon=EPSILON)
    for pixel_splats in tqdm(splats)
)

# Compute Image From Clusters
Compute the final pixel color by alpha compositing each pixel's clusters in order of increasing depth.

In [None]:
def alpha_compose_pixel(pixel_clusters, min_transmittance=0.001):
    """Alpha-composite a pixel's clusters in the order given.

    Each cluster contributes ``alpha * color * T``, where ``T`` is the
    transmittance left over by the clusters before it. It then scales ``T``
    by ``1 - alpha``. Alpha is clamped to at most 1 and that same clamped
    value is used for both steps, so a cluster with alpha > 1 cannot
    contribute more than its own color.

    Args:
        pixel_clusters: Iterable of rows ``(alpha, *color)``, e.g. the output
            of ``cluster_pixel`` (rows sorted by increasing depth).
        min_transmittance: Compositing stops once the remaining transmittance
            is at or below this value. Defaults to 0.001.

    Returns:
        ``np.ndarray`` of length 3 holding the composited color (assumes
        three color channels).
    """
    transmittance = 1.0
    pixel_color = np.zeros(3)

    for cluster_alpha, *cluster_color in pixel_clusters:
        # Fully transparent clusters neither add color nor block light.
        if not cluster_alpha:
            continue

        # Anything further along can no longer noticeably change the color.
        if transmittance <= min_transmittance:
            break

        # Clamp once and reuse, keeping color and transmittance consistent.
        alpha = min(1, cluster_alpha)
        pixel_color += alpha * np.array(cluster_color) * transmittance
        transmittance *= 1 - alpha

    return pixel_color

## Do the Computation

In [None]:
from clustering_exploration.utils.image_handler import alpha_compose_splats

In [None]:
# Composite each pixel's clusters into a single color, one entry per pixel.
# NOTE(review): this calls the imported alpha_compose_splats, not the
# alpha_compose_pixel defined earlier in this notebook. Confirm that the
# imported helper expects rows of (alpha, *color) with no depth column.
computed_image = list(map(alpha_compose_splats, tqdm(clustered_splats)))

## Display the Computed Image

In [None]:
from clustering_exploration.utils.image_handler import save_array_to_image
# Stack the per-pixel colors into an (IMAGE_HEIGHT, IMAGE_WIDTH, 3) image,
# save it under "epsilon_clustering_1", and show the result inline.
display(
    save_array_to_image(
        np.reshape(np.array(computed_image), (IMAGE_HEIGHT, IMAGE_WIDTH, 3)),
        "epsilon_clustering_1",
    )
)