In [None]:
# Notebook bootstrap: enable auto-reloading of imported modules, log where the
# kernel is running, and configure the environment before the heavy imports below.
%load_ext autoreload
%autoreload 2
# Print the host name and working directory of the kernel.
!hostname
!pwd
import os, sys
# Print the path of the Python interpreter backing this kernel.
print(sys.executable)
# os.environ['CUDA_VISIBLE_DEVICES'] = "4"
# NOTE(review): presumably disables JAX/XLA GPU memory preallocation; this only
# has an effect if it is set before JAX initialises its backend — confirm.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Put the parent directory on the import path (the local modules imported in the
# following cells are presumably located there — confirm).
sys.path.append(os.path.abspath(".."))

In [None]:
import os, sys, glob, pickle
from functools import partial

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

from tqdm.auto import tqdm
from einops import rearrange, reduce, repeat

In [None]:
import experiment_utils
import util

# Create Experiment

In [None]:
# Reference copy of a command-line interface whose flags largely match the keys
# of the sweep configs built in the next cell (presumably the CLI of main_opt.py
# — confirm against that script).
import argparse

# `parser` and `group` were previously used without being created, so this cell
# raised NameError on a fresh kernel (Restart & Run All).
parser = argparse.ArgumentParser()

# NOTE(review): the name of this first argument group is not visible in the
# notebook; "meta" is a placeholder and only affects --help output.
group = parser.add_argument_group("meta")
group.add_argument("--seed", type=int, default=0)
group.add_argument("--save_dir", type=str, default=None)

group = parser.add_argument_group("model")
group.add_argument("--sim", type=str, default='boids')

group = parser.add_argument_group("data")
group.add_argument("--n_rollout_imgs", type=int, default=1)
group.add_argument("--prompts", type=str, default="an artificial cell,a bacterium")
group.add_argument("--clip_model", type=str, default="clip-vit-base-patch32") # clip-vit-base-patch32 or clip-vit-large-patch14
group.add_argument("--coef_prompts", type=float, default=1.)
group.add_argument("--coef_novelty", type=float, default=0.)

group = parser.add_argument_group("optimization")
group.add_argument("--bs", type=int, default=4)
group.add_argument("--pop_size", type=int, default=16)
group.add_argument("--n_iters", type=int, default=10000)
group.add_argument("--sigma", type=float, default=1.)

In [None]:
from itertools import product

# Baseline configuration; each sweep entry below overrides a subset of these keys.
cfg_default = dict(
    seed=0, save_dir=None,
    sim='boids',

    n_rollout_imgs=1,
    prompts="a cell",

    coef_prompts=1., coef_novelty=0.,
    bs=1, pop_size=16,
    n_iters=1000,
    sigma=0.1,
)

seed_sweep = np.arange(1)

# One prompt per line; the line index (`iprt`) becomes part of the run directory name.
with open("./prompts.txt", "r") as f:
    prompt_sweep = f.read().strip().split('\n')
print(prompt_sweep)

sigma_sweep = [0.1]

# coef_novelty_sweep = [-1., -0.3, -0.1, 0., 0.1, 0.3, 1.]
coef_novelty_sweep = [0.]

sims_sweep = ['boids', 'dnca', 'lenia', 'nca_d1', 'nca_d3', 'plenia', 'plife_a', 'plife_ba', 'plife_ba_c3']
# n_iters_sweep = [1000, 1000, 1000, 1000, 1000, 1000, 400, 400, 400]
n_iters_sweep = [1000] * len(sims_sweep)

# Root directory of all runs. The "{seed}_{sim}_{iprt}_{sigma}_{coef_novelty}"
# naming scheme is rebuilt by hand in the visualization cells, so keep them in sync.
save_root = "/home/akarsh_sakana_ai/nca-alife-data/main_opt"

# Full cross product of the sweeps; each simulator is paired with its own n_iters.
# Iteration order matches the original nested loops (sim outermost, prompt innermost).
cfgs = []
for (sim, n_iters), seed, sigma, coef_novelty in product(
        zip(sims_sweep, n_iters_sweep), seed_sweep, sigma_sweep, coef_novelty_sweep):
    # Store a plain Python int rather than a NumPy scalar; it formats identically
    # in the directory name but has a clean repr wherever the config is printed.
    seed = int(seed)
    for iprt, prompt in enumerate(prompt_sweep):
        cfg = cfg_default.copy()
        cfg.update(sim=sim, seed=seed, prompts=prompt, sigma=sigma, coef_novelty=coef_novelty, n_iters=n_iters)
        cfg.update(save_dir=f"{save_root}/{seed}_{sim}_{iprt}_{sigma}_{coef_novelty}")
        cfgs.append(cfg)

# print(cfgs)
print(len(cfgs))

In [None]:
# Turn every config into a command line prefixed with `python main_opt.py`;
# the target script path is passed to create_commands as `out_file`.
commands = experiment_utils.create_commands(
    cfgs,
    prefix='python main_opt.py',
    out_file='../science_scripts/main_opt.sh',
)

# Preview the first three commands, then report the total count.
preview = '\n'.join(commands[:3])
print(preview, '\n...')
print(len(commands), 'commands')

# Visualize Results

In [None]:
import jax
import jax.numpy as jnp
import copy

In [None]:
# Collect one summary record per config, then build the DataFrame in a single
# step. Accumulating in a separate list means `df` is only (re)bound once every
# run has loaded, instead of being left as a half-filled list when a load fails.
records = []

for cfg in tqdm(cfgs):
    save_dir = cfg['save_dir']
    record = copy.copy(cfg)

    data = util.load_pkl(save_dir, "data")
    # Summarise a run by the mean of its last 50 `best_loss` entries. float()
    # turns a 0-d array result into a plain number so the column is numeric.
    record['best_loss'] = float(data['best_loss'][-50:].mean())
    # record.update({k: v[-50:].max(axis=-1).mean().item() for k, v in data['loss_dict'].items()})

    records.append(record)
df = pd.DataFrame(records)

In [None]:
# Display the summary table built above (one row per config, plus `best_loss`).
df

In [None]:
# Scatter of the summarised best loss: one x position per simulator, one hue per prompt.
sns.relplot(
    data=df,
    x='sim',
    y='best_loss',
    hue='prompts',
)
# Rotate the x tick labels to vertical.
plt.xticks(rotation=90)

In [None]:
from create_sim import create_sim
import evosax
from jax.random import split

def unroll_params(rng, params, sim, img_size=64, ret='vid'):
    """Roll `sim` forward for `sim.rollout_steps` steps and render the outcome.

    Args:
        rng: JAX PRNG key. It is passed to `sim.init_state` and is also split
            into one key per step.
        params: simulation parameters forwarded to `sim.init_state`,
            `sim.step_state` and `sim.render_state`.
        sim: object providing `init_state`, `step_state`, `render_state` and
            `rollout_steps`.
        img_size: forwarded to `sim.render_state`.
        ret: 'vid' to render the state held *before* each step (so the initial
            state is included and the final state is not), or 'img' to render
            only the final state.

    Returns:
        The stacked per-step renders for ret='vid', or a single render for
        ret='img'.

    Raises:
        ValueError: if `ret` is neither 'vid' nor 'img'. Previously such a value
            ran the whole rollout and then silently returned None.
    """
    if ret not in ('vid', 'img'):
        raise ValueError(f"ret must be 'vid' or 'img', got {ret!r}")
    keep_states = ret == 'vid'

    def step(state, _rng):
        next_state = sim.step_state(_rng, state, params)
        # Only stack the per-step states when a video is requested; for a single
        # image the intermediate states are never read, so emit nothing.
        return next_state, (state if keep_states else None)

    state_init = sim.init_state(rng, params)
    # NOTE(review): `rng` is reused for init_state and as the parent of the step
    # keys. Left unchanged on purpose so rollouts stay identical to earlier runs.
    state_final, state_vid = jax.lax.scan(step, state_init, split(rng, sim.rollout_steps))

    if keep_states:
        render = partial(sim.render_state, params=params, img_size=img_size)
        return jax.vmap(render)(state_vid)
    return sim.render_state(state_final, params=params, img_size=img_size)

In [None]:
# Render the final frame of the best parameters found for every (simulator, prompt)
# pair of one sweep setting; one figure per simulator, one panel per prompt.
seed = 0
sigma = 0.1
coef_novelty = 0.

# Panel grid: 8 columns and at least 4 rows (the original fixed 4x8 layout). Rows
# are added when there are more than 32 prompts, which previously made
# plt.subplot fail with an out-of-range index.
n_cols = 8
n_rows = max(4, -(-len(prompt_sweep) // n_cols))  # ceil division without importing math

for sim_name, n_iters in zip(sims_sweep, n_iters_sweep):
    print(sim_name)
    sim = create_sim(sim_name)
    rng = jax.random.PRNGKey(0)
    param_reshaper = evosax.ParameterReshaper(sim.default_params(rng))

    # 2.5 in of height per row reproduces the original 20x10 figure for 4 rows.
    plt.figure(figsize=(20, 2.5 * n_rows))
    for iprt, prompt in enumerate(tqdm(prompt_sweep)):
        # Must match the save_dir naming used when the configs were created.
        save_dir=f"/home/akarsh_sakana_ai/nca-alife-data/main_opt/{seed}_{sim_name}_{iprt}_{sigma}_{coef_novelty}"

        # The 'best' pickle unpacks into two items; only the first is used here.
        params, _ = util.load_pkl(save_dir, 'best')
        params = param_reshaper.reshape_single(params)
        img = unroll_params(rng, params, sim, img_size=256, ret='img')
        plt.subplot(n_rows, n_cols, iprt+1)
        plt.imshow(img)
        plt.grid(False)
        plt.title(prompt)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Overlay the best-loss curve of every (simulator, prompt) run, coloured by simulator.
seed = 0
sigma = 0.1
coef_novelty = 0.

# One colour per simulator. The palette is built once and wrapped with modulo so
# that more than nine simulators reuse colours instead of raising IndexError.
sim_colors = ['r', 'g', 'b', 'y', 'm', 'c', 'k', 'purple', 'orange']

for isim, (sim_name, n_iters) in enumerate(zip(sims_sweep, n_iters_sweep)):
    color = sim_colors[isim % len(sim_colors)]
    for iprt, prompt in enumerate(prompt_sweep):
        # Must match the save_dir naming used when the configs were created.
        save_dir=f"/home/akarsh_sakana_ai/nca-alife-data/main_opt/{seed}_{sim_name}_{iprt}_{sigma}_{coef_novelty}"
        data = util.load_pkl(save_dir, 'data')
        # Label only the first curve of each simulator so the legend has one entry per simulator.
        plt.plot(data['best_loss'], color=color, alpha=0.5, label=sim_name if iprt == 0 else "")
plt.xlabel('Iterations'); plt.ylabel('Best Loss')
plt.title('Loss Curves for Different Simulations and Prompts')
plt.legend()
plt.tight_layout()
plt.show()