In [1]:
from run_lla import run_lla_gymnax
import jax.random as jrd
import jax.numpy as jnp
from jax import jit, vmap, default_backend

In [2]:
# Interpolation resolution, passed as `n=` to run_lla_gymnax in the next cell.
N: int = 200

In [3]:
# Landscape analysis via `run_lla_gymnax` from the `run_lla` module.
# Alternative run previously used here: "arxaqapi/Cartpole Soft/8of12aj9"
run_lla_gymnax(
    jrd.PRNGKey(0),
    n=N,
    title="LLA of PA",
    final_genome_name="genomes/g100_mean_indiv.npy",
    initial_genome_name="genomes/g0_mean_indiv.npy",
    run_name="arxaqapi/Cartpole Soft/3ugflqlq",
)

In [4]:
from jax import lax, jit, vmap
import jax.numpy as jnp
import jax.random as jrd
import wandb

from gene.lla import load_genomes, interpolate_2D, plot_ll
from gene.evaluate import evaluate_individual_soft

from functools import partial


def run_lla_gymnax_pa(
    rng,
    run_name: str = "arxaqapi/Cartpole/qvobnkry",  # direct | seed 9
    initial_genome_name: str = "genomes/1685094639_g0_mean_indiv.npy",
    final_genome_name: str = "genomes/1685094639_g100_mean_indiv.npy",
    title: str = "",
    n: int = 200,
):
    """Run a loss-landscape analysis between two genomes of a wandb run.

    Downloads the initial and final genome files from the wandb run, loads
    them with ``load_genomes``, interpolates between them with
    ``interpolate_2D``, evaluates every interpolated genome with
    ``evaluate_individual_soft`` (vmapped and jitted), and passes the values
    together with the two endpoint scores to ``plot_ll``.

    Args:
        rng: JAX PRNG key. It is split into interpolation, evaluation and
            action-sampling keys.
        run_name: wandb run path, as accepted by ``wandb.Api().run``.
        initial_genome_name: name of the initial genome file inside the run.
        final_genome_name: name of the final genome file inside the run.
        title: plot title forwarded to ``plot_ll``.
        n: resolution forwarded to ``interpolate_2D``. This used to be read
            from the notebook-global ``N``; the default matches ``N = 200``.
    """
    # Keep the same 4-way split as before so results are reproducible; the
    # first key is not needed afterwards.
    _, interpolation_rng, eval_rng, rng_action_sampling = jrd.split(rng, 4)

    # NOTE - 1. download files from run
    api = wandb.Api()
    run = api.run(run_name)
    config = run.config

    path_initial = run.file(initial_genome_name).download(replace=True).name
    path_final = run.file(final_genome_name).download(replace=True).name

    # NOTE - 2. load files
    initial_genome, final_genome = load_genomes(path_initial, path_final)

    # NOTE - 3. interpolate
    genomes, xs, ys = interpolate_2D(
        initial_genome, final_genome, n=n, key=interpolation_rng
    )

    # NOTE - 4. evaluate at each interpolation step
    # Bind every argument except the genome by keyword, so the grid and the
    # endpoints are evaluated with exactly the same config and keys.
    part_eval = partial(
        evaluate_individual_soft,
        config=config,
        rng=eval_rng,
        rng_action_sampling=rng_action_sampling,
    )
    # Map over the leading axis of `genomes` (one genome per row,
    # presumably, as produced by interpolate_2D — TODO confirm).
    vmap_eval = jit(vmap(part_eval, in_axes=0))

    values = vmap_eval(genomes)

    # NOTE - 5. plot landscape
    # The endpoints go through the same keyword-bound evaluator instead of
    # relying on the positional argument order of evaluate_individual_soft.
    plot_ll(
        values,
        xs,
        ys,
        part_eval(initial_genome),
        part_eval(final_genome),
        title=title,
    )

In [5]:
# Same run and genome files as the run_lla_gymnax call above, this time through
# the in-notebook run_lla_gymnax_pa, which evaluates with evaluate_individual_soft.
run_lla_gymnax_pa(
    jrd.PRNGKey(0),
    title="LLA of PA",
    final_genome_name="genomes/g100_mean_indiv.npy",
    initial_genome_name="genomes/g0_mean_indiv.npy",
    run_name="arxaqapi/Cartpole Soft/3ugflqlq",
)