# Muscle Activation On A Simple Musculoskeletal Mesh Using Simplicits
Let's go step by step and simulate the muscle-activated motion of a simple musculoskeletal mesh (with no joints) using the [Simplicits](https://research.nvidia.com/labs/toronto-ai/simplicits/) API.

We use a muscle mesh with fiber directions borrowed from [EMU](https://www.dgp.toronto.edu/projects/efficient-muscles/) and simulate it at interactive rates.

In [None]:
# Notebook requires k3d
!pip install k3d

In [None]:
import os, sys, copy, math
import logging
from functools import partial
import k3d

from IPython.display import display
from ipywidgets import Button, HBox, VBox
from pathlib import Path
sys.path.append(str(Path("..")))

import numpy as np
import torch
import kaolin as kal
import kaolin.physics as physics
from tutorial_common import COMMON_DATA_DIR

# Send INFO-level (and above) log records to stdout so they appear in the notebook output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)

def sample_file_path(fname):
    """Return the path of ``fname`` inside the shared ``meshes`` data folder."""
    meshes_dir = os.path.join(COMMON_DATA_DIR, 'meshes')
    return os.path.join(meshes_dir, fname)

def print_tensor(t, name='', **kwargs):
    """Print the summary string produced by ``kal.utils.testing.tensor_info`` for ``t``.

    Any extra keyword arguments are forwarded unchanged to ``tensor_info``.
    """
    summary = kal.utils.testing.tensor_info(t, name=name, **kwargs)
    print(summary)
    

## Import the muscle's volumetric sample points
Since the mesh was created for another project, we must rescale it to fit our units and do some basic cleanup.

In [None]:
# Rescale the muscle to fit Simplicits units
so_pts = torch.load(sample_file_path("simple_muscle/simple_muscle_vol_pts.pt"))/10
so_fibers = torch.load(sample_file_path("simple_muscle/simple_muscle_fibers.pt"))

# Manually zero-out bone fibers (bad model, also has fibers for bones).
# Points with y < -0.45 or y > 0.45 are treated as the bones at the two ends.
# TODO: Store material parameters per-pt in the data folder
y_values = so_pts[:,1]
bone_mask = torch.as_tensor((y_values < -0.45) | (y_values > 0.45))
# NOTE: `np.nonzero(mask)[0]` is wrong when `mask` is a torch tensor: it dispatches to
# `Tensor.nonzero()`, which returns an (N, 1) tensor, so `[0]` kept only the FIRST bone
# index. `as_tuple=True` returns a 1-D tensor holding every bone index.
indices = torch.nonzero(bone_mask, as_tuple=True)[0]  # bone indices
so_fibers[indices,:] *= 0


Let's set the material parameters and display the sample points used by Simplicits.

In [None]:
# Muscles should be soft (1e5 Pascals)
so_yms = 1e5*torch.ones(so_pts.shape[0])
# Bones should be harder (1e7 Pascals): scale the bone points selected above by 100x
so_yms[indices] *= 1e2
# Poisson's ratio of 0.45 for every point
so_prs = 0.45*torch.ones_like(so_yms)
# Set densities to 100kg/m^3
so_rhos = 100*torch.ones_like(so_yms)
# Guess the approximate volume (m^3); passed to training and to the simulation precomputations
so_appx_vol = 0.3


# Plot the musculoskeletal object using k3d
# Middle portion is muscle. Ends are bones.
plot = k3d.plot()
plot += k3d.points(so_pts, point_size=0.01)
plot.display()

## Train The Object Using Simplicits Training
Set up the training parameters here.

In [None]:
NUM_HANDLES = 10
NUM_STEPS = 15000
LR_START = 1e-3
NUM_SAMPLES = 1000
device = 'cuda'
dtype = torch.float32
# Not referenced elsewhere in this notebook; kept for compatibility.
ENERGY_INTERP_LINSPACE = np.linspace(0, 1, NUM_STEPS, endpoint=False)
so_model = physics.simplicits.network.SimplicitsMLP(spatial_dimensions=3, layer_width=64, num_handles=NUM_HANDLES, num_layers=8)
so_optimizer = torch.optim.Adam(so_model.parameters(), LR_START)

# Move per-point data to the training device; material parameters become (N, 1) columns.
so_pts = torch.as_tensor(so_pts, device=device, dtype=dtype)
so_yms = torch.as_tensor(so_yms, device=device, dtype=dtype).unsqueeze(-1)
so_prs = torch.as_tensor(so_prs, device=device, dtype=dtype).unsqueeze(-1)
# `as_tensor` (like the lines above) instead of `torch.tensor`, which warns when copy-
# constructing from an existing tensor.
so_rhos = torch.as_tensor(so_rhos, device=device, dtype=dtype).unsqueeze(-1)
so_fibers = torch.as_tensor(so_fibers, device=device, dtype=dtype)
# Row-wise normalize the fiber directions
row_norms = torch.norm(so_fibers, p=2, dim=1, keepdim=True)
so_fibers = so_fibers / row_norms
# Zeroed (bone) rows have norm 0 and become NaN above; map them back to 0 (basic cleanup)
so_fibers = torch.nan_to_num(so_fibers, nan=0.0)

so_model.to(device)


### Visualize the muscle fibers
Let's visualize the fiber directions to see the direction of motion during muscle contraction.

In [None]:
# Pick 200 random sample points to display together with their fiber directions
sample_indices = torch.randint(low=0, high=so_pts.shape[0], size=(200,), device=so_pts.device)
sampled_pts = so_pts[sample_indices].cpu().detach().numpy()
sampled_fibers = so_fibers[sample_indices].cpu().detach().numpy()

# Create a K3D plot
plot = k3d.plot()
# Add the sampled points (red) to the plot
plot += k3d.points(sampled_pts, point_size=0.01, color=0xff0000)

# Add all fiber vectors (blue, scaled by 1/10 for display) as ONE k3d object:
# k3d.vectors accepts (N, 3) arrays, so there is no need for one object per point.
plot += k3d.vectors(
    origins=sampled_pts,
    vectors=sampled_fibers / 10,
    head_size=0.04,
    line_width=0.001,
    color=0x0000ff,
)

# Display the plot
plot.display()

### Proceed With Training
Now let's train the object and learn its skinning eigenmodes.

In [None]:
# Bind the fixed loss hyperparameters; each iteration then supplies the model, the
# per-point data and the fraction of training completed.
partial_compute_losses = partial(physics.simplicits.compute_losses,
                             batch_size=10,
                             num_handles=NUM_HANDLES,
                             appx_vol=so_appx_vol,
                             num_samples=NUM_SAMPLES,
                             le_coeff=1e-1,
                             lo_coeff=1e6)

so_model.train()
# Iterate NUM_STEPS times (previously hard-coded as 15000) so the progress fraction
# i / NUM_STEPS stays consistent with the number of iterations actually run.
for i in range(NUM_STEPS):
    # Set grads to zero
    so_optimizer.zero_grad()
    # Train a step: le / lo are presumably the elastic and orthogonality losses (see compute_losses)
    le, lo = partial_compute_losses(so_model, so_pts, so_yms, so_prs, so_rhos, float(i/NUM_STEPS))
    loss = le + lo
    # Backprop over the losses
    loss.backward()
    # Take optimizer step
    so_optimizer.step()

    if i%1000 == 0:
        print(f'Log training step: {i}, le: {le.item()}, lo: {lo.item()}')

so_model.eval()

## Set Up Simulation
Set up a simple muscle simulation with the following simulation parameters and precomputations. This step is the same as in `simplicits_low_level_api.ipynb`, with the addition of the muscle energy, gradient and Hessian (`partial_muscle_e`, `partial_muscle_g`, `partial_muscle_h`).

In [None]:
NUM_STEPS = 100
DT = 0.05
FLOOR_PLANE = -1
PENALTY_WEIGHT = 10000
NUM_SAMPLES = 2000
device = 'cuda'
dtype = torch.float32
# NOTE(review): MAX_NEWTON_STEPS / MAX_LS_STEPS are not passed to newtons_method below.
MAX_NEWTON_STEPS=20
MAX_LS_STEPS = 30

# Random subset of the volumetric points used as simulation integration samples
sample_indices = torch.randint(low=0, high=so_pts.shape[0], size=(NUM_SAMPLES,), device=so_pts.device)

sim_pts = so_pts[sample_indices]
sim_normalized_pts = so_pts[sample_indices] #leave it the same as sim_pts
sim_yms = 1*so_yms[sample_indices]
sim_prs = so_prs[sample_indices]
sim_rhos = 1*so_rhos[sample_indices]


def model_plus_rigid(pts):
    """Skinning weights: the trained handles plus one constant (all-ones) handle column."""
    return torch.cat((so_model(pts), torch.ones((pts.shape[0], 1), device=device)), dim=1)


sim_weights = model_plus_rigid(sim_normalized_pts)

# Constant +y fiber direction in the muscle; zero fibers in the bones (y < -0.45 or y > 0.45).
# `as_tuple=True` always yields a 1-D index tensor (`.squeeze()` gave 0-d for a single match).
y_vals = sim_pts[:,1]
indices = torch.nonzero((y_vals < -0.45) | (y_vals > 0.45), as_tuple=True)[0]
sim_fiber_vecs = torch.tensor([0, 1, 0], device=device, dtype=dtype).expand(NUM_SAMPLES, 3).clone()
sim_fiber_vecs[indices,:] = 0

# Reduced coordinates: 12 values (one 3x4 transform) per handle, starting at zero (rest state)
z = torch.zeros(sim_weights.shape[1]*12 , dtype=dtype, device = device).unsqueeze(-1)
z_prev = z.clone().detach()
z_dot = torch.zeros_like(z, device=device)
x0_flat = sim_pts.flatten().unsqueeze(-1)


M, invM = physics.simplicits.precomputed.lumped_mass_matrix(sim_rhos, so_appx_vol, dim = 3)
dFdz = physics.simplicits.precomputed.jacobian_dF_dz(model_plus_rigid, sim_normalized_pts, z).detach()
# NOTE(review): dxdz (and BinvMB below) are not used elsewhere in this notebook.
dxdz = torch.autograd.functional.jacobian(lambda x: physics.simplicits.utils.weight_function_lbs(sim_pts, tfms = x.reshape(-1,3,4).unsqueeze(0), fcn = model_plus_rigid).flatten(), z.flatten())
# Flattened 3x3 identity per sample, added to dFdz @ z to form deformation gradients
bigI = torch.tile(torch.eye(3, device=device).flatten().unsqueeze(dim=1), (NUM_SAMPLES,1)).detach()
B = physics.simplicits.precomputed.lbs_matrix(sim_pts, sim_weights).detach()

# A single gravitational acceleration vector (m/s^2) along the y axis, shared by all samples;
# its sign convention follows physics.utils.Gravity — TODO confirm.
grav = torch.tensor([0, 9.8, 0], device=device)

# Reduced mass matrix (and reduced inverse-mass matrix)
BMB = B.T @ M @ B
BinvMB = B.T @ invM @ B

print(" Density: ",str(sim_rhos[0].item())+"kg/m^3\n", 
      "Youngs Mod: ", str(sim_yms[0].item())+"Pa\n", 
      "Poiss Ratio: ", str(sim_prs[0].item())+"\n", 
      "Appx Vol: ", str(so_appx_vol)+"m^3\n")

In [None]:
#########################SETUP MATERIAL AND SCENE FORCES######################################################
# Lame parameters (not used further in this notebook; NeohookeanMaterial takes yms/prs directly)
mus, lams = physics.materials.utils.to_lame(sim_yms, sim_prs)
material_object = physics.materials.NeohookeanMaterial(sim_yms, sim_prs)
# Muscle material driven by the per-sample fiber directions (zero in the bones)
muscle_object = physics.materials.MuscleMaterial(sim_fiber_vecs)
gravity_object = physics.utils.Gravity(rhos=sim_rhos, acceleration=grav)
floor_object = physics.utils.Floor(floor_height=FLOOR_PLANE, floor_axis=1)
# Pin every sample of the top bone (y > 0.45) at its rest position.
bdry_cond = physics.utils.Boundary()
# `as_tuple=True` always yields a 1-D index tensor; `.squeeze()` turned a single pinned point
# into a 0-d index, making `bdry_pos` shape (3,) instead of (1, 3).
bdry_indx = torch.nonzero(sim_pts[:,1]>0.45, as_tuple=True)[0]
bdry_pos = sim_pts[bdry_indx,:]
bdry_cond.set_pinned_verts(bdry_indx, bdry_pos)
# Each sample point carries an equal share of the approximate volume
integration_sampling = torch.tensor(so_appx_vol/NUM_SAMPLES, device=device, dtype=sim_pts.dtype)

#######################Physics Energy, Forces, Hessians########################################################
# Penalty terms (boundary, floor) use PENALTY_WEIGHT and no volume weighting;
# gravity and the materials are weighted by integration_sampling.
partial_bdry_e = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_energy(bdry_cond, B, coeff=PENALTY_WEIGHT, integration_sampling=None)
partial_bdry_g = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_gradient(bdry_cond, B, coeff=PENALTY_WEIGHT, integration_sampling=None)
partial_bdry_h = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_hessian(bdry_cond, B, coeff=PENALTY_WEIGHT, integration_sampling=None)

partial_grav_e = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_energy(gravity_object, B, coeff=1, integration_sampling=integration_sampling)
partial_grav_g = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_gradient(gravity_object, B, coeff=1, integration_sampling=integration_sampling)
partial_grav_h = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_hessian(gravity_object, B, coeff=1, integration_sampling=integration_sampling)

partial_material_e = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_material_energy(material_object, dFdz, coeff=1, integration_sampling=integration_sampling)
partial_material_g = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_material_gradient(material_object, dFdz, coeff=1, integration_sampling=integration_sampling)
partial_material_h = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_material_hessian(material_object, dFdz, coeff=1, integration_sampling=integration_sampling)

# Add the muscle energy, gradient, hessian to the lists of energies, grads, hessians
partial_muscle_e = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_material_energy(muscle_object, dFdz, coeff=1, integration_sampling=integration_sampling)
partial_muscle_g = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_material_gradient(muscle_object, dFdz, coeff=1, integration_sampling=integration_sampling)
partial_muscle_h = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_material_hessian(muscle_object, dFdz, coeff=1, integration_sampling=integration_sampling)

partial_floor_e = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_energy(floor_object, B, coeff=PENALTY_WEIGHT, integration_sampling=None)
partial_floor_g = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_gradient(floor_object, B, coeff=PENALTY_WEIGHT, integration_sampling=None)
partial_floor_h = physics.simplicits.simplicits_scene_forces.generate_fcn_simplicits_scene_hessian(floor_object, B, coeff=PENALTY_WEIGHT, integration_sampling=None)
###############################################################################

####################Backwards Euler Functions###########################################################
def potential_sum(output, z, z_dot, B, dFdz, x0_flat, bigI, defo_grad_fcns=(), pt_wise_fcns=()):
    """Accumulate potential terms (energies, gradients or Hessians) into ``output`` in place.

    The reduced coordinates ``z`` are mapped to
      * stacked deformation gradients ``F_ele = dFdz @ z + bigI`` and
      * positions ``x = (B @ z + x0_flat).reshape(-1, 3)``.
    Each callable in ``defo_grad_fcns`` is evaluated on ``F_ele`` and each callable in
    ``pt_wise_fcns`` on ``x``; every result is added to ``output`` with ``+=``.

    Args:
        output: tensor updated in place; must be shape-compatible with the callables' results.
        z: reduced coordinates.
        z_dot: unused; accepted so all newton_* callers share one call signature.
        B: matrix mapping ``z`` to flattened position offsets.
        dFdz: matrix mapping ``z`` to flattened deformation-gradient offsets.
        x0_flat: flattened rest positions.
        bigI: flattened identity deformation gradients (one per sample).
        defo_grad_fcns: iterable of callables taking ``F_ele``.
        pt_wise_fcns: iterable of callables taking ``x``.

    Returns:
        None; the sum is written into ``output``.
    """
    # Tuple defaults instead of mutable list defaults; any iterable of callables works.
    F_ele = torch.matmul(dFdz, z) + bigI
    x = (B @ z + x0_flat).reshape(-1, 3)
    for fcn in defo_grad_fcns:
        output += fcn(F_ele)
    for fcn in pt_wise_fcns:
        output += fcn(x)
        
def newton_E(z, z_prev, z_dot, B, BMB, dt, x0_flat, dFdz, bigI, defo_grad_energies=(), pt_wise_energies=()):
    """Backward-Euler incremental potential that Newton's method minimizes each timestep.

    Returns ``0.5 z^T BMB z - z^T BMB z_prev - dt z^T BMB z_dot + dt^2 * PE(z)``, where
    ``PE`` is the sum of ``defo_grad_energies`` (on deformation gradients) and
    ``pt_wise_energies`` (on positions), accumulated by ``potential_sum``. For symmetric
    ``BMB`` this equals ``0.5 ||z - z_prev - dt z_dot||^2_BMB + dt^2 PE(z)`` up to a term
    independent of ``z``. For a column-vector ``z`` the result is a (1, 1) tensor.
    """
    # Allocate the accumulator on z's device/dtype rather than the notebook-level
    # `device`/`dtype` globals, so the function does not depend on them.
    pe_sum = torch.zeros(1, device=z.device, dtype=z.dtype)
    potential_sum(pe_sum, z, z_dot, B, dFdz, x0_flat, bigI, defo_grad_energies, pt_wise_energies)
    return 0.5 * z.T @ BMB @ z - z.T @ BMB @ z_prev - dt * z.T @ BMB @ z_dot + dt * dt * pe_sum

def newton_G(z, z_prev, z_dot, B, BMB, dt, x0_flat, dFdz, bigI, defo_grad_gradients=(), pt_wise_gradients=()):
    """Gradient w.r.t. ``z`` of the backward-Euler objective computed by ``newton_E``.

    Returns ``BMB z - BMB z_prev - dt BMB z_dot + dt^2 * grad PE(z)``, with the potential
    gradients accumulated by ``potential_sum`` into a tensor shaped like ``z``.
    """
    # Tuple defaults instead of mutable list defaults.
    pe_grad_sum = torch.zeros_like(z)
    potential_sum(pe_grad_sum, z, z_dot, B, dFdz, x0_flat, bigI, defo_grad_gradients, pt_wise_gradients)
    return BMB @ z - BMB @ z_prev - dt * BMB @ z_dot + dt * dt * pe_grad_sum

def newton_H(z, z_prev, z_dot, B, BMB, dt, x0_flat, dFdz, bigI, defo_grad_hessians=(), pt_wise_hessians=()):
    """Hessian w.r.t. ``z`` of the backward-Euler objective computed by ``newton_E``.

    Returns ``BMB + dt^2 * Hess PE(z)`` as an (n, n) tensor with ``n = z.shape[0]``;
    ``z_prev`` is accepted only for signature compatibility with newton_E/newton_G.
    """
    # Allocate on z's device/dtype rather than the notebook-level `device`/`dtype` globals.
    pe_hess_sum = torch.zeros(z.shape[0], z.shape[0], device=z.device, dtype=z.dtype)
    potential_sum(pe_hess_sum, z, z_dot, B, dFdz, x0_flat, bigI, defo_grad_hessians, pt_wise_hessians)
    return BMB + dt * dt * pe_hess_sum
    
##########################Backwards Euler Partials#####################################################
# Bind every argument that stays fixed for the whole simulation. Only z, z_prev and z_dot stay
# free: run_sim binds z_prev/z_dot each timestep, and Newton's method varies z.
# Material and muscle terms act on deformation gradients; gravity, floor and boundary terms
# act on positions (see potential_sum).
# Objective
partial_newton_E = partial(newton_E, 
                           B=B.detach(), 
                           BMB = BMB.detach(), 
                           dt=DT, 
                           x0_flat=x0_flat.detach(), 
                           dFdz=dFdz.detach(), 
                           bigI=bigI.detach(),
                           defo_grad_energies=[partial_material_e, partial_muscle_e],
                           pt_wise_energies=[partial_grav_e, partial_floor_e, partial_bdry_e])
# Gradient of the objective
partial_newton_G = partial(newton_G, 
                           B=B.detach(), 
                           BMB = BMB.detach(), 
                           dt=DT, 
                           x0_flat=x0_flat.detach(), 
                           dFdz=dFdz.detach(), 
                           bigI=bigI.detach(),
                           defo_grad_gradients=[partial_material_g, partial_muscle_g],
                           pt_wise_gradients=[partial_grav_g, partial_floor_g, partial_bdry_g])
# Hessian of the objective
partial_newton_H = partial(newton_H, 
                           B=B.detach(), 
                           BMB = BMB.detach(), 
                           dt=DT, 
                           x0_flat=x0_flat.detach(), 
                           dFdz=dFdz.detach(), 
                           bigI=bigI.detach(),
                           defo_grad_hessians=[partial_material_h, partial_muscle_h],
                           pt_wise_hessians=[partial_grav_h, partial_floor_h, partial_bdry_h])

## Loading And Displaying Surface Mesh

Simplicits results can be displayed in a variety of ways.
For the purpose of this tutorial, we will use a mesh to visualize muscle contraction.

In [None]:
# Import and triangulate (rasterization needs triangle faces); move to GPU
mesh = kal.io.import_mesh(sample_file_path('simple_muscle/simple_muscle.obj'), triangulate=True).cuda()

# Normalize so it is easy to set up default camera
mesh.vertices = kal.ops.pointcloud.center_points(mesh.vertices.unsqueeze(0), normalize=True).squeeze(0)
orig_vertices = mesh.vertices.clone()  # Also save original undeformed vertices
# NOTE(review): run_sim skins these normalized vertices with weights trained on the
# simulation sample points (rescaled by 1/10 earlier); this assumes both end up in the
# same coordinate frame — confirm.


## Visualizing Geometry

Because we are working with a mesh, we can use Kaolin's built-in easy render function. To use any other renderer (such as a Gaussian splat renderer),
simply define the following rendering functions, which take a `Camera` as input.

In [None]:
def render(in_cam, **kwargs):
    """Render the global ``mesh`` from ``in_cam`` and return the image as uint8 in [0, 255].

    Pixels covered by no face (``face_idx < 0``) are set to 0, i.e. a black background.
    Extra keyword arguments are forwarded to ``easy_render.render_mesh``.
    """
    render_res = kal.render.easy_render.render_mesh(in_cam, mesh, **kwargs)
    # clamp() returns a new tensor, so the in-place write below does not modify render_res
    img = render_res[kal.render.easy_render.RenderPass.render].squeeze(0).clamp(0, 1)

    # The comparison already yields a bool mask; index 0 drops its leading (batch) dimension
    background_mask = render_res[kal.render.easy_render.RenderPass.face_idx] < 0
    img[background_mask[0]] = 0  # background black
    return (img * 255.).to(torch.uint8)

def fast_render(in_cam, factor=8):
    """Render a low-resolution preview with ``in_cam``'s width/height divided by ``factor``.

    The camera is deep-copied, so the caller's camera is left unchanged.
    """
    preview_cam = copy.deepcopy(in_cam)
    preview_cam.width, preview_cam.height = in_cam.width // factor, in_cam.height // factor
    return render(preview_cam)
    
# Reset the display mesh to its undeformed (normalized) vertices
mesh.vertices = orig_vertices
resolution = 512
# Default easy_render camera at this resolution, moved forward by 1 along its view direction
new_camera = kal.render.easy_render.default_camera(resolution).cuda()
new_camera.extrinsics.move_forward(1)
# Interactive turntable viewer with y as the world up axis; `fast_render` is presumably the
# low-res preview used while interacting — confirm in IpyTurntableVisualizer docs.
visualizer = kal.visualize.IpyTurntableVisualizer(
    resolution, resolution, copy.deepcopy(new_camera), render, fast_render=fast_render,
    max_fps=24, world_up_axis=1)

## Simulate Muscle Activation
Here's the exposed simulation loop.
It keeps the muscle inactive for the first 20 timesteps and then ramps the muscle activation up linearly.

In [None]:
# Reset the simulation state: 12 reduced coordinates per handle (reshaped to 3x4 transforms
# in run_sim), all zero, so positions B @ z + x0_flat equal the rest positions.
z = torch.zeros(sim_weights.shape[1]*12 , dtype=dtype, device = device).unsqueeze(-1)
z_prev = z.clone().detach()
z_dot = torch.zeros_like(z, device=device).detach()
x0_flat = sim_pts.flatten().unsqueeze(-1)
# Muscle starts with zero activation
muscle_object.set_activation(0)
# History of z, one entry per timestep, starting with the rest state
states = [z.clone().detach()]

def run_sim(activation_start=20, max_activation=100000):
    """Run NUM_STEPS backward-Euler steps of the muscle simulation and animate the mesh.

    The muscle keeps its current activation until ``time_step > activation_start``; from then
    on the activation is ``(time_step - activation_start) / (NUM_STEPS - activation_start)
    * max_activation`` (a linear ramp). Each step solves for the reduced coordinates ``z``
    with Newton's method, prints the individual energy terms, updates ``z_dot``, appends
    ``z`` to ``states`` and re-skins the display mesh.

    The simulation continues from the current global ``z``/``z_prev``/``z_dot``; re-run the
    state-reset cell to start again from rest.

    Args:
        activation_start: timestep after which the activation ramp begins (default 20).
        max_activation: scale of the linear activation ramp (default 100000).
    """
    # Only these globals are rebound; states, mesh and muscle_object are merely mutated.
    global z, z_prev, z_dot

    for time_step in range(int(NUM_STEPS)):
        print(f"Timestep:{time_step}")
        # Linear activation ramp once past activation_start
        if time_step > activation_start:
            ramp = (time_step - activation_start) / (NUM_STEPS - activation_start)
            muscle_object.set_activation(ramp * max_activation)
        z_prev = z.clone().detach()
        # Bind this step's previous state and velocity into the Newton objective/gradient/Hessian
        more_partial_newton_E = partial(partial_newton_E, z_prev=z_prev, z_dot=z_dot)
        more_partial_newton_G = partial(partial_newton_G, z_prev=z_prev, z_dot=z_dot)
        more_partial_newton_H = partial(partial_newton_H, z_prev=z_prev, z_dot=z_dot)
        z = physics.utils.optimization.newtons_method(z, more_partial_newton_E, more_partial_newton_G, more_partial_newton_H, conv_criteria=1)
        # Report each energy term at the new state
        F_ele = torch.matmul(dFdz, z) + bigI
        x_pts = (B @ z + x0_flat).reshape(-1,3)
        print(f'\t Floor E:{partial_floor_e(x_pts).item()}, Grav E:{partial_grav_e(x_pts).item()}, Bdry E:{partial_bdry_e(x_pts).item()}, Elastic E:{partial_material_e(F_ele).item()}, Muscle E:{partial_muscle_e(F_ele).item()}')
        with torch.no_grad():
            # Backward-Euler velocity update
            z_dot = (z - z_prev)/DT
            states.append(z.clone().detach())
            # Skin the display mesh's rest vertices with the current per-handle 3x4 transforms
            x = physics.simplicits.utils.weight_function_lbs(orig_vertices, tfms = z.reshape(-1,3,4).unsqueeze(0), fcn = model_plus_rigid).squeeze()
            mesh.vertices = x
            visualizer.render_update()
    print("done")


def start_simulation(b):
    """Click handler for the "Run Sim" button; the clicked button ``b`` is ignored."""
    return run_sim()

# "Run Sim" button wired to start_simulation, shown beside the visualizer canvas
button = Button(description='Run Sim')
button.on_click(start_simulation)
display(HBox([visualizer.canvas, button]), visualizer.out)