In [1]:
#@title (imports & utils)
import io
import base64
import time
from functools import partial
from typing import NamedTuple
from collections import namedtuple
import subprocess

import PIL
import numpy as np
import matplotlib.pylab as pl

from IPython.display import display, Image, HTML, clear_output
import ipywidgets as widgets

import jax
import jax.numpy as jp


def np2pil(a):
  """Convert an image array to a PIL image.

  Any floating-point input (float16/32/64) is clipped to [0, 1] and scaled to
  uint8; other dtypes are handed to PIL unchanged.
  """
  # np.asarray lets jax arrays / lists through: PIL.Image.fromarray needs a
  # real NumPy array (it reads `__array_interface__`).
  a = np.asarray(a)
  if np.issubdtype(a.dtype, np.floating):
    a = np.uint8(np.clip(a, 0, 1)*255)
  return PIL.Image.fromarray(a)

def imwrite(f, a, fmt=None):
  """Write image array `a` to `f`.

  Args:
    f: a file path (str) or a writable binary file object.
    a: image array, converted with `np2pil`.
    fmt: image format for PIL. When None and `f` is a path, it is inferred
      from the file extension. 'jpg' (any case) is normalised to 'jpeg'.

  The image is saved with `quality=95` (used by JPEG).
  """
  a = np.asarray(a)
  if fmt is None and isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1]
  if fmt is not None:
    fmt = fmt.lower()
    if fmt == 'jpg':
      fmt = 'jpeg'  # PIL registers the JPEG writer under 'JPEG', not 'JPG'
  # Given a path, PIL opens *and closes* the file itself, so no handle leaks.
  np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  """Encode an image array into bytes of format `fmt`.

  Four-channel (RGBA) images are always encoded as PNG, since JPEG cannot
  store an alpha channel.
  """
  arr = np.asarray(a)
  has_alpha = arr.ndim == 3 and arr.shape[-1] == 4
  buffer = io.BytesIO()
  imwrite(buffer, arr, 'png' if has_alpha else fmt)
  return buffer.getvalue()

def imshow(a, fmt='jpeg', display=display):
  """Show image array `a` inline as an IPython `Image` encoded by `imencode`.

  `display` defaults to IPython's `display` (bound when this cell runs); its
  return value is passed through.
  """
  payload = imencode(a, fmt)
  return display(Image(data=payload))

def grab_plot(close=True):
  """Return the current Matplotlib figure as an RGB uint8 image.

  The figure is drawn, its RGBA pixels are alpha-composited over a white
  background, and (unless `close` is False) the figure is closed.
  """
  fig = pl.gcf()
  fig.canvas.draw()
  # `buffer_rgba()` is the public accessor for the Agg canvas' (H, W, 4)
  # pixel buffer; `canvas.renderer._renderer` is a private attribute.
  img = np.array(fig.canvas.buffer_rgba())
  a = np.float32(img[..., 3:]/255.0)
  img = np.uint8(255*(1.0-a) + img[...,:3] * a)  # composite alpha over white
  if close:
    pl.close(fig)  # close exactly the figure we grabbed
  return img

def show_videofile(fn):
  """Embed the MP4 file `fn` in the notebook output as a looping HTML5 video.

  The whole file is base64-encoded into a data URI, so the saved notebook
  grows by roughly 4/3 of the video size.
  """
  with open(fn, 'rb') as fp:  # close the handle instead of leaking it
    b64 = base64.b64encode(fp.read()).decode('utf8')
  s = f'''<video controls loop>
 <source src="data:video/mp4;base64,{b64}" type="video/mp4">
 Your browser does not support the video tag.</video>'''
  display(HTML(s))

class VideoWriter:
  """Stream image frames into a video file through an `ffmpeg` subprocess.

  Frames may be float (clipped to [0, 1] and scaled to uint8) or integer,
  grayscale (H, W) or colour (H, W, 3) / (H, W, 4); an alpha channel is
  dropped. ffmpeg is started on the first frame, whose size fixes the video
  size, and encodes with libx264 / yuv420p (which needs even width and
  height). While recording, a preview of the latest frame and the frame
  count are refreshed at most once per second in an ipywidgets Output area.

  Use as a context manager (`with VideoWriter(...) as vid: vid(frame)`) to
  finalise the file on exit and, if `show_on_finish`, embed it in the
  notebook. Requires the `ffmpeg` executable on PATH.
  """

  def __init__(self, filename='_tmp.mp4', fps=30.0, show_on_finish=True):
    self.ffmpeg = None            # encoder process, started by the first add()
    self.filename = filename
    self.fps = fps
    self.view = widgets.Output()  # live preview area; None after close()
    self.last_preview_time = 0.0
    self.frame_count = 0
    self.show_on_finish = show_on_finish
    display(self.view)

  def add(self, img):
    """Append one frame (array-like image) to the video."""
    img = np.asarray(img)
    h, w = img.shape[:2]
    if self.ffmpeg is None:
      self.ffmpeg = self._open(w, h)
    if np.issubdtype(img.dtype, np.floating):  # float16/32/64 alike
      img = np.uint8(img.clip(0, 1)*255)
    if img.ndim == 2:
      img = np.repeat(img[..., None], 3, -1)
    elif img.shape[-1] == 4:
      # ffmpeg is told the input is rgb24; RGBA bytes would garble the stream.
      img = img[..., :3]
    self.ffmpeg.stdin.write(img.tobytes())
    self.frame_count += 1
    t = time.time()
    if self.view and t-self.last_preview_time > 1.0:
      self.last_preview_time = t
      with self.view:
        clear_output(wait=True)
        imshow(img)
        print(self.frame_count)

  def __call__(self, img):
    return self.add(img)

  def _open(self, w, h):
    """Start ffmpeg reading raw w x h rgb24 frames from its stdin."""
    # An argument list (not str.split) keeps filenames with spaces intact.
    # `-loglevel error -nostats` keeps stderr tiny: it is a pipe that is only
    # drained in close(), and progress output could otherwise fill it and
    # block ffmpeg (and therefore our writes to its stdin).
    cmd = ['ffmpeg', '-y', '-loglevel', 'error', '-nostats',
           '-f', 'rawvideo', '-vcodec', 'rawvideo', '-s', f'{w}x{h}',
           '-pix_fmt', 'rgb24', '-r', f'{self.fps}', '-i', '-',
           '-pix_fmt', 'yuv420p', '-c:v', 'libx264', '-crf', '20',
           str(self.filename)]
    try:
      return subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as e:
      raise FileNotFoundError(
          "'ffmpeg' executable not found on PATH; install ffmpeg to write "
          f"videos (needed for {self.filename!r})") from e

  def close(self):
    """Finish the video file and remove the live preview widget.

    Raises:
      RuntimeError: if ffmpeg exited with a non-zero status; its error output
        is included in the message.
    """
    proc, self.ffmpeg = self.ffmpeg, None
    err = b''
    if proc is not None:
      # communicate() closes stdin, drains stderr and waits; a bare wait()
      # can deadlock while ffmpeg is blocked writing to the stderr pipe.
      _, err = proc.communicate()
    if self.view:
      with self.view:
        clear_output()
      self.view.close()
      self.view = None
    if proc is not None and proc.returncode != 0:
      raise RuntimeError(f'ffmpeg exited with code {proc.returncode} while '
                         f'writing {self.filename!r}:\n'
                         + (err or b'').decode('utf8', 'replace'))

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()
    if self.show_on_finish:
      self.show()

  def _ipython_display_(self):
    self.show()

  def show(self):
    """Finalise the file and embed it in the notebook output."""
    self.close()
    show_videofile(self.filename)

# JAX utils

def vmap2(f):
  """Vectorise `f` over the two leading axes of its inputs (nested jax.vmap)."""
  over_inner_axis = jax.vmap(f)
  return jax.vmap(over_inner_axis)

def norm(v, axis=-1, keepdims=False, eps=0.0):
  """Euclidean norm of `v` along `axis`.

  The squared norm is floored at `eps` before taking the square root.
  """
  squared = jp.sum(v*v, axis=axis, keepdims=keepdims)
  return jp.sqrt(jp.maximum(squared, eps))

def normalize(v, axis=-1, eps=1e-20):
  """Scale `v` to unit Euclidean length along `axis`.

  The squared length is floored at `eps`, so zero vectors stay finite.
  """
  length = norm(v, axis, keepdims=True, eps=eps)
  return v / length


# Draw grid lines on all Matplotlib axes by default.
pl.rcParams["axes.grid"] = True

In [2]:
# #@title (OG Lenia creatures diagram)
# # based on Bert Chan's original Lenia implementation
# # https://colab.research.google.com/github/Chakazul/Lenia/blob/master/Jupyter/Lenia.ipynb

# class Lenia(namedtuple('Lenia', 'R, peaks, mu, sigma')):
#   def rescale(p, x, n=4):
#     x = jp.repeat(jp.repeat(jp.array(x), n, axis=0), n, axis=1)
#     p = p._replace(R = p.R * n)
#     return p, x

#   def kernel_shell(p, r):
#     def kernel_core(r):
#       rm = jp.minimum(r, 1)
#       return (4 * rm * (1-rm))**4
#     k = len(p.peaks)
#     kr = k * r
#     peak = p.peaks[jp.minimum(jp.floor(kr).astype(int), k-1)]
#     return (r<1) * kernel_core(kr % 1) * peak

#   @jax.jit
#   def step(p, x, dt=0.05):
#     SIZE = x.shape[0]
#     MID = SIZE // 2
#     I = jp.array([jp.arange(SIZE),]*SIZE)
#     X = (I-MID) / p.R
#     Y = X.T
#     D = jp.sqrt(X**2 + Y**2)

#     kernel = p.kernel_shell(D)
#     kernel_FFT = jp.fft.fft2(kernel / jp.sum(kernel))
#     x_FFT = jp.fft.fft2(x)
#     potential = jp.roll(jp.real(jp.fft.ifft2(kernel_FFT * x_FFT)), MID, (0, 1))
#     delta = jp.maximum(0, 1 - (potential - p.mu)**2 / (p.sigma**2 * 9) )**4 * 2 - 1
#     return jp.maximum(0, jp.minimum(1, x + delta * dt))

#   @staticmethod
#   def render_world(x, vmin=0, vmax=1, title_1='', title_2='', sep_x=None, alpha_1 = 1.0, alpha_2 = 1.0):
#     SIZE = x.shape[0]
#     if sep_x is None:
#       sep_x = SIZE
#     fig = pl.figure(figsize=(jp.shape(x)[1]/80, jp.shape(x)[0]/80), dpi=80)
#     ax = fig.add_axes([0, 0, 1, 1])
#     ax.grid(False)
#     ax.text(sep_x//2, SIZE - 40, title_1, fontsize='xx-large', color='white', ha='center', va='center', alpha = alpha_1)
#     ax.text(SIZE + sep_x//2, SIZE - 40, title_2, fontsize='xx-large', color='white', ha='center', va='center', alpha = alpha_2)
#     ax.axvline(x=sep_x, linestyle='--', linewidth=4)
#     img = ax.imshow(x, cmap='jet', interpolation='none', aspect=1, vmin=vmin, vmax=vmax)
#     return grab_plot()

#   @staticmethod
#   def get_creatures():
#     p1 = Lenia(R=13, peaks=jp.array([1]), mu=0.15, sigma=0.014)
#     p2 = Lenia(R=13, peaks=jp.array([1]), mu=0.156, sigma=0.0224)
#     c1 = jp.array([[0,0,0,0,0,0,0.1,0.14,0.1,0,0,0.03,0.03,0,0,0.3,0,0,0,0],[0,0,0,0,0,0.08,0.24,0.3,0.3,0.18,0.14,0.15,0.16,0.15,0.09,0.2,0,0,0,0],[0,0,0,0,0,0.15,0.34,0.44,0.46,0.38,0.18,0.14,0.11,0.13,0.19,0.18,0.45,0,0,0],[0,0,0,0,0.06,0.13,0.39,0.5,0.5,0.37,0.06,0,0,0,0.02,0.16,0.68,0,0,0],[0,0,0,0.11,0.17,0.17,0.33,0.4,0.38,0.28,0.14,0,0,0,0,0,0.18,0.42,0,0],[0,0,0.09,0.18,0.13,0.06,0.08,0.26,0.32,0.32,0.27,0,0,0,0,0,0,0.82,0,0],[0.27,0,0.16,0.12,0,0,0,0.25,0.38,0.44,0.45,0.34,0,0,0,0,0,0.22,0.17,0],[0,0.07,0.2,0.02,0,0,0,0.31,0.48,0.57,0.6,0.57,0,0,0,0,0,0,0.49,0],[0,0.59,0.19,0,0,0,0,0.2,0.57,0.69,0.76,0.76,0.49,0,0,0,0,0,0.36,0],[0,0.58,0.19,0,0,0,0,0,0.67,0.83,0.9,0.92,0.87,0.12,0,0,0,0,0.22,0.07],[0,0,0.46,0,0,0,0,0,0.7,0.93,1,1,1,0.61,0,0,0,0,0.18,0.11],[0,0,0.82,0,0,0,0,0,0.47,1,1,0.98,1,0.96,0.27,0,0,0,0.19,0.1],[0,0,0.46,0,0,0,0,0,0.25,1,1,0.84,0.92,0.97,0.54,0.14,0.04,0.1,0.21,0.05],[0,0,0,0.4,0,0,0,0,0.09,0.8,1,0.82,0.8,0.85,0.63,0.31,0.18,0.19,0.2,0.01],[0,0,0,0.36,0.1,0,0,0,0.05,0.54,0.86,0.79,0.74,0.72,0.6,0.39,0.28,0.24,0.13,0],[0,0,0,0.01,0.3,0.07,0,0,0.08,0.36,0.64,0.7,0.64,0.6,0.51,0.39,0.29,0.19,0.04,0],[0,0,0,0,0.1,0.24,0.14,0.1,0.15,0.29,0.45,0.53,0.52,0.46,0.4,0.31,0.21,0.08,0,0],[0,0,0,0,0,0.08,0.21,0.21,0.22,0.29,0.36,0.39,0.37,0.33,0.26,0.18,0.09,0,0,0],[0,0,0,0,0,0,0.03,0.13,0.19,0.22,0.24,0.24,0.23,0.18,0.13,0.05,0,0,0,0],[0,0,0,0,0,0,0,0,0.02,0.06,0.08,0.09,0.07,0.05,0.01,0,0,0,0,0]])
#     c2 = jp.array([[0,0,0,0,0,0,0,0,0.003978,0.016492,0.004714,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0.045386,0.351517,0.417829,0.367137,0.37766,0.426948,0.431058,0.282864,0.081247,0,0,0,0,0,0],[0,0,0,0,0.325473,0.450995,0.121737,0,0,0,0.003113,0.224278,0.47101,0.456459,0.247231,0.071609,0.013126,0,0,0],[0,0,0,0.386337,0.454077,0,0,0,0,0,0,0,0.27848,0.524466,0.464281,0.242651,0.096721,0.038476,0,0],[0,0,0.258817,0.583802,0.150994,0,0,0,0,0,0,0,0.226639,0.548329,0.550422,0.334764,0.153108,0.087049,0.042872,0],[0,0.008021,0.502406,0.524042,0.059531,0,0,0,0,0,0,0.033946,0.378866,0.615467,0.577527,0.357306,0.152872,0.090425,0.058275,0.023345],[0,0.179756,0.596317,0.533619,0.162612,0,0,0,0,0.015021,0.107673,0.325125,0.594765,0.682434,0.594688,0.381172,0.152078,0.073544,0.054424,0.030592],[0,0.266078,0.614339,0.605474,0.379255,0.195176,0.16516,0.179148,0.204498,0.299535,0.760743,1,1,1,1,0.490799,0.237826,0.069989,0.043549,0.022165],[0,0.333031,0.64057,0.686886,0.60698,0.509866,0.450525,0.389552,0.434978,0.859115,0.94097,1,1,1,1,1,0.747866,0.118317,0.037712,0.006271],[0,0.417887,0.6856,0.805342,0.824229,0.771553,0.69251,0.614328,0.651704,0.843665,0.910114,1,1,0.81765,0.703404,0.858469,1,0.613961,0.035691,0],[0.04674,0.526827,0.787644,0.895984,0.734214,0.661746,0.670024,0.646184,0.69904,0.723163,0.682438,0.618645,0.589858,0.374017,0.30658,0.404027,0.746403,0.852551,0.031459,0],[0.130727,0.658494,0.899652,0.508352,0.065875,0.009245,0.232702,0.419661,0.461988,0.470213,0.390198,0.007773,0,0.010182,0.080666,0.17231,0.44588,0.819878,0.034815,0],[0.198532,0.810417,0.63725,0.031385,0,0,0,0,0.315842,0.319248,0.321024,0,0,0,0,0.021482,0.27315,0.747039,0,0],[0.217619,0.968727,0.104843,0,0,0,0,0,0.152033,0.158413,0.114036,0,0,0,0,0,0.224751,0.647423,0,0],[0.138866,1,0.093672,0,0,0,0,0,0.000052,0.015966,0,0,0,0,0,0,0.281471,0.455713,0,0],[0,1,0.145606,0.005319,0,0,0,0,0,0,0,0,0,0,0,0.016878,0.381439,0.173336,0,0],[0,0.97421,0.262735,0.096478,0,0,0,0,0,0,0,0,0,0,0.013827,0.217967,0.287352,0,0
,0],[0,0.593133,0.2981,0.251901,0.167326,0.088798,0.041468,0.013086,0.002207,0.009404,0.032743,0.061718,0.102995,0.1595,0.24721,0.233961,0.002389,0,0,0],[0,0,0.610166,0.15545,0.200204,0.228209,0.241863,0.243451,0.270572,0.446258,0.376504,0.174319,0.154149,0.12061,0.074709,0,0,0,0,0],[0,0,0.354313,0.32245,0,0,0,0.151173,0.479517,0.650744,0.392183,0,0,0,0,0,0,0,0,0],[0,0,0,0.329339,0.328926,0.176186,0.198788,0.335721,0.534118,0.549606,0.361315,0,0,0,0,0,0,0,0,0],[0,0,0,0,0.090407,0.217992,0.190592,0.174636,0.222482,0.375871,0.265924,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0.050256,0.235176,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0.180145,0.132616,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0.092581,0.188519,0.118256,0,0,0,0]])
#     p1, c1 = p1.rescale(c1, n=6)
#     p2, c2 = p2.rescale(c2, n=6)
#     return (p1, c1), (p2, c2)

# def lenia_universes_diagram(SIZE=1<<8):
#   (p1, c1), (p2, c2) = Lenia.get_creatures()
#   MID = int(SIZE / 2)
#   titles = dict(title_1="Glider's World\n$\mu=0.15, \sigma=0.014$",
#                 title_2="Rotator's World\n$\mu=0.156, \sigma=0.0224$")

#   x_1 = x_2 = jp.zeros((SIZE, SIZE))

#   def center(x):
#     CoM = ((jp.mgrid[0:SIZE:1, 0:SIZE:1] * x).sum(axis=(1,2))/jp.sum(x)).astype(jp.int32)
#     shift = jp.array([SIZE//2, SIZE//2]) - CoM
#     return jp.roll(x, shift, axis=(0, 1))

#   # paste creature 1
#   cs = c1.shape
#   x_1 = x_1.at[SIZE//2:SIZE//2 + cs[0], SIZE//2 - cs[1]:SIZE//2].set(c1)
#   x_1 = center(x_1)

#   # paste creature 2
#   cs = c2.shape
#   x_2 = x_2.at[(SIZE - cs[0])//2:(SIZE + cs[0])//2, (SIZE - cs[1])//2:(SIZE + cs[1])//2].set(c2)

#   with VideoWriter(fps=15) as vid:
#     for i in range(100):
#       x = jp.concatenate((x_1, x_2), axis=1)
#       vid(Lenia.render_world(x, **titles))
#       for _ in range(2):
#         x_1 = center(p1.step(x_1))
#         x_2 = p2.step(x_2)
#     for i in range(60):
#       # move separator
#       scale = 1 - i/60
#       vid(Lenia.render_world(x, sep_x=SIZE*scale, alpha_1=2.5*max(scale-0.6, 0), **titles))
#     for i in range(100):
#       x = jp.concatenate((x_1, x_2), axis=1)
#       vid(Lenia.render_world(x, sep_x=0, alpha_1=0.0, **titles))
#       for _ in range(2):
#         x_1 = center(p2.step(x_1))
#         x_2 = p2.step(x_2)
#     for i in range(60):
#       # move separator
#       scale = i/60
#       vid(Lenia.render_world(x, sep_x=SIZE*2*scale, alpha_1=2*max(scale-0.5, 0), alpha_2=2*max(0.5-scale, 0)))
#     for i in range(100):
#       x = jp.concatenate((x_1, x_2), axis=1)
#       vid(Lenia.render_world(x, sep_x=2*SIZE, alpha_2=0.0, **titles))
#       for _ in range(2):
#         x_1 = center(p1.step(x_1))
#         x_2 = p1.step(x_2)
# lenia_universes_diagram()

In [3]:
# Particle Lenia parameters: kernel peak distance/width/weight (mu_k, sigma_k,
# w_k), growth peak position/width (mu_g, sigma_g) and repulsion strength c_rep.
Params = namedtuple('Params', ['mu_k', 'sigma_k', 'w_k', 'mu_g', 'sigma_g', 'c_rep'])
# Fields at one location: potential U, growth G, repulsion R, energy E = R - G.
Fields = namedtuple('Fields', ['U', 'G', 'R', 'E'])

def peak_f(x, mu, sigma):
  """Gaussian-shaped bump exp(-((x - mu) / sigma)**2); equals 1 at x == mu."""
  z = (x - mu) / sigma
  return jp.exp(-z**2)

def fields_f(p: Params, points, x):
  """Particle Lenia fields at location `x` induced by all `points`.

  U = w_k * sum_i peak_f(|x - p_i|, mu_k, sigma_k)    (potential)
  G = peak_f(U, mu_g, sigma_g)                        (growth)
  R = c_rep/2 * sum_i max(1 - |x - p_i|, 0)**2        (repulsion)
  E = R - G                                           (energy)
  """
  # Squared distances are floored at 1e-10 before the sqrt, presumably so
  # gradients stay finite when x coincides with one of the points.
  dist = jp.sqrt(jp.square(x - points).sum(-1).clip(1e-10))
  potential = p.w_k * peak_f(dist, p.mu_k, p.sigma_k).sum()
  growth = peak_f(potential, p.mu_g, p.sigma_g)
  overlap = (1.0 - dist).clip(0.0)
  repulsion = p.c_rep/2 * (overlap**2).sum()
  return Fields(potential, growth, repulsion, E=repulsion - growth)

def motion_f(params, points):
  """Velocity of every point: minus the gradient of the energy field E,
  evaluated at that point's own position."""
  def energy(x):
    return fields_f(params, points, x).E
  energy_grad = jax.grad(energy)
  return -jax.vmap(energy_grad)(points)

In [4]:
#@title (show_lenia)
import PIL.ImageFont, PIL.ImageDraw

def lerp(x, a, b):
  """Linear interpolation from `a` (x=0) to `b` (x=1); endpoints cast to float32."""
  start, end = jp.float32(a), jp.float32(b)
  return start*(1.0 - x) + end*x
def cmap_e(e):
  """Map values `e` to RGB: 0 is white, positive values fade toward red,
  negative values toward green. Adds a trailing axis of size 3."""
  weights = jp.float32([[0.3,1,1], [1,0.3,1]])
  pos_neg = jp.stack([e, -e], -1).clip(0)
  return 1.0 - pos_neg @ weights
def cmap_ug(u, g):
  """RGB image for potential `u`, blended by growth `g` toward (1.17, 0.91, 0.13).

  `u` interpolates between (0.1, 0.1, 0.3) at 0 and (0.2, 0.7, 1.0) at 1.
  """
  u_lo, u_hi = [0.1,0.1,0.3], [0.2,0.7,1.0]
  g_color = [1.17,0.91,0.13]
  u_rgb = lerp(u[...,None], u_lo, u_hi)
  return lerp(g[...,None], u_rgb, g_color)

@partial(jax.jit, static_argnames=['w', 'show_UG', 'show_cmap'])
def show_lenia(params, points, extent, w=400, show_UG=False, show_cmap=True):
  """Render the Particle Lenia fields induced by `points` as an RGB image.

  The energy field E is sampled on a w x w grid spanning [-extent, extent] in
  both coordinates. It is coloured with `cmap_e` relative to e0 = -G(U=0),
  the energy far from all points. Pixels close to a point are darkened:
  colours are scaled by min(r2 / 0.02, 1), where r2 is the squared distance
  to the nearest point.

  Args:
    params: model parameters; mu_g / sigma_g are read here, the rest via
      `fields_f`.
    points: particle positions, presumably (n_points, 2) — TODO confirm.
    extent: half-width of the rendered square.
    w: grid resolution (static: each new value triggers recompilation).
    show_UG: if True, prepend a panel showing U and G via `cmap_ug`.
    show_cmap: if True, add 16-px-wide colour bars next to the panels.

  Returns:
    Float RGB image of height w. Panels from left to right are
    [U/G bar | U/G panel |] E panel [| E bar]; the bracketed parts depend on
    show_UG / show_cmap.
  """
  # Sample locations: xy[i, j] = (x_j, y_i), both scaled to [-extent, extent].
  xy = jp.mgrid[-1:1:w*1j, -1:1:w*1j].T*extent
  # Energy with U=0 and no repulsion (far from all points): the colour map's zero.
  e0 = -peak_f(0.0, params.mu_g, params.sigma_g)
  f = partial(fields_f, params, points)
  fields = vmap2(f)(xy)
  # Squared distance from each pixel to its nearest point.
  r2 = jp.square(xy[...,None,:]-points).sum(-1).min(-1)
  # 0 at a point, growing linearly in r2 up to 1 at r2 >= 0.02.
  points_mask = (r2/0.02).clip(0, 1.0)[...,None]
  vis = cmap_e(fields.E-e0) * points_mask
  if show_cmap:
    # Mean energy at the points themselves, drawn as a dark line on the bar.
    e_mean = jax.vmap(f)(points).E.mean()
    # Vertical colour bar for E - e0, from +0.5 (top row) to -0.5 (bottom row).
    bar = np.r_[0.5:-0.5:w*1j]
    bar = cmap_e(bar) * (1.0-peak_f(bar, e_mean-e0, 0.005)[:,None])
    vis = jp.hstack([vis, bar[:,None].repeat(16, 1)])
  if show_UG:
    vis_u = cmap_ug(fields.U, fields.G)*points_mask
    if show_cmap:
      # Colour bar for U from 1 (top) to 0 (bottom), paired with its growth G.
      u = np.r_[1:0:w*1j]
      bar = cmap_ug(u, peak_f(u, params.mu_g, params.sigma_g))
      bar = bar[:,None].repeat(16, 1)
      vis_u = jp.hstack([bar, vis_u])
    vis = jp.hstack([vis_u, vis])
  return vis

# Monospace TrueType font shipped in Matplotlib's data directory, used by
# `text_overlay` for frame captions.
fontpath = f"{pl.matplotlib.get_data_path()}/fonts/ttf/DejaVuSansMono.ttf"
pil_font = PIL.ImageFont.truetype(fontpath, size=16)

def text_overlay(img, text, pos=(20,10), color=(255,255,255)):
  """Convert `img` to a PIL image and draw `text` on it at `pos` in `pil_font`.

  Returns the PIL image.
  """
  canvas = np2pil(img)
  PIL.ImageDraw.Draw(canvas).text(pos, text, fill=color, font=pil_font)
  return canvas

def animate_lenia(params, tracks, rate=10, slow_start=0, w=400, show_UG=True,
                  name='_tmp.mp4', text=None, vid=None, bar_len=None,
                  bar_ofs=0, extent=None, fps=60):
  """Render a sequence of point configurations as video frames.

  Args:
    params: model parameters passed to `show_lenia`.
    tracks: array of point positions per step; presumably
      (steps, n_points, 2) — TODO confirm.
    rate: after `slow_start`, only every `rate`-th step is rendered.
    slow_start: every step with index < slow_start is rendered.
    w: resolution passed to `show_lenia`.
    show_UG: also render the U/G panel.
    name: output filename; used only when a new VideoWriter is created.
    text: optional caption drawn on every frame.
    vid: existing VideoWriter to append frames to.
    bar_len: steps covered by the progress bar (defaults to len(tracks)).
    bar_ofs: step offset added to the progress-bar position.
    extent: half-width of the rendered square; defaults to 1.2 * max |coord|.
    fps: frame rate of a newly created VideoWriter (previously fixed at 60).

  Returns:
    The VideoWriter, not closed, so more frames can be appended or it can be
    displayed as the cell's last expression.
  """
  if vid is None:
    vid = VideoWriter(fps=fps, filename=name)
  if extent is None:
    extent = jp.abs(tracks).max()*1.2
  if bar_len is None:
    bar_len = len(tracks)
  for i, points in enumerate(tracks):
    if not (i<slow_start or i%rate==0):
      continue
    img = show_lenia(params, points, extent, w=w, show_UG=show_UG)
    # Two-pixel progress bar below the image: columns before step i+bar_ofs
    # are grey (0.5), the remaining ones white (1.0).
    bar = np.linspace(0, bar_len, img.shape[1])
    bar = (0.5+(bar>=i+bar_ofs)[:,None]*jp.ones(3)*0.5)[None].repeat(2, 0)
    frame = jp.vstack([img, bar])
    if text is not None:
      frame = text_overlay(frame, text)
    vid(frame)
  return vid

In [5]:
# Model parameters and a random initial cloud of 200 points in [-6, 6)^2.
params = Params(
    mu_k=4.0, sigma_k=1.0, w_k=0.022,  # kernel
    mu_g=0.6, sigma_g=0.15,            # growth
    c_rep=1.0,                         # repulsion
)
key = jax.random.PRNGKey(20)
points0 = (jax.random.uniform(key, [200, 2])-0.5)*12.0
dt = 0.1  # Euler step size

def odeint_euler(f, params, x0, dt, n):
  """Integrate dx/dt = f(params, x) with `n` explicit Euler steps of size `dt`.

  Returns the states after each step stacked along a new leading axis
  (length n); `x0` itself is not included.
  """
  def euler_step(state, _):
    new_state = state+dt*f(params, state)
    return new_state, new_state
  _, trajectory = jax.lax.scan(euler_step, x0, None, length=n)
  return trajectory

# Simulate 10000 Euler steps, then render every 10th step into rotor.mp4
# (writing the video needs the `ffmpeg` executable on PATH).
rotor_story = odeint_euler(motion_f, params, points0, dt, n=10000)
animate_lenia(params, rotor_story, name='rotor.mp4')

Output()

FileNotFoundError: [Errno 2] No such file or directory: 'ffmpeg'