<a href="https://colab.research.google.com/github/SakanaAI/asal/blob/main/asal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Automating the Search for Artificial Life with Foundation Models

Automated Search for Artificial Life (ASAL) is an algorithm to automatically discover interesting ALife simulations!

Rather than hand-designing a particular simulation, just parameterize a large set of simulations (e.g., all Lenia simulations) and search over them automatically.

This notebook will:
- Show you how to run ASAL
    - Search for supervised target simulations specified by a single prompt
    - Search for supervised target simulations specified by multiple temporal prompts
    - Search for open-ended simulations
    - Search for supervised target + open-ended simulations
    - Illuminate the entire substrate by finding a diverse set of simulations
- Give some tips on getting the search to work for whatever you want to discover.
- Show you how to play around with the other substrates we implemented.
- Load previously discovered lifeforms from our large-scale searches.

NOTE: If you are in Google Colab, make sure to change the runtime to have a GPU!

This notebook was run on Sakana AI's cluster.
Because JAX compilation is nondeterministic across different hardware and the simulations are chaotic, your results will probably differ from the exact pictures shown here.

In [2]:
# Reload edited project modules automatically before running each cell.
%load_ext autoreload
%autoreload 2

import os
import sys

# Must be set before JAX is imported (next cell): turns off XLA's up-front GPU
# memory preallocation so memory is allocated on demand instead
# (see the JAX "GPU memory allocation" documentation).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [3]:
import os, sys, glob, pickle
from functools import partial

import jax
import jax.numpy as jnp
from jax.random import split
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce, repeat

import substrates
import foundation_models
from rollout import rollout_simulation
import asal_metrics
import util


## Illumination

Great! Now let's try illuminating the entire Lenia substrate and finding a diverse set of Lenia creatures!

Run this command in a terminal (or run the next cell, which invokes it through the `!` shell escape):

In [5]:
# Illuminate the 'plenia' substrate; the next cell loads the "data" and "pop" results from --save_dir.
! python main_illuminate.py --seed=0 --save_dir="./data/illuminate_1" --substrate='plenia' --n_child=32 --pop_size=256 --n_iters=1000 --sigma=0.1



In [6]:
# Load the illumination results and re-render every discovered simulation.
save_dir = "./data/illuminate_1"
data = util.load_pkl(save_dir, "data")  # optimization log (its 'loss' is plotted below)
pop = util.load_pkl(save_dir, "pop")    # best parameters found by the search
params = pop['params']

substrate = substrates.FlattenSubstrateParameters(substrates.create_substrate('plenia'))

# No foundation model is passed (fm=None): only the rendered 'rgb' frames are used below.
rollout_fn = jax.jit(
    partial(
        rollout_simulation,
        s0=None,
        substrate=substrate,
        fm=None,
        rollout_steps=substrate.rollout_steps,
        time_sampling='final',
        img_size=224,
        return_state=False,
    )
)

rng = jax.random.PRNGKey(0)

# Every simulation is rolled out with the same rng key, one at a time to keep
# GPU memory low; with enough memory, jax.vmap over rollout_fn would also work.
per_simulation = [rollout_fn(rng, p) for p in params]
# Turn the list of per-simulation pytrees into one pytree with a leading batch axis.
rollout_data = jax.tree.map(lambda *leaves: jnp.stack(leaves, axis=0), *per_simulation)



In [8]:
# --- Optimization curve of the illumination run ---
plt.figure(figsize=(20, 5))
plt.plot(data['loss'])
plt.xlabel("Iterations", fontsize=20)
plt.ylabel("Loss", fontsize=20)
plt.title("Optimization", fontsize=25)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.show()

# --- Mosaic of final frames: 8 rows, each tile and the whole mosaic framed
# by a 2-pixel grey (0.5) border ---
plt.figure(figsize=(20, 6))
border = 2
frames = np.array(rollout_data['rgb'])
frames = np.pad(frames, ((0, 0), (border, border), (border, border), (0, 0)), constant_values=0.5)
mosaic = rearrange(frames, "(R C) H W D -> (R H) (C W) D", R=8)
mosaic = np.pad(mosaic, ((border, border), (border, border), (0, 0)), constant_values=0.5)
plt.imshow(mosaic)
plt.xticks([], fontsize=15)
plt.yticks([], fontsize=15)
plt.title("Visualizing Final State of All Simulations", fontsize=25)
plt.show()





In [None]:
# Install UMAP (only needed the first time). %pip installs into the environment
# of the running kernel, whereas !pip may target a different Python installation.
%pip install umap-learn matplotlib

In [None]:
# Project each simulation's embedding to 2-D with UMAP and draw its final
# rendered frame at the projected position.
import umap
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np

# UMAP settings: a smaller n_neighbors preserves more local structure,
# a smaller min_dist lets points pack more densely.
umap_settings = dict(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)
reducer = umap.UMAP(**umap_settings)

# pop['z'] holds one high-dimensional embedding per simulation
# (labelled as CLIP in the plot title).
z_2d = reducer.fit_transform(np.array(pop['z']))
rgb_images = np.array(rollout_data['rgb'])

plt.figure(figsize=(20, 20))
# Grey dots underneath the thumbnails mark the exact projected points.
plt.scatter(z_2d[:, 0], z_2d[:, 1], alpha=0.5, c='gray')

axes = plt.gca()
for i, (u, v) in enumerate(z_2d):
    thumbnail = OffsetImage(rgb_images[i], zoom=0.5)  # frame drawn at half size
    axes.add_artist(AnnotationBbox(thumbnail, (u, v), frameon=False))

plt.title("UMAP Visualization of CLIP Embedding Space", fontsize=25)
plt.xlabel("UMAP Dimension 1", fontsize=20)
plt.ylabel("UMAP Dimension 2", fontsize=20)
plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()









In [11]:
# Simulation atlas: project the per-simulation embeddings in pop['z'] to 2-D
# with UMAP, snap every simulation to a cell of a square grid, and tile the
# final-frame renders into one large image.
import umap
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from scipy.ndimage import zoom

# --- Settings ---------------------------------------------------------------
grid_size = 30               # the atlas is grid_size x grid_size cells
random_seed = 42             # seeds UMAP and the optional "Random" panel
cell_size = 100              # side length of one cell, in pixels
show_random_section = True   # paint a panel of randomly drawn patterns in the top-left corner


def _scale_to_grid(points, n_cells):
    """Min-max scale each column of `points` into the range [0, n_cells - 1].

    A column with zero spread (all values equal) maps to 0 instead of
    producing NaN from a 0/0 division.
    """
    lo = points.min(axis=0)
    span = points.max(axis=0) - lo
    span = np.where(span > 0, span, 1)
    return (points - lo) / span * (n_cells - 1)


def _find_free_cell(x, y, occupied, n_cells, max_radius):
    """Return (x, y) if it is free, else the first free cell on the nearest square ring.

    Rings of Chebyshev radius 1..max_radius around (x, y) are scanned in
    order (row-major in (dx, dy) within a ring). Returns None when no ring
    contains a free in-bounds cell.
    """
    if (x, y) not in occupied:
        return x, y
    for r in range(1, max_radius + 1):
        for dx in range(-r, r + 1):
            for dy in range(-r, r + 1):
                if max(abs(dx), abs(dy)) != r:
                    continue  # interior cell: already checked at a smaller radius
                nx, ny = x + dx, y + dy
                if 0 <= nx < n_cells and 0 <= ny < n_cells and (nx, ny) not in occupied:
                    return nx, ny
    return None


def _resize_to_cell(image, size):
    """Bilinearly (order=1) resize an (H, W, C) image to (size, size, C)."""
    return zoom(image, (size / image.shape[0], size / image.shape[1], 1), order=1)


# --- 2-D embedding ----------------------------------------------------------
reducer = umap.UMAP(n_components=2,
                    n_neighbors=15,
                    min_dist=0.1,
                    random_state=random_seed)
# pop['z'] holds one embedding per simulation (labelled CLIP in the figure).
z_2d = reducer.fit_transform(np.array(pop['z']))
z_2d_scaled = _scale_to_grid(z_2d, grid_size)

# --- Grid placement ---------------------------------------------------------
grid = np.full((grid_size, grid_size), -1, dtype=int)  # -1 marks an empty cell
used_positions = set()
n_dropped = 0
for idx, (sx, sy) in enumerate(z_2d_scaled):
    # Searching up to grid_size rings covers the whole grid, so a pattern is
    # only skipped when every cell is taken -- it never overwrites another one.
    cell = _find_free_cell(int(sx), int(sy), used_positions, grid_size, max_radius=grid_size)
    if cell is None:
        n_dropped += 1
        continue
    x, y = cell
    grid[y, x] = idx  # row = UMAP dim 2, column = UMAP dim 1 (image coordinates)
    used_positions.add((x, y))
if n_dropped:
    print(f"Warning: {n_dropped} patterns did not fit on the {grid_size}x{grid_size} grid and were skipped")

# --- Render the placed patterns ----------------------------------------------
rgb_images = np.array(rollout_data['rgb'])  # convert once instead of once per cell
full_img = np.zeros((grid_size * cell_size, grid_size * cell_size, 3))

for y in range(grid_size):
    for x in range(grid_size):
        idx = grid[y, x]
        if idx != -1:
            full_img[y*cell_size:(y+1)*cell_size, x*cell_size:(x+1)*cell_size] = \
                _resize_to_cell(rgb_images[idx], cell_size)

# --- Optional "Random" panel in the top-left corner ---------------------------
# Each panel cell shows, with 30% probability, a uniformly drawn pattern.
# NOTE: the panel is painted over the atlas, so embedded patterns that landed in
# these cells are hidden; set show_random_section = False to see them.
random_size = grid_size // 5
box_y, box_x = 0, 0
box_h = box_w = random_size * cell_size
line_width = 3
if show_random_section:
    panel_rng = np.random.default_rng(random_seed)  # seeded so the figure is reproducible
    for y in range(random_size):
        for x in range(random_size):
            if panel_rng.random() > 0.7:
                random_idx = panel_rng.integers(0, len(rgb_images))
                full_img[y*cell_size:(y+1)*cell_size, x*cell_size:(x+1)*cell_size] = \
                    _resize_to_cell(rgb_images[random_idx], cell_size)

# --- Grid lines -------------------------------------------------------------
line_color = [0.2, 0.2, 0.2]
for i in range(grid_size + 1):
    # Clamp at 0: for i == 0 a start of -1 would wrap around and draw nothing.
    start = max(i * cell_size - 1, 0)
    full_img[start:i*cell_size+1, :] = line_color
    full_img[:, start:i*cell_size+1] = line_color

# White frame around the "Random" panel (drawn after the grid lines, on top of them).
if show_random_section:
    full_img[box_y:box_y+line_width, box_x:box_x+box_w] = 1              # top
    full_img[box_y+box_h-line_width:box_y+box_h, box_x:box_x+box_w] = 1  # bottom
    full_img[box_y:box_y+box_h, box_x:box_x+line_width] = 1              # left
    full_img[box_y:box_y+box_h, box_x+box_w-line_width:box_x+box_w] = 1  # right

# --- Figure -----------------------------------------------------------------
plt.figure(figsize=(30, 30), facecolor='black')
plt.imshow(full_img)
if show_random_section:
    plt.text(box_x + box_w/2, box_y + cell_size/2, "Random",
             color='white', fontsize=30, ha='center', va='center')
plt.axis('off')
plt.title('Simulation Atlas of Plenia', fontsize=40, color='white', pad=20)

# Axis labels, drawn inside the image because the axes are hidden.
plt.text(grid_size * cell_size / 2, grid_size * cell_size - 40,
         'CLIP UMAP 1', fontsize=30, color='white', ha='center')
plt.text(40, grid_size * cell_size / 2,
         'CLIP UMAP 2', fontsize=30, color='white', va='center', rotation=90)

plt.tight_layout()
plt.savefig('plenia_atlas.png', dpi=150, bbox_inches='tight', facecolor='black')
plt.show()





In [13]:
# 必要なライブラリをインポート
import umap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import sklearn.decomposition as decomposition

# Dimensionality reduction and visualization of the simulation-parameter space.
def visualize_parameter_space(params, rollout_data, method='umap', n_samples=None):
    """Project simulation parameters to 2-D and scatter them, optionally with thumbnails.

    Args:
        params: sequence of per-simulation parameter arrays; each one is
            flattened into a single row before projection.
        rollout_data: mapping whose 'rgb' entry holds one rendered image per
            simulation, in the same order as `params`.
        method: 'umap' or 'pca' (case-insensitive).
        n_samples: if given and smaller than len(params), project a random
            subset of that many simulations (np.random.choice without
            replacement); None uses all simulations.

    Returns:
        The (n, 2) array of projected coordinates that was plotted.

    Raises:
        ValueError: if `method` is neither 'umap' nor 'pca'.

    The figure is left open so the caller can lay out, save, and show it.
    """
    # Validate first: previously any unrecognised method silently fell back to PCA.
    method = method.lower()
    if method not in ('umap', 'pca'):
        raise ValueError(f"method must be 'umap' or 'pca', got {method!r}")

    # One row of flattened parameters per simulation.
    param_array = np.array([np.asarray(p).flatten() for p in params])
    rgb_all = np.array(rollout_data['rgb'])

    if n_samples is not None and n_samples < len(param_array):
        indices = np.random.choice(len(param_array), n_samples, replace=False)
    else:
        indices = np.arange(len(param_array))
    param_subset = param_array[indices]
    rgb_subset = rgb_all[indices]

    if method == 'umap':
        reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)
        embedding = reducer.fit_transform(param_subset)
    else:
        embedding = decomposition.PCA(n_components=2).fit_transform(param_subset)

    plt.figure(figsize=(20, 16))
    # Colour encodes each simulation's index in the population (an ordering, not a density).
    scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=indices, cmap='viridis',
                          alpha=0.7, s=100)

    # Thumbnails only for small sets; beyond ~100 they would pile up unreadably.
    if len(param_subset) <= 100:
        ax = plt.gca()
        for (ex, ey), thumb in zip(embedding, rgb_subset):
            ab = AnnotationBbox(OffsetImage(thumb, zoom=0.3), (ex, ey),
                                frameon=True, pad=0.2)
            ax.add_artist(ab)

    plt.colorbar(scatter, label='Pattern Index')

    plt.title(f'Pattern Distribution in Parameter Space ({method.upper()})', fontsize=24)
    plt.xlabel(f'{method.upper()} Component 1', fontsize=18)
    plt.ylabel(f'{method.upper()} Component 2', fontsize=18)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Footer with the number of flattened parameters per simulation.
    param_dim = param_array.shape[1]
    plt.figtext(0.02, 0.02, f"Parameter Space Dimensionality: {param_dim}", fontsize=14)

    return embedding

# Project the parameter space with UMAP and then PCA; lay out, save, and show each figure.
param_projections = {}
for proj_method, out_path in (('umap', 'parameter_space_umap.png'),
                              ('pca', 'parameter_space_pca.png')):
    param_projections[proj_method] = visualize_parameter_space(params, rollout_data, method=proj_method)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.show()
umap_params = param_projections['umap']
pca_params = param_projections['pca']

# Correlation heatmap between individual (flattened) simulation parameters.
def parameter_heatmap(params, max_params=20):
    """Show the Pearson correlation matrix of the leading simulation parameters.

    Args:
        params: sequence of per-simulation parameter arrays; each is flattened
            to one row, so every column is one scalar parameter.
        max_params: only the first `max_params` parameter columns are shown
            (default 20) so the heatmap stays readable.

    Parameters that are constant across the population have no defined
    correlation; their entries are NaN and render as empty cells.
    """
    param_array = np.array([np.asarray(p).flatten() for p in params])
    # Slicing past the last column is a no-op, so no size check is needed.
    param_subset = param_array[:, :max_params]

    # Constant columns make np.corrcoef divide 0 by 0; silence that warning and
    # keep the NaNs. atleast_2d covers the single-parameter case, where
    # np.corrcoef returns a scalar that the heatmap cannot draw.
    with np.errstate(divide='ignore', invalid='ignore'):
        corr = np.atleast_2d(np.corrcoef(param_subset, rowvar=False))
    labels = np.arange(param_subset.shape[1])

    plt.figure(figsize=(16, 14))
    sns.heatmap(corr,
                cmap='coolwarm', center=0,
                xticklabels=labels,
                yticklabels=labels)
    plt.title('Parameter Correlation Matrix', fontsize=20)
    plt.tight_layout()
    plt.show()

# Relationship between the CLIP embedding space and the parameter space.
def compare_spaces(clip_embeddings, param_embeddings):
    """Plot the first components of two projections of the same simulations against each other.

    Args:
        clip_embeddings: (n, k) array; row i is simulation i in the CLIP projection.
        param_embeddings: (n, k) array; row i is the same simulation in the
            parameter-space projection.

    The footer reports Pearson and Spearman correlations of the first
    components. Independently fitted projections have no shared axis
    orientation, so it also reports the Spearman correlation of all pairwise
    Euclidean distances. That score measures how similarly the two projections
    arrange the simulations, independent of how the axes happen to point.
    """
    from scipy.stats import spearmanr
    from scipy.spatial.distance import pdist

    clip_embeddings = np.asarray(clip_embeddings)
    param_embeddings = np.asarray(param_embeddings)

    plt.figure(figsize=(10, 8))
    plt.scatter(clip_embeddings[:, 0], param_embeddings[:, 0], alpha=0.7)
    plt.title('CLIP Space vs Parameter Space (First Components)', fontsize=18)
    plt.xlabel('CLIP UMAP Component 1', fontsize=14)
    plt.ylabel('Parameter UMAP Component 1', fontsize=14)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Previously a full spearmanr matrix was computed and then discarded;
    # every statistic computed here is now reported.
    correlation = np.corrcoef(clip_embeddings[:, 0], param_embeddings[:, 0])[0, 1]
    rank_corr, _ = spearmanr(clip_embeddings[:, 0], param_embeddings[:, 0])
    dist_corr, _ = spearmanr(pdist(clip_embeddings), pdist(param_embeddings))
    plt.figtext(0.02, 0.02,
                f"Correlation: {correlation:.3f}   Spearman: {rank_corr:.3f}   "
                f"Pairwise-distance Spearman: {dist_corr:.3f}",
                fontsize=14)

    plt.tight_layout()
    plt.show()

# Fit a separate 2-D UMAP of the CLIP embeddings (UMAP defaults apart from the
# seed) so it can be compared with the parameter-space UMAP computed above.
clip_features = np.array(pop['z'])
reducer_clip = umap.UMAP(n_components=2, random_state=42)
clip_embeddings = reducer_clip.fit_transform(clip_features)

# CLIP projection vs. parameter projection.
compare_spaces(clip_embeddings, umap_params)

# Correlations between the raw simulation parameters.
parameter_heatmap(params)











