# Active matter model simulation

We will simulate the 2D system of interacting self-propelled chiral particles shown in Fig. 1A of ["Learning hydrodynamic equations for active matter from particle simulations and experiments"](https://www.pnas.org/doi/10.1073/pnas.2206994120) by Rohit Supekar et al. (PNAS, 2023). We will use [JAX](https://www.github.com/google/jax) and [JAX, MD](https://www.github.com/google/jax-md) for this task.

## Setup

First, we import the necessary packages and define helper functions for plotting, progress reporting, and rendering.

In [None]:
%matplotlib inline

# Imports
import numpy as onp

from jax import config ; config.update('jax_enable_x64', True)
import jax.numpy as np
from jax import random
from jax import jit
from jax import vmap
from jax import lax

vectorize = np.vectorize

from functools import partial

from collections import namedtuple
import base64

import IPython
from IPython.display import HTML, display
import time

import os

from jax_md import space, smap, energy, minimize, quantity, simulate, partition, util
from jax_md.util import f32

import ffmpeg

# Plotting

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as ani
import seaborn as sns

# Embed Matplotlib animations in the notebook as interactive JavaScript
# (the alternative 'html5' embeds an HTML5 <video> instead).
plt.rcParams['animation.html'] = 'jshtml'

sns.set_style(style='white')

# Shared plot colours as RGB triples in [0, 1]; dark_color and axis_color
# are used by format_plot / finalize_plot.
dark_color = [56 / 256 for _ in range(3)]
light_color = [213 / 256 for _ in range(3)]
axis_color = 'white'

def format_plot(x='', y='', grid=True):
  """Apply the notebook's dark theme to the current axes and label them.

  Args:
    x: text for the x-axis label.
    y: text for the y-axis label.
    grid: whether grid lines are drawn.
  """
  ax = plt.gca()

  # Recolour the axes frame, tick marks and axis labels.
  for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_color(axis_color)
  for which in ('x', 'y'):
    ax.tick_params(axis=which, colors=axis_color)
  for label in (ax.yaxis.label, ax.xaxis.label):
    label.set_color(axis_color)
  ax.set_facecolor(dark_color)

  plt.grid(grid)
  plt.xlabel(x, fontsize=20)
  plt.ylabel(y, fontsize=20)

def finalize_plot(shape=(1, 1)):
  """Give the current figure the dark background and resize it.

  The new size is shape[0] by shape[1] units, where one unit is 1.5 times
  the figure's current height (in inches). The layout is then tightened.
  """
  fig = plt.gcf()
  fig.patch.set_facecolor(dark_color)
  # Both dimensions are scaled from the height read before resizing.
  height = fig.get_size_inches()[1]
  fig.set_size_inches(shape[0] * 1.5 * height, shape[1] * 1.5 * height)
  plt.tight_layout()

# Progress Bars

def ProgressIter(iter_fun, iter_len=0):
  """Yield the items of `iter_fun` while updating a text progress bar.

  The bar is shown with IPython's `display` and refreshed in place after each
  item has been consumed.

  Args:
    iter_fun: the iterable to wrap.
    iter_len: total number of items; when falsy, `len(iter_fun)` is used.
  """
  total = iter_len or len(iter_fun)
  handle = display(progress(0, total), display_id=True)
  for done, item in enumerate(iter_fun, start=1):
    yield item
    handle.update(progress(done, total))

def progress(value, max):
    """Return a text progress bar for `value` completed steps out of `max`.

    The bar is 100 characters wide ('*' for done, ' ' for remaining) and is
    followed by a "| value/max " counter.

    Args:
      value: number of completed steps.
      max: total number of steps. (The name shadows the builtin `max`; it is
        kept so existing keyword callers keep working.)

    Returns:
      The formatted progress string.
    """
    # An empty task (max <= 0) counts as finished instead of dividing by
    # zero, e.g. when ProgressIter wraps an empty range.
    if max <= 0:
        prog = 100
    else:
        prog = int(100 * value / max)
    # Clamp so the bar stays exactly 100 characters wide.
    prog = min(100, prog) if prog >= 0 else 0
    return "Progress: " + "*" * prog + " " * (100 - prog) + f"| {value}/{max} "

def normalize(v):
  """Scale every row of the 2-D array `v` to unit Euclidean length."""
  row_norms = np.linalg.norm(v, axis=1, keepdims=True)
  return v / row_norms

# Rendering

def render(box_size, states, name="supekar"):
  """
  Render one or more Chiral states to an MP4 movie on disk.

  Each frame is a scatter plot of the particle positions, coloured by the
  orientation angle theta (HSV colormap over [0, 2*pi]). Only 2D is supported.

  The Chiral namedtuple has the form [R, theta] where R is an ndarray of
  shape [particle_count, spatial_dimension] while theta is an ndarray of shape
  [particle_count].

  Inputs:
    box_size (float): side length of the square box.
    states (Chiral or list of Chiral): a single state, or a list of
      simulation frames.
    name (str): output file stem; the movie is written to f"{name}.mp4".

  Output:
    None. The animation (10 ms per frame, i.e. 100 fps) is saved with the
    ffmpeg writer rather than returned, because embedding a long animation in
    the notebook is very memory-consuming.

  Raises:
    TypeError: if `states` is neither a Chiral nor a non-empty list of Chiral.
  """
  # Collect positions into [frames, N, 2] and angles into [frames, N].
  if isinstance(states, Chiral):
    R = onp.asarray(states.R)[None]
    theta = onp.asarray(states.theta)[None]
  elif (isinstance(states, list) and states
        and all(isinstance(s, Chiral) for s in states)):
    R = onp.stack([onp.asarray(s.R) for s in states])
    theta = onp.stack([onp.asarray(s.theta) for s in states])
  else:
    raise TypeError("states must be a Chiral or a non-empty list of Chiral")

  frames = R.shape[0]

  fig, ax = plt.subplots()
  ax.set_xlim(0, box_size)
  ax.set_ylim(0, box_size)

  # Create one scatter artist and update it in place for every frame. This
  # keeps memory flat and needs no recursion, however many frames there are.
  # Colour encodes each particle's orientation angle.
  scatter = ax.scatter(R[0, :, 0], R[0, :, 1], c=theta[0], s=.005,
                       cmap="hsv", vmin=0, vmax=2 * onp.pi)
  fig.colorbar(scatter)

  def update_frame(frame_num):
    """Move the scatter artist to frame `frame_num`."""
    scatter.set_offsets(R[frame_num])
    scatter.set_array(theta[frame_num])
    return (scatter,)

  anim = ani.FuncAnimation(fig, update_frame, frames=frames,
                           interval=10, repeat_delay=1000, blit=False)

  plt.close(fig)            # keep the static PNG from appearing
  anim.save(f"{name}.mp4", writer="ffmpeg", dpi=150)

## Chiral 

To model the chiral particles described in the paper above, we define a `Chiral` named tuple that stores the state of the active matter system in two arrays: `R`, an `ndarray` of shape `[particle_count, spatial_dim]` holding the particle positions, and `theta`, an `ndarray` of shape `[particle_count]` holding each particle's orientation angle.

In [None]:
# State container for the particle system:
#   R     -- particle positions, shape [particle_count, spatial_dim]
#   theta -- particle orientation angles, shape [particle_count]
Chiral = namedtuple('Chiral', 'R theta')

To test our code, we instantiate $N = 12000$ chiral particles with uniformly random positions and orientations in a 2D box of side length $L = 100$. We use [periodic boundary conditions](https://en.wikipedia.org/wiki/Periodic_boundary_conditions) for our simulation, similar to the conditions in the original paper.

In [None]:
# Simulation params
box_size = 100      # side length of the square periodic box
N = 12000           # number of particles in our system

# Create the JAX PRNG key from which all random numbers below are drawn
# (JAX uses explicit, splittable keys rather than a global seed).
rng = random.PRNGKey(0)

# Periodic boundary conditions: `displacement` computes displacements that
# respect the periodic box, and `shift` moves points and wraps them back into it.
displacement, shift = space.periodic(box_size)

# Split the key: v_rng / omega_rng are consumed by `init` (speeds and radii),
# R_rng / theta_rng by the initial positions and orientations below, and
# rng is later used as the noise key of the dynamics.
rng, v_rng, omega_rng, R_rng, theta_rng = random.split(rng, 5)

# Initialize Chiral: uniformly random positions in the box and uniformly
# random orientations in [0, 2*pi).
chiral = Chiral(
    R = box_size * random.uniform(R_rng, (N, 2)),
    theta = random.uniform(theta_rng, (N,), maxval=2*onp.pi)
)

In [None]:
# render() writes supekar_init.mp4 to disk and returns None, so there is
# nothing to pass to display() (display(None) would just print "None").
render(box_size, chiral, name="supekar_init")

## Physical behavior

In the paper, Supekar et al. used the following model, which is known to capture essential aspects of the experimentally observed self-organization of protein filaments, bacterial swarms, and cell monolayers:

$$\frac{d\textbf{x}_i}{dt} = v_i \textbf{p}_i$$

$$\frac{d\theta_i}{dt} = \Omega_i + g\sum_{j \in \mathcal N_i} \sin(\theta_j - \theta_i) + \sqrt{2D_r}\,\eta_i$$

Here, $v_i$ and $\Omega_i$ are the self-propulsion speed and intrinsic angular velocity of particle $i$, $\textbf{p}_i = (\cos{\theta_i}, \sin{\theta_i})^T$ is its orientation vector, $\eta_i(t)$ denotes orientational Gaussian white noise with zero mean and $\langle \eta_i(t) \eta_j(t') \rangle = \delta_{ij} \delta(t - t')$, $D_r$ is the rotational diffusion constant, and $g > 0$ sets the strength of the alignment interaction between particle $i$ and the particles $j$ in its neighborhood $\mathcal N_i$, i.e. those within the interaction radius $R = 1$ (not to be confused with the trajectory radii $R_i$ below).

The original paper considered particles whose parameters are drawn from the distribution $\tilde p(v_i, R_i) = G(v_i; \mu_v, \sigma_v)\, G(R_i; \mu_R, \sigma_R)$, where $G(x; \mu_x, \sigma_x)$ denotes a Gaussian distribution with mean $\mu_x$ and standard deviation $\sigma_x$. The distribution $p(v_i, \Omega_i)$ is then implicitly defined through the relation $v_i = \Omega_i R_i$, so $R_i$ is the radius of the circular path a non-interacting, noise-free particle would follow. The original paper used $\mu_v = 1, \sigma_v = 0.4, \mu_R = 2.2, \sigma_R = 1.7, g = 0.018, D_r = 0.009$, and all particles with $\Omega_i > 1.4$ were removed; the simulation was run with time step $dt = 0.0176$, and data were saved at intervals of $\Delta t = 0.44 = 25\,dt$. (In our implementation, samples with non-positive $v_i$ or $R_i$ are also discarded.)



In [None]:
@vmap
def normal(v, theta):
  """Velocity vector v * (cos theta, sin theta) of a particle.

  Vectorised with `vmap` over the leading axis of `v` and `theta`, so for
  arrays of length N it returns an array of shape [N, 2].
  """
  direction = np.array([np.cos(theta), np.sin(theta)])
  return v * direction

@vmap
def angle_sum(theta1, theta2):
    """Add two angles and wrap the result into [0, 2*pi).

    Vectorised with `vmap` over the leading axis of both arguments. A modulo
    is used rather than a single conditional +/- 2*pi shift, so sums lying
    more than one period outside [0, 2*pi) are also wrapped correctly. The
    result can equal 2*pi only through floating-point rounding.
    """
    return np.mod(theta1 + theta2, 2 * onp.pi)

def align_fn(dr, theta_i, theta_j, radius=1.):
    """Alignment term sin(theta_j - theta_i) contributed by particle j to i.

    Args:
      dr: displacement vector between the two particles, or their (scalar,
        non-negative) distance.
      theta_i, theta_j: orientation angles of particles i and j.
      radius: interaction radius; pairs at distance >= radius contribute 0.

    Returns:
      A scalar: sin(theta_j - theta_i) if |dr| < radius, otherwise 0.
    """
    align_spd = np.sin(theta_j - theta_i)
    # Compare the Euclidean length of dr (squared, avoiding a sqrt). An
    # elementwise `dr < radius` on a vector would accept distant pairs whose
    # components are negative, and it would return a vector, not a scalar.
    dist_sq = np.sum(np.square(dr))
    return np.where(dist_sq < radius ** 2, align_spd, 0.)

def align_tot(R, theta):
    """Per-particle alignment sum  sum_{j in N_i} sin(theta_j - theta_i).

    Args:
      R: particle positions, shape [N, spatial_dim].
      theta: orientation angles, shape [N].

    Returns:
      Array of shape [N]. Entry i sums the alignment term over all particles
      within the interaction radius (align_fn's default, 1) of particle i;
      the i == j term is sin(0) = 0.

    Note: this builds dense [N, N] pair arrays (O(N^2) time and memory). For
    large N, a `partition.neighbor_list` would scale better.
    """
    # Alignment term for every (i, j) pair: outer vmap over i, inner over j.
    align = vmap(vmap(align_fn, (0, None, 0)), (0, 0, None))

    # Pairwise displacements using the module-level `displacement`, reduced
    # to scalar distances of shape [N, N].
    dR = space.map_product(displacement)(R, R)
    dist = space.distance(dR)

    # Sum over j only (axis=1). Summing over both axes gives ~0, because
    # sin(theta_j - theta_i) is antisymmetric in (i, j).
    return np.sum(align(dist, theta, theta), axis=1)

def dynamics(v, omega, dt=0.0176, g=0.018, D_r=0.009):
    """Build the single-step update of the chiral active-particle model.

    One Euler-Maruyama step of size `dt` is taken for
        dR/dt     = v * (cos theta, sin theta)
        dtheta/dt = omega + g * align_tot(R, theta) + sqrt(2 D_r) * noise

    Args:
      v: per-particle self-propulsion speeds.
      omega: per-particle intrinsic angular velocities.
      dt: integration time step.
      g: alignment interaction strength.
      D_r: rotational diffusion constant.

    Returns:
      A jitted `update(i, state)` with the `lax.fori_loop` body signature
      (the loop index is ignored). `state` is a dict whose 'chiral' entry is
      a Chiral and whose 'key' entry is a PRNG key; a dict with both entries
      advanced by one step is returned.
    """
    @jit
    def update(_, state):
        positions, angles = state['chiral']
        key, noise_key = random.split(state['key'])

        # Deterministic forward-Euler increments.
        dR = normal(v, angles) * dt
        drift = (omega + g * align_tot(positions, angles)) * dt

        # Rotational noise increment: sqrt(2 D_r dt) * N(0, 1).
        noise = np.sqrt(2 * D_r * dt) * random.normal(noise_key, angles.shape)

        new_chiral = Chiral(
            shift(positions, dR),
            angle_sum(angles, drift + noise)
        )
        return {**state, 'chiral': new_chiral, 'key': key}

    return update

def init(mu_v=1, sigma_v=0.4, mu_r=2.2, sigma_r=1.7, N=N, omega_max=1.4,
         key=None):
    """Sample per-particle speeds v_i and angular velocities Omega_i.

    Speeds v_i ~ G(mu_v, sigma_v) and radii R_i ~ G(mu_r, sigma_r) are drawn
    independently for 10 * N candidates. Candidates with v_i <= 0 or
    R_i <= 0 are discarded, Omega_i = v_i / R_i is computed, and candidates
    with Omega_i > omega_max are removed. The first N survivors are returned.

    Args:
      mu_v, sigma_v: mean and standard deviation of the speed distribution.
      mu_r, sigma_r: mean and standard deviation of the radius distribution.
      N: number of particles to return.
      omega_max: largest angular velocity kept (1.4 in the paper).
      key: optional PRNG key. If None, the module-level `v_rng` and
        `omega_rng` keys are used, as in the original notebook.

    Returns:
      (spd, omega): two arrays of shape [N].

    Raises:
      ValueError: if fewer than N of the 10 * N candidates pass the filters.
    """
    if key is None:
        v_key, r_key = v_rng, omega_rng
    else:
        v_key, r_key = random.split(key)

    n_samples = 10 * N
    spd = mu_v + sigma_v * random.normal(v_key, (n_samples,))
    r = mu_r + sigma_r * random.normal(r_key, (n_samples,))

    # Keep only positive speeds and radii.
    positive = (spd > 0) & (r > 0)
    spd = spd[positive]
    r = r[positive]
    omega = spd / r

    # Remove fast rotators (Omega_i > omega_max), as in the paper.
    slow = omega <= omega_max
    spd = spd[slow]
    omega = omega[slow]

    # Fail loudly instead of silently returning fewer than N particles,
    # which would only surface later as a shape mismatch in the dynamics.
    if spd.shape[0] < N:
        raise ValueError(
            f"only {spd.shape[0]} of {n_samples} samples passed the filters; "
            f"need N={N}")

    return spd[:N], omega[:N]

In [None]:
v, omega = init()
update = dynamics(v, omega)

state = {
    'chiral' : chiral,
    'key' : rng,
}

steps_per_frame = 25   # save every 25 dt = 0.44 time units, as in the paper
n_frames = 70000       # number of saved frames

# Compile the whole 25-step advance once with jit, so that each frame is a
# single call to a compiled function.
advance = jit(partial(lax.fori_loop, 0, steps_per_frame, update))

# NOTE(review): with x64 enabled, keeping every frame costs about
# n_frames * N * 3 * 8 bytes (~20 GB here); consider subsampling frames or
# writing them to disk for long runs.
chiral_buffer = []

for _ in ProgressIter(range(n_frames)):
  state = advance(state)
  chiral_buffer.append(state['chiral'])

# render() writes supekar_12000.mp4 to disk and returns None, so there is
# nothing to pass to display().
render(box_size, chiral_buffer, name="supekar_12000")