# Compression Pipeline


<div style="background-color: white; padding: 10px;">
    <img src="./docs/static/img/pipeline.svg" alt="Compression pipeline diagram" width="1500" />
</div>

Instead of running the compression pipeline all at once, here you can run it step-by-step and explore the process.

## First, simulate the command-line arguments

In [None]:
# Options the compression script would normally receive on the command line,
# written as flag -> value pairs and flattened into argv-style tokens.
_simulated_options = {
    "--model_path": "./input_models/flower_hq",
    "--data_device": "cuda",
    "--output_vq": "./output",
}
simulated_args = [token for pair in _simulated_options.items() for token in pair]

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import time

from argparse import ArgumentParser
from arguments import (
    CompressionParams,
    ModelParams,
    OptimizationParams,
    PipelineParams,
    get_combined_args,
)

from compress import unique_output_folder, calc_importance
from gaussian_renderer import GaussianModel
from scene import Scene
from compression.vq import CompressionSettings, compress_color, compress_covariance
from typing import Tuple, Optional

In [None]:
def parse_arguments(simulated_args=None):
    """Build the compression argument parser and parse ``simulated_args``.

    Registers the same parameter groups as the compression script (model,
    pipeline, optimization, compression) on a fresh ``ArgumentParser`` and
    passes the parser together with the token list to ``get_combined_args``.

    Args:
        simulated_args: Command-line tokens to parse, e.g.
            ``["--model_path", "./input_models/flower_hq"]`` (presumably used
            in place of the real ``sys.argv`` — confirm in get_combined_args).
            ``None`` (the default) is treated as an empty list.

    Returns:
        Tuple ``(args, model, pipeline, op, comp)``: the parsed namespace and
        the four parameter-group objects; this notebook calls each group's
        ``extract(args)`` afterwards.
    """
    # Use a fresh list per call: a mutable ``[]`` default is created once and
    # shared by every call, so any mutation by get_combined_args would leak.
    if simulated_args is None:
        simulated_args = []

    parser = ArgumentParser(description="Compression script parameters")

    # Register the same argument groups as in the compression script.
    model = ModelParams(parser, sentinel=True)
    # Set on the group object itself (not on the parsed args namespace).
    model.data_device = "cuda"
    pipeline = PipelineParams(parser)
    op = OptimizationParams(parser)
    comp = CompressionParams(parser)

    # Combine the supplied tokens with the parser's arguments.
    args = get_combined_args(parser, simulated_args)
    return args, model, pipeline, op, comp


args, model, pipeline, op, comp = parse_arguments(simulated_args)

# No output folder given: fall back to a freshly generated unique one.
if args.output_vq is None:
    args.output_vq = unique_output_folder()

# Turn the parsed namespace into one parameter object per argument group
# (tuple elements are evaluated left to right: model, op, pipeline, comp).
model_params, optim_params, pipeline_params, comp_params = (
    model.extract(args),
    op.extract(args),
    pipeline.extract(args),
    comp.extract(args),
)

In [None]:
# Initialize the Gaussians
gaussians = GaussianModel(
    model_params.sh_degree, quantization=not optim_params.not_quantization_aware
)

# Initialize the scene (test cameras + train cameras)
scene = Scene(
    model_params, gaussians, load_iteration=comp_params.load_iteration, shuffle=True
)

# Extract the Gaussians from the pre-trained model (checkpoint)
if comp_params.start_checkpoint:
    (checkpoint_params, first_iter) = torch.load(comp_params.start_checkpoint)
    gaussians.restore(checkpoint_params, optim_params)


timings ={}

## Step 1: Parameter Sensitivity
Note: the authors use the terms 'sensitivity' and 'importance' interchangeably, which can be confusing; the variable names below mix both.

In [None]:
# Time the per-parameter sensitivity (importance) computation.
# perf_counter is monotonic and high-resolution, unlike the wall clock.
start_time = time.perf_counter()

color_importance, gaussian_sensitivity = calc_importance(
    gaussians, scene, pipeline_params
)

# CUDA kernels launch asynchronously; wait for them to finish so the measured
# interval covers the GPU work and not just the kernel launches.
if torch.cuda.is_available():
    torch.cuda.synchronize()

end_time = time.perf_counter()
timings["sensitivity_calculation"] = end_time - start_time

In [None]:
# Include thresholds, read from the compression parameters so these statistics
# use the same values as the CompressionSettings built for clustering later,
# instead of duplicating the numbers here.
color_importance_include = torch.as_tensor(comp_params.color_importance_include)
gaussian_importance_include = torch.as_tensor(comp_params.gaussian_importance_include)

# Count elements strictly above each threshold.
color_above_threshold = (color_importance > color_importance_include).sum().item()
total_elements_color = color_importance.numel()

gaussian_above_threshold = (gaussian_sensitivity > gaussian_importance_include).sum().item()
total_elements_gaussian = gaussian_sensitivity.numel()

# Fraction in [0, 1] of elements at or below the threshold (not a percentage).
color_threshold = 1.0 - (color_above_threshold / total_elements_color)
gaussian_threshold = 1.0 - (gaussian_above_threshold / total_elements_gaussian)

# The ``%`` format spec multiplies the fraction by 100 before printing.
print(f"Percentage of color_importance values below the threshold: {color_threshold:.2%}")
print(f"Percentage of gaussian_importance values below the threshold: {gaussian_threshold:.2%}")

In [None]:
# Normalize the tensors
color_importance_norm = torch.nn.functional.normalize(color_importance.clone(), p=2)
gaussian_sensitivity_norm = torch.nn.functional.normalize(gaussian_sensitivity.clone(), p=2)

In [None]:
# Flatten the row-normalised sensitivities to 1-D numpy arrays for plotting.
# detach() makes .numpy() safe even if a tensor still tracks gradients.
color_importance_norm_np = color_importance_norm.detach().cpu().numpy().flatten()
gaussian_sensitivity_norm_np = gaussian_sensitivity_norm.detach().cpu().numpy().flatten()

num_bins = 20

fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

# NOTE(review): color_threshold / gaussian_threshold are the *fractions* of raw
# values at or below the include threshold (from the thresholds cell). They are
# not sensitivity values, so the dashed line's x-position is not in the units of
# these histograms. Confirm what this line is meant to mark.
panels = (
    (axes[0], color_importance_norm_np, "#1f77b4", color_threshold, "Color Sensitivity Distribution"),
    (axes[1], gaussian_sensitivity_norm_np, "#ff7f0e", gaussian_threshold, "Shape Sensitivity Distribution"),
)
for ax, values, bar_color, threshold, title in panels:
    ax.hist(values, bins=num_bins, color=bar_color, density=True)
    ax.axvline(threshold, color="red", linestyle="--", label=f"Threshold ({threshold:.2f})")
    ax.set_title(title)
    ax.set_xlabel("Sensitivity")
    # axvline labels are only displayed once a legend is drawn.
    ax.legend()

# Shared y-axis: label it on the left panel only.
axes[0].set_ylabel("Density")

plt.tight_layout()
plt.show()

## Step 2: Sensitivity-aware vector clustering (K-Means) 

In [None]:
import plotly.graph_objects as go

def _plot_gaussian_positions(positions_np, values_np, title):
    """Show an interactive 3D scatter of Gaussian positions coloured by ``values_np``."""
    fig = go.Figure(data=[go.Scatter3d(
        x=positions_np[:, 0],
        y=positions_np[:, 1],
        z=positions_np[:, 2],
        mode='markers',
        marker=dict(
            size=3,
            color=values_np,             # Color by sensitivity
            colorscale='Viridis',
            opacity=0.7,
            colorbar=dict(title="Sensitivity"),
        ),
    )])
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
        ),
    )
    fig.show()


def prune_gaussians(
    gaussians: GaussianModel,
    color_importance: torch.Tensor,
    color_importance_n_norm: torch.Tensor,
    gaussian_importance: torch.Tensor,
    prune_threshold: float = 0.,
):
    """Report how many Gaussians would be pruned and plot their positions.

    Despite its name, this function does NOT modify ``gaussians``. The actual
    pruning step (``gaussians.mask_splats``) is disabled, so the keep-mask is
    only used for the printed statistic.

    Args:
        gaussians: Model whose ``get_xyz`` positions are plotted (the first
            three columns are used as x, y, z).
        color_importance: Colour importance values compared against
            ``prune_threshold``.
        color_importance_n_norm: Normalised colour sensitivity used to colour
            the scatter points (presumably one value per Gaussian, aligned
            with ``get_xyz`` rows — TODO confirm).
        gaussian_importance: Currently unused; kept for interface
            compatibility.
        prune_threshold: Values at or below this are counted as pruned. A
            negative value disables the report and the plot entirely.
    """
    with torch.no_grad():
        if prune_threshold < 0:
            return

        non_prune_mask = color_importance > prune_threshold
        print(f"prune: {(1-non_prune_mask.float().mean())*100:.2f}%")

        # detach() so .numpy() works even when get_xyz is a CPU tensor that
        # requires grad (.cpu() is a no-op in that case and keeps grad tracking).
        positions_np = gaussians.get_xyz.detach().cpu().numpy()
        sensitivity_np = color_importance_n_norm.detach().cpu().numpy()

        _plot_gaussian_positions(
            positions_np, sensitivity_np, "Gaussian Positions Before Pruning"
        )


In [None]:
with torch.no_grad():
    # NOTE(review): this interval is stored under "clustering", but no
    # clustering runs in this cell. The cell only reduces the sensitivities,
    # builds the two CompressionSettings objects and calls prune_gaussians
    # (which prints a statistic and shows a 3D plot). Confirm whether the
    # actual clustering step should fall inside this timed region.
    start_time = time.time()

    # Given a vector x ∈ R^D, we define its sensitivity as the maximum over its component’s sensitivity
    # amax(-1) reduces over the last axis, giving one value per row
    # (presumably one per Gaussian — TODO confirm the tensor layout).
    color_importance_n = color_importance.amax(-1)
    gaussian_importance_n = gaussian_sensitivity.amax(-1)

    # Same reduction on the row-normalised colour sensitivities. In this cell it
    # is passed to prune_gaussians, which uses it to colour the 3D scatter.
    color_importance_n_norm = color_importance_norm.amax(-1)

    # Release unused cached GPU memory held by PyTorch's caching allocator.
    torch.cuda.empty_cache()

    # Initialize the color codebook settings (all values from comp_params).
    color_compression_settings = CompressionSettings(
        codebook_size=comp_params.color_codebook_size,              # K = number of centroids = codebook size
        importance_prune=comp_params.color_importance_prune,
        importance_include=comp_params.color_importance_include,
        steps=int(comp_params.color_cluster_iterations),
        decay=comp_params.color_decay,
        batch_size=comp_params.color_batch_size,
    )

    # Initialize the Gaussian shape codebook settings. Unlike the colour
    # settings, importance_prune is passed as None here.
    gaussian_compression_settings = CompressionSettings(
        codebook_size=comp_params.gaussian_codebook_size,
        importance_prune=None,
        importance_include=comp_params.gaussian_importance_include,
        steps=int(comp_params.gaussian_cluster_iterations),
        decay=comp_params.gaussian_decay,
        batch_size=comp_params.gaussian_batch_size,
    )

    # Prints the share of Gaussians at/below prune_threshold and plots them;
    # does not modify `gaussians` (its masking call is disabled).
    prune_gaussians(
        gaussians,
        color_importance_n,
        color_importance_n_norm,
        gaussian_importance_n,
        prune_threshold=comp_params.prune_threshold,
    )

    end_time = time.time()
    timings["clustering"]=end_time-start_time