# Simplicits Easy API Demo

[Simplicits](https://research.nvidia.com/labs/toronto-ai/simplicits/) is a mesh-free, representation-agnostic way to simulate elastic deformations.

In [1]:
# Imports
import builtins
import copy
import threading
from contextlib import contextmanager
from pathlib import Path

import torch
import kaolin as kal
import k3d
from loguru import logger
from tqdm.auto import tqdm
from ipywidgets import Button, HBox, VBox
from IPython.display import display

# Only reached when every import above succeeded, so this log line marks a healthy environment.
logger.info("Imports loaded")

ImportError: DLL load failed while importing _C: The specified module could not be found.

In [2]:
# Helper functions
@contextmanager
def training_progress_hook(total_steps):
    """Temporarily show a tqdm progress bar for Kaolin's Simplicits training loop.

    While active, a module-level ``range`` shim is installed in
    ``kaolin.physics.simplicits.easy_api``. Any range whose length equals
    ``total_steps`` is wrapped in ``tqdm``; every other range is returned
    unchanged.

    NOTE(review): this relies on the training loop in ``easy_api`` iterating
    with a call to ``range(...)`` looked up by name; it is an assumption about
    Kaolin internals, not something this notebook can guarantee.

    Args:
        total_steps: Number of training iterations. Used to recognise the
            training loop's ``range`` call.

    Yields:
        None. On exit, even if the body raises, ``easy_api.range`` is returned
        to exactly its previous state. If the module had no own ``range``
        binding, the shim is removed. Otherwise the previous object is put back.
    """
    import kaolin.physics.simplicits.easy_api as easy_api

    # A private sentinel distinguishes "attribute absent" from "attribute bound
    # to builtins.range". Comparing the value alone cannot tell these apart.
    missing = object()
    saved_range = getattr(easy_api, "range", missing)
    base_range = builtins.range if saved_range is missing else saved_range

    def tqdm_range(*args):
        rng = base_range(*args)
        # Heuristic: the training loop is the range whose length matches total_steps.
        if hasattr(rng, "__len__") and len(rng) == total_steps:
            return tqdm(rng, desc="Training", total=total_steps, unit="step", leave=False)
        return rng

    easy_api.range = tqdm_range
    try:
        yield
    finally:
        if saved_range is missing:
            # The module previously resolved `range` via builtins; drop the shim.
            if hasattr(easy_api, "range"):
                delattr(easy_api, "range")
        else:
            easy_api.range = saved_range


def load_and_sample_mesh(mesh_path, num_samples=200_000, hash_resolution=512):
    """Load a mesh, center/normalize it, and sample points inside it.

    Candidate points are drawn uniformly in the mesh's axis-aligned bounding
    box. Only those that ``kal.ops.mesh.check_sign`` reports as inside the
    surface are kept, so the returned count is random and at most
    ``num_samples``.

    Args:
        mesh_path: Path to a mesh file readable by ``kal.io.import_mesh``.
        num_samples: Number of candidate points drawn in the bounding box.
        hash_resolution: ``hash_resolution`` forwarded to ``check_sign``.

    Returns:
        (mesh, orig_vertices, pts):
            - the centered mesh, moved to CUDA;
            - a clone of its vertices;
            - the interior sample points, one row of 3 coordinates per point.

    Raises:
        ValueError: if no candidate point lands inside the mesh. This can
            happen, for example, when the mesh is not closed or
            ``num_samples`` is tiny.
    """
    mesh = kal.io.import_mesh(mesh_path, triangulate=True).cuda()
    mesh.vertices = kal.ops.pointcloud.center_points(
        mesh.vertices.unsqueeze(0), normalize=True
    ).squeeze(0)
    orig_vertices = mesh.vertices.clone()
    logger.info(f"Loaded mesh: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")

    # Rejection sampling: uniform points in the bounding box, keep those inside.
    min_corner = orig_vertices.min(dim=0).values
    max_corner = orig_vertices.max(dim=0).values
    uniform_pts = (
        torch.rand(num_samples, 3, device=orig_vertices.device) * (max_corner - min_corner)
        + min_corner
    )
    inside = kal.ops.mesh.check_sign(
        mesh.vertices.unsqueeze(0), mesh.faces, uniform_pts.unsqueeze(0),
        hash_resolution=hash_resolution,
    ).squeeze(0)  # drop only the batch dim; a bare squeeze() yields a 0-d mask when num_samples == 1
    pts = uniform_pts[inside]
    if pts.shape[0] == 0:
        raise ValueError(
            f"No interior points found in {mesh_path} from {num_samples} samples; "
            "check that the mesh is closed or increase num_samples"
        )
    logger.info(f"Sampled {len(pts)} interior points")

    return mesh, orig_vertices, pts


def create_sim_object(pts, youngs=1e5, poisson=0.45, density=500.0, volume=0.5, handles=5, steps=10000):
    """Create and train a homogeneous Simplicits object, showing a progress bar.

    Every sample point gets the same material values. The per-point material
    tensors are allocated on ``pts.device``, so they always live on the same
    device as the sample points.

    Args:
        pts: Sample points, one row per point. An example is the ``pts``
            returned by ``load_and_sample_mesh``.
        youngs: Young's modulus assigned to every point.
        poisson: Poisson's ratio assigned to every point.
        density: Density assigned to every point.
        volume: Volume value forwarded to ``SimplicitsObject.create_trained``.
        handles: Forwarded as ``num_handles``.
        steps: Number of training iterations; also sizes the progress bar.

    Returns:
        The object returned by ``SimplicitsObject.create_trained``.
    """
    num_pts = pts.shape[0]
    yms = torch.full((num_pts,), youngs, device=pts.device)
    prs = torch.full((num_pts,), poisson, device=pts.device)
    rhos = torch.full((num_pts,), density, device=pts.device)

    logger.info(f"Training Simplicits object: {steps} steps, {handles} handles")

    # The hook wraps the training loop's range(steps) in a tqdm progress bar.
    with training_progress_hook(steps):
        sim_obj = kal.physics.simplicits.SimplicitsObject.create_trained(
            pts, yms, prs, rhos, volume,
            num_handles=handles,
            training_num_steps=steps,
            training_lr_start=1e-3,
            training_lr_end=1e-3,
            training_le_coeff=1e-1,
            training_lo_coeff=1e6,
            training_log_every=max(1, steps // 10),  # ~10 log lines per run
            normalize_for_training=True
        )

    logger.info("Training complete")
    return sim_obj


def setup_scene(sim_obj, gravity=(0, 9.8, 0), floor_height=-0.8, timestep=0.03,
                floor_axis=1, floor_penalty=1000):
    """Create a ``SimplicitsScene`` containing ``sim_obj``, gravity and a floor.

    Args:
        sim_obj: Trained Simplicits object to add to the scene.
        gravity: 3-component acceleration forwarded to ``set_scene_gravity``.
            It is converted to a float32 tensor, so an all-integer tuple such
            as ``(0, 10, 0)`` does not become an int64 tensor.
        floor_height: Forwarded to ``set_scene_floor``.
        timestep: Scene time step.
        floor_axis: Axis index forwarded to ``set_scene_floor``. The default
            of 1 is the same up axis the visualizer uses.
        floor_penalty: Penalty weight forwarded to ``set_scene_floor``.

    Returns:
        (scene, obj_idx): the scene, and the index ``add_object`` returned for
        ``sim_obj``.
    """
    scene = kal.physics.simplicits.SimplicitsScene()
    scene.max_newton_steps = 5
    scene.timestep = timestep
    scene.direct_solve = True

    obj_idx = scene.add_object(sim_obj)
    scene.set_scene_gravity(acc_gravity=torch.as_tensor(gravity, dtype=torch.float32))
    scene.set_scene_floor(floor_height=floor_height, floor_axis=floor_axis, floor_penalty=floor_penalty)

    return scene, obj_idx


# Marks in the log that this cell finished defining the helper functions above.
logger.info("Helper functions defined")

2025-09-30 14:12:34.564 | INFO     | __main__:<module>:86 - Helper functions defined


In [4]:
# Load the object, then append a flat square floor (two triangles) so that
# the floor is rendered together with the object as a single mesh.
mesh_path = "assets/fox.obj"
mesh, orig_vertices, pts = load_and_sample_mesh(mesh_path)

floor_height = -0.8   # y-coordinate of the floor plane
floor_size = 5.0      # half-extent of the floor square along x and z
num_obj_verts = mesh.vertices.shape[0]

# Corners listed as (-x,-z), (+x,-z), (+x,+z), (-x,+z).
floor_verts = torch.tensor(
    [[sign_x * floor_size, floor_height, sign_z * floor_size]
     for sign_x, sign_z in ((-1, -1), (1, -1), (1, 1), (-1, 1))],
    device='cuda', dtype=torch.float32,
)

# Two triangles over the quad, offset so they index the appended floor vertices.
floor_faces = num_obj_verts + torch.tensor(
    [[0, 1, 2], [0, 2, 3]], device='cuda', dtype=torch.long
)

combined_verts, combined_faces = (
    torch.cat([mesh.vertices, floor_verts], dim=0),
    torch.cat([mesh.faces, floor_faces], dim=0),
)
orig_vertices = torch.cat([orig_vertices, floor_verts], dim=0)
mesh = kal.rep.SurfaceMesh(vertices=combined_verts, faces=combined_faces)

logger.info(f"Added floor: {mesh.vertices.shape[0]} total vertices, {mesh.faces.shape[0]} total faces")

# Preview the interior sample points as a point cloud.
logger.info("Visualizing sampled points...")
plot = k3d.plot()
plot += k3d.points(pts.detach().cpu().numpy(), point_size=0.01)
plot.display()

2025-09-30 14:49:16.042 | INFO     | __main__:load_and_sample_mesh:31 - Loaded mesh: 5002 vertices, 10000 faces
2025-09-30 14:49:16.181 | INFO     | __main__:load_and_sample_mesh:41 - Sampled 44055 interior points
2025-09-30 14:49:16.184 | INFO     | __main__:<module>:27 - Added floor: 5006 total vertices, 10002 total faces
2025-09-30 14:49:16.184 | INFO     | __main__:<module>:30 - Visualizing sampled points...


Output()

In [None]:

# Train the Simplicits object on the interior samples (10,000 steps by default)
# and build the physics scene around it.
sim_obj = create_sim_object(pts)
# Reuse the floor height of the rendered floor so the physics floor and the
# visible floor cannot drift apart if either default is edited.
scene, obj_idx = setup_scene(sim_obj, floor_height=floor_height)

logger.info("Setup complete")

In [None]:
# Split the combined geometry back into object and floor parts. The floor
# appended in the loading cell is a fixed quad of 4 vertices and 2 triangles,
# stored after the object's vertices and faces.
NUM_FLOOR_VERTS = 4
NUM_FLOOR_FACES = 2
num_obj_verts = orig_vertices.shape[0] - NUM_FLOOR_VERTS
num_obj_faces = mesh.faces.shape[0] - NUM_FLOOR_FACES
floor_verts = orig_vertices[num_obj_verts:].clone()
obj_faces = mesh.faces[:num_obj_faces].clone()
floor_faces = mesh.faces[num_obj_faces:].clone()

# Rest-pose object vertices and the face list never change during the
# simulation (only deformed positions do), so compute them once.
obj_rest_verts = orig_vertices[:num_obj_verts]
scene_faces = torch.cat([obj_faces, floor_faces], dim=0)

mesh.vertices = orig_vertices.clone()

resolution = 512
camera = kal.render.easy_render.default_camera(resolution).cuda()
light_direction = kal.render.lighting.sg_direction_from_azimuth_elevation(1., 1.)
lighting = kal.render.lighting.SgLightingParameters(amplitude=3., sharpness=5., direction=light_direction).cuda()


def render(in_cam):
    """Render the current global ``mesh`` from ``in_cam``.

    Returns:
        ``{"img": uint8 tensor}``. Pixels not covered by any face are set to
        white. Only the first entry of the batch dimension is kept (assumes a
        single camera).
    """
    render_res = kal.render.easy_render.render_mesh(in_cam, mesh, lighting=lighting)
    img = torch.clamp(render_res[kal.render.easy_render.RenderPass.render], 0, 1)[0]
    # face_idx < 0 marks background pixels; the comparison is already boolean.
    background_mask = render_res[kal.render.easy_render.RenderPass.face_idx][0] < 0
    img[background_mask] = 1
    return {"img": (img * 255.).to(torch.uint8)}


def fast_render(in_cam, factor=8):
    """Render at 1/``factor`` resolution, used by the visualizer while interacting."""
    lowres_cam = copy.deepcopy(in_cam)
    lowres_cam.width = in_cam.width // factor
    lowres_cam.height = in_cam.height // factor
    return render(lowres_cam)


# Background thread running run_sim; None until the first "Run Sim" click.
sim_thread = None


def _wait_for_sim():
    """Block until the background simulation thread, if any, has finished."""
    if sim_thread is not None and sim_thread.is_alive():
        sim_thread.join()


def _update_mesh_from_scene():
    """Rebuild the global ``mesh`` from the scene's current object deformation."""
    global mesh
    obj_verts = scene.get_object_deformed_pts(obj_idx, obj_rest_verts)
    mesh = kal.rep.SurfaceMesh(
        vertices=torch.cat([obj_verts, floor_verts], dim=0),
        faces=scene_faces,
    )


def reset_simulation(visualizer):
    """Reset the scene to its initial state and redraw.

    Waits for a running simulation first, so the scene is never reset while
    another thread is stepping it.
    """
    _wait_for_sim()
    with visualizer.out:
        scene.reset_scene()
    _update_mesh_from_scene()
    visualizer.render_update()


def run_sim(num_steps=100):
    """Advance the scene ``num_steps`` steps, redrawing after each step."""
    for _ in range(num_steps):
        with visualizer.out:
            scene.run_sim_step()
            print(".", end="")
        _update_mesh_from_scene()
        visualizer.render_update()


def start_simulation(b):
    """Button callback: run ``run_sim`` in a background daemon thread.

    If a previous run is still in progress, wait for it to finish first so
    two threads never step the scene at once.
    """
    global sim_thread
    with visualizer.out:
        _wait_for_sim()
        sim_thread = threading.Thread(target=run_sim, daemon=True)
        sim_thread.start()


visualizer = kal.visualize.IpyTurntableVisualizer(
    resolution, resolution, copy.deepcopy(camera), render, fast_render=fast_render,
    max_fps=24, world_up_axis=1
)

buttons = [Button(description=x) for x in ['Run Sim', 'Reset']]
buttons[0].on_click(start_simulation)
buttons[1].on_click(lambda _: reset_simulation(visualizer))

reset_simulation(visualizer)
display(HBox([visualizer.canvas, VBox(buttons)]), visualizer.out)