<a href="https://colab.research.google.com/github/Nishant-Ramakuru/Inference-based-GNNS/blob/main/Predicting_Collective_Dynamics_w_GNN_Simulator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Simulator used to generate collective dynamics

This simulator uses the JAX-MD library to run simple, fully differentiable simulations of the collective dynamics of active agents on GPU hardware. The first model combines a soft repulsive interaction potential, so that agents keep a minimum distance from each other, with an individual self-propulsion force for each agent. Additionally, agents move stochastically: each agent's orientation performs a random walk (rotational diffusion), so its heading changes randomly over time.

## TODO: 

- Output neighbour lists (as adjacency lists)
- Extend to include alignment
- Extend to include populations of agents
- Extend to include novel interaction dynamics 

In [None]:
!pip install jax
!pip install jax-md
!pip install tqdm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pickle

import numpy as onp

from jax.config import config ; config.update('jax_enable_x64', True) 
import jax.numpy as np 
from jax import random 
from jax import jit  
from jax import lax  
from jax import vmap 

import time 
from tqdm import tqdm
from jax_md import space, smap, energy, minimize, quantity, simulate, partition,dataclasses
from jax_md import util
from collections import namedtuple 
from functools import partial 
from typing import Any, Callable, TypeVar, Union, Tuple, Dict, Optional



In [None]:
def interaction_potential(dR, J_avoid, D_avoid, alpha):
    """Soft repulsive pair energy between two agents.

    With ``r = space.distance(dR)`` this returns
    ``J_avoid / alpha * (1 - r / D_avoid) ** alpha`` when ``r < D_avoid``
    and ``0`` otherwise, so the energy decreases to zero at ``D_avoid``.

    Args:
      dR: Displacement between two agents (reduced to a distance by
        ``space.distance``).
      J_avoid: Energy scale of the repulsion.
      D_avoid: Interaction range; pairs at least this far apart do not
        interact.
      alpha: Exponent controlling how steeply the energy rises as agents
        approach each other.

    Returns:
      The pair energy.
    """
    dr = space.distance(dR) / D_avoid
    in_range = dr < 1.
    # Double-`where`: outside the interaction range, feed the power a safe
    # base of 1 instead of the negative `1 - dr`. Otherwise, for a
    # non-integer `alpha`, the discarded branch evaluates to NaN and its NaN
    # gradient propagates through `np.where` into the forces.
    dr_safe = np.where(in_range, dr, 0.)
    return np.where(in_range,
                    J_avoid / alpha * (1 - dr_safe) ** alpha,
                    0.)

def energy_fn(state, J_avoid=25., D_avoid=30., alpha=3.):
    """Total pairwise interaction energy of all agents.

    Args:
      state: The agent position array (despite the name, not an
        `ActiveBrownianState`); `activeBrownian` calls its force function as
        `force_fn(R)`.
      J_avoid: Energy scale passed to `interaction_potential`.
      D_avoid: Interaction range passed to `interaction_potential`.
      alpha: Exponent passed to `interaction_potential`.
      The defaults reproduce the previously hard-coded parameter values.

    Returns:
      Scalar energy. The full N x N sum counts every pair twice, hence the
      0.5 factor. Self-pairs (zero displacement) are included as well; they
      only add a constant term to the energy.

    Notes:
      Depends on the module-level `displacement_fn`, which must be defined
      before this function is first evaluated.
      NOTE(review): the default `D_avoid=30.` exceeds half of the
      `box_size=55.0` used in this notebook. With minimum-image periodic
      displacements, only the nearest image of each pair is counted even
      when another image is also within range. Confirm this range is
      intended.
    """
    pair_energy = vmap(vmap(partial(interaction_potential,
                                    J_avoid=J_avoid,
                                    D_avoid=D_avoid,
                                    alpha=alpha)))

    # All pairwise displacements, one per (i, j) pair of agents.
    dR = space.map_product(displacement_fn)(state, state)

    return 0.5 * np.sum(pair_energy(dR))

In [None]:
@dataclasses.dataclass
class ActiveBrownianState:
    """Dataclass holding the state of an active Brownian simulation.

    Attributes:
    position: The current position of the particles. An ndarray of floats with
      shape `[n, spatial_dimension]`.
    theta: The current heading angle of each particle, in radians (it is
      turned into the direction `(cos(theta), sin(theta))` by
      `activeBrownian`). Initialized in this notebook with shape `[n]`.
    rng: The current state of the random number generator.
    """
    position: util.Array
    theta: util.Array
    rng: util.Array


In [None]:
# Type aliases for the (init_fn, apply_fn) pair returned by simulators.
T = TypeVar('T')
# Builds the initial simulation state from arbitrary arguments.
InitFn = Callable[..., T]
# Advances a state by one step. `activeBrownian`'s apply_fn takes the loop
# index and the state (the `lax.fori_loop` body convention) plus optional
# keyword arguments, so the argument list is left unconstrained.
ApplyFn = Callable[..., T]

def activeBrownian(energy_or_force: Callable[..., util.Array],
    shift: space.ShiftFn,
    dt: float,
    tau: float,
    v0: float=0.1) -> Tuple[InitFn, ApplyFn]:
    """Build an integrator for 2D active Brownian particles.

    Every step moves the particles with the overdamped update
    ``R <- shift(R, dt * (v0 * e(theta) + F(R)))``, where ``e(theta)`` is the
    unit heading ``(cos theta, sin theta)`` and ``F`` is the interaction
    force. Then the headings diffuse:
    ``theta <- theta + sqrt(2 * dt / tau) * eta`` with ``eta`` standard normal.
    There is no translational noise; randomness enters only through ``theta``.

    Args:
      energy_or_force: Function of the positions (an ndarray of shape
        `[n, spatial_dimension]`) returning either a scalar energy or the
        forces; it is normalised with `quantity.canonicalize_force`.
      shift: Function that displaces positions `R` by `dR`; both are
        ndarrays of shape `[n, spatial_dimension]`.
      dt: Integration time step.
      tau: Persistence time of the heading; the rotational noise scales as
        `sqrt(2 * dt / tau)`.
      v0: Self-propulsion speed.

    Returns:
      A pair `(init_fn, apply_fn)`:
        * `init_fn(R, theta, key)` wraps positions, headings and a PRNG key
          in an `ActiveBrownianState`.
        * `apply_fn(i, state, **kwargs)` returns `state` advanced by one
          step. The index `i` is ignored, so `apply_fn` can be passed directly
          as a `lax.fori_loop` body. `kwargs` are forwarded to `shift` only.
    """
    # Normalise the scalar parameters with JAX-MD's `util.static_cast` and
    # turn `energy_or_force` into a force function.
    dt, tau, v0 = util.static_cast(dt, tau, v0)
    force_fn = quantity.canonicalize_force(energy_or_force)

    @vmap
    def heading(angle):
        # Unit vector pointing along `angle`; vmapped over the particles.
        return np.array([np.cos(angle), np.sin(angle)])

    def init_fn(R, theta, key):
        return ActiveBrownianState(R, theta, key)  # pytype: disable=wrong-arg-count

    def apply_fn(_, state, **kwargs):
        positions = state.position
        angles = state.theta
        next_key, noise_key = random.split(state.rng)

        # Standard-normal kicks for the headings.
        kicks = random.normal(noise_key, angles.shape, angles.dtype)

        # Overdamped translation: self-propulsion plus interaction force.
        velocity = v0 * heading(angles) + force_fn(positions)
        positions = shift(positions, dt * velocity, **kwargs)

        # Rotational diffusion of the headings.
        angles = angles + (dt * util.f32(2) / tau) ** (1 / 2) * kicks

        return ActiveBrownianState(positions, angles, next_key)

    return init_fn, apply_fn

In [None]:
# Simulation parameters.
dim = 2            # spatial dimension (activeBrownian's heading is 2D)
Nparticles = 650   # number of agents
box_size = 55.0    # side length of the periodic box
dt = 1e-2          # integration time step
tau = 1            # heading persistence time; rotational noise ~ sqrt(2*dt/tau)
v0 = 0.1           # self-propulsion speed
poly = 0.3         # NOTE(review): not used in the visible cells; confirm purpose

# Root PRNG key from which the initial conditions are drawn.
key = random.PRNGKey(0)


In [None]:
# Define periodic boundary conditions.
displacement_fn, shift_fn = space.periodic(box_size)

# Define simulation function
init_fn, apply_fn = activeBrownian(energy_fn, shift_fn, dt, tau, v0)


In [None]:
# Initialize the particle positions, theta
# Split the root key into separate streams for the initial positions and
# headings; `rng` is a fresh key that is not consumed in this cell.
rng, R_rng, theta_rng = random.split(key, num=3)

# Positions uniform over the box, headings uniform over [0, 2*pi).
R = box_size * random.uniform(R_rng, shape=(Nparticles, dim))
theta = random.uniform(theta_rng, shape=(Nparticles,), maxval=2. * np.pi)


In [None]:
# Seed the simulation's PRNG with `rng`, the key split off for it when the
# initial conditions were drawn. Reusing the root `key` (already split to
# draw R and theta) violates JAX's "never reuse a key" rule. It can also
# correlate the simulation noise with the initial conditions.
state = init_fn(R, theta, rng)


In [None]:
# Run the simulation as 2 chunks of 200 integration steps each, storing the
# agent positions at the end of every chunk in `state_buffer`.
state_buffer = []

t0 = time.time()
for i in tqdm(range(2)):
    # `apply_fn` ignores the loop index, so it serves directly as the body.
    state = lax.fori_loop(0, 200, apply_fn, state)
    snapshot = state.position
    state_buffer.append(snapshot)
tend = time.time()

100%|██████████| 2/2 [00:07<00:00,  3.92s/it]
