## Training a simple genetic algorithm on the Snake environment from Jumanji

In [None]:
# IPython autoreload extension in mode 2, so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from functools import partial
from typing import Tuple, Type

import jax
import jax.numpy as jnp

import jumanji

import functools

from qdax.baselines.genetic_algorithm import GeneticAlgorithm

from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP

from qdax.tasks.jumanji_envs import jumanji_scoring_function

from qdax.core.emitters.mutation_operators import isoline_variation

from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.types import ExtraScores, Fitness, RNGKey, Descriptor
from qdax.utils.metrics import default_ga_metrics

## Define hyperparameters

In [None]:
# Seed of the PRNG key that drives the evolutionary run
seed = 0

# Widths of the hidden layers of the MLP policy
policy_hidden_layer_sizes = (128, 128)

# Rollout length handed to the scoring function
episode_length = 1_000

# Size of the population; the emitter batch size is kept equal to it
population_size = 500
batch_size = population_size

# Number of update steps run by the scan loop
num_iterations = 1_000

# Noise scales forwarded to the isoline variation operator
line_sigma = 0.05
iso_sigma = 0.005

## Instantiate the snake environment

In [None]:
# Smoke test: build the environment and play a single step. The resulting
# `timestep` is reused below to inspect the observation shape.

# Instantiate a Jumanji environment using the registry
env = jumanji.make('Snake-6x6-v0')

# Reset the (jit-able) environment with a fixed key
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# (Optional) Render the env state
# env.render(state)

# Step once with a placeholder action generated from the action spec
action = env.action_spec().generate_value()          # Action selection (dummy value here)
state, timestep = jax.jit(env.step)(state, action)

In [None]:
# Number of values reported by the action spec
env.action_spec().num_values

In [None]:
# Shape of the sample observation once flattened
jnp.ravel(timestep.observation).shape

In [None]:
# Total number of entries in an observation, computed from the observation spec
jnp.prod(jnp.array(env.observation_spec().shape))

In [None]:
def observation_processing(observation):
    """Flatten an observation into the 1-D vector fed to the policy network."""
    return jnp.ravel(observation)


def play_step_fn(
    env_state,
    timestep,
    policy_params,
    random_key,
):
    """Play one environment step with a stochastic policy.

    The policy network outputs one probability per action and the action is
    sampled from that distribution.

    Args:
        env_state: current state of the environment.
        timestep: current timestep; its observation is fed to the policy.
        policy_params: parameters of the policy network.
        random_key: JAX PRNG key. It is split on every call so that
            consecutive steps never sample with the same key.

    Returns:
        A tuple (next_state, next_timestep, policy_params, random_key,
        transition): random_key is the updated key to carry to the next
        step and transition is the QDTransition describing this step.
    """
    # Consume a fresh subkey and hand back the updated key. Sampling with the
    # incoming key and returning it unchanged would reuse the same randomness
    # at every step of a rollout.
    random_key, subkey = jax.random.split(random_key)

    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(policy_params, network_input)

    action = jax.random.choice(
        key=subkey,
        a=proba_action.shape[0],
        p=proba_action,
    )

    next_state, next_timestep = env.step(env_state, action)

    transition = QDTransition(
        obs=timestep.observation,
        next_obs=next_timestep.observation,
        rewards=next_timestep.reward,
        # x and y are passed positionally, which every JAX version accepts.
        # NOTE(review): newer JAX releases reportedly make them
        # positional-only - confirm against the installed version.
        dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)),
        actions=action,
        # TODO: fill in real truncation flags; hard-coded to 0 for now
        truncations=jnp.array(0),
        # No state descriptor is computed in this example
        state_desc=None,
        next_state_desc=None,
    )

    return next_state, next_timestep, policy_params, random_key, transition


In [None]:
# Master PRNG key for the experiment, derived from the `seed` hyperparameter
random_key = jax.random.PRNGKey(seed)

# Policy network: an MLP with a softmax output holding one entry per action

# The number of actions is derived from the largest value in the action spec
num_actions = env.action_spec().maximum + 1

policy_layer_sizes = policy_hidden_layer_sizes + (num_actions,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jax.nn.softmax,
)

# Init population of controllers: one PRNG key per individual
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)

# Size of a flattened observation, computed from the observation spec

# NOTE(review): `obs_spec` is not used anywhere else in this cell
obs_spec = env.observation_spec()

import numpy as np
observation_size = np.prod(np.array(env.observation_spec().shape))

# Initialise one set of parameters per individual by vmapping the network
# init over the keys and a dummy batch of flattened observations
fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)


# Create the initial environment states. The same subkey is repeated for
# every individual, so all environments are reset with an identical key.
random_key, subkey = jax.random.split(random_key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))
init_states, init_timesteps = reset_fn(keys)

In [None]:
# Prepare the scoring function
def bd_extraction(
    data: QDTransition, mask: jnp.ndarray, linear_projection: jnp.ndarray
) -> Descriptor:
    """Extract a behavior descriptor from a batch of trajectories.

    Each observation is flattened, the flattened observations are averaged
    over the trajectory, and the mean is projected with `linear_projection`
    and squashed with tanh, so every descriptor entry lies in [-1, 1].

    Assumes `data.obs` has shape (batch, episode_length, *observation_shape)
    - TODO confirm against the scoring function that builds `data`.

    Args:
        data: transitions collected during the rollouts.
        mask: per-step mask supplied by the scoring function. It is accepted
            to match the extractor signature but is not applied here.
            NOTE(review): every step therefore contributes to the mean -
            confirm whether masked steps should be excluded.
        linear_projection: array with one row per descriptor dimension and
            one column per flattened observation entry.

    Returns:
        One descriptor per trajectory, of shape
        (batch, linear_projection.shape[0]).
    """
    # Flatten each observation: the outer vmap runs over the batch, the inner
    # one over the steps of a trajectory.
    observation = jax.vmap(jax.vmap(observation_processing))(data.obs)

    # Average over the time axis, keeping the flattened observation axis
    mean_observation = jnp.mean(observation, axis=-2)

    # Project onto the descriptor space and squash into [-1, 1]
    descriptors = jnp.tanh(mean_observation @ linear_projection.T)

    return descriptors

# Draw the random projection used by the behavior descriptor extractor.
# The fresh subkey is consumed here, so `random_key` is never used twice.
random_key, subkey = jax.random.split(random_key)
linear_projection = jax.random.uniform(
    subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32
)

# Bind the projection so the extractor only takes (data, mask)
bd_extraction_fn = functools.partial(
    bd_extraction,
    linear_projection=linear_projection,
)

# Bind the rollout settings so the scoring function only takes the genotypes
# and a random key
scoring_fn = functools.partial(
    jumanji_scoring_function,
    init_states=init_states,
    init_timesteps=init_timesteps,
    episode_length=episode_length,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)

In [None]:
def scoring_function(
    genotypes: jnp.ndarray, random_key: RNGKey
) -> Tuple[Fitness, ExtraScores, RNGKey]:
    """Score genotypes and return their fitnesses as a single column.

    Thin wrapper around `scoring_fn` that discards its second output and
    reshapes the fitnesses to (-1, 1).
    """
    fitness_values, _unused, extras, next_key = scoring_fn(genotypes, random_key)
    column_fitnesses = fitness_values.reshape(-1, 1)
    return column_fitnesses, extras, next_key

In [None]:
# Variation operator: isoline variation with the configured noise scales
variation_fn = partial(
    isoline_variation,
    line_sigma=line_sigma,
    iso_sigma=iso_sigma,
)

# Emitter: no mutation operator is supplied and variation_percentage is 1.0
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    variation_percentage=1.0,
    variation_fn=variation_fn,
    mutation_fn=None,
)

# Genetic algorithm wired with the scoring function, emitter and GA metrics
algo_instance = GeneticAlgorithm(
    metrics_function=default_ga_metrics,
    emitter=mixing_emitter,
    scoring_function=scoring_function,
)

# Initialise the algorithm from the initial policy parameters
repertoire, emitter_state, random_key = algo_instance.init(
    init_variables, population_size, random_key
)

In [None]:
%%time

# Run the algorithm: `num_iterations` calls to `scan_update` are chained in a
# single `jax.lax.scan`; `metrics` collects the per-iteration outputs.
(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
    algo_instance.scan_update,
    (repertoire, emitter_state, random_key),
    (),
    length=num_iterations,
)

In [None]:
# `max_fitness` metric at the last iteration
metrics["max_fitness"][-1]

In [None]:
# Display the final repertoire
repertoire

## Play snake with the best policy

In [None]:
# Locate the fittest individual in the final repertoire.
# NOTE(review): argmax indexes the flattened array; this equals the row index
# only if `fitnesses` has a single column - confirm.
best_idx = jnp.argmax(repertoire.fitnesses)
best_fitness = jnp.max(repertoire.fitnesses)

In [None]:
# Adjacent f-strings are concatenated into a single message, so print does
# not insert its separator space at the start of the second line.
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)

In [None]:
# Slice every leaf of the genotype pytree at the best individual's index
my_params = jax.tree_util.tree_map(
    lambda leaf: leaf[best_idx], repertoire.genotypes
)

In [None]:
# Roll out the best policy for at most 100 steps, rendering every state.

# Reset the (jit-able) environment
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# Wrap the step function once instead of re-wrapping it at every iteration
step_fn = jax.jit(env.step)

for _ in range(100):
    # Render the current state (the terminal one included)
    env.render(state)

    # Stop as soon as the episode is over rather than stepping a finished game
    if timestep.last():
        break

    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(my_params, network_input)

    # Use a fresh subkey at every step; reusing `random_key` directly would
    # sample with the same randomness each time.
    random_key, subkey = jax.random.split(random_key)
    action = jax.random.choice(
        key=subkey,
        a=proba_action.shape[0],
        p=proba_action,
    )

    state, timestep = step_fn(state, action)