In [1]:
# Notebook bootstrap: enable autoreload, log where the kernel is running, and
# prepare the process environment before any heavy library is imported.

# Re-import edited modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2
# Record the host name and working directory in the cell output.
!hostname
!pwd
import os, sys
# Show which Python interpreter backs this kernel.
print(sys.executable)
# Optional GPU pinning, currently disabled.
# os.environ['CUDA_VISIBLE_DEVICES'] = "4"
# Set before `jax` is imported in a later cell; per the JAX GPU memory
# allocation docs, 'false' asks XLA not to preallocate GPU memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Put the parent directory on the import path; presumably that is where the
# project modules imported below (util, create_sim, models) live -- confirm.
sys.path.append(os.path.abspath(".."))

slurm0-gpu1nodeset-0
/home/akarsh_sakana_ai/nca-alife/src/notebooks
/home/akarsh_sakana_ai/miniconda3/envs/nca-alife-jax/bin/python


In [2]:
import os, sys, glob, pickle
from functools import partial

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

from tqdm.auto import tqdm
from einops import rearrange, reduce, repeat

In [3]:
import jax
from jax import numpy as jnp
from jax.random import split

import util

In [4]:
from create_sim import create_sim, rollout_and_embed_simulation, rollout_simulation, FlattenSimulationParameters
from models.models_boids import Boids
from models.models_boids_mushy import MushyBoids
import imageio

In [5]:
from models.models_gol import GameOfLife

In [6]:
# Plain (single-rule) Game of Life simulator built with grid_size=128.
# NOTE(review): `sim` is rebound to MushyGoL instances in the video-rendering
# cell further down, and nothing visible in between reads this instance.
sim = GameOfLife(grid_size=128)

In [7]:
# Species palette: 16 hex colours joined by '-'. The order is shuffled with a
# fixed seed so the assignment of colours to species is reproducible.
palette_in_order = 'ffadad-ffd6a5-fdffb6-caffbf-9bf6ff-a0c4ff-bdb2ff-ffc6ff-448aff-1565c0-009688-8bc34a-ffc107-ff9800-f44336-ad1457'
print(palette_in_order)

hex_codes = palette_in_order.split('-')
# Legacy global seeding is kept on purpose: it fixes the shuffled order below.
np.random.seed(0)
np.random.shuffle(hex_codes)  # shuffles the list in place

# `colors` is the name later cells consume (a '-'-joined string).
colors = '-'.join(hex_codes)
print(colors)

ffadad-ffd6a5-fdffb6-caffbf-9bf6ff-a0c4ff-bdb2ff-ffc6ff-448aff-1565c0-009688-8bc34a-ffc107-ff9800-f44336-ad1457
ffd6a5-bdb2ff-448aff-1565c0-ff9800-9bf6ff-fdffb6-f44336-009688-ffc6ff-ad1457-8bc34a-caffbf-ffadad-a0c4ff-ffc107


In [8]:
from matplotlib import colors as mcolors
class MushyGoL():
    """Several Game-of-Life rule sets competing on one shared grid.

    The simulation state is a dict with two entries:
      * 'state'      -- the cell grid produced by the wrapped `GameOfLife` sim.
      * 'rule_state' -- an integer grid of the same height/width holding, for
                        every cell, the index of the rule that owns it.

    At each step the grid is advanced under every rule, each cell keeps the
    result of the rule that owns it, and dead cells with a living neighbour
    adopt the rule index of one such neighbour.

    Args:
        k_sims: number of rules; must be one of 4, 9, 16, 25, 36 so the initial
            ownership map can be laid out as a square of square tiles.
        grid_size: forwarded to `GameOfLife`; must be divisible by sqrt(k_sims).
        double_step: if True, apply the underlying step twice per `step_state`.
        colors: '-'-separated hex colours (no leading '#'); the first `k_sims`
            are used to render the owning rule of each cell.
    """
    def __init__(self, k_sims=4, grid_size=128, double_step=True,
                 colors='448aff-1565c0-009688-8bc34a-ffc107-ff9800-f44336-ad1457-448aff-1565c0-009688-8bc34a-ffc107-ff9800-f44336-ad1457'):
        assert k_sims in (4, 9, 16, 25, 36), "k_sims must be one of 4, 9, 16, 25, 36"
        self.k_sims = k_sims
        # k_sims is a small perfect square, so the float root rounds exactly.
        self.sqrt_ksims = int(round(k_sims ** 0.5))
        # init_state tiles the grid with sqrt_ksims x sqrt_ksims equal blocks; a
        # non-divisible grid_size would yield a rule map smaller than the grid
        # and fail later with an opaque shape error.
        assert grid_size % self.sqrt_ksims == 0, "grid_size must be divisible by sqrt(k_sims)"
        self.sim = GameOfLife(grid_size=grid_size)
        hex_codes = colors.split('-')
        assert self.k_sims <= len(hex_codes)
        self.species_colors = jnp.array([mcolors.to_rgb(f"#{c}") for c in hex_codes])[:self.k_sims]
        self.double_step = double_step

    def default_params(self, rng):
        """Return one default parameter set per rule, stacked on a leading axis of size k_sims.

        Delegates to the wrapped simulator's `default_params`, once per split key.
        (The previous implementation referenced a non-existent `self.model_boids`.)
        """
        return jax.vmap(self.sim.default_params, in_axes=(0,))(split(rng, self.k_sims))

    def init_state(self, rng, params):
        """Build the initial state dict.

        The cell grid comes from the wrapped sim using the first rule's
        parameters (`params[0]`); the ownership map is a sqrt_ksims x sqrt_ksims
        arrangement of equal square tiles, one rule index per tile.
        """
        state = self.sim.init_state(rng, params[0])
        rule_state = jnp.arange(self.k_sims).reshape(self.sqrt_ksims, self.sqrt_ksims)
        rule_state = repeat(rule_state, "x y -> (x W) (y H)", W=self.sim.grid_size//self.sqrt_ksims, H=self.sim.grid_size//self.sqrt_ksims)
        return dict(state=state, rule_state=rule_state)

    def step_state(self, rng, state, params):
        """Advance the grid one (or two, if `double_step`) steps and update rule ownership.

        Args:
            rng: PRNG key; used for the underlying step and for neighbour shuffling.
            state: dict with 'state' and 'rule_state' as produced by `init_state`.
            params: per-rule parameters, indexed by rule along axis 0.

        Returns:
            A new dict with the updated 'state' and 'rule_state'.
        """
        state, rule_state = state['state'], state['rule_state']

        def step_fn(rng, state, params):
            state = self.sim.step_state(rng, state, params)
            if self.double_step:
                state = self.sim.step_state(rng, state, params)
            return state
        # Advance the same grid under every rule: one candidate grid per rule.
        state = jax.vmap(step_fn, in_axes=(None, None, 0))(rng, state, params)
        state = rearrange(state, "D H W -> H W D")
        def index_fn(states, rule_idx):
            return states[rule_idx]
        # Each cell keeps the candidate value computed by the rule that owns it.
        state = jax.vmap(jax.vmap(index_fn))(state, rule_state)

        # CHANGING DYNAMICS CODE
        def get_neighbors(x):
            # Gather the 8 surrounding cells of every position, with wrap-around
            # at the borders, stacked along a new trailing axis.
            x = jnp.pad(x, pad_width=1, mode='wrap')
            neighs = jnp.stack([x[:-2, :-2], x[:-2, 1:-1], x[:-2, 2:], x[1:-1, :-2], x[1:-1, 2:], x[2:, :-2], x[2:, 1:-1], x[2:, 2:]], axis=-1)
            return neighs
        state_neighs = get_neighbors(state)
        rule_state_neighs = get_neighbors(rule_state)

        def get_rule_idx(rng, state, rule_state, state_neighs, rule_state_neighs):
            # The SAME key is used for both permutations on purpose: the two
            # length-8 vectors are shuffled identically and stay aligned.
            state_neighs = jax.random.permutation(rng, state_neighs)
            rule_state_neighs = jax.random.permutation(rng, rule_state_neighs)

            # argmax returns the first maximal entry of the shuffled neighbours;
            # the condition below only uses it when some neighbour is non-zero.
            rule_state_2 = rule_state_neighs[jnp.argmax(state_neighs)]
            # only change rule_state if state is dead and there is a living neighbor
            return jax.lax.select((state==0)& (state_neighs.sum()>0), rule_state_2, rule_state)

        state_neighs = rearrange(state_neighs, "H W D -> (H W) D")
        rule_state_neighs = rearrange(rule_state_neighs, "H W D -> (H W) D")
        # One independent key per cell.
        rule_state = jax.vmap(get_rule_idx)(split(rng, len(state_neighs)), state.flatten(), rule_state.flatten(), state_neighs, rule_state_neighs)
        rule_state = rule_state.reshape(*state.shape)
        return dict(state=state, rule_state=rule_state)

    def render_state(self, state, params, img_size=None):
        """Render the state as an (H, W, 3) float image.

        Each cell's value is multiplied by the colour of its owning rule, so
        zero-valued cells render black. `params` is accepted but not used here.
        If `img_size` is given the image is resized to (img_size, img_size, 3)
        with nearest-neighbour interpolation.
        """
        state, rule_state = state['state'], state['rule_state']
        img = repeat(state.astype(float), "H W -> H W 3")
        img = img * self.species_colors[rule_state]
        if img_size is not None:
            img = jax.image.resize(img, (img_size, img_size, 3), method='nearest')
        return img


In [9]:
# One integer parameter per rule (16 entries, matching k_sims=16 below).
# NOTE(review): presumably GameOfLife rule encodings -- confirm against models_gol.
params = jnp.array([7291, 7787, 55871, 15923, 2579, 5691, 47999, 55615, 56127, 8059, 55903, 133519, 55359, 3731, 36703, 34335])
rng = jax.random.PRNGKey(0)

# Single place to change where the videos are written.
FIG_DIR = '/home/akarsh_sakana_ai/nca-alife-data/figs_final'
ROLLOUT_STEPS = 1024*2
N_ROLLOUT_IMGS = 2048

# (file name, grid/image size, double_step, frame stride, fps)
# Single-step runs keep every other frame at 30 fps (original note: "for the
# flashing"); double-step runs keep every frame at 60 fps.
VIDEO_SPECS = [
    ('gol_mush.mp4', 256, False, 2, 30),
    ('gol_mush_large.mp4', 512, False, 2, 30),
    ('gol_mush_double_step.mp4', 256, True, 1, 60),
    ('gol_mush_double_step_large.mp4', 512, True, 1, 60),
]

for video_name, size, double_step, frame_stride, fps in VIDEO_SPECS:
    sim = MushyGoL(k_sims=16, grid_size=size, double_step=double_step, colors=colors)
    rollout_fn = partial(rollout_simulation, sim=sim, rollout_steps=ROLLOUT_STEPS, img_size=size, n_rollout_imgs=N_ROLLOUT_IMGS)
    rollout_data = rollout_fn(rng, params)
    print(rollout_data['rgb'].shape)
    # Convert the float frames to 8-bit before encoding.
    vid = np.array(rollout_data['rgb'])
    vid = (vid*255).astype(np.uint8)
    imageio.mimsave(os.path.join(FIG_DIR, video_name), vid[::frame_stride], fps=fps)

2024-12-21 19:31:58.122178: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.68). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Could not load symbol cuFuncGetName. Error: /lib/x86_64-linux-gnu/libcuda.so.1: undefined symbol: cuFuncGetName


(2048, 256, 256, 3)


  self.pid = _posixsubprocess.fork_exec(


(2048, 512, 512, 3)


  self.pid = _posixsubprocess.fork_exec(


(2048, 256, 256, 3)


  self.pid = _posixsubprocess.fork_exec(


(2048, 512, 512, 3)


  self.pid = _posixsubprocess.fork_exec(
