In [135]:
#@title Imports & Utils


# @title Utilities

import base64
import dataclasses
import json

from IPython.display import HTML, JSON
import google.colab.output
import imageio.v3
import IPython
import IPython.display
import matplotlib.pyplot
import numpy
import seaborn
import torch

from beignet.func.__dataclass import _dataclass
import beignet.func

seaborn.set_style(style="white")

# Three zero-valued components. Not referenced by any code visible in this
# notebook — presumably an RGB default (black); confirm before relying on it.
BACKGROUND_COLOR = [0.0] * 3

# Counter read by `render` to suffix the names of the callbacks it registers;
# `render` increments it once per call.
SIMULATION_INDEX = 0

# Array rank at which `render` labels a geometry field "static". A field with
# one more axis is labelled "dynamic" and a field with one fewer axis is
# labelled "global".
TYPE_DIMENSIONS = {
    "color": 2,
    "diameter": 1,
    "neighbor_index": 2,
    "position": 2,
    "size": 1,
}


def make_from_image(filename, size_in_pixels):
    """Build particle positions and colors by sampling an image on a grid.

    The image is sampled every `size_in_pixels` pixels along both axes. Each
    sample whose alpha value equals 255 produces one particle. Every other
    sampled row is shifted horizontally by half a cell and the vertical
    spacing is multiplied by `sqrt(0.75)`, so the samples form a staggered
    (triangular) lattice. Image rows are flipped so that the top of the image
    has the largest y coordinate.

    Args:
      filename: Passed to `imageio.v3.imread`. Every sampled pixel must unpack
        into exactly four channels `(r, g, b, a)`.
      size_in_pixels: Sampling stride, in pixels, along both image axes.

    Returns:
      A tuple `(box_size, position, color)`:
        box_size: 0-d tensor equal to 1.5 times the larger scaled image extent.
        position: float64 tensor with one `(x, y)` row per opaque sample,
          offset by `box_size / 2 - img_size / 2`.
        color: float64 tensor with one `(r, g, b)` row per opaque sample, each
          channel divided by 255.
    """
    img = imageio.v3.imread(filename)
    height, width = img.shape[:2]

    # Lattice spacing in simulation units — presumably the Lennard-Jones
    # minimum for sigma = 1; confirm against the potential used downstream.
    scale = 2 ** (1 / 6)

    # Row-spacing factor of the staggered lattice.
    ratio = numpy.sqrt(1 - 0.25)

    position = []
    color = []

    for i, y in enumerate(range(0, height, size_in_pixels)):
        # Odd sampled rows are shifted right by half a cell.
        hshift = size_in_pixels * (i % 2) / 2.0

        for x in range(0, width, size_in_pixels):
            r, g, b, a = img[y, x]

            # Only fully opaque samples become particles.
            if a != 255:
                continue

            position.append(
                [
                    scale * (x + hshift) / size_in_pixels,
                    scale * (height - y) / size_in_pixels * ratio,
                ],
            )
            color.append([r / 255, g / 255, b / 255])

    # NOTE(review): `img_size` is ordered (rows, cols) while positions are
    # ordered (x, y), so the centering offset below only matches axis-for-axis
    # on square images. Kept as-is to preserve existing coordinates — confirm
    # whether this is intended.
    img_size = torch.tensor(img.shape[:2]) / size_in_pixels * scale

    box_size = torch.max(img_size) * 1.5

    position = torch.tensor(position).to(torch.float64)
    position = position + box_size / 2.0 - img_size / 2

    color = torch.tensor(color).to(torch.float64)

    return box_size, position, color


def static_field():
    """Return a dataclass field whose metadata is `{"static": True}`."""
    static_metadata = {"static": True}

    return dataclasses.field(metadata=static_metadata)


def to_np(*xs):
    """Convert array-like arguments to NumPy arrays, leaving the rest alone.

    Torch tensors are detached, moved to the CPU and converted; NumPy arrays
    are copied. Any other value (for example a Python float or `None`) is
    returned unchanged.

    Args:
      *xs: Values to convert.

    Returns:
      A list with one entry per argument, in the same order.
    """
    converted = []

    for x in xs:
        # `render` only treats `numpy.ndarray` fields as arrays, so tensors
        # must be converted here rather than passed through untouched.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()

        if isinstance(x, numpy.ndarray):
            # Copy so the result does not alias the caller's buffer.
            x = numpy.array(x)

        converted.append(x)

    return converted


@_dataclass
class Disk:
    """Disk geometry that can be passed to `render`.

    Args:
      position: Array of positions; `count` is read from its second-to-last
        axis.
      diameter: Stored in the `size` field. Defaults to 1.0.
      color: Stored in the `color` field. Defaults to `[0.8, 0.8, 1.0]`.
    """

    position: numpy.ndarray
    size: numpy.ndarray
    color: numpy.ndarray
    count: int = dataclasses.field(metadata={"static": True})

    def __init__(self, position, diameter=1.0, color=None):
        color = numpy.array([0.8, 0.8, 1.0]) if color is None else color

        position, diameter, color = to_np(position, diameter, color)

        attributes = (
            ("position", position),
            ("size", diameter),
            ("color", color),
            ("count", position.shape[-2]),
        )

        # Attributes are written through `object.__setattr__` — presumably
        # because `_dataclass` produces a frozen class; confirm.
        for attribute, value in attributes:
            object.__setattr__(self, attribute, value)

    def __repr__(self):
        # `render` sends `str(geometry)` to the front-end as the shape name.
        return "Disk"


@_dataclass
class Sphere:
    """Sphere geometry that can be passed to `render`.

    Args:
      position: Array of positions; `count` is read from its second-to-last
        axis.
      diameter: Stored in the `size` field. Defaults to 1.0.
      color: Stored in the `color` field. Defaults to `[0.8, 0.8, 1.0]`.
    """

    position: numpy.ndarray
    size: numpy.ndarray
    color: numpy.ndarray
    # Tagged static exactly like `Disk.count`; the untagged field was
    # inconsistent with the sibling geometry. Presumably `_dataclass` uses the
    # tag to keep the field out of the array leaves — confirm.
    count: int = dataclasses.field(metadata={"static": True})

    def __init__(self, position, diameter=1.0, color=None):
        if color is None:
            color = numpy.array([0.8, 0.8, 1.0])

        position, diameter, color = to_np(position, diameter, color)

        object.__setattr__(self, "position", position)
        object.__setattr__(self, "size", diameter)
        object.__setattr__(self, "color", color)
        object.__setattr__(self, "count", position.shape[-2])

    def __repr__(self):
        # `render` sends `str(geometry)` to the front-end as the shape name.
        return "Sphere"


def _encode(R):
    dtype = R.dtype

    if dtype == numpy.float64:
        dtype = numpy.float32

    if dtype == numpy.int64:
        dtype = numpy.int32

    dtype = numpy.float32

    return base64.b64encode(numpy.array(R, dtype).tobytes()).decode("utf-8")


def _to_json(data):
    """Wrap `data` in an `IPython.display.JSON` object.

    If constructing `JSON` from `data` raises, the construction is retried
    with `data` serialized through `json.dumps`.

    Args:
      data: The payload to wrap.

    Returns:
      An `IPython.display.JSON` instance.
    """
    try:
        return JSON(data=data)
    except Exception:
        # Only ordinary errors trigger the fallback; KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return JSON(data=json.dumps(data))


def render(
    box_size,
    geometries,
    buffer_size=None,
    background_color=None,
    resolution=None,
    frame_rate=None,
    verbose=False,
):
    """Creates a rendering front-end along with callbacks in the host program.

    Registers four Colab callbacks (simulation metadata, per-geometry
    metadata, whole-array fetch and chunked-array fetch) and then displays the
    visualization HTML page.

    Args:
      box_size: A float, an int, or a NumPy array / torch tensor that is
        either a scalar or has shape `(spatial_dimension,)`. Specifies the
        size of the simulation volume. Used to position the camera.
      geometries: A dictionary containing names paired with geometric objects
        such as Disk or Sphere. A single geometry is also accepted and is
        registered under the name "all".
      buffer_size: The maximum number of timesteps to send to the front-end in
        a single call.
      background_color: An array of shape (3,) specifying the background color
        of the visualization.
      resolution: The resolution of the renderer.
      frame_rate: An optional integer specifying the target frames-per-second
        for the renderer.
      verbose: Specifies whether or not the client should emit information and
        error messages. Useful for debugging visualizations, but adds clutter.

    Raises:
      AssertionError: If no geometry has a `position`, if geometries disagree
        on spatial dimension or frame count, or if an array-valued `box_size`
        does not have shape `(spatial_dimension,)`.
    """
    global SIMULATION_INDEX

    # Captured once so every callback registered below shares the same suffix.
    simulation_index = SIMULATION_INDEX

    frame_count = None
    dimension = None

    if not isinstance(geometries, dict):
        geometries = {"all": geometries}

    for geometry in geometries.values():
        if hasattr(geometry, "position"):
            assert dimension is None or geometry.position.shape[-1] == dimension

            dimension = geometry.position.shape[-1]

            # A rank-3 position carries a leading frame axis.
            if geometry.position.ndim == 3:
                assert frame_count is None or frame_count == geometry.position.shape[0]

                frame_count = geometry.position.shape[0]

    assert dimension is not None

    # Normalize the box to a list of plain floats, one per spatial dimension,
    # so it can be serialized. Torch tensors are accepted alongside NumPy
    # arrays (the original tuple listed `numpy.ndarray` twice).
    if isinstance(box_size, (torch.Tensor, numpy.ndarray)):
        if box_size.shape:
            assert tuple(box_size.shape) == (dimension,)

            box_size = [float(length) for length in box_size]
        else:
            box_size = [float(box_size)] * dimension
    elif isinstance(box_size, (float, int)):
        box_size = [box_size] * dimension

    def _as_array(value):
        # Lists and torch tensors become NumPy arrays; anything else is
        # returned unchanged.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()

        if isinstance(value, list):
            return numpy.array(value)

        return value

    def get_metadata():
        metadata = {
            "box_size": box_size,
            "dimension": dimension,
            "geometry": list(geometries),
            "simulation_idx": simulation_index,
        }

        # Optional settings are only sent when the caller supplied them.
        optional = {
            "frame_count": frame_count,
            "buffer_size": buffer_size,
            "background_color": background_color,
            "resolution": resolution,
            "frame_rate": frame_rate,
        }

        for key, value in optional.items():
            if value is not None:
                metadata[key] = value

        if verbose:
            metadata["verbose"] = True

        return _to_json(metadata)

    # This callback name carries no index; the index travels inside the
    # metadata as "simulation_idx".
    google.colab.output.register_callback(
        "GetSimulationMetadata",
        get_metadata,
    )

    def get_dynamic_geometry_metadata(name):
        assert name in geometries

        geom = geometries[name]

        geom_metadata = {
            "fields": {},
            "shape": str(geom),
        }

        for field, value in dataclasses.asdict(geom).items():
            value = _as_array(value)

            # Non-array fields are forwarded verbatim as metadata.
            if not isinstance(value, numpy.ndarray):
                geom_metadata[field] = value

                continue

            rank = len(value.shape)
            static_rank = TYPE_DIMENSIONS[field]

            if rank == static_rank + 1:
                geom_metadata["fields"][field] = "dynamic"
            elif rank == static_rank:
                geom_metadata["fields"][field] = "static"
            elif rank == static_rank - 1:
                geom_metadata["fields"][field] = "global"

        return _to_json(geom_metadata)

    google.colab.output.register_callback(
        f"GetGeometryMetadata{simulation_index}",
        get_dynamic_geometry_metadata,
    )

    def get_array_chunk(name, field, offset, size):
        assert name in geometries

        geom = dataclasses.asdict(geometries[name])

        assert field in geom

        array = _as_array(geom[field])

        return _to_json({"array_chunk": _encode(array[offset : (offset + size)])})

    google.colab.output.register_callback(
        f"GetArrayChunk{simulation_index}",
        get_array_chunk,
    )

    def get_array(name, field):
        assert name in geometries

        geom = dataclasses.asdict(geometries[name])

        assert field in geom

        array = _as_array(geom[field])

        return _to_json({"array": _encode(array)})

    google.colab.output.register_callback(
        f"GetArray{simulation_index}",
        get_array,
    )

    # The next call to `render` gets a fresh set of callback names.
    SIMULATION_INDEX = simulation_index + 1

    IPython.display.display(
        HTML(
            url=(
                "https://raw.githubusercontent.com/google/jax-md/main/jax_md/colab_tools/visualization.html"
            ),
        ),
    )

Obtaining file:///Users/isaacsoh/com/github/Genentech/beignet
  Installing build dependencies ... [?25l/^C
[?25canceled
[31mERROR: Operation cancelled by user[0m[31m
[0m

ModuleNotFoundError: No module named 'google'

# Sand Castle

In this demo we simulate a sand castle and then demolish it using a projectile.

## Load the sand castle

In [93]:
box, positions, colors = make_from_image('sand_castle.png', 24)

  img = imageio.imread(filename)


In [94]:
render(box, Disk(positions, color=colors))

There are 1616 grains.


In [96]:
print(f'There are {len(positions)} grains.')

## Spaces

In [97]:
import torch
from beignet.func import space

displacement_fn, shift_fn = space(box, parallelepiped=False, remapped=False)

tensor([73.9422, 62.5549])

In [98]:
positions[0]

tensor([-27.5003,  41.7995])

In [99]:
displacement_fn(positions[0], positions[-1])

tensor([83.9422, 62.5549])

In [100]:
shift_fn(positions[0], torch.tensor([10.0, 0.0]))

## Energy

"Energy" in Physics plays a similar role to "Loss" in machine learning.

Write down an energy function between two grains of sand, $\epsilon(r)$.

The total energy is the sum of the pair energies over all pairs of grains.

$$E = \sum_{i,j} \epsilon(r_{ij})$$

where $r_{ij}$ is the distance between grain $i$ and grain $j$.


We want to model wet sand:

*   Grains are hard (no interpenetration).
*   Grains stick together a little bit.
*   Grains far away from one another don't notice each other.

In [101]:
import matplotlib.pyplot as plt
from beignet import lennard_jones_potential

rs = torch.linspace(0.5, 2.5, steps=50)
plt.plot(rs, lennard_jones_potential(rs, sigma=1, epsilon=1))

plt.ylim([-1, 1])
plt.xlim([0, 2.5])
plt.xlabel('$r_{ij}$')
plt.ylabel('$\\epsilon$')

tensor(-4977.0781)

In [102]:
from beignet._lennard_jones_pair_potential import lennard_jones_pair_potential

sand_energy = lennard_jones_pair_potential(displacement_fn)

sand_energy(positions)

## Simulate

In [103]:
simulation_steps = 10000
write_every = 50

In [104]:
from beignet.func._simulate import ensemble

setup_fn, step_fn = ensemble(sand_energy, shift_fn, step_size=5e-3, temperature=0.0, thermostat="Langevin", friction=1e-2)

sand = setup_fn(positions, temperature=0.0)

In [105]:
# Collect a snapshot of the positions on every `write_every`-th step (starting
# with step 0), advancing the state once per iteration, then stack the
# snapshots along a new leading axis.
trajectory = []

for i in range(simulation_steps):
  if not i % write_every:
    trajectory.append(sand.positions)

  sand = step_fn(sand)

trajectory = torch.stack(trajectory)

In [107]:
render(box, Disk(trajectory, color=colors))

## Simulate slightly faster...

In [108]:
import torch

def simulation_fn(i, sand_trajectory, write_every, step_fn):
    """Record frame `i`, then advance the state `write_every` times.

    Args:
      i: Index of the frame to write.
      sand_trajectory: A pair `(sand, trajectory)`; `sand.positions` is stored
        at `trajectory[i]`, which is modified in place.
      write_every: Number of times `step_fn` is applied after recording.
      step_fn: Callable mapping a state to the next state.

    Returns:
      The pair `(sand, trajectory)` with the advanced state and the same
      trajectory object.
    """
    current, frames = sand_trajectory

    # The frame is stored before stepping, so frame `i` holds the state as it
    # was on entry.
    frames[i] = current.positions

    for _ in range(write_every):
        current = step_fn(current)

    return current, frames

In [109]:
write_steps = simulation_steps // write_every
n = positions.shape[0]

sand = setup_fn(positions, temperature=0.0)
# Allocate the frame buffer with the dtype and device of `positions`
# (`make_from_image` returns float64). A default float32 buffer would silently
# downcast every frame written into it — assumes `sand.positions` keeps the
# dtype of `positions`; confirm.
trajectory = torch.zeros(
    (write_steps, n, 2),
    dtype=positions.dtype,
    device=positions.device,
)
for i in range(write_steps):
    sand, trajectory = simulation_fn(i, (sand, trajectory), write_every, step_fn)

In [110]:
render(box, Disk(trajectory, color=colors))

SyntaxError: invalid syntax. Perhaps you forgot a comma? (715785472.py, line 1)

## Let's blow it up!

### The projectile

In [111]:
# NOTE(review): `jnp` is not imported in any cell visible above — presumably
# `jax.numpy` left over from the JAX version of this notebook; confirm.
# NOTE(review): `box` comes from `make_from_image`, which returns a torch
# tensor here — verify that `jnp.array` accepts it.

# Initial projectile value, later advanced by `velocity` on every step.
projectile = jnp.array([1.0, box / 3.0])

# Used as `sigma=radius + 1.0` in `projectile_energy` and as
# `diameter=radius * 2` when the projectile is rendered.
radius = jnp.array(2.0)
# Passed as `epsilon` to the soft-sphere energy.
strength = 1000.0
# Added to the projectile once per simulation step.
velocity = jnp.array([3e-2, 0.0])

Model the projectile by adding a term to the energy,

$$E = \sum_{i,j}\epsilon(r_{ij}) + \sum_i \epsilon_p(r_{ip})$$

where $r_{ip}$ is the distance between grain $i$ and the projectile.

We want the projectile to only repel the sand (no attraction).

In [120]:
from jax_md import energy

rs = jnp.linspace(0.5, 2.5)
plt.plot(rs, energy.lennard_jones(rs))
plt.plot(rs, energy.soft_sphere(rs, epsilon=strength))

plt.ylim([-1, 10])
plt.xlim([0, 2.5])
plt.xlabel('$r_{ij}$')
plt.ylabel('$\\epsilon$')

In [122]:
def projectile_energy(sand, projectile):
  """Sum the soft-sphere energy between the projectile and the sand.

  Args:
    sand: Array from which `projectile` is subtracted; the norm is taken over
      its last axis.
    projectile: Value broadcast against `sand`.

  Returns:
    The `jnp.sum` of the soft-sphere energies of all distances.
  """
  offsets = sand - projectile
  separations = jnp.linalg.norm(offsets, axis=-1)

  # `radius` and `strength` are read from the notebook's global scope.
  per_grain = energy.soft_sphere(
      separations,
      sigma=radius + 1.0,
      epsilon=strength,
  )

  return jnp.sum(per_grain)

def total_energy(sand, projectile, **kwargs):
  """Return the sand energy plus the sand-projectile energy.

  Extra keyword arguments are accepted and ignored.
  """
  grain_term = sand_energy(sand)
  projectile_term = projectile_energy(sand, projectile)

  return grain_term + projectile_term

### Run the simulation

In [123]:
from jax_md import dataclasses

@dataclasses.dataclass
class SandCastle:
  """Bundles the sand with the projectile.

  Later cells use it both for the live simulation state and, with plain
  arrays in both fields, as the container for the recorded trajectory.

  NOTE(review): `simulate` is imported in a later cell and `jnp` is not
  imported in any visible cell — confirm both are in scope when this runs.
  """
  # Simulation state of the grains; a stacked positions array for trajectories.
  sand: simulate.NVTLangevinState
  # Projectile value; a stacked array for trajectories.
  projectile: jnp.ndarray

In [124]:
simulation_steps = 10000
write_every = 50
write_steps = simulation_steps // write_every

In [125]:
from jax_md import simulate

init_fn, step_fn = simulate.nvt_langevin(total_energy, shift_fn, dt=5e-3, kT=0.0)

SyntaxError: unterminated string literal (detected at line 3) (2448750232.py, line 3)

In [126]:
from jax import lax

def simulation_fn(i, state_trajectory):
  """Record frame `i`, then advance sand and projectile `write_every` steps.

  Args:
    i: Index of the frame to write.
    state_trajectory: A pair `(state, traj)` of `SandCastle` objects.

  Returns:
    The pair `(state, traj)` with the advanced state and the updated
    trajectory.
  """
  state, traj = state_trajectory

  # Write the current sand positions and projectile into slot `i`.
  recorded_sand = traj.sand.at[i].set(state.sand.position)
  recorded_projectile = traj.projectile.at[i].set(state.projectile)
  traj = SandCastle(recorded_sand, recorded_projectile)

  def total_step_fn(_, state):
    # One step: the sand is stepped against the current projectile, and the
    # projectile is moved by the global `velocity`.
    next_sand = step_fn(state.sand, projectile=state.projectile)
    next_projectile = state.projectile + velocity
    return SandCastle(next_sand, next_projectile)

  state = lax.fori_loop(0, write_every, total_step_fn, state)

  return state, traj

  img = imageio.imread(filename)


In [127]:
n = positions.shape[0]

# Live state: the initialized sand state together with the projectile.
# NOTE(review): `key` is not defined in any cell visible above — the later
# neighbor-list cell passes `random.PRNGKey(0)` in the same position; confirm.
state = SandCastle(
    init_fn(key, positions, projectile=projectile),
    projectile
)
# Trajectory buffers: one slot per written frame for the sand positions and
# for the projectile.
trajectory = SandCastle(
    jnp.zeros((write_steps, n, 2)),
    jnp.zeros((write_steps, 2))
)

# Each iteration records one frame and then advances `write_every` steps.
state, trajectory = lax.fori_loop(0, write_steps, simulation_fn, (state, trajectory))

25961

In [128]:
renderer.render(
    box,
    {
        'sand': renderer.Disk(trajectory.sand, color=colors),
        'projectile': renderer.Disk(trajectory.projectile[:, None, :],
                                    diameter=radius * 2)
    }
)

## Scaling Up

So far at each step we have been computing the interaction between every pair of grains.

But grains that are far apart don't affect each other.

In [129]:
box, positions, colors = make_from_image('sand_castle.png', 6)

  particle_index = torch.tensor(positions / unit_size, dtype=torch.int32)


cell_capacity: 15


In [130]:
len(positions)

torch.Size([2, 472671])

In [131]:
from jax_md.colab_tools import renderer

renderer.render(box, renderer.Disk(positions, color=colors))

In [132]:
displacement_fn, shift_fn = space.periodic(box)

### Neighbor lists

In [133]:
neighbor_fn, sand_energy = energy.lennard_jones_neighbor_list(displacement_fn, box)

In [134]:
nbrs = neighbor_fn.allocate(positions)

In [86]:
nbrs.idx.shape

In [None]:
def total_energy(sand, projectile, neighbor, **kwargs):
  """Return the neighbor-list sand energy plus the sand-projectile energy.

  Extra keyword arguments are accepted and ignored.
  """
  grain_term = sand_energy(sand, neighbor)
  projectile_term = projectile_energy(sand, projectile)

  return grain_term + projectile_term

### Simulation

In [None]:
simulation_steps = 30000
write_every = 400
write_steps = simulation_steps // write_every

projectile = jnp.array([1.0, box / 3.0])
radius = jnp.array(8.0)

In [None]:
from jax_md import partition

@dataclasses.dataclass
class SandCastle:
  """Bundles the sand, the projectile and the neighbor list.

  Later cells use it both for the live simulation state and as the container
  for the recorded trajectory, where `neighbor` is set to `None`.

  NOTE(review): `jnp` is not imported in any visible cell — confirm it is in
  scope when this runs.
  """
  # Simulation state of the grains; a stacked positions array for trajectories.
  sand: simulate.NVTLangevinState
  # Projectile value; a stacked array for trajectories.
  projectile: jnp.ndarray
  # Neighbor list passed to the step function; `None` for trajectories.
  neighbor: partition.NeighborList

In [None]:
from jax_md import simulate

init_fn, step_fn = simulate.nvt_langevin(total_energy, shift_fn, dt=5e-3, kT=0.0, gamma=1e-2)

In [None]:
from jax import lax

def simulation_fn(i, state_trajectory):
  """Record frame `i`, then advance the neighbor-list simulation.

  Args:
    i: Index of the frame to write.
    state_trajectory: A pair `(state, traj)` of `SandCastle` objects; `traj`
      holds the recorded arrays and has `neighbor` set to `None`.

  Returns:
    The pair `(state, traj)` with the state advanced `write_every` steps and
    the updated trajectory.
  """
  state, traj = state_trajectory

  # Write the current sand positions and projectile into slot `i`.
  traj = SandCastle(
     traj.sand.at[i].set(state.sand.position),
     traj.projectile.at[i].set(state.projectile),
     None
  )

  def total_step_fn(_, state):
    sand = step_fn(state.sand,
                   projectile=state.projectile,
                   neighbor=state.neighbor)
    projectile = state.projectile + velocity
    # Refresh the neighbor list from the freshly stepped positions. The
    # previous code used `state.sand.position`, the pre-step positions, which
    # left the list one step behind the state it is stored with.
    neighbor = state.neighbor.update(sand.position)
    return SandCastle(sand, projectile, neighbor)

  state = lax.fori_loop(0, write_every, total_step_fn, state)

  return state, traj

In [None]:
n = positions.shape[0]

# Live state: the initialized sand state, the projectile and the neighbor
# list allocated earlier.
# NOTE(review): `random` is not imported in any cell visible above —
# presumably `jax.random`; confirm.
state = SandCastle(
    init_fn(random.PRNGKey(0), positions, projectile=projectile, neighbor=nbrs),
    projectile,
    nbrs
)
# Trajectory buffers: one slot per written frame; no neighbor list is stored.
trajectory = SandCastle(
    jnp.zeros((write_steps, n, 2)),
    jnp.zeros((write_steps, 2)),
    None
)

# Each iteration records one frame and then advances `write_every` steps.
state, trajectory = lax.fori_loop(0, write_steps, simulation_fn, (state, trajectory))

In [None]:
state.neighbor.did_buffer_overflow

In [None]:
renderer.render(
    box,
    {
        'sand': renderer.Disk(trajectory.sand, color=colors),
        'projectile': renderer.Disk(trajectory.projectile[:, None, :],
                                    diameter=radius * 2)
    },
    buffer_size=10
)