# Physics Simulation and Interaction

Have you recently generated or captured an awesome 3D object and want to interact with it without pre-processing it for physics simulation? With Kaolin's implementation of Simplicits, it is now very easy to run deformable physics simulation on any point-sampled geometry.

Let's walk through this simple example step by step.

In [1]:
import copy, math, os, sys, logging, threading
from typing import List, Tuple
from pathlib import Path


import numpy as np
import torch
import kaolin as kal

from IPython.display import display
from ipywidgets import Button, HBox, VBox


# Verbose (DEBUG-level) logging to stdout for the tutorial.
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
# Make the parent directory importable so the shared tutorial helpers can be found.
sys.path.append(str(Path("..")))
from tutorial_common import COMMON_DATA_DIR

def sample_mesh_path(fname):
    """Return the path of ``fname`` inside the shared tutorial ``meshes`` data directory."""
    meshes_dir = os.path.join(COMMON_DATA_DIR, 'meshes')
    return os.path.join(meshes_dir, fname)

def print_tensor(t, name='', **kwargs):
    """Print Kaolin's ``tensor_info`` description of ``t``; extra kwargs are forwarded to it."""
    description = kal.utils.testing.tensor_info(t, name=name, **kwargs)
    print(description)
    

## Loading Geometry

The Simplicits physics method works with **any** point-sampled geometry: point clouds, Gaussian splats, meshes.

For the purpose of this tutorial, we will use a mesh.

In [2]:
# Import and triangulate to enable rasterization; move to GPU
mesh = kal.io.import_mesh(sample_mesh_path('armchair.obj'), triangulate=True).cuda()
    
# Normalize so it is easy to set up default camera. The unsqueeze/squeeze pair adds and
# then removes a leading batch dimension around center_points.
mesh.vertices = kal.ops.pointcloud.center_points(mesh.vertices.unsqueeze(0), normalize=True).squeeze(0) 
# Also save original undeformed vertices; later cells assign these back to mesh.vertices
# before each simulation run.
orig_vertices = mesh.vertices.clone()

# Inspect
print(mesh)

SurfaceMesh object with batching strategy NONE
            vertices: [9204, 3] (torch.float32)[cuda:0]  
               faces: [18400, 3] (torch.int64)[cuda:0]  
             normals: [8507, 3] (torch.float32)[cuda:0]  
    face_normals_idx: [18400, 3] (torch.int64)[cuda:0]  
                 uvs: [10800, 2] (torch.float32)[cuda:0]  
        face_uvs_idx: [18400, 3] (torch.int64)[cuda:0]  
material_assignments: [18400] (torch.int16)[cuda:0]  
           materials: list of length 2
       face_vertices: if possible, computed on access from: (faces, vertices)
        face_normals: if possible, computed on access from: (normals, face_normals_idx) or (vertex_normals, faces) or (vertices, faces)
            face_uvs: if possible, computed on access from: (uvs, face_uvs_idx)
      vertex_normals: if possible, computed on access from: (faces, face_normals)
     vertex_tangents: if possible, computed on access from: (faces, face_vertices, face_uvs, vertex_normals)
       vertex_colors: if poss

## Visualizing Geometry

Because we are working with a mesh, we can use Kaolin's built-in easy render function. To use any other renderer (such as a Gaussian splat renderer),
simply define the following rendering functions, which take a `Camera` as input.

In [13]:
def render(in_cam, **kwargs):
    """Render the global ``mesh`` from ``in_cam`` with Kaolin's easy_render, on a white background.

    Args:
        in_cam: Kaolin camera to render from.
        **kwargs: forwarded to ``kal.render.easy_render.render_mesh``.

    Returns:
        dict with
          "img": uint8 image (values 0-255) of the first batch element; background pixels are 255.
          "face_idx": per-pixel face index (negative where no face was hit), batch dimension
                      squeezed and a trailing singleton dimension added.
    """
    render_pass = kal.render.easy_render.RenderPass
    render_res = kal.render.easy_render.render_mesh(in_cam, mesh, **kwargs)
    face_idx = render_res[render_pass.face_idx]

    img = torch.clamp(render_res[render_pass.render], 0, 1)[0]
    # Pixels not covered by any face have a negative face index; paint them white.
    background_mask = face_idx[0] < 0
    img[background_mask] = 1
    return {"img": (img * 255.).to(torch.uint8),
            "face_idx": face_idx.squeeze(0).unsqueeze(-1)}

def fast_render(in_cam, factor=8):
    """Render a cheap preview: a copy of ``in_cam`` with width and height floor-divided by ``factor``."""
    preview_cam = copy.deepcopy(in_cam)
    preview_cam.width, preview_cam.height = in_cam.width // factor, in_cam.height // factor
    return render(preview_cam)

# Turntable viewer for the loaded (undeformed) mesh. `render` is the full-quality renderer;
# `fast_render` is a reduced-resolution variant, presumably used while interacting.
resolution = 512
camera = kal.render.easy_render.default_camera(resolution).cuda()
# The visualizer receives a deep copy, so the `camera` object defined here is not shared with it.
visualizer = kal.visualize.IpyTurntableVisualizer(
    resolution, resolution, copy.deepcopy(camera), render, fast_render=fast_render,
    max_fps=24, world_up_axis=1)
display(HBox([visualizer.canvas]), visualizer.out)


HBox(children=(Canvas(height=512, width=512),))

Output()

## Preparing Geometry for Simulation

To enable simulation of any point-sampled geometry, we first need to train a Simplicits object from its physics material parameters.

In [6]:
# Physics material parameters
soft_youngs_modulus = 1e5
hard_youngs_modulus = 1e7  # not used in this cell; handy for experimenting with stiffer objects
poisson_ratio = 0.45
rho = 100  # kg/m^3
approx_volume = 1  # m^3

# Point samples: one material value per point.
# dtype is set explicitly because torch.full infers the dtype from the fill value;
# with the integer `rho` above, the densities would otherwise be an int64 tensor.
pts = mesh.vertices
yms = torch.full((pts.shape[0],), soft_youngs_modulus, device="cuda", dtype=torch.float32)
prs = torch.full((pts.shape[0],), poisson_ratio, device="cuda", dtype=torch.float32)
rhos = torch.full((pts.shape[0],), rho, device="cuda", dtype=torch.float32)

# Initialize and train a Simplicits object to enable simulation
sim_obj = kal.physics.simplicits.SimplicitsObject(
    pts, yms, prs, rhos,
    torch.tensor([approx_volume], dtype=torch.float32, device="cuda"),
    num_handles=10)
# NOTE: per the original author's note, training does nothing if num_handles = 0.
sim_obj.train(num_steps=10000, le_coeff=0.1)

## Saving/Loading the Network
It can be useful to save an object's trained network to disk and load it back later.

In [7]:
#torch.save(sim_obj.model, "./results/sim_obj_model")

In [8]:
# NOTE(review): this loads a whole pickled module. On recent PyTorch versions torch.load
# defaults to weights_only=True, so this may need weights_only=False (only for trusted files).
#loaded_model = torch.load("./results/sim_obj_model").to(device="cuda")
#sim_obj.model = loaded_model

## Setting up Simulation

Once the object is trained, we can set up any number of simulated scenes with that object. 

Let's set up a default scene with a floor and gravity, so the object falls.


In [16]:
# Scene with default simulation parameters
scene1 = kal.physics.simplicits.SimplicitsScene() 
# Reduce to 3 newton steps per timestep for speed
scene1.max_newton_steps = 3

# Add an object to the scene
obj_idx = scene1.add_object(sim_obj)
# HACK: overrides add_object's return value (the original author notes the return needs fixing).
# scene1 was just created, so sim_obj is its first object and index 0 is used.
obj_idx = 0

# Add gravity to the scene. NOTE(review): +9.8 along y (the up axis); presumably Kaolin's
# gravity term uses this sign convention to pull toward -y — confirm against the API docs.
scene1.set_scene_gravity(acc_gravity=torch.tensor([0, 9.8, 0]))
# Add floor to the scene at height -1 along axis 1 (y)
scene1.set_scene_floor(floor_height=-1, floor_axis=1, floor_penalty=10000)
# Set the object's Young's modulus to 1e5 (the same value as soft_youngs_modulus used for training)
scene1.sim_obj_dict[obj_idx].set_materials(yms=torch.tensor(1e5, device="cuda", dtype=torch.float32))

## Running and Visualizing Simulation

We will run the simulation, updating the object's point locations (mesh vertices in this case) at every timestep and visualizing the result at the same time.

In [17]:
# Restore the undeformed geometry before simulating.
mesh.vertices = orig_vertices

resolution = 512
# Default camera adjusted with extrinsics.move_forward(1) for this view.
new_camera = kal.render.easy_render.default_camera(resolution).cuda()
new_camera.extrinsics.move_forward(1)
visualizer = kal.visualize.IpyTurntableVisualizer(
    resolution, resolution, copy.deepcopy(new_camera), render, fast_render=fast_render,
    max_fps=24, world_up_axis=1)

def run_sim(num_steps=150):
    """Reset the object's simulation state, then simulate ``num_steps`` timesteps.

    After every step the deformed simulation points are copied into the global ``mesh``
    (which ``render`` draws) and the visualizer is refreshed. Runs in the calling thread.

    Args:
        num_steps: number of timesteps to simulate (default 150).
    """
    scene1.sim_obj_dict[obj_idx].reset_sim_state()
    for _ in range(num_steps):
        scene1.run_sim_step()
        mesh.vertices = scene1.get_object_deformed_pts(obj_idx).squeeze()
        visualizer.render_update()


def start_simulation(b):
    """Button click handler: run the simulation. ``b`` (the clicked button) is not used."""
    del b  # unused
    run_sim()
        
# Clicking the button calls start_simulation directly (no background thread).
button = Button(description='Run Sim')
button.on_click(start_simulation)
display(HBox([visualizer.canvas, button]), visualizer.out)

HBox(children=(Canvas(height=512, width=512), Button(description='Run Sim', style=ButtonStyle())))

Output()

## Setting up Interactive Simulation

In [18]:
def convert_offset_to_world_coords(kal_cam, point: torch.Tensor, offset: torch.Tensor,
                                   offset_divisor: float = 1000.):
    """Translate a 3D point within the camera's image plane by a 2D screen-space offset.

    The point is moved along the camera's right and up axes only, never along the
    viewing direction. ``offset`` is divided by ``offset_divisor`` before it is applied.
    This is an empirical scale, not an exact screen-to-world conversion; a depth/FOV-aware
    conversion is still a TODO.

    Args:
        kal_cam: Kaolin camera providing ``cam_right()`` and ``cam_up()``.
        point: 3D point(s) to translate, shape ``(3,)`` or ``(N, 3)``.
        offset: 2D offset ``(dx, dy)``. It is NOT modified.
        offset_divisor: divisor applied to ``offset``; the default keeps the original 1/1000 scale.

    Returns:
        The translated point(s); a ``(3,)`` input gains a leading dimension, giving ``(1, 3)``.
    """
    if point.ndim == 1:
        point = point[None]
    # Scale into a new tensor so the caller's `offset` is left untouched.
    scaled = offset / offset_divisor
    right = kal_cam.cam_right().squeeze(-1)
    up = kal_cam.cam_up().squeeze(-1)
    return point + right * scaled[0] + up * scaled[1]
    
def find_closest_3d_points(query_3d_pts, object_3d_pts, radius, k=10):
    """Find up to ``k`` object points lying within ``radius`` of any query point.

    Each object point is scored by its distance to the NEAREST query point, so it
    appears at most once in the result even if several query points are close to it.

    Args:
        query_3d_pts: query points, shape ``(Q, 3)``.
        object_3d_pts: candidate points, shape ``(O, 3)``.
        radius: maximum distance (inclusive) to the nearest query point.
        k: maximum number of points returned.

    Returns:
        ``(indices, points)``: indices into ``object_3d_pts`` and the matching points,
        ordered by increasing distance. Both are empty if nothing is within ``radius``.
    """
    if query_3d_pts.shape[0] == 0:
        no_indices = torch.empty(0, dtype=torch.long, device=object_3d_pts.device)
        return no_indices, object_3d_pts[no_indices]
    # (Q, O) pairwise Euclidean distances
    dists = torch.cdist(query_3d_pts, object_3d_pts)
    # Reduce over the query axis: distance of every object point to its closest query point.
    nearest_dists = dists.min(dim=0).values
    candidates = torch.nonzero(nearest_dists <= radius, as_tuple=True)[0]
    # Keep the k closest candidates.
    order = torch.argsort(nearest_dists[candidates])[:k]
    top_k_indices = candidates[order]
    return top_k_indices, object_3d_pts[top_k_indices]


### SHIFT and Click
Now, let's enable pulling on object parts during simulation.
To interact, click "Run Test" to start the simulation, then hold SHIFT, click on the mesh, and drag.

In [20]:
# Reset interaction and simulation state before the interactive demo.
active_face = -1        # face under the cursor at the last SHIFT+mousedown (-1: none)
mouse_clicked = 0       # 1 while a SHIFT-drag is in progress
pull_from_2d = None     # cursor coordinates where the current drag segment starts
sim_history = []        # deformed points recorded by run_sim, one entry per step
mesh.vertices = orig_vertices
scene1.sim_obj_dict[obj_idx].reset_sim_state()

def boundary_func(pts, axis=1, height_fraction=0.9):
    """Select the points in the top slab of the object along ``axis``.

    Args:
        pts: point coordinates, indexed as ``pts[:, axis]``.
        axis: coordinate treated as height; default 1 (y, the visualizer's up axis here).
        height_fraction: points with height ``>= min + height_fraction * (max - min)`` are
            selected. The default 0.9 selects the upper 10% of the height range.

    Returns:
        Boolean mask with one entry per point (empty for empty input).
    """
    heights = pts[:, axis]
    if heights.numel() == 0:
        # torch.min/max raise on empty tensors.
        return torch.zeros_like(heights, dtype=torch.bool)
    h_min = torch.min(heights)
    h_max = torch.max(heights)
    threshold = h_min + height_fraction * (h_max - h_min)
    return heights >= threshold

# Register boundary condition "boundary1" (presumably applied to the points where boundary_func
# is True — confirm). The returned handle is used by the event handler to pin and move vertices.
boundary1 = scene1.sim_obj_dict[obj_idx].set_boundary_condition("boundary1", boundary_func, bdry_penalty=10000)

def additional_event_handler(visualizer, event):
    """Event handler passed to Kaolin's visualizer to drag the object with SHIFT + mouse.

    SHIFT+mousedown on the mesh pins the simulated points near the clicked face
    (through ``boundary1``). SHIFT+mousemove then moves those pinned points in the
    camera's image plane by the cursor motion since the previous event. Any mouseup
    ends the drag.

    Returns:
        False when the event was consumed here (SHIFT held), True to let the
        visualizer's default handling run.
    """
    global active_face, mouse_clicked, pull_from_2d

    with visualizer.out:
        # End a drag on any mouseup, even if SHIFT was released first; otherwise a later
        # SHIFT+mousemove would keep dragging stale pinned vertices.
        if event['type'] == 'mouseup':
            mouse_clicked = 0

        if not event['shiftKey']:
            return True

        if event['type'] == 'mousedown':
            pull_from_2d = torch.tensor(visualizer._get_clamped_coords(event), device='cuda', dtype=torch.float)
            current_values = visualizer.get_values_under_cursor(event)
            active_face = current_values['face_idx'][0][0]
            if active_face == -1:
                # Clicked on the background: release any pinned vertices and do not start a drag.
                mouse_clicked = 0
                boundary1.set_pinned_verts(None, None)
                return False

            mouse_clicked = 1
            bdry_inds, bdry_pts = find_closest_3d_points(
                mesh.vertices[mesh.faces[active_face], :],
                scene1.get_object_deformed_pts(obj_idx, scene1.sim_obj_dict[obj_idx].sim_pts).squeeze(),
                radius=0.2, k=10)
            boundary1.set_pinned_verts(bdry_inds, bdry_pts)
            visualizer.render_update()

        elif event['type'] == 'mousemove' and mouse_clicked == 1:
            pinned_verts = boundary1.pinned_vertices
            if pinned_verts is None or pinned_verts.shape[0] == 0:
                # Nothing is pinned (e.g. no points near the clicked face).
                return False
            visualizer.out.clear_output()
            pull_to_2d = torch.tensor(visualizer._get_clamped_coords(event), device='cuda', dtype=torch.float)

            # Cursor delta since the previous event. Advancing the anchor makes each move
            # apply only its own increment instead of re-applying the whole drag distance.
            pxl_offset = pull_to_2d - pull_from_2d
            pull_from_2d = pull_to_2d
            # Convert to opengl sign convention
            pxl_offset[1] *= -1

            point = pinned_verts.mean(dim=0)  # in 3D world coord
            updated_point = convert_offset_to_world_coords(visualizer.camera, point, pxl_offset)
            offset_3d = updated_point - point
            # In-place so boundary1.pinned_vertices itself reflects the accumulated drag.
            pinned_verts += offset_3d
            boundary1.update_pinned(pinned_verts)
            print(f'point:{point}')
            print(f'updated_pt:{updated_point}')
            print(f'offset:{offset_3d}')

        return False

# Fresh visualizer for the interactive demo, wired to additional_event_handler.
new_camera = kal.render.easy_render.default_camera(resolution).cuda()
new_camera.extrinsics.move_forward(1)
visualizer = kal.visualize.IpyTurntableVisualizer(
    resolution, resolution, copy.deepcopy(new_camera), render, fast_render=fast_render,
    max_fps=24, world_up_axis=1,
    additional_event_handler=additional_event_handler,
    # Extra canvas events forwarded to the handler. NOTE(review): the handler reads
    # event['shiftKey'] on mouse events and never handles 'keydown' itself; 'mouseup' is
    # assumed to be watched by the visualizer by default — confirm.
    additional_watched_events=['mousedown', 'mousemove', 'keydown']
)

def run_sim(num_steps=500):
    """Advance the interactive scene ``num_steps`` timesteps, rendering after each one.

    This does not call ``reset_sim_state``, so repeated calls continue from the current
    state. Each step's deformed points are appended to the global ``sim_history`` and
    copied into ``mesh`` for rendering. A dot is printed per step as a progress indicator.

    Args:
        num_steps: number of timesteps to simulate (default 500).
    """
    for _ in range(num_steps):
        with visualizer.out:
            scene1.run_sim_step()
            sim_history.append(scene1.get_object_deformed_pts(obj_idx).squeeze())
            print('.', end='')
        mesh.vertices = sim_history[-1]
        visualizer.render_update()

# Start the simulation on a background daemon thread so the button callback returns
# immediately and the visualizer can keep handling mouse events while it runs.
def callback(b):
    """Button click handler: launch ``run_sim`` in a new daemon thread (``b`` is unused)."""
    # NOTE(review): every click spawns another thread; clicking again before a run
    # finishes would step scene1 from several threads at once.
    worker = threading.Thread(target=run_sim)
    worker.daemon = True
    worker.start()
        
# Clicking "Run Test" starts the threaded simulation via callback.
button = Button(description='Run Test')
button.on_click(callback)
display(HBox([visualizer.canvas, button]), visualizer.out)

HBox(children=(Canvas(height=512, width=512), Button(description='Run Test', style=ButtonStyle())))

Output()

...................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................