In [1]:
import jax.numpy as jnp
from jax import jit, vmap
from jax.experimental import sparse
import jax
from numpy import random
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import HTML
from base64 import b64encode

In [2]:
!pip install --quiet mediapy einops

In [3]:
import mediapy as media
from einops import reduce

We could render any sprites we want, but for this example we'll just render each particle as a nice, smooth Gaussian circle.

In [4]:
def gaussian_kern(kernlen, nsig):
  """Build a normalized 2D Gaussian kernel of shape (kernlen, kernlen).

  The interval [-nsig, nsig] (measured in standard deviations) is split into
  `kernlen` equal bins. Each 1D weight is the standard-normal probability mass
  falling in its bin (a difference of CDF values at the bin edges), and the 2D
  kernel is the outer product of those weights, rescaled to sum to 1.
  """
  bin_edges = jnp.linspace(-nsig, nsig, kernlen + 1)
  cdf_at_edges = jax.scipy.stats.norm.cdf(bin_edges)
  bin_mass = cdf_at_edges[1:] - cdf_at_edges[:-1]
  kernel = bin_mass[:, None] * bin_mass[None, :]
  return kernel / kernel.sum()

"Draw" a sprite. This does not actually draw anything to a full sized image yet. It creates a sparse array which says where this particular sprite is poisitioned within a yet-to-be initialized full size image array. 

In [5]:
def draw_single_sprite(pos, sprite, sp_width, sp_height, out_dims):
  """Describe one sprite placed on a not-yet-materialized canvas as a sparse array.

  Nothing is rasterized here. The result is a BCOO array of shape `out_dims`
  whose stored entries are the sprite's pixels, shifted so that sprite element
  [x, y, c] lands at canvas element [x + pos[0], y + pos[1], c].

  Args:
    pos: integer offset (pos[0], pos[1]) of the sprite's first pixel on the canvas.
    sprite: array holding sp_width * sp_height * channels values, laid out
      row-major as (sp_width, sp_height, channels).
    sp_width: sprite extent along canvas axis 0.
    sp_height: sprite extent along canvas axis 1.
    out_dims: canvas shape (axis0, axis1, channels).

  Returns:
    sparse.BCOO of shape `out_dims`. Pixels that would land at a negative
    canvas coordinate are given zero weight. Otherwise negative-index semantics
    would wrap them to the opposite edge when densified. Pixels past the far
    edge are left as out-of-bounds entries (JAX skips out-of-bounds scatter
    updates).
  """
  channels = out_dims[2]
  numel = sp_width * sp_height * channels
  data = sprite.reshape(numel)
  # grid[:, c, y, x] == (c, y, x); grid.shape[0] is the number of coordinates (3).
  grid = jnp.indices((channels, sp_height, sp_width), dtype=jnp.int16)
  # Transposing gives (x, y, c, coord), so rows enumerate pixels in the same
  # row-major order as `data`; flipping turns each (c, y, x) into (x, y, c).
  # Reshape by the number of coordinates, NOT by `channels`: the two only
  # coincide when channels == 3.
  indices = jnp.flip(grid.T.reshape(-1, grid.shape[0]), 1)
  indices = indices.at[:, 0].set(indices[:, 0] + pos[0])
  indices = indices.at[:, 1].set(indices[:, 1] + pos[1])
  # Drop (zero out) pixels with a negative coordinate instead of letting them wrap.
  on_canvas = (indices >= 0).all(axis=1)
  data = jnp.where(on_canvas, data, 0)
  return sparse.BCOO((data, indices), shape=out_dims)

Create one of these sparse sprites at each position, then trigger them all to be rasterized onto a single image array by calling "sum" (and then densifying the result).   
JIT to make it fast.

In [6]:
def draw_all_sprites(all_pos, all_indices, sprite_sheet, canv_dims):
  """Rasterize every sprite onto a single dense canvas.

  Args:
    all_pos: one (axis0, axis1) position pair per sprite, in canvas pixels.
      Positions are truncated to int16, so canvas coordinates must stay within
      the int16 range.
    all_indices: one index into `sprite_sheet` per sprite.
    sprite_sheet: stack of sprites; sprite_sheet[k] is laid out as
      (width, height, channels).
    canv_dims: canvas shape tuple (must be static under jit).

  Returns:
    Dense array of shape `canv_dims` with all sprites summed, clipped to [0, 255].
  """
  sp_width, sp_height = sprite_sheet.shape[1], sprite_sheet.shape[2]

  def draw_bound(p, sprite):
    return draw_single_sprite(p, sprite, sp_width, sp_height, canv_dims)

  # vmap yields one batched sparse array holding every sprite. Summing over the
  # batch axis merges them, and todense() materializes the canvas.
  draw_all = vmap(draw_bound, in_axes=0)
  render = draw_all(all_pos.astype(jnp.int16), sprite_sheet[all_indices]).sum(0)
  return jnp.clip(sparse.todense(render), 0, 255)

fast_draw_sprites = jit(draw_all_sprites, static_argnums=(3,))

Render at `r_scale` times the output resolution in each dimension (currently 4×) and then average back down (supersampling anti-aliasing).   
This compensates for the fact that sprites are drawn at integer coordinates.   

"r_scale" can be tuned for a performance/quality tradeoff. It is hardcoded because there was an issue passing it to the jit as a static argument. 

In [7]:
def scaled_render(pos, indices, sprites, dim):
  """Render sprites into a (dim, dim, 3) image using supersampling.

  The scene is rasterized at `r_scale` times the output resolution in each
  dimension. It is then box-filtered down to `dim` x `dim` by taking the mean
  over each r_scale x r_scale tile, which smooths out the integer snapping of
  sprite positions.

  Args:
    pos: one (axis0, axis1) position pair per sprite, in output-pixel
      coordinates. Each sprite is centered on its position.
    indices: per-sprite index into `sprites`.
    sprites: sprite sheet; sprites[k] is laid out as (width, height, channels).
    dim: output image side length (must be static under jit).

  Returns:
    Image of shape (dim, dim, 3).
  """
  r_scale = 4
  render_res = dim * r_scale
  # Shift by half the sprite extent so each sprite is centered on its position.
  # sprites.shape[0] is the number of sprites in the sheet, so the per-axis
  # extents come from shape[1] (pos axis 0) and shape[2] (pos axis 1).
  half_extent = 0.5 * jnp.array(sprites.shape[1:3], dtype=jnp.float32)
  img = fast_draw_sprites(
      pos * r_scale - half_extent,
      indices,
      sprites,
      (render_res, render_res, 3),
  )
  return reduce(img, "(h sh) (w sw) c -> h w c", "mean", sh=r_scale, sw=r_scale)

fast_scaled_render = jit(scaled_render, static_argnums=(3,))

Compute the inverse-square force between particles. Double the work is done computing the symmetric distance matrix (each pair of particles has its distance computed twice) because it's very easy to vectorize this way. Potential room for improvement.

In [8]:
def compute_forces(pos, scale, eps=0.1):
  """Pairwise inverse-square interaction, summed per particle.

  For each particle j this returns
      sum_i (pos[i] - pos[j]) / ((|pos[i] - pos[j]| * scale) ** 3 + eps).
  Each term has magnitude about 1 / (scale**3 * dist**2), i.e. inverse-square.
  The summed vector points toward the other particles. `eps` softens close
  encounters and makes the i == j term exactly zero instead of 0/0.

  A full (N, N, D) difference tensor is built, so memory grows as O(N^2) and
  every pair is evaluated twice.

  Args:
    pos: (N, D) particle positions.
    scale: distance scale applied before cubing.
    eps: softening term added to the denominator (static under jit).

  Returns:
    (N, D) summed interaction vector per particle.
  """
  diff = pos[:, None] - pos[None]  # diff[i, j] = pos[i] - pos[j]
  dist = ((diff * diff).sum(axis=-1, keepdims=True)) ** 0.5
  force = diff / ((dist * scale) ** 3 + eps)
  return force.sum(0)

# NB: ("eps",) — a bare ("eps") is just a string, not a tuple.
fast_compute_forces = jit(compute_forces, static_argnames=("eps",))

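Integrate particle positions and velocities using a semi-implicit (symplectic) Euler method. Positions are first advanced with the current velocities, then velocities are updated (with damping) from the forces at the new positions.  
JIT into a nice cozy burrito. 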

In [9]:
def sim_update_force(parts_pos, parts_vel, t_delta=0.05, scale=5, repel_mag=0.1, center_mag=2.5, steps=10, damp=0.99):
  """Advance the particle system by `steps` integration sub-steps of size `t_delta`.

  Each sub-step:
    1. moves positions by t_delta * velocity;
    2. evaluates the pairwise interaction at the new positions (subtracted, so
       it acts as a repulsion scaled by `repel_mag`);
    3. adds a pull toward the origin scaled by `center_mag`;
    4. damps the velocity by `damp` and applies both forces.

  The centering term divides the positions by the L2 norm of *all* positions
  together (not by each particle's own distance). It behaves like a spring
  toward the origin whose stiffness shrinks as the whole cloud grows.

  Returns:
    (positions, velocities) after the final sub-step.
  """
  pos = jnp.array(parts_pos)
  vel = jnp.array(parts_vel)
  # Plain Python loop: with `steps` static under jit it is unrolled at trace time.
  for _step in range(steps):
    pos = pos + t_delta * vel
    repulsion = fast_compute_forces(pos, scale)
    pull = pos / ((pos ** 2).sum() ** 0.5)
    vel = damp * vel - t_delta * (repulsion * repel_mag + pull * center_mag)
  return pos, vel

fast_sim_update_force = jit(sim_update_force, static_argnames=("steps", "scale"))

A helper function to run a simulation and render it to a video. Nice default parameters for easy tweaking. 

In [10]:
def generate_video(
    name="test_parts.mp4", p_count=800, sprite_dim=27, vid_dim=300, brightness=80,
    t_delta=0.05, scale=25, center_mag=0.5, repel_mag=0.05, damp=0.997, total_steps=500, steps=4, seed=144):
  """Simulate `p_count` particles and write one rendered frame per step to `name`.

  Particles start uniformly distributed in [-0.5, 0.5) per axis (seeded by
  `seed`) with zero velocity. Each of the `total_steps` iterations renders the
  current state to a `vid_dim` x `vid_dim` frame and then advances the
  simulation by `steps` sub-steps of size `t_delta`.

  Every particle uses the same Gaussian sprite of side `sprite_dim`, replicated
  across 3 channels and multiplied by `brightness`.
  """
  rng_key = jax.random.PRNGKey(seed)
  positions = jax.random.uniform(rng_key, (p_count, 2), minval=-0.5, maxval=0.5)
  velocities = jnp.zeros((p_count, 2))

  # One-entry sprite sheet; every particle points at sprite 0.
  sheet_indices = jnp.zeros((p_count,), dtype=int)
  kernel = gaussian_kern(sprite_dim, 3.0).reshape(1, sprite_dim, sprite_dim, 1)
  sprite_sheet = jnp.tile(kernel, 3) * brightness  # -> (1, sprite_dim, sprite_dim, 3)

  with media.VideoWriter(name, (vid_dim, vid_dim), crf=30, fps=45) as writer:
    for _ in tqdm(range(total_steps)):
      # Map sim coordinates in [-0.5, 0.5] into [0.05, 0.95] of the frame.
      screen_pos = (positions * 0.9 + 0.5) * vid_dim
      frame = fast_scaled_render(screen_pos, sheet_indices, sprite_sheet, vid_dim)
      positions, velocities = fast_sim_update_force(
          positions, velocities,
          t_delta=t_delta, scale=scale,
          center_mag=center_mag, repel_mag=repel_mag,
          damp=damp, steps=steps,
      )
      writer.add_image(frame)

Generate a video!  
Included are a couple of extra examples demonstrating different rendering and sim parameters.

In [15]:
VIDEO_PATH = "test_parts.mp4"  # written by generate_video, read back below for display
generate_video(name=VIDEO_PATH)

# blob
#generate_video(p_count=16000, vid_dim=1024, sprite_dim=21, scale=150, center_mag=0.25, damp=0.998, steps=4, brightness=40, total_steps=1000)

# galaxy
#generate_video(p_count=16000, vid_dim=1024, sprite_dim=21, scale=150, center_mag=0.0, repel_mag=-0.01, damp=1, steps=4, brightness=15, total_steps=1000)

# Embed the video as a base64 data URL so it plays inline; `with` closes the file handle.
with open(VIDEO_PATH, "rb") as video_file:
  mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=300 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

100%|██████████| 500/500 [00:02<00:00, 176.34it/s]


In [12]:
!nvidia-smi

Thu Jul 21 02:58:28 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   59C    P0    28W /  70W |  13660MiB / 15109MiB |     19%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces