# Compression Pipeline


<div style="background-color: white; padding: 10px;">
    <img src="./docs/static/img/pipeline.svg" alt="Overview of the compression pipeline" width="1500" />
</div>

Instead of running the compression pipeline all at once, here you can run it step-by-step and explore the process.

In [None]:
# General
import time
import os
from os import path
from shutil import copyfile
import gc
import json
from random import randint

# Tensors
import torch
import numpy as np
from tqdm import tqdm, trange
from scipy.spatial import cKDTree

# Visualisations
from visualisation.plots import *
from visualisation.utils import *

# c3dgs functions / classes
from argparse import Namespace
from compress import unique_output_folder, calc_importance, render_and_eval
from gaussian_renderer import GaussianModel
from scene import Scene
from compression.vq import CompressionSettings
from typing import Tuple
from compression.vq import VectorQuantize, join_features
from finetune import prepare_output_and_logger
from gaussian_renderer import render
from utils.loss_utils import l1_loss, ssim


## First simulate the command-line arguments

In [None]:
# Command-line flags that would normally be passed to the compression script.
simulated_args = []
simulated_args += ["--model_path", "./input_models/flower_hq"]
simulated_args += ["--data_device", "cuda"]
simulated_args += ["--output_vq", "./output"]

In [None]:
# Parse the simulated flags into the parsed namespace plus the four parameter groups.
args, model, pipeline, op, comp = parse_arguments(simulated_args)

# No explicit output folder given -> generate a fresh one.
if args.output_vq is None:
    args.output_vq = unique_output_folder()

# Pull every group's values out of the parsed namespace (same order as before:
# model, optimization, pipeline, compression).
model_params, optim_params, pipeline_params, comp_params = (
    group.extract(args) for group in (model, op, pipeline, comp)
)

In [None]:
# Create the Gaussian model; it is quantization-aware unless explicitly disabled.
gaussians = GaussianModel(
    model_params.sh_degree,
    quantization=not optim_params.not_quantization_aware,
)

# Build the scene (train and/or test cameras) for the requested iteration.
scene = Scene(
    model_params,
    gaussians,
    load_iteration=comp_params.load_iteration,
    shuffle=True,
)

# Optionally restore the Gaussians from a checkpoint file.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
if comp_params.start_checkpoint:
    checkpoint_params, first_iter = torch.load(comp_params.start_checkpoint)
    gaussians.restore(checkpoint_params, optim_params)

# Wall-clock duration of each pipeline stage, and the Gaussian count before pruning.
timings = dict()
starting_gaussians = len(gaussians.get_xyz)

## Step 1: Parameter Sensitivity
Note: The authors use the terms 'sensitivity' and 'importance' interchangeably; in this notebook both refer to the same quantity.

In [None]:
# Important hyperparameters (default values in the trailing comments).
# A notebook cell only echoes its LAST bare expression, so print every value
# explicitly; otherwise the first one would never be shown.
print(f"color_importance_include    = {comp_params.color_importance_include}")     # 0.6*1e-6
print(f"gaussian_importance_include = {comp_params.gaussian_importance_include}")  # 0.3*1e-5

In [None]:
# Measure how long the sensitivity (a.k.a. importance) computation takes.
start_time = time.time()
color_importance, gaussian_sensitivity = calc_importance(gaussians, scene, pipeline_params)
end_time = time.time()

timings["sensitivity_calculation"] = end_time - start_time

In [None]:
# Thresholds as tensors for the element-wise comparisons below.
color_importance_include = torch.tensor(comp_params.color_importance_include)
gaussian_importance_include = torch.tensor(comp_params.gaussian_importance_include)

# Count the individual importance entries strictly above each threshold.
# NOTE(review): these counts are per element, whereas the compression steps
# compare the per-Gaussian maximum (amax over the last axis) to the threshold.
total_elements_color = color_importance.numel()
color_above_threshold = int(torch.count_nonzero(color_importance > color_importance_include))
total_elements_gaussian = gaussian_sensitivity.numel()
gaussian_above_threshold = int(torch.count_nonzero(gaussian_sensitivity > gaussian_importance_include))

# Fraction of entries that are NOT above their threshold.
color_threshold = 1.0 - color_above_threshold / total_elements_color
gaussian_threshold = 1.0 - gaussian_above_threshold / total_elements_gaussian

print(f"Percentage of color_importance values below the threshold: {color_threshold * 100:.2f}%")
print(f"Percentage of gaussian_importance values below the threshold: {gaussian_threshold * 100:.2f}%")

# Gaussians above the threshold are not clustered; they are stored separately instead.

In [None]:
# Row-wise L2 normalization (F.normalize uses dim=1 by default), then flatten
# to 1-D so every entry ends up in the histograms.
color_importance_norm = torch.nn.functional.normalize(
    color_importance.clone(), p=2, dim=1
).reshape(-1)
gaussian_sensitivity_norm = torch.nn.functional.normalize(
    gaussian_sensitivity.clone(), p=2, dim=1
).reshape(-1)

color_shape_sensitivity_hist(
    color_importance_norm,
    gaussian_sensitivity_norm,
    color_threshold,
    gaussian_threshold,
)

## Step 2: Sensitivity-aware vector clustering
Note: In this notebook, "vector clustering", "vector quantization" and "k-means clustering" all refer to the same procedure.

------------------ Pruning ------------------

In [None]:
# Important hyperparameters (default values in the trailing comments).
# A bare expression followed by an assignment is never echoed by the notebook,
# so print the values explicitly.
lambda_r = 1.0  # std-dev multiplier in the adaptive density-pruning threshold
print(f"prune_threshold = {comp_params.prune_threshold}")  # default 0.0
print(f"lambda_r        = {lambda_r}")                      # default 1.0

In [None]:
# Which pruning stages to run (defaults: sensitivity pruning on, density pruning off).
apply_sens_pruning, apply_dens_pruning = True, False

In [None]:
with torch.no_grad():
    start_time = time.time()

    # Per-Gaussian maximum sensitivity over all of its components.
    color_importance_n = color_importance.amax(-1)
    gaussian_importance_n = gaussian_sensitivity.amax(-1)

    torch.cuda.empty_cache()

    # Sensitivity-based pruning threshold
    prune_threshold = comp_params.prune_threshold

    # Start from "keep everything"; each enabled stage narrows the mask.
    final_mask = torch.ones(gaussians.get_xyz.shape[0], dtype=torch.bool, device=gaussians.get_xyz.device)

    print('#Gaussians before pruning:', gaussians.get_xyz.shape[0])

    # Stage 1: sensitivity-based pruning. Keep only Gaussians whose maximum
    # color sensitivity is strictly above the threshold.
    if apply_sens_pruning and prune_threshold >= 0:
        non_prune_mask_sensitivity = color_importance_n > prune_threshold
        final_mask &= non_prune_mask_sensitivity

        print('#Gaussians after sensitivity pruning:', non_prune_mask_sensitivity.sum().item())

    # Stage 2: spatial density pruning, applied to the stage-1 survivors only.
    if apply_dens_pruning:
        positions = gaussians.get_xyz[final_mask]
        positions_np = positions.detach().cpu().numpy()
        kdtree = cKDTree(positions_np)
        radius = determine_radius(scene, gaussians)

        # One batched query instead of a Python loop with a device->host copy
        # per point. return_length=True yields neighbor counts directly. Each
        # point lies at distance 0 from itself, so subtract 1 to exclude it.
        neighbor_counts = kdtree.query_ball_point(positions_np, radius, return_length=True) - 1

        redundancy_scores = torch.as_tensor(neighbor_counts, device=gaussians.get_xyz.device)

        # Adaptive redundancy threshold: mean + lambda_r * std, but at least 3.
        # std() of fewer than two values is NaN, which would make every
        # comparison below False and prune everything; treat it as 0 instead.
        mean_score = redundancy_scores.float().mean()
        std_dev_score = torch.nan_to_num(redundancy_scores.float().std(), nan=0.0)
        tau_p = (mean_score + lambda_r * std_dev_score).clamp_min(3.0)

        # Keep Gaussians whose neighborhood is not overly crowded.
        non_prune_mask_redundancy = redundancy_scores <= tau_p

        # Scatter the stage-2 decisions back into a mask over all Gaussians.
        updated_final_mask = torch.zeros_like(final_mask)
        updated_final_mask[final_mask] = non_prune_mask_redundancy
        final_mask = updated_final_mask

        print('#Gaussians after Spatial Density Pruning:', final_mask.sum().item())

    # Positions for the kept/pruned scatter plot, then apply the mask.
    pos_keep = gaussians.get_xyz[final_mask].cpu().numpy()
    pos_prune = gaussians.get_xyz[~final_mask].cpu().numpy()

    gaussians.mask_splats(final_mask)
    gaussian_importance_n = gaussian_importance_n[final_mask]
    color_importance_n = color_importance_n[final_mask]

    end_time = time.time()
    timings["pruning"] = end_time - start_time

Gaussians with sufficiently low sensitivity (and, if density pruning is enabled, Gaussians in overly crowded regions) are removed from the scene.
The removed Gaussians are visualised in red:

In [None]:
# Plot the kept vs. pruned Gaussian centers computed in the pruning cell.
scatterplot_prune_gaussians(pos_prune, pos_keep, final_mask)

------------------ Color Compression ------------------

In [None]:
# Important hyperparameters (default values in the trailing comments).
# NOTE: the iteration override below is a quick-run shortcut flagged as a
# TODO; the pipeline default is 100 cluster iterations.
comp_params.color_cluster_iterations = 1  # TODO: remove (default: 100)
color_use_kmeanspp_init = False           # default: False

# A notebook cell only echoes its last bare expression (and this cell ends with
# an assignment), so print every value explicitly.
print(f"color_codebook_size      = {comp_params.color_codebook_size}")       # 2**12
print(f"color_cluster_iterations = {comp_params.color_cluster_iterations}")  # 100
print(f"color_decay              = {comp_params.color_decay}")               # 0.8
print(f"color_batch_size         = {comp_params.color_batch_size}")          # 2**18
print(f"color_compress_non_dir   = {comp_params.color_compress_non_dir}")    # True
print(f"color_use_kmeanspp_init  = {color_use_kmeanspp_init}")               # False

# Collect the color clustering settings.
color_compression_settings = CompressionSettings(
    codebook_size=comp_params.color_codebook_size,
    importance_prune=comp_params.color_importance_prune,
    importance_include=comp_params.color_importance_include,
    steps=int(comp_params.color_cluster_iterations),
    decay=comp_params.color_decay,
    batch_size=comp_params.color_batch_size,
)

In [None]:
# Flatten each Gaussian's SH coefficients into one color feature vector.
# With color_compress_non_dir every coefficient is clustered; otherwise the
# index-0 coefficient is dropped and only the remaining ones are clustered.
if not comp_params.color_compress_non_dir:
    n_sh_coefs = gaussians.get_features.shape[1] - 1
    color_features = gaussians.get_features[:, 1:].detach().flatten(-2)
else:
    n_sh_coefs = gaussians.get_features.shape[1]
    color_features = gaussians.get_features.detach().flatten(-2)

Plot the initial Color Feature Space:

In [None]:
# PCA view of the color features before clustering.
plot_features_pca(color_features, label="Color Features")

In [None]:
def vq_features_vis(
    features: torch.Tensor,
    importance: torch.Tensor,
    codebook_size: int,
    vq_chunk: int = 2**16,
    steps: int = 1000,
    decay: float = 0.8,
    scale_normalize: bool = False,
    use_kmeanspp_init: bool = False,
    history_every: int = 5,
) -> Tuple[torch.Tensor, torch.Tensor, list, list]:
    """Run importance-weighted vector quantization and record its progress.

    Args:
        features: Feature vectors to cluster; the last dimension is the
            channel dimension of the codebook.
        importance: Per-feature importance used to weight codebook updates;
            it is divided by its maximum before use.
        codebook_size: Number of centroids in the codebook.
        vq_chunk: Number of randomly sampled features per update step.
        steps: Number of codebook update steps.
        decay: Decay passed to ``VectorQuantize``.
        scale_normalize: If True, divide every centroid by the sum of its
            entries 0, 3 and 5 after each step (the trace of the covariance
            matrix the centroid encodes).
        use_kmeanspp_init: Use k-means++ instead of random initialization.
        history_every: Store a codebook snapshot every this many steps.

    Returns:
        ``(codebook, indices, errors, centroids_history)``: the final codebook,
        the codebook index assigned to every feature, the mean update error of
        each step, and numpy snapshots of the codebook (the initial one, one
        every ``history_every`` steps, and always the final one).

    Raises:
        ValueError: If ``history_every`` is smaller than 1.
    """
    if history_every < 1:
        raise ValueError(f"history_every must be >= 1, got {history_every}")

    importance_n = importance / importance.max()
    vq_model = VectorQuantize(
        channels=features.shape[-1],
        codebook_size=codebook_size,
        decay=decay,
    ).to(device=features.device)

    if use_kmeanspp_init:
        print("using kmeans++ init")
        vq_model.kmeanspp_init_batch(features)  # k-means++ initialization
    else:
        print("using random init (default)")
        vq_model.uniform_init(features)  # random initialization

    def snapshot():
        # Copy so later in-place codebook updates don't alter stored frames
        # (on CPU, .cpu().numpy() shares memory with the tensor).
        return vq_model.codebook.data.cpu().numpy().copy()

    errors = []
    centroids_history = [snapshot()]

    for step in trange(steps):
        batch = torch.randint(low=0, high=features.shape[0], size=[vq_chunk])
        error = vq_model.update(features[batch], importance=importance_n[batch]).mean().item()
        errors.append(error)

        if scale_normalize:
            # Divide by the trace so the matrices have normalized eigenvalues / scales.
            tr = vq_model.codebook[:, [0, 3, 5]].sum(-1)
            vq_model.codebook /= tr[:, None]

        # Snapshot after normalization so each frame matches the codebook
        # that enters the next step.
        if (step + 1) % history_every == 0:
            centroids_history.append(snapshot())

    # Make sure the final codebook is always part of the history.
    if steps % history_every != 0:
        centroids_history.append(snapshot())

    gc.collect()
    torch.cuda.empty_cache()

    start = time.perf_counter()
    _, vq_indices = vq_model(features)
    # Only synchronize on CUDA tensors; torch.cuda.synchronize fails without CUDA.
    if vq_indices.is_cuda:
        torch.cuda.synchronize(device=vq_indices.device)
    end = time.perf_counter()
    print(f"calculating indices took {end-start} seconds ")
    return vq_model.codebook.data.detach(), vq_indices.detach(), errors, centroids_history

In [None]:
def compress_color_vis(
    gaussians: GaussianModel,
    color_importance_n: torch.Tensor,
    color_features: torch.Tensor,
    color_comp: CompressionSettings,
    use_kmeanspp_init: bool = False,
):
    """Cluster the color features and write the indexed colors into ``gaussians``.

    Gaussians whose importance is above ``color_comp.importance_include`` are
    kept verbatim; all others are vector-quantized with ``vq_features_vis``.

    Args:
        gaussians: Model whose colors are replaced via ``set_color_indexed``.
        color_importance_n: Per-Gaussian color importance.
        color_features: Flattened color features; the last dimension holds
            ``n_sh_coefs`` RGB triplets (see the reshape below).
        color_comp: Codebook size, steps, decay, batch size and thresholds.
        use_kmeanspp_init: Use k-means++ initialization for the codebook.

    Returns:
        ``(errors, centroids_history, compressed_features)``. ``errors`` and
        ``centroids_history`` are empty when no Gaussian needed clustering.
    """
    keep_mask = color_importance_n > color_comp.importance_include

    print(f"color keep: {keep_mask.float().mean()*100:.2f}%")

    vq_mask_c = ~keep_mask

    # Defaults for the no-clustering path so the return is always defined.
    errors, centroids_history = [], []

    if vq_mask_c.any():
        color_codebook, color_vq_indices, errors, centroids_history = vq_features_vis(
            color_features[vq_mask_c],
            color_importance_n[vq_mask_c],
            color_comp.codebook_size,
            color_comp.batch_size,
            color_comp.steps,
            decay=color_comp.decay,  # honor the configured decay instead of the default
            use_kmeanspp_init=use_kmeanspp_init,
        )
    else:
        color_codebook = torch.empty(
            (0, color_features.shape[-1]),
            device=color_features.device,
            dtype=color_features.dtype,
        )
        color_vq_indices = torch.empty(
            (0,), device=color_features.device, dtype=torch.long
        )

    compressed_features, indices = join_features(
        color_features, keep_mask, color_codebook, color_vq_indices
    )

    # Derive the SH coefficient count from the features themselves rather than
    # relying on a notebook global: each feature holds n_sh_coefs RGB triplets.
    n_sh_coefs = color_features.shape[-1] // 3
    gaussians.set_color_indexed(compressed_features.reshape(-1, n_sh_coefs, 3), indices)

    return errors, centroids_history, compressed_features

In [None]:
# Outputs of color clustering. They keep these defaults when color compression
# is disabled, so later cells see empty data instead of undefined names.
color_errors = []
color_centroids_history = []
color_compressed_features = None

with torch.no_grad():
    start_time = time.time()

    color_comp = None if comp_params.not_compress_color else color_compression_settings
    if color_comp is not None:
        color_errors, color_centroids_history, color_compressed_features = compress_color_vis(
            gaussians,
            color_importance_n,
            color_features,
            color_comp,
            use_kmeanspp_init=color_use_kmeanspp_init,
        )

    end_time = time.time()
    timings["color clustering"] = end_time - start_time

Animation of the centroid positions over time:

In [None]:
%matplotlib notebook
# Animate the recorded color-codebook snapshots over a 2D projection of the
# original color features.
ani = animate_feature_clustering(color_features, color_centroids_history, title="2D Projection of Color Features")

Final Result - Initial Color Features (blue) vs Compressed Color Features (red):

In [None]:
%matplotlib inline
# Original color features vs. the compressed color features.
plot_features_and_compressed(color_features, color_compressed_features, title="2D Projection of All Features and Compressed Features")

In [None]:
# Per-step clustering error recorded while building the color codebook.
plot_error_curve(color_errors)

------------------ Gaussian Shape Compression ------------------

In [None]:
# Important hyperparameters (default values in the trailing comments).
# NOTE: the iteration override below is a quick-run shortcut flagged as a
# TODO; the pipeline default is 800 cluster iterations.
comp_params.gaussian_cluster_iterations = 1  # TODO: remove (default: 800)
gaussian_use_kmeanspp_init = False           # default: False

# A notebook cell only echoes its last bare expression (and this cell ends with
# an assignment), so print every value explicitly.
print(f"gaussian_codebook_size      = {comp_params.gaussian_codebook_size}")       # 2**12
print(f"gaussian_cluster_iterations = {comp_params.gaussian_cluster_iterations}")  # 800
print(f"gaussian_decay              = {comp_params.gaussian_decay}")               # 0.8
print(f"gaussian_batch_size         = {comp_params.gaussian_batch_size}")          # 2**20
print(f"gaussian_use_kmeanspp_init  = {gaussian_use_kmeanspp_init}")               # False

# Collect the Gaussian shape clustering settings.
gaussian_compression_settings = CompressionSettings(
    codebook_size=comp_params.gaussian_codebook_size,
    importance_prune=None,
    importance_include=comp_params.gaussian_importance_include,
    steps=int(comp_params.gaussian_cluster_iterations),
    decay=comp_params.gaussian_decay,
    batch_size=comp_params.gaussian_batch_size,
)

In [None]:
# Clustering features: normalized covariances with the symmetric duplicates
# stripped (strip_sym=True).
gaussian_shape_features = gaussians.get_normalized_covariance(strip_sym=True).detach()

# Plotting features: the full symmetric matrices (needed for the matrix
# decomposition), split into rotation and scale components.
gaussian_shape_features_plot = gaussians.get_normalized_covariance(
    strip_sym=False
).detach()
rot_plot, scale_plot = extract_rot_scale(gaussian_shape_features_plot)

In [None]:
# Inspect the rotation and scale parts of the shape features before clustering.
plot_features_pca(rot_plot)
plot_features_pca(scale_plot)
plot_features_3d(scale_plot, title="", elev=40, azim=130)

In [None]:
def compress_covariance_vis(
    gaussians: GaussianModel,
    gaussian_importance_n: torch.Tensor,
    gaussian_shape_features: torch.Tensor,
    gaussian_comp: CompressionSettings,
    use_kmeanspp_init: bool = False
):
    """Cluster the Gaussian shape features and write the indexed shapes back.

    Gaussians whose importance is above ``gaussian_comp.importance_include``
    are kept verbatim; all others are vector-quantized with scale
    normalization. The joined covariances are decomposed into rotation and
    scale and stored via ``set_gaussian_indexed``.

    Args:
        gaussians: Model whose shapes are replaced.
        gaussian_importance_n: Per-Gaussian shape importance.
        gaussian_shape_features: 2-D tensor of stripped covariance features,
            one row per Gaussian.
        gaussian_comp: Codebook size, steps, decay, batch size and thresholds.
        use_kmeanspp_init: Use k-means++ initialization for the codebook.

    Returns:
        ``(errors, centroids_history, rot_vq, scale_vq)``. ``errors`` and
        ``centroids_history`` are empty when no Gaussian needed clustering.
    """
    keep_mask_g = gaussian_importance_n > gaussian_comp.importance_include

    vq_mask_g = ~keep_mask_g

    print(f"gaussians keep: {keep_mask_g.float().mean()*100:.2f}%")

    # Defaults for the no-clustering path so the return is always defined.
    errors, centroids_history = [], []

    if vq_mask_g.any():
        cov_codebook, cov_vq_indices, errors, centroids_history = vq_features_vis(
            gaussian_shape_features[vq_mask_g],
            gaussian_importance_n[vq_mask_g],
            gaussian_comp.codebook_size,
            gaussian_comp.batch_size,
            gaussian_comp.steps,
            decay=gaussian_comp.decay,  # honor the configured decay instead of the default
            scale_normalize=True,
            use_kmeanspp_init=use_kmeanspp_init,
        )
    else:
        # Empty 2-D codebook with the same width as the features, so that
        # join_features can concatenate it with the kept (2-D) features.
        cov_codebook = torch.empty(
            (0, gaussian_shape_features.shape[-1]),
            device=gaussian_shape_features.device,
            dtype=gaussian_shape_features.dtype,
        )
        cov_vq_indices = torch.empty((0,), device=gaussian_shape_features.device, dtype=torch.long)

    compressed_cov, cov_indices = join_features(
        gaussian_shape_features,
        keep_mask_g,
        cov_codebook,
        cov_vq_indices,
    )

    rot_vq, scale_vq = extract_rot_scale(to_full_cov(compressed_cov))

    gaussians.set_gaussian_indexed(
        rot_vq.to(compressed_cov.device),
        scale_vq.to(compressed_cov.device),
        cov_indices,
    )

    return errors, centroids_history, rot_vq, scale_vq

In [None]:
# Outputs of shape clustering. They keep these defaults when shape compression
# is disabled, so later cells see empty data instead of undefined names.
shape_errors = []
shape_centroids_history = []
rot_vq = scale_vq = None

with torch.no_grad():
    start_time = time.time()

    gaussian_comp = None if comp_params.not_compress_gaussians else gaussian_compression_settings
    if gaussian_comp is not None:
        shape_errors, shape_centroids_history, rot_vq, scale_vq = compress_covariance_vis(
            gaussians,
            gaussian_importance_n,
            gaussian_shape_features,
            gaussian_comp,
            use_kmeanspp_init=gaussian_use_kmeanspp_init,
        )

    end_time = time.time()
    timings["shape clustering"] = end_time - start_time

In [None]:
%matplotlib notebook
rot_history, scale_history = convert_centroids_to_rot_scale(shape_centroids_history)

ani = animate_feature_clustering(rot_plot, rot_history, title="2D Projection of Gaussian Shape Features")
ani = animate_feature_clustering(scale_plot, scale_history, title="2D Projection of Gaussian Shape Features")
ani = animate_feature_clustering_3d(scale_plot, scale_history, title="2D Projection of Gaussian Shape Features")

The animations above show the high redundancy in the Gaussian shape features.

In [None]:
%matplotlib inline
# Original vs. compressed rotation and scale features; the last plot is the 3D variant.
plot_features_and_compressed(rot_plot, rot_vq)
plot_features_and_compressed(scale_plot, scale_vq)
plot_features_and_compressed_3d(scale_plot, scale_vq)

In [None]:
%matplotlib inline
# Per-step clustering error recorded while building the shape codebook.
plot_error_curve(shape_errors)

Before moving on to the fine-tuning step, we have to prepare the output directory.

In [None]:
# Release memory held by the clustering stages before fine-tuning.
gc.collect()
torch.cuda.empty_cache()

# Create the output directory and carry over the original model configuration.
os.makedirs(comp_params.output_vq, exist_ok=True)
copyfile(
    os.path.join(model_params.model_path, "cfg_args"),
    os.path.join(comp_params.output_vq, "cfg_args"),
)

# From here on the model path points at the new output directory.
model_params.model_path = comp_params.output_vq

# Record the compression parameters next to the copied configuration.
with open(path.join(comp_params.output_vq, "cfg_args_comp"), "w") as cfg_log_f:
    cfg_log_f.write(str(Namespace(**vars(comp_params))))

## Step 3: Quantization-Aware Fine-Tuning

In [None]:
# Important hyperparameters (default values in the trailing comments).
# NOTE: reduced from the default of 5000 iterations for a quicker run.
comp_params.finetune_iterations = 200

# A notebook cell only echoes its last bare expression, so print both values.
print(f"finetune_iterations = {comp_params.finetune_iterations}")  # default 5000
print(f"lambda_dssim        = {optim_params.lambda_dssim}")        # default 0.2

In [None]:
def finetune_vis(scene: Scene, vis_cam, dataset, opt, comp, pipe, testing_iterations, debug_from):
    """Quantization-aware fine-tuning that also records renders of one camera.

    Continues training ``scene.gaussians`` for ``comp.finetune_iterations``
    iterations after ``scene.loaded_iter`` with an L1 + D-SSIM loss, and
    renders ``vis_cam`` periodically so the progress can be animated.

    Args:
        scene: Scene holding the Gaussians and the training cameras.
        vis_cam: Camera rendered for the visualisation snapshots.
        dataset: Model parameters; ``white_background`` selects the background.
        opt: Optimization parameters (passed to ``training_setup``; provides
            ``lambda_dssim``).
        comp: Compression parameters (``output_vq``, ``finetune_iterations``).
        pipe: Pipeline parameters passed to ``render``.
        testing_iterations: Unused; kept for signature compatibility.
        debug_from: ``pipe.debug`` is switched on when ``iteration - 1`` equals it.

    Returns:
        ``(rendered_images, losses)``: channels-last numpy renders of
        ``vis_cam`` clipped to [0, 1] (before training, every 100 iterations,
        and the final state), and the per-iteration loss as detached CPU
        tensors.
    """
    prepare_output_and_logger(comp.output_vq, dataset)

    first_iter = scene.loaded_iter
    max_iter = first_iter + comp.finetune_iterations

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    scene.gaussians.training_setup(opt)
    scene.gaussians.update_learning_rate(first_iter)

    def snapshot():
        # Render vis_cam without recording an autograd graph, so the snapshot
        # does not keep any graph (and its saved tensors) alive.
        with torch.no_grad():
            rendering = render(vis_cam, scene.gaussians, pipe, background)["render"]
        return np.clip(rendering.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)

    rendered_images = [snapshot()]
    losses = []
    # True while the parameters have changed since the last snapshot.
    pending_snapshot = False

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, max_iter), desc="Training progress")
    first_iter += 1

    for iteration in range(first_iter, max_iter + 1):
        # Draw training cameras without replacement; refill when exhausted.
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))

        if (iteration - 1) == debug_from:
            pipe.debug = True

        image = render(viewpoint_cam, scene.gaussians, pipe, background)["render"]

        # L1 + D-SSIM photometric loss against the ground-truth image.
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
            1.0 - ssim(image, gt_image)
        )
        loss.backward()
        losses.append(loss.detach().cpu())

        scene.gaussians.update_learning_rate(iteration)

        with torch.no_grad():
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            # Advance per iteration so the bar always reaches 100%.
            progress_bar.update(1)
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})

            # The parameters are not stepped on the last iteration.
            if iteration < max_iter:
                scene.gaussians.optimizer.step()
                scene.gaussians.optimizer.zero_grad()
                pending_snapshot = True

            if (iteration + 1) % 100 == 0:
                rendered_images.append(snapshot())
                pending_snapshot = False

    progress_bar.close()

    # Capture the final parameters if the periodic snapshots missed them.
    if pending_snapshot:
        rendered_images.append(snapshot())

    return rendered_images, losses

In [None]:
# Absolute iteration number the fine-tuned model is saved under.
iteration = scene.loaded_iter + comp_params.finetune_iterations

if comp_params.finetune_iterations > 0:
    start_time = time.time()

    # Camera used for the progress renders during fine-tuning.
    vis_cam = select_best_camera(scene, model_params, pipeline_params)

    rendered_images, losses = finetune_vis(
        scene, vis_cam,
        model_params, optim_params, comp_params, pipeline_params,
        testing_iterations=[-1],
        debug_from=-1,
    )

    end_time = time.time()
    timings["finetune"] = end_time - start_time

In [None]:
%matplotlib notebook
# Animate the vis_cam renders captured during fine-tuning.
ani = animate_training_renders(rendered_images)

In [None]:
%matplotlib inline

# Ground-truth image of the visualisation camera: first three channels of the
# (assumed channels-first) image, converted to a channels-last numpy array.
gt = vis_cam.original_image[0:3, :, :].unsqueeze(0)
gt_np = gt.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

draw_ground_truth_image(gt_np)

In [None]:
%matplotlib inline
# Fine-tuning loss curve (window_size presumably sets a smoothing window; confirm in the helper).
plot_finetune_losses(losses, window_size = 100)

## Step 4: Storage

In [None]:
# Size of the uncompressed input: the files of the point-cloud iteration the
# scene was loaded from. They live under point_cloud/iteration_<N>/, so summing
# only the top-level files of the model folder would miss them. args.model_path
# still holds the original input path (model_params.model_path was redirected
# to the output folder).
input_model_dir = args.model_path
input_pc_dir = path.join(input_model_dir, "point_cloud", f"iteration_{scene.loaded_iter}")

total_size = sum(
    os.path.getsize(path.join(input_pc_dir, f))
    for f in os.listdir(input_pc_dir)
    if path.isfile(path.join(input_pc_dir, f))
)
input_size = total_size / (1024 ** 2)
print(f"Total size of the input model: {input_size:.2f} MB")

In [None]:
# Save the compressed model under the final fine-tuning iteration number.
out_file = path.join(
    comp_params.output_vq,
    f"point_cloud/iteration_{iteration}/point_cloud.npz",
)

start_time = time.time()
gaussians.save_npz(out_file, sort_morton=not comp_params.not_sort_morton)
end_time = time.time()

timings["encode"] = end_time - start_time
# Sum every stage except a previously stored "total", so re-running this cell
# does not add the old total into the new one.
timings["total"] = sum(seconds for stage, seconds in timings.items() if stage != "total")

with open(f"{comp_params.output_vq}/times.json", "w") as f:
    json.dump(timings, f)

file_size = os.path.getsize(out_file) / 1024**2
print(f"saved vq finetuned model to {out_file}")
print(f"File size of the output model = {file_size:.2f}MB")

# [input size, output size] in MB for the storage plot.
sizes = [input_size, file_size]

In [None]:
# Visualise the per-stage wall-clock timings collected in `timings`.
visualise_timings(timings)

In [None]:
# Compare input vs. output model size in MB (sizes = [input_size, file_size]).
visualise_storage_metrics(sizes)

In [None]:
# Gaussian count before the pipeline and after pruning/compression.
final_gaussians = len(gaussians.get_xyz)

print(starting_gaussians, final_gaussians, sep="\n")

## Evaluation

In [None]:
# Evaluate the compressed model and attach its file size (MB) to the metrics.
metrics = render_and_eval(gaussians, scene, model_params, pipeline_params)
metrics.update(size=file_size)
print(metrics)

# Persist the metrics, keyed by the iteration the model was saved under.
with open(f"{comp_params.output_vq}/results.json", "w") as f:
    f.write(json.dumps({f"ours_{iteration}": metrics}, indent=4))