In [None]:
from functools import partial
from typing import Tuple, Type

import jax
import jax.numpy as jnp

import functools

from qdax.baselines.genetic_algorithm import GeneticAlgorithm

from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP

from qdax.core.emitters.mutation_operators import (
    polynomial_crossover,
    polynomial_mutation,
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.types import ExtraScores, Fitness, RNGKey, Descriptor
from qdax.utils.metrics import default_ga_metrics

In [None]:
# --- Reproducibility ---
# Seed used to build the notebook's main PRNG key.
seed = 0

# --- Policy network ---
# Widths of the hidden layers; the output layer is appended later.
policy_hidden_layer_sizes = (64,) * 2

# --- Rollouts ---
# Value forwarded as `episode_length` when the scoring function is built.
episode_length = 100

# --- Genetic algorithm ---
batch_size = 100
population_size = 1_000
num_iterations = 10_000

# --- Variation operators ---
# The emitter receives `1 - proportion_mutation` as its variation percentage.
proportion_mutation = 0.8
proportion_var_to_change = proportion_to_mutate = 0.5
eta = 0.05

In [None]:
# Smoke test: build the environment, reset it once and take one step.
# `env` is reused by the cells below; `state`, `timestep` and `action`
# are only inspected interactively.
import jumanji

# Instantiate a Jumanji environment from its registry id.
env = jumanji.make('Snake-6x6-v0')

# Reset the environment through a jitted call, using a fixed PRNG key.
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# (Optional) Render the env state
# env.render(state)

# Take a single jitted step with a placeholder action taken from the action spec.
action = env.action_spec().generate_value()          # dummy action, not a policy decision
state, timestep = jax.jit(env.step)(state, action)

In [None]:
# Inspect the `maximum` field of the action spec (used later as `maximum + 1` actions).
env.action_spec().maximum

In [None]:
# Shape of the observation once flattened to 1-D.
jnp.ravel(timestep.observation).shape

In [None]:
# Number of scalar entries in one observation: product of the spec's shape dimensions.
jnp.prod(jnp.array(env.observation_spec().shape))

In [None]:
def observation_processing(observation):
    """Flatten ``observation`` into the 1-D array given to the policy network."""
    return jnp.ravel(observation)


def play_step_fn(
    env_state,
    timestep,
    policy_params,
    random_key,
):
    """Play one environment step with the stochastic policy.

    The observation in ``timestep`` is flattened, passed through the
    module-level ``policy_network`` and an action is sampled from the
    resulting probability vector.

    Args:
        env_state: current environment state, forwarded to ``env.step``.
        timestep: current timestep; its ``observation`` feeds the policy.
        policy_params: parameters applied to ``policy_network``.
        random_key: PRNG key; it is split so that sampling consumes a fresh
            subkey and the caller gets back an advanced key.

    Returns:
        A tuple ``(next_state, next_timestep, policy_params, random_key,
        transition)`` where ``random_key`` is the advanced key and
        ``transition`` is the ``QDTransition`` recorded for this step.
    """
    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(policy_params, network_input)

    # Sample with a fresh subkey: reusing the incoming key unchanged would make
    # every step of a rollout draw with the same randomness.
    random_key, subkey = jax.random.split(random_key)
    action = jax.random.choice(
        key=subkey,
        # One candidate per policy output, so `a` always matches the size of `p`.
        a=proba_action.shape[-1],
        p=proba_action,
    )

    # Same two-value unpacking as the jitted `env.step` call earlier in the notebook.
    next_state, next_timestep = env.step(env_state, action)

    transition = QDTransition(
        obs=timestep.observation,
        next_obs=next_timestep.observation,
        rewards=next_timestep.reward,
        dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)),
        actions=action,
        # TODO: no truncation signal is read from the environment yet, so
        # truncations are always reported as 0.
        truncations=jnp.array(0),
        # No state descriptor is extracted for this environment.
        state_desc=None,
        next_state_desc=None,
    )

    return next_state, next_timestep, policy_params, random_key, transition


In [None]:
import numpy as np

# Main PRNG key for this notebook, derived from the configured seed.
random_key = jax.random.PRNGKey(seed)

# --- Policy network ---
# One output unit per action value in [0, maximum]; the final softmax turns
# the outputs into a probability vector.
num_actions = env.action_spec().maximum + 1

policy_layer_sizes = (*policy_hidden_layer_sizes, num_actions)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jax.nn.softmax,
)

# --- Initial population of controllers ---
# One independent init key per individual in the batch.
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)

# The network input is the flattened observation, so its size is the product
# of the dimensions of the observation spec.
obs_spec = env.observation_spec()
observation_size = np.prod(np.array(obs_spec.shape))

# Placeholder inputs: only their shape matters for parameter initialisation.
fake_batch = jnp.zeros((batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# --- Initial environment states ---
# The same subkey is repeated for every individual, so all of them are reset
# with an identical key.
random_key, subkey = jax.random.split(random_key)
keys = jnp.repeat(subkey[jnp.newaxis], batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))
init_states, init_timesteps = reset_fn(keys)

In [None]:
# Prepare the scoring function
def bd_extraction_fn(data: QDTransition, mask: jnp.ndarray) -> Descriptor:
    """Build a two-value behavior descriptor from the observations in ``data``.

    Every entry along the leading axis of ``data.obs`` is flattened with
    ``observation_processing``. The descriptor is then the pair
    (mean of the first half of that flat vector, mean of the second half).

    Args:
        data: transitions whose ``obs`` field holds the observations; the
            leading axis is mapped over with ``jax.vmap``.
        mask: accepted so the function keeps the extractor signature, but it
            is not applied: every value in ``data.obs`` contributes to the
            means. NOTE(review): confirm whether masked steps should be
            excluded from the averages.

    Returns:
        The two descriptor values concatenated along axis 1.
    """
    observation = jax.vmap(observation_processing)(data.obs)

    # Split point between the two halves of the flattened observations.
    half = observation.shape[-1] // 2

    desc1 = jnp.mean(observation[..., :half], axis=1, keepdims=True)
    desc2 = jnp.mean(observation[..., half:], axis=1, keepdims=True)

    return jnp.concatenate([desc1, desc2], axis=1)


# Bind everything that stays fixed across evaluations, leaving only the
# per-call arguments of `scoring_function` free.
# NOTE(review): `scoring_function` is neither defined nor imported in the
# visible cells - confirm where it is meant to come from.
scoring_fn = partial(
    scoring_function,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
    episode_length=episode_length,
    init_states=init_states,
    init_timesteps=init_timesteps,
)

In [None]:
def scoring_fn(
    genotypes: jnp.ndarray, random_key: RNGKey
) -> Tuple[Fitness, ExtraScores, RNGKey]:
    """Wrap ``scoring_function`` into a (fitnesses, extra scores, key) triple.

    ``scoring_function(genotypes)`` must return exactly two values; only the
    first is kept. The extra scores are always an empty dict and
    ``random_key`` is handed back untouched.

    NOTE(review): this definition rebinds the name ``scoring_fn`` that an
    earlier cell assigns to a ``functools.partial``, and ``scoring_function``
    is not defined in the visible cells - confirm which one is intended.
    """
    fitnesses, _unused = scoring_function(genotypes)
    extra_scores: ExtraScores = {}
    return fitnesses, extra_scores, random_key

# Initial population: `batch_size` genotypes drawn uniformly in [minval, maxval).
# This restarts `random_key` from a fixed seed of 42, replacing the key built
# earlier in the notebook.
# NOTE(review): `genotype_dim`, `minval` and `maxval` are not defined in the
# visible cells - confirm where they are set.
random_key = jax.random.PRNGKey(42)
random_key, subkey = jax.random.split(random_key)
init_genotypes = jax.random.uniform(
    key=subkey,
    shape=(batch_size, genotype_dim),
    dtype=jnp.float32,
    minval=minval,
    maxval=maxval,
)

In [None]:
# Crossover operator with its proportion of variables to change fixed up front.
crossover_function = functools.partial(
    polynomial_crossover,
    proportion_var_to_change=proportion_var_to_change,
)

# Mutation operator with its bounds and hyper-parameters fixed up front.
mutation_function = functools.partial(
    polynomial_mutation,
    proportion_to_mutate=proportion_to_mutate,
    eta=eta,
    minval=minval,
    maxval=maxval,
)

# Emitter combining both operators; the variation percentage passed in is
# the complement of `proportion_mutation`.
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    mutation_fn=mutation_function,
    variation_fn=crossover_function,
    variation_percentage=1 - proportion_mutation,
)

# Genetic algorithm built from the scoring function, emitter and metrics.
algo_instance = GeneticAlgorithm(
    emitter=mixing_emitter,
    scoring_function=scoring_fn,
    metrics_function=default_ga_metrics,
)

# Initialise the algorithm state from the starting genotypes; `population_size`
# and the current PRNG key are passed along, and an updated key is returned.
repertoire, emitter_state, random_key = algo_instance.init(
    init_genotypes, population_size, random_key
)

# Run the algorithm
# `jax.lax.scan` calls `scan_update` `num_iterations` times, threading
# (repertoire, emitter_state, random_key) as the carry and stacking the
# per-iteration metrics along a new leading axis. The empty tuple means there
# is no per-step input.
(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
    algo_instance.scan_update,
    (repertoire, emitter_state, random_key),
    (),
    length=num_iterations,
)

# Read the "max_fitness" metric of the last iteration.
# NOTE(review): unpacking into `x, y` requires exactly two values in that
# entry - confirm this matches the fitness dimension of this task.
x, y = metrics["max_fitness"][-1]
x, y