# Physics Simulation of 3D Gaussian Splats Using Simplicits - Now With Collisions!

Let's simulate 3D Gaussian Splat objects using [Simplicits](https://research.nvidia.com/labs/toronto-ai/simplicits/), fully integrated into the Kaolin Library. We will be able to set up and interactively view the simulation directly in this Jupyter notebook.

As of v0.18.0, Kaolin also supports collision handling between objects, which we demonstrate here as well.

<img src="../../../assets/physics_bulldozer.gif" alt="image info" width="500px"/>

## Installation and Requirements
For splat rendering, we rely on a specific version of [INRIA's splatting and rasterization code](https://github.com/graphdeco-inria/gaussian-splatting). In the setup below, make sure the paths and packages are set correctly so that the Inria code can be imported into the notebook.

We have recently tested this notebook with the following environment. Please follow [Kaolin Installation docs](https://kaolin.readthedocs.io/en/latest/notes/installation.html) to install Kaolin. 
- python 3.11.10
- cuda 12.4
- pytorch 2.5.1
- setuptools 70.1.1

In [None]:
### Install necessary packages
!pip install -q plyfile k3d matplotlib

### Import Kaolin Library and Other Requirements

In [None]:
import copy
import ipywidgets
import json
import kaolin

import matplotlib.pyplot as plt
import numpy as np
import os
import logging
import sys
import time
import threading  
import k3d
from pathlib import Path
from functools import partial
import warp as wp

import torch
import torchvision

from IPython.display import display
from ipywidgets import Button, HBox, VBox

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format="%(asctime)s|%(levelname)8s| %(message)s")
logger = logging.getLogger(__name__)

%load_ext autoreload
%autoreload 2

def log_tensor(t, name, **kwargs):
    """ Debugging util, e.g. call: log_tensor(t, 'my tensor', print_stats=True) """
    logger.info(kaolin.utils.testing.tensor_info(t, name=name, **kwargs)) 

### Import local Gaussian utils for this notebook

In order to deform the Gaussians during simulation, we define a couple of helper functions in a utility file, `gaussian_utils.py`.

In [None]:
from gaussian_utils import transform_gaussians_lbs, pad_transforms, PHYS_NOTEBOOKS_DIR

### Setting up Inria Gaussian Splatting Codebase

This will clone and build the Gaussian renderer in a subfolder relative to this notebook: `examples/tutorial/physics/inria/`. If the build fails, you may need to set `REBUILD_INRIA=True`, fix issues and rerun this cell.

**Note:** We have occasionally run into the following [bug](https://github.com/graphdeco-inria/gaussian-splatting/issues/373), which requires adding `#include <float.h>` to the includes in `examples/tutorial/physics/inria/gaussian-splatting/submodules/simple-knn/simple_knn.cu`.

INRIA's Gaussian Splatting is not a package. Once it's built, this cell will `cd` into the `..../kaolin/examples/tutorial/physics/inria/gaussian-splatting` directory in order to import the Gaussian rendering utilities.

In [None]:
#### Setup and Installation ###

REBUILD_INRIA = False
inria_path = os.path.join(PHYS_NOTEBOOKS_DIR, 'inria', 'gaussian-splatting')
if REBUILD_INRIA or not os.path.isdir(inria_path):
    logger.info(f'Cloning and building inria gaussian-splatting in {inria_path}')
    %cd {PHYS_NOTEBOOKS_DIR}

    ### Create an inria folder
    %mkdir inria
    %cd inria

    ### Clone the repo recursively
    !git clone --recursive https://github.com/graphdeco-inria/gaussian-splatting.git    

    ### Install the submodules
    %cd gaussian-splatting
    !git checkout --recurse-submodules 472689c
    !pip install submodules/diff-gaussian-rasterization
    !pip install submodules/simple-knn
else:
    logger.info(f'Inria gaussian-splatting already exists; cd {inria_path}')
    %cd {inria_path}


### Import Inria Gaussian Splat rendering utils

**If you get a `module not found` error, check your paths**

In [None]:
# Gaussian splatting dependencies
from utils.graphics_utils import focal2fov
from utils.system_utils import searchForMaxIteration
from gaussian_renderer import render, GaussianModel
from scene.cameras import Camera as GSCamera
from utils.general_utils import strip_symmetric, build_scaling_rotation
%pwd

## Download Splat Models from AWS
Let's grab two pre-trained 3D Gaussian Splat models from AWS.
The cell below downloads and unzips them (if they are not already present) and sets the model paths to the corresponding `.ply` files.

In [None]:
# Download and unzip the nerfsynthetic bulldozer
!if test -d output/dozer; then echo "Pretrained bulldozer splats already exist."; else wget https://nvidia-kaolin.s3.us-east-2.amazonaws.com/data/dozer.zip -P output/; unzip output/dozer.zip -d output/; fi;
model_path = 'output/dozer/point_cloud/iteration_30000/point_cloud.ply'

# Download and unzip the doll splat, captured and trained by the Kaolin team (please cite Kaolin if you use this model)
!if test -d output/doll; then echo "Pretrained doll splats already exist."; else wget https://nvidia-kaolin.s3.us-east-2.amazonaws.com/data/doll.zip -P output/; unzip output/doll.zip -d output/; fi;
model_path2 = 'output/doll/point_cloud/iteration_30000/point_cloud.ply' 

### Load 3D Gaussian Splat Models

After the setup, we can load the splat models with the Inria `GaussianModel` class. The next section uses Kaolin to display them within the Jupyter notebook.

In [None]:
class PipelineParamsNoparse:
    """ Same as PipelineParams but without argument parser. """
    def __init__(self):
        self.convert_SHs_python = False
        self.compute_cov3D_python = False #True # covariances will be updated during simulation
        self.debug = False

def load_model(model_path, sh_degree=3, iteration=-1):
    # Load gaussians
    gaussians = GaussianModel(sh_degree)
    gaussians.load_ply(model_path)
    logger.info(f'Loaded {gaussians.get_xyz.shape[0]} gaussians from {model_path}')
    return gaussians

gaussians = load_model(model_path)
gaussians2 = load_model(model_path2)
pipeline = PipelineParamsNoparse()
background = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") # Set white bg

## Interactive Rendering Using Kaolin Visualizer

In order to easily view splats in the notebook, let's set up Gaussian Splat rendering using Kaolin camera conventions.
You should be able to see the rendering below this cell and to control the camera with your left mouse button.

In [None]:
resolution = 512
default_cam = kaolin.render.camera.Camera.from_args(
        eye=torch.ones((3,)) * 2, at=torch.zeros((3,)), up=torch.tensor([0., 0., 1.]),
        fov=torch.pi * 45 / 180, height=resolution, width=resolution)

class GaussianRenderer:
    """ Define a rendering closure. """
    def __init__(self, gaussians, downscale_factor=1):
        self.gaussians = gaussians
        self.downscale_factor = int(downscale_factor)

    def downscale_camera(self, in_cam):
        lowres_cam = copy.deepcopy(in_cam)
        lowres_cam.width = in_cam.width // self.downscale_factor
        lowres_cam.height = in_cam.height // self.downscale_factor
        return lowres_cam

    def __call__(self, camera):
        if self.downscale_factor > 1:
            camera = self.downscale_camera(camera)
        # Convert kaolin camera to inria gaussian-splatting camera
        cam = kaolin.render.camera.kaolin_camera_to_gsplats(camera, GSCamera)
        # Render gaussians using the inria rendering utilities
        render_res = render(cam, self.gaussians, pipeline, background)
        rendering = render_res["render"]
        return (torch.clamp(rendering.permute(1, 2, 0), 0, 1) * 255).to(torch.uint8).detach().cpu()

static_scene_viz = kaolin.visualize.IpyTurntableVisualizer(
    resolution, resolution, copy.deepcopy(default_cam), GaussianRenderer(gaussians), 
    focus_at=None, world_up_axis=2, max_fps=12, img_quality=75, img_format='JPEG')
static_scene_viz.show()

## Creating and Training Simplicits Objects from Points
[Simplicits](https://research.nvidia.com/labs/toronto-ai/simplicits/) is a mesh-free, representation-agnostic method for simulating elastic deformations. We can use it to simulate Gaussian Splats at interactive rates within the Jupyter notebook. To simulate any point-sampled geometry, such as splats, Simplicits first _trains_ an object-specific weight function representing the object's reduced degrees of freedom. The physics solver then uses this reduced space to solve for deformations during simulation.

Next, let's use the Simplicits API within Kaolin to create, train and simulate splat objects.

First, let's set some material parameters.

In [None]:
# Physics material parameters (use approximated values, or look them up online)
# We'll create a few presets that can be used
youngs_modulus_presets = {"softest": 2000, "soft": 21000, "medium": 1e6, "stiff": 1e7}
soft_youngs_modulus = youngs_modulus_presets["soft"]  # we will use this for training
poisson_ratio = 0.45
rho = 100  # kg/m^3
approx_volume = 3  # m^3
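Elastic energies are often parameterized by the Lamé parameters (mu, lambda) rather than by Young's modulus E and the Poisson ratio nu. For reference, the sketch below shows the standard conversion and the approximate mass implied by the density and volume above. Whether Kaolin converts internally in exactly this way is an assumption, and `lame_parameters` is a hypothetical helper.

```python
def lame_parameters(youngs_modulus, poisson_ratio):
    """Convert Young's modulus E and Poisson ratio nu to Lamé parameters (mu, lambda)."""
    mu = youngs_modulus / (2.0 * (1.0 + poisson_ratio))
    lam = youngs_modulus * poisson_ratio / ((1.0 + poisson_ratio) * (1.0 - 2.0 * poisson_ratio))
    return mu, lam

# Values matching the "soft" preset and poisson_ratio = 0.45 above
mu, lam = lame_parameters(21000, 0.45)

# rho [kg/m^3] * approx_volume [m^3]
approx_mass = 100 * 3
```

As nu approaches 0.5, lambda grows without bound. This is why values such as 0.45 already describe a nearly incompressible material.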

### Sampling Within Splat Volume

Because splats tend to occupy the surface of the object, they provide poor sampling of the object's interior. This can affect the quality of the learned reduced space. To sample within the splat volume, we will use Kaolin's utility `kaolin.ops.gaussian.sample_points_in_volume`.  
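To build intuition for why interior samples help, here is a minimal sketch of drawing points from inside anisotropic Gaussians, using `x = mu + R diag(s) z` with `z ~ N(0, I)`. This is **not** the algorithm that `sample_points_in_volume` uses, which also takes opacity into account. The helper names are hypothetical, and the quaternion order (w, x, y, z) is assumed to match the Inria code.

```python
import torch

def quat_to_rotmat(q):
    """Convert quaternions (w, x, y, z), shape (N, 4), to rotation matrices (N, 3, 3)."""
    q = q / q.norm(dim=-1, keepdim=True)
    w, x, y, z = q.unbind(-1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(-1, 3, 3)

def sample_inside_gaussians(xyz, scale, rotation, samples_per_gaussian=8):
    """Hypothetical helper: draw samples from each Gaussian N(xyz, R diag(scale)^2 R^T)."""
    n = xyz.shape[0]
    R = quat_to_rotmat(rotation)                                   # (N, 3, 3)
    z = torch.randn(n, samples_per_gaussian, 3, device=xyz.device, dtype=xyz.dtype)
    local = z * scale[:, None, :]                                  # stretch along principal axes
    world = torch.einsum('nij,nkj->nki', R, local)                 # rotate into the world frame
    return (xyz[:, None, :] + world).reshape(-1, 3)                # (N * samples, 3)
```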

In [None]:
densified_pos = kaolin.ops.gaussian.sample_points_in_volume(
    xyz=gaussians.get_xyz.detach(), 
    scale=gaussians.get_scaling.detach(),
    rotation=gaussians.get_rotation.detach(),
    opacity=gaussians.get_opacity.detach(),
    clip_samples_to_input_bbox=False
)
log_tensor(gaussians.get_xyz, 'original_pos', print_stats=True)
log_tensor(densified_pos, 'densified_pos', print_stats=True)

In [None]:
def visualize_pts_k3d(densified_pos, pos):
    plot = k3d.plot()
    plot += k3d.points(densified_pos.detach().cpu().numpy(), point_size=0.01, color=0x00ff00)
    plot += k3d.points(pos.detach().cpu().numpy(), point_size=0.02, color=0xff0000)
    plot.display()

visualize_pts_k3d(densified_pos, gaussians.get_xyz)

### Training
Next we create a `SimplicitsObject` and train its skinning weight functions using the volume samples, visualized above. The simulator will then use these reduced degrees of freedom to drive the simulation.

**Note:** since training takes a bit of time, we cache the result and reuse it next time we run the notebook.
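In the reduced space, each handle carries an affine transform stored as a 3×4 matrix. The notebook later reshapes the simulator state `sim_z` to `(-1, 3, 4)`. As a result, the state size grows with the number of handles rather than with the number of splats. The back-of-the-envelope sketch below assumes exactly `num_handles` affine handles (Kaolin may add extra handles internally), and the splat count is made up for illustration.

```python
def reduced_dof_count(num_handles, params_per_handle=12):
    """Each affine handle is a 3x4 matrix: 9 linear + 3 translation parameters."""
    return num_handles * params_per_handle

# Hypothetical splat count; a full simulation would need 3 DOFs per point
full_dofs = 3 * 300_000
# num_handles=40, as passed to create_trained in the cell below
reduced = reduced_dof_count(40)
```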

In [None]:
# Whether to save the reduced degrees of freedom used by the simulator and load them from cache automatically
ENABLE_SIMPLICITS_CACHING = True # set to False to always retrain

cache_dir = os.path.join(PHYS_NOTEBOOKS_DIR, 'cache')
os.makedirs(cache_dir, exist_ok=True)
logger.info(f'Caching trained simplicits objects in {cache_dir}')

In [None]:
def train_or_load_simplicits_object(points, fname):
    if not ENABLE_SIMPLICITS_CACHING or not os.path.exists(fname):
        logger.info('Training simplicits object. This will take 2-3min... ')
        start = time.time()

        # One-liner to set up Simplicits object
        sim_obj = kaolin.physics.simplicits.SimplicitsObject.create_trained(
            points,  # point samples
            soft_youngs_modulus, poisson_ratio, rho, approx_volume,  # default global values set above
            num_samples=2048, model_layers=10, num_handles=40)
        
        end = time.time()
        logger.info(f"Ended training in {end-start} seconds")

        # We'll cache the result so we can quickly rerun the notebook.
        torch.save(sim_obj, fname)
        logger.info(f"Cached training result in {fname}")
    else:
        logger.info(f'Loading cached simplicits object from: {fname}')
        sim_obj = torch.load(fname, weights_only=False)
    return sim_obj

# We'll run training on the first object's volume points
sim_obj = train_or_load_simplicits_object(
    densified_pos, os.path.join(cache_dir, 'simplicits_dozer.pt'))

## Setup Simulated Scene Using Simplicits Easy API
Let's create an empty scene with default parameters, then reduce the maximum number of Newton steps for faster runtimes and set the timestep and Hessian regularizer.

**Note:** be patient, some of the steps below take time, as we need to build matrices used during simulation.
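`max_newton_steps` caps how many Newton iterations each timestep takes. `newton_hessian_regularizer` is presumably added to the Hessian diagonal to keep the linear solve well conditioned. The toy sketch below shows that pattern on a small quadratic energy. It illustrates regularized, truncated Newton's method in general and is not Kaolin's solver.

```python
import torch

def truncated_newton(grad_fn, hess_fn, z0, max_steps=3, reg=1e-5):
    """Run at most `max_steps` Newton updates z <- z - (H + reg*I)^{-1} g."""
    z = z0.clone()
    eye = torch.eye(z.numel(), dtype=z.dtype)
    for _ in range(max_steps):
        g = grad_fn(z)
        H = hess_fn(z) + reg * eye   # diagonal regularization keeps H invertible
        z = z - torch.linalg.solve(H, g)
    return z

# Toy quadratic energy E(z) = 0.5 z^T A z - b^T z, minimized where A z = b
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0], dtype=torch.float64)
z = truncated_newton(lambda z: A @ z - b, lambda z: A, torch.zeros(2, dtype=torch.float64))
```

For a quadratic energy, a single step already lands essentially on the minimum. Elastic energies are not quadratic, so three steps may leave the solve unconverged, which is the speed/accuracy trade-off noted in the cell below.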

In [None]:
scene = kaolin.physics.simplicits.SimplicitsScene()  # create a default empty scene

scene.max_newton_steps = 3  # convergence is not guaranteed with so few Newton iterations, but it runs very fast
scene.timestep = 0.03
scene.newton_hessian_regularizer = 1e-5

Now we add our object to the scene. We use 2048 cubature points to integrate over.
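Cubature here means approximating the integral of the elastic energy density over the object's volume with a weighted sum over a subset of sample points. The sketch below is a minimal Monte Carlo version that assumes uniformly weighted samples. The weighting Kaolin actually uses may differ, and `cubature_integral` is a hypothetical helper.

```python
import torch

def cubature_integral(energy_density, points, volume, num_qp=2048, generator=None):
    """Estimate the integral of energy_density over the volume using num_qp random samples."""
    idx = torch.randperm(points.shape[0], generator=generator)[:num_qp]
    qp = points[idx]
    # Each cubature point represents an equal share of the total volume
    return energy_density(qp).sum() * (volume / qp.shape[0])

# Sanity check: a constant density of 2.0 over a 3 m^3 object integrates to 6.0
pts = torch.rand(10_000, 3)
est = cubature_integral(lambda x: torch.full((x.shape[0],), 2.0), pts, volume=3.0)
```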

In [None]:
# The scene copies it into an internal SimulatableObject utility class
obj_idx = scene.add_object(sim_obj, num_qp=2048)

Let's set gravity and floor forces on the scene.
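A common way to model a floor is a one-sided quadratic penalty: zero above the floor and growing as points sink below it. The sketch below assumes that form to illustrate what `floor_height`, `floor_axis` and `floor_penalty` control. Kaolin's exact floor energy, and its `flip_floor` behavior, may differ.

```python
import torch

def floor_penalty_energy(points, floor_height=-0.7, floor_axis=2, penalty=1000.0):
    """Quadratic penalty on how far each point penetrates below the floor plane."""
    penetration = torch.clamp(floor_height - points[:, floor_axis], min=0.0)
    return 0.5 * penalty * (penetration ** 2).sum()

pts = torch.tensor([[0.0, 0.0, 0.0],     # above the floor: contributes no energy
                    [0.0, 0.0, -0.8]])   # 0.1 below the floor
energy = floor_penalty_energy(pts)
```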

In [None]:
# Add gravity to the scene
scene.set_scene_gravity(acc_gravity=torch.tensor([0, 0, 9.8]))
# Add floor to the scene
scene.set_scene_floor(floor_height=-0.7, floor_axis=2, floor_penalty=1000, flip_floor=False)

We can also adjust the material parameters of the object, identified by `obj_idx`.

## Simulating and Interactively Visualizing

That's it! We are ready to simulate. Let's just make sure we can visualize the simulation as it is running.

### Handling Splat Deformation

As the splats deform, we must update their attributes using the transforms predicted by the simulator.
For this, we need the reduced degrees of freedom and the ability to apply linear blend skinning to splats. These utilities can be found in `gaussian_utils.py`, next to this notebook.
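Linear blend skinning moves each point by a weighted combination of per-handle affine transforms: `x' = sum_h w_h(x) * T_h [x; 1]`, where `T_h = [A_h | t_h]` is 3×4. The sketch below applies this to point positions only, with assumed tensor shapes. `transform_gaussians_lbs` in `gaussian_utils.py` additionally updates rotations and scales. Simplicits-style simulators may also parameterize handles as offsets from the identity, so the exact convention used there may differ.

```python
import torch

def lbs_points(rest_xyz, weights, transforms):
    """Linear blend skinning for points (hypothetical helper).

    rest_xyz:   (N, 3) rest positions
    weights:    (N, H) skinning weights, one column per handle
    transforms: (H, 3, 4) per-handle affine transforms [A | t]
    """
    ones = torch.ones(rest_xyz.shape[0], 1, dtype=rest_xyz.dtype, device=rest_xyz.device)
    x_h = torch.cat([rest_xyz, ones], dim=1)                    # (N, 4) homogeneous coordinates
    per_handle = torch.einsum('hij,nj->nhi', transforms, x_h)   # (N, H, 3) each handle's result
    return (weights[..., None] * per_handle).sum(dim=1)         # blend the results by weight
```

For example, with one identity handle and one handle translating by +1 in x, equal weights of 0.5 move every point by 0.5 in x.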

In [None]:
# We will save undeformed Gaussian properties, so that we can properly transform and reset them during simulation.
rest_xyz = gaussians._xyz.clone()
rest_rot = gaussians._rotation.clone()
rest_scales = gaussians._scaling.clone()

# Precompute learned skinning weights for all splats
skinning_weights = sim_obj.skinning_weight_function(rest_xyz)

def dozer_to_timestep(transforms):
    global gaussians
    gaussians._xyz, gaussians._rotation, gaussians._scaling = \
        transform_gaussians_lbs(rest_xyz, rest_rot, rest_scales, skinning_weights, transforms)
    
# Reset to rest pose
def reset_single_object_sim():
    global gaussians
    scene.reset_scene()
    dozer_to_timestep(scene.get_object_transforms(obj_idx))

reset_single_object_sim()

### Threading

We will run simulation in a separate thread, so it is possible to interact with the viewer as the simulation is running (in fact, it's encouraged). We'll reuse these utils for this and the multi-object simulation below.
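The next cell starts a daemon worker thread and joins any previous one before launching a new run, so two simulations never step the same scene at once. The sketch below strips that pattern down to the standard library; the class and method names are hypothetical.

```python
import threading

class BackgroundRunner:
    """Run one job at a time on a daemon thread, waiting for the previous job first."""

    def __init__(self):
        self._thread = None

    def wait(self):
        # Block until the previous job (if any) has finished
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def start(self, job):
        self.wait()
        self._thread = threading.Thread(target=job, daemon=True)
        self._thread.start()

results = []
runner = BackgroundRunner()
runner.start(lambda: results.extend(range(3)))
runner.start(lambda: results.append("second"))  # waits for the first job to finish
runner.wait()
```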

In [None]:
sim_thread_open = False
sim_thread = None

def wait_for_simulation(visualizer):
    global sim_thread_open, sim_thread
    with visualizer.out:
        if sim_thread_open:
            sim_thread.join()
            sim_thread_open = False
    
def start_simulation(sim_function, visualizer):
    wait_for_simulation(visualizer)
    
    global sim_thread_open, sim_thread
    with visualizer.out:
        sim_thread_open = True
        sim_thread = threading.Thread(target=sim_function, daemon=True)
        sim_thread.start()

def reset_simulation(reset_function, visualizer):
    with visualizer.out:
        reset_function()
    visualizer.render_update()

### Simulation: Let's bring everything together!

In [None]:
num_sim_iterations = 100
reset_single_object_sim()

def single_object_sim():
    for s in range(num_sim_iterations):
        with sim_visualizer.out:
            scene.run_sim_step()
            print(".", end="")
            with torch.no_grad():
                dozer_to_timestep(scene.get_object_transforms(obj_idx))
        sim_visualizer.render_update()

resolution = 512
sim_visualizer = kaolin.visualize.IpyTurntableVisualizer(
    resolution, resolution, copy.deepcopy(default_cam),
    GaussianRenderer(gaussians), fast_render=GaussianRenderer(gaussians, 8),
    focus_at=torch.tensor([0, 0, 0.0]),
    world_up_axis=2, max_fps=12, img_quality=75, img_format='JPEG')

buttons = [Button(description=x) for x in
           ['Run Sim', 'Reset']]
buttons[0].on_click(lambda e: start_simulation(single_object_sim, sim_visualizer))
buttons[1].on_click(lambda e: reset_simulation(reset_single_object_sim, sim_visualizer))

sim_visualizer.render_update()
display(VBox([HBox([sim_visualizer.canvas, VBox(buttons)]), sim_visualizer.out]))

# Part 2: Multiple Objects and Collisions

It's time to make this simulation more exciting. Let's train and add the second object that we loaded above to the simulation.

### Train Second Simplicits Object

As before, we will sample and visualize points in the object volume. Then, we'll train and cache a Simplicits object.

In [None]:
densified_pos2 = kaolin.ops.gaussian.sample_points_in_volume(
    xyz=gaussians2.get_xyz.detach(), 
    scale=gaussians2.get_scaling.detach(),
    rotation=gaussians2.get_rotation.detach(),
    opacity=gaussians2.get_opacity.detach(),
    clip_samples_to_input_bbox=False
)
log_tensor(gaussians2.get_xyz, 'original_pos', print_stats=True)
log_tensor(densified_pos2, 'densified_pos', print_stats=True)

In [None]:
visualize_pts_k3d(densified_pos2, gaussians2.get_xyz)

In [None]:
# We'll run training on the second object's volume points
sim_obj2 = train_or_load_simplicits_object(
    densified_pos2, os.path.join(cache_dir, 'simplicits_doll.pt'))

### Set Up New Scene

We'll set up a new scene so that the previous simulation cell remains functional.

In [None]:
scene2 = kaolin.physics.simplicits.SimplicitsScene()  # create a default empty scene

scene2.max_newton_steps = 3  # convergence is not guaranteed with so few Newton iterations, but it runs very fast
scene2.timestep = 0.03
scene2.newton_hessian_regularizer = 1e-5

This time we'll add two objects, offsetting the doll by 1 unit in the z direction.
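The `init_transform` below is a 4×4 homogeneous matrix whose last column holds the translation. The `1` in the third row of that column shifts the doll up by 1 along z. The standalone check below applies it to a single point.

```python
import torch

init_transform = torch.tensor([[1, 0, 0, 0],
                               [0, 1, 0, 0],
                               [0, 0, 1, 1],
                               [0, 0, 0, 1]], dtype=torch.float32)

point = torch.tensor([0.2, -0.1, 0.0, 1.0])  # homogeneous coordinates (x, y, z, 1)
moved = init_transform @ point               # translation is applied; x and y are unchanged
```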

In [None]:
scene2_obj_idx = scene2.add_object(sim_obj, num_qp=2048)

# Offset the doll by 1 unit along z using a 4x4 homogeneous transform
scene2_obj_idx2 = scene2.add_object(
    sim_obj2,
    num_qp=2048,
    init_transform=torch.tensor([[1, 0, 0, 0],
                                 [0, 1, 0, 0],
                                 [0, 0, 1, 1],
                                 [0, 0, 0, 1]],
                                dtype=torch.float32, device=gaussians.get_xyz.device))

We'll set up forces as before.

In [None]:
# Add gravity to the scene
scene2.set_scene_gravity(acc_gravity=torch.tensor([0, 0, 9.8]))
# Add floor to the scene
scene2.set_scene_floor(floor_height=-0.7, floor_axis=2, floor_penalty=1000, flip_floor=False)

### Enable Collisions (new!)

We will enable inter-object collisions here. 
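The parameters of `enable_collisions` translate into three distances, as described in the comments of the cell below, and these can be computed directly. The penalty curve here is a hypothetical log barrier chosen only to show the qualitative behavior: zero beyond the particle radius and infinite at the impenetrable barrier. It is not Kaolin's collision energy.

```python
import math

radius = 0.1
detection_distance = radius * 1.5   # pairs closer than this are considered for contact
barrier_distance = radius * 0.25    # energy becomes infinite at this separation

def toy_barrier_energy(d, radius=radius, barrier=barrier_distance, k=1000.0):
    """Hypothetical log barrier: 0 for d >= radius, growing to infinity as d -> barrier."""
    if d >= radius:
        return 0.0
    if d <= barrier:
        return math.inf
    # (d - radius)^2 > 0 and the log is negative here, so the energy is positive
    return -k * (d - radius) ** 2 * math.log((d - barrier) / (radius - barrier))
```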

In [None]:
scene2.enable_collisions(collision_particle_radius=0.1, # radius of each collision particle - energy starts accumulating at r
                        detection_ratio=1.5, # radius * detection ratio is the area that is searched for potential contact
                        impenetrable_barrier_ratio=0.25, # radius * barrier is the distance at which energy is infinite
                        collision_penalty=1000.0, # coefficient of collision energy, force, gradient
                        max_contact_pairs=10000, # the maximum number of particle contact pairs to allow
                        friction=0.5, # friction coefficient
                    )

### Handle Deforming and Rendering Multiple Gaussians

Because the Inria renderer is not set up to render multi-object scenes, we need a little extra work to visualize the simulation: we concatenate both objects into a single `GaussianModel`.

In [None]:
reset_single_object_sim()

combined_gaussians = GaussianModel(sh_degree=3)
combined_gaussians._xyz = torch.cat([
    gaussians._xyz, gaussians2._xyz
], dim=0)
combined_gaussians._scaling = torch.cat([
    gaussians._scaling, gaussians2._scaling
], dim=0)
combined_gaussians._rotation = torch.cat([
    gaussians._rotation, gaussians2._rotation
], dim=0)
combined_gaussians._opacity = torch.cat([
    gaussians._opacity, gaussians2._opacity
], dim=0)
combined_gaussians._features_dc = torch.cat([
    gaussians._features_dc, gaussians2._features_dc
], dim=0)
combined_gaussians._features_rest = torch.cat([
    gaussians._features_rest, gaussians2._features_rest
], dim=0)

# Save rest state of the combined model
combined_rest_xyz = combined_gaussians._xyz.clone()
combined_rest_rot = combined_gaussians._rotation.clone()
combined_rest_scales = combined_gaussians._scaling.clone()

Let's make sure we can deform both objects using the learned degrees of freedom, which the Simplicits simulator is using to predict deformations.
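Stacking the two weight matrices block-diagonally means each splat is influenced only by its own object's handles. The doll's splats get zero weight on the bulldozer's handles, and vice versa. This assumes the handle transforms are ordered object by object, matching the simulator state. The shape check below uses made-up sizes.

```python
import torch

w_dozer = torch.rand(5, 2)   # 5 splats, 2 handles (sizes are made up)
w_doll = torch.rand(3, 4)    # 3 splats, 4 handles
combined = torch.block_diag(w_dozer, w_doll)  # (5 + 3) splats x (2 + 4) handles
```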

In [None]:
skinning_weights2 = sim_obj2.skinning_weight_function(gaussians2.get_xyz)
# Block-diagonal stacking: each object's splats are only influenced by that object's handles
combined_skinning_weights = torch.block_diag(skinning_weights, skinning_weights2)

def combined_to_timestep(warp_z):
    global combined_gaussians
    # TODO: switch to using scene.get_object_transforms
    obj_tfms = wp.to_torch(warp_z, requires_grad=False).reshape((-1, 3, 4))
    transforms = pad_transforms(obj_tfms).unsqueeze(0)
    combined_gaussians._xyz, combined_gaussians._rotation, combined_gaussians._scaling = \
        transform_gaussians_lbs(combined_rest_xyz, combined_rest_rot, combined_rest_scales, combined_skinning_weights, transforms)

def reset_multi_object_sim():
    global combined_gaussians
    scene2.reset_scene()
    combined_to_timestep(scene2.sim_z)

reset_multi_object_sim()

### Simulate and Visualize

Now we are ready to run the simulation and visualize it.

In [None]:
num_sim_iterations = 100
reset_multi_object_sim()

def run_one_multisim_step():
    scene2.run_sim_step()
    with torch.no_grad():
        combined_to_timestep(scene2.sim_z)
    multi_sim_visualizer.render_update()

def multi_object_sim():
    for s in range(num_sim_iterations):
        with multi_sim_visualizer.out:
            print(".", end="")
            run_one_multisim_step()

# Set the resolution before creating the camera so the camera and the visualizer use the same size
resolution = 700
multi_cam = kaolin.render.camera.Camera.from_args(
        eye=torch.tensor([3.0, 2.0, 3.0]), at=torch.zeros((3,)), up=torch.tensor([0., 0., 1.]),
        fov=torch.pi * 45 / 180, height=resolution, width=resolution)

multi_sim_visualizer = kaolin.visualize.IpyTurntableVisualizer(
    resolution, resolution, multi_cam,
    GaussianRenderer(combined_gaussians), fast_render=GaussianRenderer(combined_gaussians, 8),
    focus_at=torch.tensor([0, 0, -0.7]),
    world_up_axis=2, max_fps=12, img_quality=75, img_format='JPEG')

buttons = [Button(description=x) for x in
           ['Run Sim', 'Reset']]
buttons[0].on_click(lambda e: start_simulation(multi_object_sim, multi_sim_visualizer))
buttons[1].on_click(lambda e: reset_simulation(reset_multi_object_sim, multi_sim_visualizer))

run_one_multisim_step()
multi_sim_visualizer.render_update()
display(VBox([HBox([multi_sim_visualizer.canvas, VBox(buttons)]), multi_sim_visualizer.out]))