In [1]:
%load_ext autoreload
%autoreload 2
!hostname
!pwd
import os, sys
print(sys.executable)
# os.environ['CUDA_VISIBLE_DEVICES'] = "4"
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
sys.path.append(os.path.abspath(".."))

slurm0-gpu1nodeset-0
/home/akarsh_sakana_ai/nca-alife/src/science_notebooks
/home/akarsh_sakana_ai/miniconda3/envs/nca-alife-jax/bin/python


In [2]:
import os, sys, glob, pickle
from functools import partial

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

from tqdm.auto import tqdm
from einops import rearrange, reduce, repeat

In [3]:
import experiment_utils
import util

# Create Experiment

In [None]:
group.add_argument("--seed", type=int, default=0)
group.add_argument("--save_dir", type=str, default=None)

group = parser.add_argument_group("model")
group.add_argument("--sim", type=str, default='boids')

group = parser.add_argument_group("data")
group.add_argument("--n_rollout_imgs", type=int, default=1)
group.add_argument("--prompts", type=str, default="an artificial cell,a bacterium")
group.add_argument("--clip_model", type=str, default="clip-vit-base-patch32") # clip-vit-base-patch32 or clip-vit-large-patch14
group.add_argument("--coef_prompts", type=float, default=1.)
group.add_argument("--coef_novelty", type=float, default=0.)

group = parser.add_argument_group("optimization")
group.add_argument("--bs", type=int, default=4)
group.add_argument("--pop_size", type=int, default=16)
group.add_argument("--n_iters", type=int, default=10000)
group.add_argument("--sigma", type=float, default=1.)

In [4]:
cfg_default = dict(
    seed=0, save_dir=None,
    sim='boids',

    n_rollout_imgs=4,
    prompts="a cell",

    coef_prompts=1., coef_novelty=0.,
    bs=1, pop_size=16,
    n_iters=1000,
    sigma=0.1,
)

seed_sweep = np.arange(1)

with open("./prompts.txt", "r") as f:
    prompt_sweep = f.read().strip().split('\n')
prompt_sweep = prompt_sweep[::4]
print(prompt_sweep)

sigma_sweep = [0.1]

# coef_novelty_sweep = [-1., -0.3, -0.1, 0., 0.1, 0.3, 1.]
coef_novelty_sweep = [0., 0.1, 0.3, 1.]

sims_sweep = ['boids', 'dnca', 'lenia', 'nca_d1', 'nca_d3', 'plenia', 'plife_a', 'plife_ba', 'plife_ba_c3']
# n_iters_sweep = [1000, 1000, 1000, 1000, 1000, 1000, 500, 500, 500]
n_iters_sweep = [1000, 1000, 1000, 1000, 1000, 1000, 300, 300, 300]

cfgs = []
for sim, n_iters in zip(sims_sweep, n_iters_sweep):
    for seed in seed_sweep:
        for sigma in sigma_sweep:
            for coef_novelty in coef_novelty_sweep:
                for iprt, prompt in enumerate(prompt_sweep):
                    print(prompt)
                    cfg = cfg_default.copy()
                    cfg.update(sim=sim, seed=seed, prompts=prompt, sigma=sigma, coef_novelty=coef_novelty, n_iters=n_iters)
                    cfg.update(save_dir=f"/home/akarsh_sakana_ai/nca-alife-data/main_opt_oe/{seed}_{sim}_{iprt}_{sigma}_{coef_novelty}")
                    cfgs.append(cfg)

# print(cfgs)
print(len(cfgs))

['a biological cell under the microscope', 'a network of neurons', 'a red organism', 'swarm intelligence', 'a fibonacci spiral in nature', 'a caterpillar', 'a beautiful flower', 'a pepperoni pizza']
a biological cell under the microscope
a network of neurons
a red organism
swarm intelligence
a fibonacci spiral in nature
a caterpillar
a beautiful flower
a pepperoni pizza
a biological cell under the microscope
a network of neurons
a red organism
swarm intelligence
a fibonacci spiral in nature
a caterpillar
a beautiful flower
a pepperoni pizza
a biological cell under the microscope
a network of neurons
a red organism
swarm intelligence
a fibonacci spiral in nature
a caterpillar
a beautiful flower
a pepperoni pizza
a biological cell under the microscope
a network of neurons
a red organism
swarm intelligence
a fibonacci spiral in nature
a caterpillar
a beautiful flower
a pepperoni pizza
a biological cell under the microscope
a network of neurons
a red organism
swarm intelligence
a fibonacci

In [13]:
commands = experiment_utils.create_commands(cfgs, prefix='python main_opt.py', out_file='../science_scripts/main_opt_oe.sh')
print('\n'.join(commands[:3]), '\n...')
print(len(commands), 'commands')

python main_opt.py --seed=0 --save_dir="/home/akarsh_sakana_ai/nca-alife-data/main_opt_oe/0_boids_0_0.1_0.0"       --sim="boids"       --n_rollout_imgs=4 --prompts="a biological cell under the microscope" --coef_prompts=1.0 --coef_novelty=0.0 --bs=1 --pop_size=16 --n_iters=1000 --sigma=0.1
python main_opt.py --seed=0 --save_dir="/home/akarsh_sakana_ai/nca-alife-data/main_opt_oe/0_boids_1_0.1_0.0"       --sim="boids"       --n_rollout_imgs=4 --prompts="a network of neurons"                   --coef_prompts=1.0 --coef_novelty=0.0 --bs=1 --pop_size=16 --n_iters=1000 --sigma=0.1
python main_opt.py --seed=0 --save_dir="/home/akarsh_sakana_ai/nca-alife-data/main_opt_oe/0_boids_2_0.1_0.0"       --sim="boids"       --n_rollout_imgs=4 --prompts="a red organism"                         --coef_prompts=1.0 --coef_novelty=0.0 --bs=1 --pop_size=16 --n_iters=1000 --sigma=0.1 
...
288 commands


# Visualize Results

In [21]:
import jax
import jax.numpy as jnp
import copy

In [22]:
df = []

for cfg in tqdm(cfgs):
    save_dir = cfg['save_dir']
    dfi = copy.copy(cfg)
    
    data = util.load_pkl(save_dir, "data")
    dfi['best_loss'] = data['best_loss'][-50:].mean()
    # dfi.update({k: v[-50:].max(axis=-1).mean().item() for k, v in data['loss_dict'].items()})
    
    df.append(dfi)
df = pd.DataFrame(df)

  0%|          | 0/288 [00:00<?, ?it/s]

In [23]:
df

Unnamed: 0,seed,save_dir,sim,n_rollout_imgs,prompts,coef_prompts,coef_novelty,bs,pop_size,n_iters,sigma,best_loss
0,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,boids,4,a biological cell under the microscope,1.0,0.0,1,16,1000,0.1,-0.226976
1,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,boids,4,a network of neurons,1.0,0.0,1,16,1000,0.1,-0.292674
2,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,boids,4,a red organism,1.0,0.0,1,16,1000,0.1,-0.268032
3,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,boids,4,swarm intelligence,1.0,0.0,1,16,1000,0.1,-0.306701
4,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,boids,4,a fibonacci spiral in nature,1.0,0.0,1,16,1000,0.1,-0.316627
...,...,...,...,...,...,...,...,...,...,...,...,...
283,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,plife_ba_c3,4,swarm intelligence,1.0,1.0,1,16,300,0.1,0.643325
284,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,plife_ba_c3,4,a fibonacci spiral in nature,1.0,1.0,1,16,300,0.1,0.634815
285,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,plife_ba_c3,4,a caterpillar,1.0,1.0,1,16,300,0.1,0.677236
286,0,/home/akarsh_sakana_ai/nca-alife-data/main_opt...,plife_ba_c3,4,a beautiful flower,1.0,1.0,1,16,300,0.1,0.656597


In [None]:
sns.relplot(data=df, x='sim', y='best_loss', hue='prompts')
plt.xticks(rotation=90)

In [None]:
from create_sim import create_sim
import evosax
from jax.random import split

def unroll_params(rng, params, sim, img_size=64, ret='vid'):
    def step(state, _rng):
        next_state = sim.step_state(_rng, state, params)
        return next_state, state
    state_init = sim.init_state(rng, params)
    state_final, state_vid = jax.lax.scan(step, state_init, split(rng, sim.rollout_steps))
    if ret=='vid':
        vid = jax.vmap(partial(sim.render_state, params=params, img_size=img_size))(state_vid)
        return vid
    elif ret=='img':
        img = sim.render_state(state_final, params=params, img_size=img_size)
        return img

In [None]:
seed = 0
sigma = 0.1
coef_novelty = 0.
for sim_name, n_iters in zip(sims_sweep, n_iters_sweep):
    print(sim_name)
    sim = create_sim(sim_name)
    rng = jax.random.PRNGKey(0)
    param_reshaper = evosax.ParameterReshaper(sim.default_params(rng))

    plt.figure(figsize=(20, 10))
    for iprt, prompt in enumerate(tqdm(prompt_sweep)):
        save_dir=f"/home/akarsh_sakana_ai/nca-alife-data/main_opt/{seed}_{sim_name}_{iprt}_{sigma}_{coef_novelty}"

        params, _ = util.load_pkl(save_dir, 'best')
        params = param_reshaper.reshape_single(params)
        img = unroll_params(rng, params, sim, img_size=256, ret='img')
        plt.subplot(4, 8, iprt+1)
        plt.imshow(img)
        plt.grid(False)
        plt.title(prompt)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
seed = 0
sigma = 0.1
coef_novelty = 0.
for sim_name, n_iters in zip(sims_sweep, n_iters_sweep):
    for iprt, prompt in enumerate(prompt_sweep):
        save_dir=f"/home/akarsh_sakana_ai/nca-alife-data/main_opt/{seed}_{sim_name}_{iprt}_{sigma}_{coef_novelty}"
        data = util.load_pkl(save_dir, 'data')
        color = ['r', 'g', 'b', 'y', 'm', 'c', 'k', 'purple', 'orange'][sims_sweep.index(sim_name)]
        plt.plot(data['best_loss'], color=color, alpha=0.5, label=sim_name if iprt == 0 else "")
plt.xlabel('Iterations'); plt.ylabel('Best Loss')
plt.title('Loss Curves for Different Simulations and Prompts')
plt.legend()
plt.tight_layout()
plt.show()

In [17]:
def unroll_params(rng, params, sim, rollout_len=512, img_size=128, ret='vid'):
    def step(state, _rng):
        next_state = sim.step_state(_rng, state, params)
        return next_state, state
    state_init = sim.init_state(rng, params)
    state_final, state_vid = jax.lax.scan(step, state_init, split(rng, rollout_len))
    if ret=='vid':
        vid = jax.vmap(partial(sim.render_state, params=params, img_size=img_size))(state_vid)
        return vid
    elif ret=='img':
        img = sim.render_state(state_final, params=params, img_size=img_size)
        return img

In [19]:
from create_sim import create_sim
sim_name = None
for cfg in cfgs[:1]:
    if sim_name is None or sim_name != cfg['sim']:
        sim_name = cfg['sim']
        sim = create_sim(cfg['sim'])
        unroll_fn = jax.jit(partial(unroll_params, sim=sim, rollout_len=sim.rollout_steps, img_size=128, ret='vid'))
    
    # unroll_fn(cfg[])

(array([-1.3907120e-01, -5.5725068e-02, -5.3500608e-03, -7.7292234e-01,
        -1.0215148e-01,  4.3550035e-01, -1.2078172e-01, -8.1291562e-01,
         1.9199310e-01,  1.1110213e+00,  7.0812583e-01,  1.0587292e+00,
        -1.7768556e-01,  7.5250983e-01, -1.0421708e+00, -5.9190845e-01,
         9.9746460e-01, -1.3373494e+00, -2.4140354e-01,  5.7246381e-01,
        -6.9753110e-01, -3.6157879e-01,  1.2048494e+00, -6.8600081e-02,
         5.5970424e-01,  7.9800159e-02,  1.1112787e+00, -1.1427866e-01,
         3.2938723e-02,  2.9171443e-01,  1.0095203e+00,  5.3387374e-01,
        -3.4883108e-02, -7.1899164e-01, -6.7817217e-01, -6.9385332e-01,
         7.6483667e-01,  1.6806908e-02,  2.0646590e-01,  1.9232227e-01,
         8.7041181e-01,  2.6317567e-01,  3.3361229e-01,  3.3676930e-02,
        -6.2710160e-01,  1.2156569e-01, -1.3024729e+00, -4.6158814e-01,
         4.8311853e-01, -3.6351919e-01, -1.4804748e+00,  5.1957268e-01,
         3.1182304e-01,  1.1822164e+00, -8.7982261e-01,  8.20923