In [1]:
# Session setup: enable module auto-reloading, log where the kernel is running,
# and configure the XLA client through an environment variable.
%load_ext autoreload
%autoreload 2
# Record the machine and working directory this session ran in.
!hostname
!pwd
import os, sys
# Show which Python interpreter backs this kernel.
print(sys.executable)
# os.environ['CUDA_VISIBLE_DEVICES'] = "7"
# NOTE(review): set here, ahead of the `import jax` in the next cell; presumably
# this stops XLA from pre-allocating GPU memory at start-up — confirm against the JAX docs.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
# sys.path.append(os.path.abspath(".."))

In [2]:
import os, sys, glob, pickle, copy, time
from functools import partial

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
# plt.rcParams['figure.figsize'] = (15, 5)
import seaborn as sns
sns.set_theme()

from tqdm.auto import tqdm
from einops import rearrange, reduce, repeat
from einop import einop

import jax
import jax.numpy as jnp
from jax.random import split

import flax
import flax.linen as nn
from flax.training.train_state import TrainState

import optax

In [3]:
from clip import CLIP
from cppn import CPPN, FlattenCPPNParameters
import util 

import experiment_utils

In [4]:
# Baseline hyper-parameters; every sweep entry below is a copy of this dict
# with a few fields overridden.
cfg_default = dict(
    seed=0,
    noun_file="/home/akarsh_sakana_ai/spaghetti/nounlist.txt",
    save_dir=None,
    prompt="an image of a {}",
    replace_only_one_niche=False,

    n_iters=100000,
    pop_size=6801,
    n_mutations=16,
    # Was misspelled "gausian", which matches neither of the mutation names
    # ("gaussian", "sparse") used elsewhere in this notebook.
    mutation="gaussian",
    sigma=0.5,
)

cfgs = []

# Wider sweep that was considered:
# for n_mutations in [16]:
    # for mutation in ["gaussian", "sparse"]:
        # for sigma in [0.01, 0.1, 0.5]:
for n_mutations in [16]:
    for mutation in ["gaussian"]:
        for sigma in [0.5]:
            for seed in [0, 1, 2, 3]:
                cfg = copy.deepcopy(cfg_default)
                # Each run gets its own output directory named after the swept values.
                cfg.update(seed=seed, n_mutations=n_mutations, mutation=mutation, sigma=sigma,
                        save_dir=f"/home/akarsh_sakana_ai/spaghetti-data/exp1/{seed}_{n_mutations}_{mutation}_{sigma}")
                cfgs.append(cfg)

In [5]:
# Hand the sweep configs to the project helper together with the command prefix
# and the target script path (presumably it writes one command per config — see
# experiment_utils.create_commands).
experiment_utils.create_commands(
    cfgs,
    cfg_default,
    prefix="python map_elites.py",
    prune=False,
    out_file="./science.sh",
)

In [17]:
# Overlay the average-quality curve of every sweep run on a single figure.

# The noun list does not depend on the run, so read it once instead of once per config.
with open('../nounlist.txt', 'r') as f:
    nouns = f.read().strip().split('\n')

plt.figure(figsize=(20, 10))
for cfg in cfgs:
    save_dir = cfg['save_dir']

    archive = util.load_pkl(save_dir, 'archive')
    data = util.load_pkl(save_dir, 'data')
    # print(jax.tree.map(lambda x: (x.shape, x.size*4/1e6), archive))
    # print(jax.tree.map(lambda x: (x.shape, x.size*4/1e6), data))

    # Label each curve with its run directory name so the runs can be told apart.
    plt.plot(data['avg_quality'], label=os.path.basename(save_dir))
plt.xlabel('Iteration', fontsize=20)
plt.ylabel('Average Quality', fontsize=20)
plt.legend(fontsize=16);
# plt.ylim(0.3, 0.32)

# Visualizing Results

In [4]:
# Run directory inspected by the cells that follow.
save_dir = "/home/akarsh_sakana_ai/spaghetti-data/run_0.5/"

# One noun per line; later cells use a noun's position in this list to index the archive.
with open('../nounlist.txt', 'r') as noun_file:
    nouns = noun_file.read().strip().split('\n')

# Load the two pickles saved by the run.
archive, data = (util.load_pkl(save_dir, name) for name in ('archive', 'data'))

In [5]:

def _shape_and_mb(x):
    """Return a leaf's shape and its size in MB, counting 4 bytes per element."""
    return (x.shape, x.size*4/1e6)

# Summarise every leaf of both loaded pytrees.
for tree in (archive, data):
    print(jax.tree.map(_shape_and_mb, tree))


In [6]:
# Left: cumulative number of transfers; right: average quality.
fig, (ax_transfers, ax_quality) = plt.subplots(1, 2, figsize=(15, 5))
ax_transfers.plot(data['n_transfers'].cumsum())
ax_transfers.set_title('n_transfers')
ax_quality.plot(data['avg_quality'])
ax_quality.set_title('avg_quality')
plt.show()

In [7]:
# Build the CPPN and wrap it; the wrapper is used below with a single `params`
# array per individual.
cppn = FlattenCPPNParameters(CPPN(n_layers=4, d_hidden=16, nonlin='tanh'))

# Look up the archive slot whose position matches this noun in the noun list.
txt = "wrestler"
idx = nouns.index(txt)

# Take entry `idx` from every leaf of the archive.
ai = jax.tree.map(lambda leaf: leaf[idx], archive)
params = ai['pheno']['params']

# Render the individual and show its stored quality.
plt.imshow(cppn.generate_image(params))
plt.title(f"quality: {ai['quality'].item():.4f}")

In [9]:
def scan_fn(carry, flat_params):
    """Scan body: render one archive entry; nothing is carried between steps."""
    return None, cppn.generate_image(flat_params)

# Render every archive entry by scanning over the leading axis of the stored params.
_, imgs = jax.lax.scan(scan_fn, None, archive['pheno']['params'])
imgs = np.array(imgs)
imgs.shape

In [47]:
# plt.figure(figsize=(50, 34*1.3))
# for i in tqdm(range(6800//4//4)):
#     noun = nouns[i]
#     quality = archive['quality'][i]
#     # params = archive['pheno']['params'][i]
#     img = imgs[i]
#     plt.subplot(68, 100, i+1)
#     plt.imshow(img);plt.axis('off')
#     plt.title(f"{noun}\n{quality.item():.3f}", fontsize=2)
# plt.tight_layout()
# plt.savefig(f'{save_dir}/archive.png', dpi=300, bbox_inches='tight')
# plt.close()

from PIL import Image
import cv2

# Build one large poster image: a 68 x 100 grid of archive renders, each with a
# white margin that leaves room for two lines of text (noun and quality) on top.
n_rows, n_cols = 68, 100
n_tiles = n_rows * n_cols
pad_top, pad_bottom, pad_side = 35, 10, 5

poster = imgs[:n_tiles]
poster = np.pad(poster, ((0, 0), (pad_top, pad_bottom), (pad_side, pad_side), (0, 0)), constant_values=1.)
# Padded tile size; derived from the array instead of assuming 128x128 renders.
tile_h, tile_w = poster.shape[1:3]
poster = rearrange(poster, "(R C) H W D -> (R H) (C W) D", R=n_rows, C=n_cols)
# cv2 draws in place on the array it is given; make sure the buffer is
# C-contiguous (no copy is made if it already is).
poster = np.ascontiguousarray(poster)

for i in tqdm(range(n_tiles)):
    row, col = divmod(i, n_cols)
    x, y = col*tile_w, row*tile_h  # top-left corner of tile i
    # First text line: the noun this archive slot corresponds to.
    cv2.putText(poster, f"{nouns[i]}", (x+5, y+12), cv2.FONT_HERSHEY_COMPLEX, .7, (0, 0, 0), 1)
    # Second text line: the stored quality of the individual.
    cv2.putText(poster, f"{archive['quality'][i].item():.3f}", (x+5, y+12+18), cv2.FONT_HERSHEY_COMPLEX, .5, (0, 0, 0), 1)

# Clip before the uint8 cast so any value outside [0, 1] saturates instead of
# producing an out-of-range cast.
poster_u8 = (np.clip(poster, 0., 1.) * 255).astype('uint8')
Image.fromarray(poster_u8).save(os.path.join(save_dir, "poster.png"))

In [214]:
def viz_feature_maps(features, image=None):
    """Plot every feature map of every layer on one grid, one row per layer.

    Args:
        features: sequence of per-layer activations, each indexed as (h, w, c).
        image: image drawn in the bottom-right cell of the grid. Pass the
            render that belongs to `features`. When omitted, the notebook-level
            variable `rgb` is used, which is what this function always drew
            before the parameter existed — regardless of which features were
            passed in.

    Returns:
        The current matplotlib figure.
    """
    if image is None:
        # Backward compatibility with existing calls that rely on the global.
        image = rgb

    n_layers = len(features)
    max_features_per_layer = max(jax.tree.map(lambda x: x.shape[-1], features))

    plt.figure(figsize=(2*max_features_per_layer, 2*n_layers))
    for i, layer_features in enumerate(features):
        # Iterate over channels: each channel is one feature map.
        for j, fmap in enumerate(rearrange(layer_features, 'h w c -> c h w')):
            plt.subplot(n_layers, max_features_per_layer, i*max_features_per_layer + j + 1)
            plt.imshow(fmap); plt.xticks([]); plt.yticks([])
            if j==0:
                # Label each row with its layer index.
                plt.ylabel(f"{i}", fontsize=25)
    # Last row, last column of the grid.
    plt.subplot(n_layers, max_features_per_layer, (n_layers-1)*max_features_per_layer + (max_features_per_layer-1) + 1)
    plt.imshow(image); plt.axis('off')
    plt.gcf().supylabel("Layer", fontsize=35)
    plt.gcf().supxlabel("Feature Map", fontsize=35)
    plt.suptitle("Feature Maps of CPPN", fontsize=35)
    plt.tight_layout()
    return plt.gcf()

def viz_random_mutations(cppn, params, mutation='gaussian', sigma=0.5):
    """Render 45 random mutations of `params` on a 3 x 15 grid.

    Args:
        cppn: object providing `generate_image(params)`.
        params: parameter array to mutate.
        mutation: 'gaussian' adds `sigma`-scaled normal noise to every
            parameter; 'sparse' replaces each parameter by a standard-normal
            sample with probability `sigma`.
        sigma: noise scale ('gaussian') or replacement probability ('sparse').

    Returns:
        The current matplotlib figure.

    Raises:
        NotImplementedError: if `mutation` is not one of the two names above.
    """
    # Validate up front so a bad name does not leave an empty figure behind.
    if mutation not in ('gaussian', 'sparse'):
        raise NotImplementedError

    def mutate_fn(rng, params):
        if mutation == 'gaussian':
            noise = jax.random.normal(rng, params.shape)
            return params + noise * sigma
        # 'sparse': independent keys for the mask and for the replacement values.
        rng, _rng = split(rng)
        mask = jax.random.uniform(rng, params.shape) < sigma
        noise = jax.random.normal(_rng, params.shape)
        return noise * mask + params * (1-mask)

    plt.figure(figsize=(20, 5))
    for i in range(45):
        # A fixed key per panel keeps the figure reproducible.
        rng = jax.random.PRNGKey(i)
        paramsp = mutate_fn(rng, params)
        imgp = cppn.generate_image(paramsp)
        plt.subplot(3, 15, i+1); plt.imshow(imgp); plt.axis('off')
    plt.suptitle("Random mutations of CPPN", fontsize=25)
    plt.tight_layout()
    return plt.gcf()

def viz_sweep_weights(cppn, params):
    """Show how the render changes as single weights are swept from -3 to 3.

    Twenty weight indices are drawn from a fixed-seed permutation of
    `cppn.n_params`. Each gets one column in which the renders for seven evenly
    spaced values of that weight are stacked vertically.

    Returns:
        The current matplotlib figure.
    """
    n_sweeps, granularity = 20, 7
    flat = jnp.array(params)
    key = jax.random.PRNGKey(0)
    chosen = jax.random.permutation(key, cppn.n_params)[:n_sweeps]

    plt.figure(figsize=(1*n_sweeps, 1*granularity))
    for col, weight_idx in enumerate(chosen, start=1):
        # One render per swept value, with only this weight overwritten.
        frames = []
        for val in jnp.linspace(-3, 3, granularity):
            frames.append(cppn.generate_image(flat.at[weight_idx].set(val)))
        column = jnp.concatenate(frames, axis=0)

        plt.subplot(1, n_sweeps, col)
        plt.imshow(column)
        plt.xticks([])
        plt.yticks([])
        plt.xlabel(f"{weight_idx}", fontsize=10)
    plt.suptitle("Sweeping weights of CPPN", fontsize=20)
    plt.gcf().supxlabel("Weight Index", fontsize=20)
    plt.tight_layout()
    return plt.gcf()

In [223]:
# Build the wrapped CPPN and pull one noun's individual out of the archive.
cppn = FlattenCPPNParameters(CPPN(n_layers=4, d_hidden=16, nonlin='tanh'))

txt = "wrestler"
# txt = "buzz"
idx = nouns.index(txt)

# Take entry `idx` from every leaf of the archive.
ai = jax.tree.map(lambda leaf: leaf[idx], archive)
params = ai['pheno']['params']

In [224]:
# Render the selected individual, also keeping the intermediate activations.
rgb, features = cppn.generate_image(params, intermediate_features=True)
plt.imshow(rgb)
plt.title(f"quality: {ai['quality'].item():.4f}")
plt.show()

# Run the three diagnostic plots in turn, showing each figure before the next.
for draw in (
    lambda: viz_feature_maps(features),
    lambda: viz_random_mutations(cppn, params, mutation='gaussian', sigma=0.5),
    lambda: viz_sweep_weights(cppn, params),
):
    draw()
    plt.show()

In [225]:
def loss_fn(paramsp):
    """Mean squared pixel difference between the render of `paramsp` and `rgb`."""
    diff = rgb - cppn.generate_image(paramsp)
    return (diff**2).mean()
grad_fn = jax.jit(jax.value_and_grad(loss_fn))

@jax.jit
def train_step(state):
    """Apply one optimiser update; returns the new state and the loss before the update."""
    loss, grad = grad_fn(state.params)
    return state.apply_gradients(grads=grad), loss

# Start from random normal parameters with the same shape as the evolved ones.
rng = jax.random.PRNGKey(0)
params_sgd = jax.random.normal(rng, params.shape)

tx = optax.adam(learning_rate=3e-4)
state = TrainState.create(apply_fn=None, params=params_sgd, tx=tx)

pbar = tqdm(range(100000))
for t in pbar:
    state, loss = train_step(state)
    # Refresh the progress-bar readout only every 100 steps.
    if not t % 100:
        pbar.set_postfix(loss=loss.item())

params_sgd = state.params

In [226]:
# Render the SGD-fitted parameters, keeping the intermediate activations too.
sgd_render = cppn.generate_image(params_sgd, intermediate_features=True)
rgb_sgd, features_sgd = sgd_render
plt.imshow(rgb_sgd)

In [227]:
# Same three diagnostic plots as for the evolved individual, now for the SGD fit.
# NOTE(review): viz_feature_maps draws the notebook-level `rgb` in its last
# panel here, not `rgb_sgd`.
for draw in (
    lambda: viz_feature_maps(features_sgd),
    lambda: viz_random_mutations(cppn, params_sgd, mutation='gaussian', sigma=0.5),
    lambda: viz_sweep_weights(cppn, params_sgd),
):
    draw()
    plt.show()

In [228]:
def viz_weight_influences(cppn, params, n_rows=62, n_cols=15):
    """Show, for each weight, which pixels change when that weight is swept.

    For every parameter index the weight is set to 11 evenly spaced values in
    [-3, 3]; the per-pixel variance of the resulting renders, averaged over the
    last (channel) axis, is that weight's influence map. The maps of the first
    `n_rows * n_cols` weights are tiled into one image.

    Args:
        cppn: object providing `generate_image(params)`.
        params: 1-D parameter array.
        n_rows: maximum number of grid rows (reduced automatically if there are
            too few parameters to fill them).
        n_cols: number of grid columns.

    Returns:
        The current matplotlib figure.

    Raises:
        ValueError: if there are fewer than `n_cols` parameters.
    """
    params = jnp.array(params)

    weight_sweep = jnp.linspace(-3, 3, 11)

    def get_img(weight_idx, weight_val):
        return cppn.generate_image(params.at[weight_idx].set(weight_val))
    # Batch over the swept values only; the weight index is shared.
    get_img = jax.jit(jax.vmap(get_img, in_axes=(None, 0)))

    def get_influence_img(weight_idx):
        imgs = get_img(weight_idx, weight_sweep) # (n_sweep_values, H, W, C)
        return jnp.var(imgs, axis=0).mean(axis=-1) # (H, W)

    # Only complete rows can be tiled; shrink the grid when there are too few weights.
    n_rows = min(n_rows, len(params) // n_cols)
    if n_rows == 0:
        raise ValueError(f"need at least {n_cols} parameters to fill one row, got {len(params)}")

    imgs = [get_influence_img(weight_idx) for weight_idx in tqdm(range(len(params)))]
    imgs = jnp.stack(imgs, axis=0) # (n_params, H, W)

    poster = rearrange(imgs[:n_rows*n_cols], "(R C) H W ... -> (R H) (C W) ...", R=n_rows, C=n_cols)
    plt.figure(figsize=(20, 80))
    plt.imshow(poster); plt.xticks([]); plt.yticks([])
    # Return the figure like the other viz_* helpers in this notebook.
    return plt.gcf()


In [229]:
# Weight-influence maps for the evolved individual.
viz_weight_influences(cppn=cppn, params=params)
plt.show()

In [230]:
# Weight-influence maps for the SGD-fitted parameters.
viz_weight_influences(cppn=cppn, params=params_sgd)
plt.show()

In [231]:
from sklearn.cluster import SpectralClustering
# Recover the structured parameters from the flat vector and take one layer's kernel.
dense1_kernel = cppn.param_reshaper.reshape_single(params)['params']['Dense_1']['kernel']
# Binarise: 1.0 where the weight exceeds 1, else 0.0.
mat = (jnp.array(dense1_kernel.copy()) > 1.).astype('float32')

In [232]:
# Display the binarised kernel.
plt.imshow(X=mat)

In [233]:
# Row-row and column-column co-activation of `mat`.
affinity1 = mat @ mat.T
affinity2 = mat.T @ mat

plt.figure(figsize=(10, 5))
for position, affinity in ((121, affinity1), (122, affinity2)):
    plt.subplot(position)
    plt.imshow(affinity)
    plt.colorbar()

In [235]:
# Cluster rows and columns of `mat` separately and display the matrix with
# members of the same cluster placed next to each other.
for n_clusters in [2, 4, 6, 8]:
    clustering1 = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', random_state=0).fit(affinity1)
    clustering2 = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', random_state=0).fit(affinity2)
    # Sort first, then reverse. The earlier form `argsort(labels_[::-1])`
    # produced indices into the reversed label array, which do not group the
    # un-reversed rows/columns by cluster.
    row_order = np.argsort(clustering1.labels_)[::-1]
    col_order = np.argsort(clustering2.labels_)[::-1]
    plt.imshow(mat[row_order][:, col_order])
    plt.title(f"n_clusters={n_clusters}")
    plt.show()

In [213]:
# Overlay the weight distributions: evolved parameters in the default colour,
# SGD-fitted parameters in translucent red.
for weights, hist_style in ((params, {}), (params_sgd, dict(color='red', alpha=0.5))):
    plt.hist(weights, bins=50, **hist_style)

In [255]:
# Gallery of renders from randomly initialised parameters of a deeper, narrower CPPN.
cppn = FlattenCPPNParameters(CPPN(n_layers=8, d_hidden=8, nonlin='tanh'))

grid_rows, grid_cols = 4, 20
plt.figure(figsize=(20, 5))
for i in range(grid_rows*grid_cols):
    # A separate fixed key per panel keeps the gallery reproducible.
    p = jax.random.normal(jax.random.PRNGKey(i), (cppn.n_params,))
    img = cppn.generate_image(p)
    plt.subplot(grid_rows, grid_cols, i+1)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
plt.show()