<a href="https://colab.research.google.com/github/Victorlouisdg/simulators/blob/main/differentiable_sparse_matrix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install meshzoo

In [None]:
import numpy as np
import jax.numpy as jnp
import jax.scipy.sparse
from jax import grad, jit, vmap
from jax import jacfwd, jacrev
from jax.lax import cond, scan
from jax.ops import index, index_add, index_update
np.set_printoptions(precision=3, suppress=True)
import meshzoo
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
matplotlib.rc('animation', html='jshtml')
from functools import partial


In [None]:
# Triangulate the unit square in (u, v) parameter space. These (u, v)
# coordinates serve as the reference configuration for the energy terms below.
vertex_positions_uv, triangle_vertex_indices = meshzoo.rectangle_tri(
    (0.0, 0.0),
    (1.0, 1.0),
    n=2,
    variant="zigzag",  # alternatives: "up", "down", "center"
)
print(vertex_positions_uv.shape)
print(triangle_vertex_indices.shape)

# Move the rest coordinates onto the JAX side for use in traced code.
vertex_positions_uv = jnp.array(vertex_positions_uv)
print(vertex_positions_uv.shape)

amount_of_triangles = len(triangle_vertex_indices)

In [None]:
amount_of_vertices = vertex_positions_uv.shape[0]
vertex_positions_z = np.zeros(amount_of_vertices)
# Flat DOF layout used throughout: [x0, y0, z0, x1, y1, z1, ...].
vertex_positions = jnp.column_stack([vertex_positions_uv, vertex_positions_z]).flatten()
vertex_velocities = jnp.zeros_like(vertex_positions)

# Per-DOF masses. The DOFs of vertices 0 and 1 (flat DOFs 0-5) get a very large
# mass m0 so they barely move (presumably acting as pinned vertices); every
# other DOF gets m1. Sizing from the mesh instead of a hard-coded 12-entry list
# keeps M consistent with `vertex_positions` if the mesh resolution changes.
m0 = 100000.0
m1 = 1.0
masses = jnp.full(vertex_positions.shape, m1).at[0:6].set(m0)
M = jnp.diag(masses)

In [None]:

def plot_cloth(ax, vertex_positions, triangle_vertex_indices):
    """Draw the cloth mesh into an existing 3D axes.

    Args:
        ax: a matplotlib 3D axes; it is cleared first.
        vertex_positions: flat array laid out as [x0, y0, z0, x1, y1, z1, ...].
        triangle_vertex_indices: per-triangle vertex indices passed to
            ``plot_trisurf``.
    """
    xs, ys, zs = (vertex_positions[axis::3] for axis in range(3))

    # Clearing is required because animations redraw into the same axes.
    ax.clear()
    ax.plot_trisurf(xs, ys, zs, triangles=triangle_vertex_indices, color='deepskyblue')
    ax.scatter(xs, ys, zs, c='deeppink', s=20, depthshade=False)
    for vertex_index, point in enumerate(zip(xs, ys, zs)):
        ax.text(*point, str(vertex_index), fontsize='medium', color='black', zorder=10)

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.set_xlim([-0.1, 1.1])
    ax.set_ylim([-0.1, 1.1])
    ax.set_zlim([-2, 0.25])

# Show the initial (flat) cloth with vertex labels.
fig, ax = plt.subplots(figsize=(8, 6), dpi=100, subplot_kw={'projection': '3d'})
plot_cloth(ax, vertex_positions, triangle_vertex_indices)

In [None]:
def triangle_area(triangle_vertices):
    """Area of a triangle given its three corners as a (v0, v1, v2) sequence.

    The area is half the norm of the cross product of two edge vectors
    sharing corner v0.
    """
    corner_a, corner_b, corner_c = triangle_vertices
    edge_ab = corner_b - corner_a
    edge_ac = corner_c - corner_a
    doubled_area = jnp.linalg.norm(jnp.cross(edge_ab, edge_ac))
    return doubled_area / 2.0

In [None]:
def slice3(i):
    """Slice selecting the three coordinates of vertex ``i`` in a flat position array."""
    first = 3 * i
    return slice(first, first + 3)

# Worked example: the first triangle's vertex indices, rest (u, v)
# coordinates and current 3D corner positions.
i, j, k = triangle_vertex_indices[0]

uv0, uv1, uv2 = vertex_positions_uv[jnp.array([i, j, k])]

x0, x1, x2 = (vertex_positions[slice3(corner)] for corner in (i, j, k))

triangle_area((x0, x1, x2))

In [None]:
# grad differentiates triangle_area w.r.t. its first argument: the corner tuple.
area_grad = grad(triangle_area)

In [None]:
# Gradient of the first triangle's area w.r.t. its three corner positions.
area_grad((x0, x1, x2))

In [None]:
# Per-triangle stretch stiffness in the u and v directions (all ones).
# NOTE(review): these arrays are not read by the energy functions in this
# notebook; confirm whether they are meant to replace the constant scale.
triangle_stretch_stiffness_u, triangle_stretch_stiffness_v = (
    np.ones(amount_of_triangles) for _ in range(2)
)

In [None]:
def deformation_gradient(vertex_positions, vertex_positions_uv):
    """Deformation-gradient columns (w_u, w_v) of one triangle.

    Solves Eq. (9) of Baraff-Witkin:
        [w_u w_v] = [dx1 dx2] @ inv([[du1, du2], [dv1, dv2]])
    where dx_n = x_n - x_0 are world-space edges and (du_n, dv_n) are the
    matching edges in (u, v) space.

    Args:
        vertex_positions: tuple (x0, x1, x2) of world-space corner positions.
        vertex_positions_uv: tuple (uv0, uv1, uv2) of (u, v) corner coordinates.

    Returns:
        (wu, wv): each a column array with one row per world coordinate.
    """
    (u0, v0), (u1, v1), (u2, v2) = vertex_positions_uv
    uv_edges = jnp.array([(u1 - u0, u2 - u0),
                          (v1 - v0, v2 - v0)])

    p0, p1, p2 = vertex_positions
    world_edges = jnp.column_stack((p1 - p0, p2 - p0))

    w_uv = world_edges @ jnp.linalg.inv(uv_edges)
    wu, wv = jnp.hsplit(w_uv, 2)
    return wu, wv

# Deformation-gradient columns (w_u, w_v) of the first triangle.
deformation_gradient(
    (x0, x1, x2),
    (uv0, uv1, uv2),
)

In [None]:
def stretch_energy(vertex_positions, vertex_positions_uv):
    """Baraff-Witkin stretch energy of one triangle with unit stiffness.

    For each material direction the stretch condition is
    C = a * (||w|| - 1), where a is the triangle's area in (u, v) space and
    w is the matching deformation-gradient column. The energy is
    0.5 * C**2, summed over the u and v directions.

    Args:
        vertex_positions: tuple (x0, x1, x2) of world-space corner positions.
        vertex_positions_uv: tuple (uv0, uv1, uv2) of (u, v) corner coordinates.

    Returns:
        Scalar energy.
    """
    wu, wv = deformation_gradient(vertex_positions, vertex_positions_uv)
    rest_area = triangle_area(vertex_positions_uv)

    def half_squared_condition(w):
        condition = rest_area * (jnp.linalg.norm(w) - 1.0)
        return 0.5 * (condition ** 2)

    return half_squared_condition(wu) + half_squared_condition(wv)

# Inspect the first triangle's inputs, then evaluate its stretch energy.
for printed_group in ((i, j, k), (uv0, uv1, uv2), (x0, x1, x2)):
    print(*printed_group)

stretch_energy((x0, x1, x2), (uv0, uv1, uv2))

In [None]:
# Gradient of the stretch energy w.r.t. the three corner positions.
grad(stretch_energy)(
    (x0, x1, x2),
    (uv0, uv1, uv2),
)

In [None]:
def hessian(f):
    """Return a function computing the Hessian of ``f`` w.r.t. its first argument.

    Uses forward-over-reverse differentiation: jacfwd applied to jacrev(f).
    """
    first_derivative = jacrev(f)
    return jacfwd(first_derivative)

# Hessian of the first triangle's stretch energy w.r.t. its corner positions.
# The result is nested: H[a][b] is the block d2E / (dx_a dx_b) for corners a, b.
H = hessian(stretch_energy)((x0, x1, x2), (uv0, uv1, uv2))
print(len(H))
print(len(H[0]))
# Loop with `a`/`b` instead of `i`/`j` so this cell does not overwrite the
# notebook-level triangle vertex indices `i, j` defined earlier.
for a in range(3):
    for b in range(3):
        print(a, b)
        print(H[a][b])

In [None]:
    
# Per-triangle energy; summing it over the mesh gives the total energy whose
# gradient and Hessian drive the simulation.


def triangle_energy(positions, positions_uv, vertex_indices, stiffness=100.0):
    """Scaled stretch energy of a single triangle.

    Args:
        positions: flat world-space positions laid out as
            [x0, y0, z0, x1, y1, z1, ...] (3 entries per vertex).
        positions_uv: per-vertex (u, v) coordinates, indexed by vertex.
        vertex_indices: the triangle's three vertex indices (i, j, k).
        stiffness: multiplier applied to ``stretch_energy`` (default 100.0,
            the value previously hard-coded).

    Returns:
        Scalar energy of the triangle.
    """
    i, j, k = vertex_indices

    uv0 = positions_uv[i]
    uv1 = positions_uv[j]
    uv2 = positions_uv[k]

    # `positions` is flat with 3 entries per vertex, so vertex n's coordinates
    # start at offset 3*n (the same layout `slice3` uses). dynamic_slice is
    # used because the indices may be traced values inside scan/vmap.
    x0 = jax.lax.dynamic_slice(positions, [3 * i], [3])
    x1 = jax.lax.dynamic_slice(positions, [3 * j], [3])
    x2 = jax.lax.dynamic_slice(positions, [3 * k], [3])

    return stiffness * stretch_energy((x0, x1, x2), (uv0, uv1, uv2))

# Energy of a single triangle for the current mesh state. Note: despite the
# name, this is a plain functools.partial and is not jit-compiled.
triangle_energy_jit = partial(triangle_energy, vertex_positions, vertex_positions_uv)

for triangle_row in (0, 1):
    print(triangle_energy_jit(triangle_vertex_indices[triangle_row]))

In [None]:
def energy_sum(carry, input, f):
    """``lax.scan`` body that adds ``f(input)`` to a running total.

    The updated total is returned both as the new carry and as the per-step
    output, so scan's stacked output is the cumulative sum.
    """
    running_total = f(input) + carry
    return running_total, running_total

def total_energy(positions, positions_uv, triangle_vertex_indices):
    """Sum of ``triangle_energy`` over every triangle of the mesh.

    Args:
        positions: flat world-space positions, 3 entries per vertex.
        positions_uv: per-vertex (u, v) coordinates.
        triangle_vertex_indices: array whose rows are the three vertex indices
            of each triangle; each row is passed to ``triangle_energy``.

    Returns:
        Scalar total energy.
    """
    # Bind the shared mesh state once, evaluate all triangles in parallel with
    # vmap, then reduce. This replaces a sequential scan accumulation and avoids
    # shadowing the builtin `sum`.
    per_triangle_energy = partial(triangle_energy, positions, positions_uv)
    triangle_energies = vmap(per_triangle_energy)(triangle_vertex_indices)
    return jnp.sum(triangle_energies)


# Perturb the initial state: place vertex 3's y coordinate (flat DOF
# 3*3 + 1 = 10) one unit beyond its (u, v) rest value. Setting it from the
# rest coordinate, instead of adding 1.0 to the current value, keeps this cell
# idempotent when re-run. `.at[...]` replaces `jax.ops.index_add`, which was
# deprecated and later removed from JAX.
vertex_positions = vertex_positions.at[10].set(vertex_positions_uv[3, 1] + 1.0)

total_energy(vertex_positions, vertex_positions_uv, triangle_vertex_indices)

In [None]:
# dE/dx for the whole mesh: one entry per flat position DOF.
grad(total_energy)(
    vertex_positions,
    vertex_positions_uv,
    triangle_vertex_indices,
)

In [None]:
# d2E/dx2 for the whole mesh: a square matrix over the flat position DOFs.
hessian(total_energy)(
    vertex_positions,
    vertex_positions_uv,
    triangle_vertex_indices,
)

In [None]:
steps = 1000
time = 5.0
h = dt = time / steps  # time step; `h` and `dt` name the same value

standard_gravity = -9.81
# Gravity acts only along z, i.e. on every third DOF of the flat
# [x, y, z, ...] layout. Vertices 0 and 1 (DOFs 0-5) get no gravity,
# consistent with their very large mass in `masses`.
# `.at[...]` replaces `jax.ops.index_update`, which was removed from JAX.
gravity_forces = jnp.zeros_like(masses).at[2::3].set(standard_gravity * masses[2::3])
gravity_forces = gravity_forces.at[0:6].set(0.0)


def calculate_forces_and_derivatives(positions, positions_uv, triangle_vertex_indices):
    """Internal elastic forces and their Jacobian w.r.t. positions.

    The force is the negative gradient of the potential energy,
    f = -dE/dx, so its Jacobian is df/dx = -d2E/dx2. (Returning +dE/dx would
    push the cloth *away* from rest and can make the implicit system matrix
    M - h^2 df/dx indefinite.)

    Args:
        positions: flat world-space positions, 3 entries per vertex.
        positions_uv: per-vertex (u, v) coordinates.
        triangle_vertex_indices: rows of per-triangle vertex indices.

    Returns:
        (f0, dfdx): the force vector (same shape as ``positions``) and its
        Jacobian with respect to ``positions``.
    """
    energy_gradient = grad(total_energy)(positions, positions_uv, triangle_vertex_indices)
    energy_hessian = hessian(total_energy)(positions, positions_uv, triangle_vertex_indices)
    return -energy_gradient, -energy_hessian


def simulate_step_implicit(carry, step_number, params):
    """Advance the cloth by one linearized backward-Euler step (Baraff-Witkin).

    Intended as a ``lax.scan`` body.

    Args:
        carry: tuple (positions, velocities) of flat DOF arrays.
        step_number: scan input; unused.
        params: tuple (k, rest_length, positions_uv, triangle_vertex_indices);
            ``k`` and ``rest_length`` are not used by this cloth step.

    Returns:
        ((positions_new, velocities_new), positions_new)
    """
    positions, velocities = carry
    _k, _rest_length, positions_uv, triangle_vertex_indices = params

    f0, dfdx = calculate_forces_and_derivatives(positions, positions_uv, triangle_vertex_indices)
    f0 = f0 + gravity_forces
    v0 = velocities

    # Linearized implicit Euler without a damping (df/dv) term:
    #   (M - h^2 df/dx) dv = h (f0 + h df/dx v0)
    A = M - (h * h) * dfdx
    b = h * (f0 + h * (dfdx @ v0))

    # Matrix-free conjugate gradient; assumes A is symmetric positive definite.
    delta_v = jax.scipy.sparse.linalg.cg(lambda x: A @ x, b)[0]

    velocities_new = velocities + delta_v
    # Backward Euler advances positions with the *new* velocity:
    # dx = h * (v0 + dv), not h * v0.
    positions_new = positions + h * velocities_new

    return (positions_new, velocities_new), positions_new

In [None]:
def simulate_full(k, rest_length, positions, velocities):
    """Run ``steps`` implicit steps starting from (positions, velocities).

    Returns:
        The positions after every step, stacked by ``lax.scan`` along a new
        leading axis of length ``steps``.
    """
    step_fn = jit(partial(
        simulate_step_implicit,
        params=(k, rest_length, vertex_positions_uv, triangle_vertex_indices),
    ))
    _, history = scan(step_fn, (positions, velocities), np.arange(steps))
    return history

# k and rest_length are passed through but not used by simulate_step_implicit.
history = simulate_full(
    k=0.0,
    rest_length=0.0,
    positions=vertex_positions,
    velocities=vertex_velocities,
)

In [None]:
def animate_cloth(history, dt, fps=30):
    """Build a matplotlib animation of a simulated cloth trajectory.

    Args:
        history: sequence of flat position arrays, one per simulation step
            (as returned by ``simulate_full``).
        dt: simulation time step in seconds.
        fps: requested playback frame rate. Frames are sub-sampled by an
            integer factor, so the effective rate is adjusted and printed.

    Returns:
        matplotlib.animation.FuncAnimation
    """
    fig = plt.figure(figsize=(5, 5), dpi=100)
    fig.subplots_adjust(0, 0, 1, 1, 0, 0)  # less padding
    ax = fig.add_subplot(111, projection='3d')
    plt.close(fig)  # prevents a duplicate static figure in the output

    fps_simulation = 1 / dt
    # Show every `skip`-th simulation step. Clamp to at least 1: when the
    # requested fps exceeds the simulation rate, the floor would be 0 and the
    # divisions below would fail.
    skip = max(1, int(np.floor(fps_simulation / fps)))
    fps_adjusted = fps_simulation / skip
    print('fps was adjusted to:', fps_adjusted)

    def animate(i):
        j = min(i * skip, len(history) - 1)
        plot_cloth(ax, history[j], triangle_vertex_indices)
        ax.text2D(0.1, 0.9, 't = {:.3f}s'.format(j * dt), transform=ax.transAxes)

    n_frames = (len(history) - 1) // skip + 1
    interval = 1000 * dt * skip
    return animation.FuncAnimation(fig, animate, frames=n_frames, interval=interval)

# Render the trajectory at (approximately) 30 frames per second.
animate_cloth(history, dt, fps=30)