<a href="https://colab.research.google.com/github/Victorlouisdg/Jax-Cloth-Tutorial/blob/main/Jax%20Cloth%20Tutorial%20-%20Part%201.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Jax Cloth Tutorial 👕 - Part 1
The blog post accompanying this notebook can be found here.

## Forces 🚀

### Measuring Deformation 🔬


In [1]:
import jax.numpy as jnp

def deformation_gradient(positions, positions_uv):
    """Deformation gradient F of a single triangle.

    F maps rest-space (uv) edge vectors onto world-space edge vectors:
    F = Ds @ inv(Dm), where the columns of Dm are the rest edges
    (u1 - u0, u2 - u0) and the columns of Ds are the current edges
    (x1 - x0, x2 - x0).

    Args:
        positions: the triangle's three current vertex positions, one per row.
        positions_uv: the same three vertices in rest (uv) coordinates.

    Returns:
        F, with one row per world coordinate and one column per uv
        coordinate (3 x 2 for 3D positions and 2D uv).
    """
    rest_a, rest_b, rest_c = positions_uv
    rest_edges = jnp.stack([rest_b - rest_a, rest_c - rest_a], axis=1)
    rest_edges_inv = jnp.linalg.inv(rest_edges)

    world_a, world_b, world_c = positions
    world_edges = jnp.stack([world_b - world_a, world_c - world_a], axis=1)

    return world_edges @ rest_edges_inv

In [2]:
# One triangle (vertices 0, 1, 2) lying flat in the z = 0 plane.
triangles = jnp.arange(3).reshape(1, 3)
positions = jnp.array(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.5, 0.0]]
)

# Rest shape in uv space: the xy-projection (drop the z column).
positions_uv = positions[:, 0:2]
positions_uv



DeviceArray([[0. , 0. ],
             [1. , 0. ],
             [0.5, 0.5]], dtype=float32)

### Visualization code

In [3]:
import plotly.graph_objs as go
import plotly.io as pio

# Plot styling; "simple_white" is another clean option.
pio.templates.default = "plotly_white"

# The "colab" renderer only displays inside Google Colab; elsewhere fall back
# to "notebook_connected" (interactive, also works on NBViewer).
# For static plots on GitHub, set pio.renderers.default = "svg" instead.
import sys
pio.renderers.default = "colab" if "google.colab" in sys.modules else "notebook_connected"

def get_vertices_go(positions, color='deeppink'):
    """Scatter3d trace drawing every vertex as a marker labelled with its index.

    Args:
        positions: one xyz row per vertex.
        color: marker color.
    """
    labels = list(map(str, range(positions.shape[0])))
    xs, ys, zs = positions.T
    return go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers+text',
        marker={'size': 5, 'color': color},
        text=labels,
    )

def get_mesh_go(positions, triangles, color='deepskyblue', opacity=1.0):
    """Mesh3d trace of a triangle mesh.

    Args:
        positions: one xyz row per vertex.
        triangles: one row of three vertex indices per face.
        color: surface color.
        opacity: surface opacity in [0, 1].
    """
    xs, ys, zs = positions.T
    face_i, face_j, face_k = triangles.T
    return go.Mesh3d(
        x=xs, y=ys, z=zs,
        i=face_i, j=face_j, k=face_k,
        color=color,
        opacity=opacity,
        hoverinfo='skip',  # no hover tooltips on the surface itself
    )

def make_fig(data,
             xrange=None,
             yrange=None,
             zrange=None,
             eye=None,
             center=None,
             up=None):
    """Build a 3D plotly figure with fixed, equal-aspect axes and tight margins.

    Args:
        data: list of plotly traces.
        xrange, yrange, zrange: [min, max] of each axis. Default to
            [-0.6, 1.6], [-1.1, 1.1] and [-1.6, 0.6] respectively.
        eye: camera eye position; defaults to dict(x=-1.0, y=-1.0, z=1.0).
        center, up: optional camera center / up vectors (None leaves them to plotly).

    Returns:
        The go.Figure, not yet shown.
    """
    # None sentinels instead of list/dict literals as defaults: default values
    # are created once and shared by every call, so mutable defaults are a
    # latent cross-call aliasing hazard.
    if xrange is None:
        xrange = [-0.6, 1.6]
    if yrange is None:
        yrange = [-1.1, 1.1]
    if zrange is None:
        zrange = [-1.6, 0.6]
    if eye is None:
        eye = dict(x=-1.0, y=-1.0, z=1.0)

    fig = go.Figure(data=data)
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=xrange, autorange=False),
            yaxis=dict(range=yrange, autorange=False),
            zaxis=dict(range=zrange, autorange=False),
            aspectmode='cube',
            camera=dict(eye=eye, center=center, up=up),
        ),
        margin=dict(l=0, r=0, b=0, t=0),  # tight layout
    )
    return fig



### Deformation of a Single Triangle 📐

In [4]:
# The single rest-state triangle with its labelled vertices.
data = [
    get_vertices_go(positions),
    get_mesh_go(positions, triangles),
]

make_fig(
    data,
    xrange=[-0.4, 1.2],
    yrange=[-0.4, 1.2],
    zrange=[-0.5, 0.5],
    eye=dict(x=1.0, y=-0.1, z=1.0),
).show()

In [5]:
deformation_gradient(positions=positions, positions_uv=positions_uv)  # undeformed: [[1, 0], [0, 1], [0, 0]]

DeviceArray([[1., 0.],
             [0., 1.],
             [0., 0.]], dtype=float32)

### Stretch Force 🚩

In [6]:
# Stretch the triangle along v by moving vertex 2 up by 0.5 in y.
# `.at[row, col]` takes the index directly; the old `jax.ops.index` helper
# has been removed from current JAX (jax.ops now only has segment_* functions),
# so importing it fails on a fresh environment.
positions_stretched = positions.at[2, 1].add(0.5)
deformation_gradient(positions_stretched, positions_uv)

DeviceArray([[1., 0.],
             [0., 2.],
             [0., 0.]], dtype=float32)

In [7]:
# Stretched triangle (orange, semi-transparent) overlaid on the rest triangle.
data = [
    get_vertices_go(positions_stretched),
    get_vertices_go(positions),
    get_mesh_go(positions_stretched, triangles, color='orange', opacity=0.6),
    get_mesh_go(positions, triangles),
]

make_fig(
    data,
    xrange=[-0.4, 1.2],
    yrange=[-0.4, 1.2],
    zrange=[-0.5, 0.5],
    eye=dict(x=1.0, y=-0.1, z=1.0),
).show()

In [8]:
def area(triangle_vertices):
    """Area of a triangle: half the magnitude of the cross product of two edges.

    Args:
        triangle_vertices: the three vertices, one per row.
    """
    first, second, third = triangle_vertices
    edge_a = second - first
    edge_b = third - first
    return 0.5 * jnp.linalg.norm(jnp.cross(edge_a, edge_b))

def energy(positions, positions_uv, ku, kv):
    """Stretch energy of one triangle.

    The two columns w_u, w_v of the deformation gradient F are the world-space
    images of the rest-space u and v directions. Their lengths should stay 1.
    The deviation C = |w| - 1 is penalised quadratically, weighted by the rest
    area a and a per-direction stiffness:

        E = 0.5 * a * ku * Cu**2 + 0.5 * a * kv * Cv**2

    Args:
        positions: current vertex positions of the triangle.
        positions_uv: rest (uv) vertex positions of the triangle.
        ku, kv: stiffness along u and v.
    """
    rest_area = area(positions_uv)
    F = deformation_gradient(positions, positions_uv)
    w_u, w_v = jnp.hsplit(F, 2)

    def stretch_term(column, stiffness):
        deviation = jnp.linalg.norm(column) - 1
        return 0.5 * rest_area * stiffness * (deviation ** 2)

    return stretch_term(w_u, ku) + stretch_term(w_v, kv)

In [9]:
from jax import grad

forces = -grad(energy)(positions_stretched, positions_uv, 100.0, 100.0)  # F = -dE/dx

In [10]:
def get_cones_go(positions, directions, sizeref=0.2):
    """Cone trace with one arrow per point, e.g. to visualise vertex forces.

    Args:
        positions: one xyz row per cone anchor point.
        directions: one xyz vector per cone. With sizemode='scaled' it sets
            both the cone's direction and its relative size.
        sizeref: plotly's cone size scaling factor. 0.2 (the previously
            hard-coded value) by default.

    Returns:
        The go.Cone trace.
    """
    x, y, z = positions.T
    u, v, w = directions.T
    return go.Cone(x=x, y=y, z=z,
                   u=u, v=v, w=w,
                   sizemode='scaled',
                   sizeref=sizeref)

# Stretched (orange) vs rest triangle, with the stretch forces drawn as cones
# anchored at the deformed vertices.
data = [
    get_vertices_go(positions_stretched),
    get_vertices_go(positions),
    get_mesh_go(positions_stretched, triangles, color='orange', opacity=0.6),
    get_mesh_go(positions, triangles),
    get_cones_go(positions_stretched, forces),
]

make_fig(
    data,
    xrange=[-0.4, 1.2],
    yrange=[-0.4, 1.2],
    zrange=[-0.5, 0.5],
    eye=dict(x=1.0, y=-0.1, z=1.0),
).show()

## Simulation  🌀

### Mesh Initialization 🏁

In [11]:
!pip install meshzoo



In [12]:
import meshzoo

# Triangulated unit square in uv space (n=4 gives 25 vertices, 32 triangles).
positions_uv, triangles = meshzoo.rectangle_tri((0.0, 0.0), (1.0, 1.0), n=4, variant="zigzag")
positions_uv = jnp.array(positions_uv)

amount_of_vertices, amount_of_triangles = positions_uv.shape[0], triangles.shape[0]

# Lift the flat uv grid into 3D: every vertex starts at z = 0.
positions_z = jnp.zeros(amount_of_vertices)
positions = jnp.hstack([positions_uv, positions_z[:, None]])

print("There are", amount_of_vertices, "vertices in the mesh.")
print("There are", amount_of_triangles, "triangles in the mesh.")
print("The first triangle contains vertices:", triangles[0])
print("The positions array has shape", positions.shape)

There are 25 vertices in the mesh.
There are 32 triangles in the mesh.
The first triangle contains vertices: [0 1 6]
The positions array has shape (25, 3)


In [13]:
# The flat rest mesh with labelled vertices.
data = [
    get_vertices_go(positions),
    get_mesh_go(positions, triangles),
]

make_fig(data, xrange=[-0.1, 1.1], yrange=[-0.1, 1.1], zrange=[-0.6, 0.6]).show()

### Forces for the entire mesh 🎏

In [14]:
from functools import partial
from jax import vmap

def mesh_energy(positions, positions_uv, triangles, energy_fn):
    """NEGATED total energy of a triangle mesh.

    Gathers each triangle's vertices with ``positions[triangles]`` and
    ``positions_uv[triangles]``, evaluates ``energy_fn`` on every triangle via
    ``vmap``, and sums the per-triangle energies.

    The sum is returned with its sign flipped, so ``grad(mesh_energy)`` with
    respect to ``positions`` directly gives the forces (-dE/dx) rather than
    the energy gradient.

    Args:
        positions: current vertex positions, one row per vertex.
        positions_uv: rest (uv) vertex positions, one row per vertex.
        triangles: one row of vertex indices per triangle.
        energy_fn: callable (triangle_positions, triangle_positions_uv) -> scalar.
    """
    per_triangle_energy = vmap(energy_fn)
    triangle_energies = per_triangle_energy(positions[triangles], positions_uv[triangles])
    return -jnp.sum(triangle_energies)

# Equal stiffness (100) along u and v. The rest mesh is undeformed, so its
# energy is zero (displayed as -0. because mesh_energy negates the sum).
energy_fn = partial(energy, kv=100.0, ku=100.0)
mesh_energy(positions, positions_uv, triangles, energy_fn)

DeviceArray(-0., dtype=float32)

In [15]:
# Bind everything except the positions. Because mesh_energy returns the
# negated energy, its gradient w.r.t. positions is directly the force field.
mesh_energy_fn = partial(
    mesh_energy,
    energy_fn=energy_fn,
    triangles=triangles,
    positions_uv=positions_uv,
)

forces_fn = grad(mesh_energy_fn)
forces_fn(positions)

DeviceArray([[0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.],
             [0., 0., 0.]], dtype=float32)

### Time Integration 🕓

In [16]:
# Total cloth mass of 1, split evenly over the vertices. It is repeated per
# coordinate so that forces / masses is an element-wise division.
m = 1.0 / amount_of_vertices
masses = jnp.full((amount_of_vertices, 3), m)

g = jnp.array([[0.0, 0.0, -9.81]])  # gravitational acceleration, shape (1, 3)
gravity = masses * g                # broadcasts: m * g for every vertex

# Vertices held in place. Presumably the two ends of the first grid row
# (5 vertices per row for n=4) — verify if the mesh resolution changes.
pinned = jnp.array([0, 4])
dt = 0.001

def step(carry, input, forces_fn, masses, pinned, dt, gravity=None):
    """Advance the cloth by one time step; usable as a ``jax.lax.scan`` body.

    Uses semi-implicit (symplectic) Euler: the velocities are updated first,
    and the *new* velocities move the positions. Plain explicit Euler (moving
    with the old velocities) systematically adds energy to undamped spring
    systems, so oscillations grow over time.

    Args:
        carry: ``(positions, velocities)``.
        input: per-step scan input; unused.
        forces_fn: maps positions to internal (e.g. stretch) forces.
        masses: masses that broadcast against the force array (here one
            value per vertex and coordinate).
        pinned: indices of vertices whose acceleration is forced to zero, so
            they keep their current velocity.
        dt: time step.
        gravity: external force array added every step. Defaults to
            ``masses * [0, 0, -9.81]``, i.e. standard gravity along -z, so the
            step no longer depends on a notebook-global ``gravity``.

    Returns:
        ``((positions_new, velocities_new), positions_new)``: the new carry
        and the per-step output.
    """
    positions, velocities = carry

    if gravity is None:
        gravity = masses * jnp.array([0.0, 0.0, -9.81])

    forces = forces_fn(positions) + gravity

    accelerations = forces / masses
    accelerations = accelerations.at[pinned].set(0.0)

    velocities_new = velocities + accelerations * dt
    # Move with the updated velocity (semi-implicit Euler) for stability.
    positions_new = positions + velocities_new * dt

    return (positions_new, velocities_new), positions_new

# Bind the simulation constants, leaving the (carry, input) signature scan expects.
step_fn = partial(
    step,
    dt=dt,
    pinned=pinned,
    masses=masses,
    forces_fn=forces_fn,
)

In [17]:
def simulate(step_fn, initial_state, amount_of_steps):
    """Plain Python-loop counterpart of simulate_jax.

    Calls ``step_fn(state, None)`` ``amount_of_steps`` times, threading the
    state through, and returns a list of the per-step outputs.
    """
    state = initial_state
    history = []
    for _ in range(amount_of_steps):
        state, frame = step_fn(state, None)
        history.append(frame)
    return history

In [18]:
from jax.lax import scan

def simulate_jax(step_fn, initial_state, amount_of_steps):
    """Run ``step_fn`` ``amount_of_steps`` times with ``jax.lax.scan``.

    No per-step inputs are supplied (``xs=None``), so ``step_fn`` receives
    None as its second argument. The final carry is discarded; the per-step
    outputs are returned stacked along a new leading axis.
    """
    _, stacked_outputs = scan(step_fn, initial_state, xs=None, length=amount_of_steps)
    return stacked_outputs

In [19]:
velocities = jnp.zeros_like(positions)  # the cloth starts at rest
initial_state = (positions, velocities)

# 2500 steps of size dt.
history = simulate_jax(step_fn, initial_state, amount_of_steps=2500)

### Result ✨

In [20]:
def frame_data(positions, triangles):
    """Traces for one animation frame: vertex markers plus the mesh surface."""
    vertex_markers = get_vertices_go(positions)
    surface = get_mesh_go(positions, triangles)
    return [vertex_markers, surface]

def animate_cloth(history, triangles, dt, fps=30, initial_positions=None):
    """Show a plotly animation of a simulated cloth trajectory.

    Only every ``skip``-th simulation step becomes a frame, with
    ``skip = floor((1 / dt) / fps)`` (at least 1). The playback rate actually
    achieved, ``(1 / dt) / skip``, is printed.

    Args:
        history: per-step vertex positions, indexable as ``history[i]``
            (e.g. the stacked output of simulate_jax).
        triangles: vertex indices of the mesh faces.
        dt: simulation time step in seconds.
        fps: requested playback frame rate.
        initial_positions: positions shown before Play is pressed. Defaults to
            ``history[0]`` (so ``history`` must then be non-empty).
    """
    # Previously this read the notebook-global `positions`; use the trajectory
    # itself (or an explicit argument) instead.
    if initial_positions is None:
        initial_positions = history[0]
    fig = make_fig(frame_data(initial_positions, triangles))

    fps_simulation = 1 / dt
    # At least 1: if the simulation rate is below `fps`, a zero skip would
    # divide by zero below and produce a zero-step range.
    skip = max(1, int(fps_simulation // fps))
    fps_adjusted = fps_simulation / skip
    print(f'fps was adjusted to: {fps_adjusted:.2f}')

    # In plotly's animate options only `duration`/`redraw` belong inside
    # `frame`; `fromcurrent` and `mode` are top-level keys and were ignored
    # when nested in `frame`.
    play_options = dict(
        frame=dict(duration=1000.0 / fps_adjusted, redraw=True),
        fromcurrent=True,
        mode='immediate',
    )
    fig.update_layout(updatemenus=[dict(
        type="buttons",
        buttons=[dict(label="Play", method="animate", args=[None, play_options])],
    )])

    frames = [go.Frame(data=frame_data(history[i], triangles))
              for i in range(0, len(history), skip)]

    fig.update(frames=frames)
    fig.show()

animate_cloth(history, triangles, dt=dt)  # default playback target: fps=30

fps was adjusted to: 30.30
