In [1]:
%load_ext autoreload
%autoreload 2 

In [2]:
import jax 
import jax.numpy as jnp 
import chex
import optax 

from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
from craftax.craftax.renderer import render_craftax_pixels as render_pixels
from craftax.craftax.world_gen.world_gen import generate_world as generate_world_craftax

from examples.craftax.craftax_plr import ActorCritic, TrainState, compute_ued_score, compute_max_returns
from examples.craftax.craftax_wrappers import LogWrapper

from craftax.craftax.constants import Achievement
from jaxued.wrappers import AutoReplayWrapper
from jaxued.level_sampler import LevelSampler 

from examples.craftax.craftax_plr import sample_trajectories_and_learn

Loading textures from cache.
Textures successfully loaded from cache.


# Setup 
*** 

In [3]:
config = {
    'num_train_envs':4, 
    "level_buffer_capacity":100,
    "replay_prob":0.8,
    "staleness_coeff":0.3,
    "minimum_fill_ratio":0.1,
    "prioritization":"rank",
    "temperature":1.0,
    "topk_k":0.3,
    "max_grad_norm":1,
    "buffer_duplicate_check":True,
    "exploratory_grad_updates":False,
    "outer_rollout_steps":64,
    "num_steps":64,
    "gamma":0.99,
    "gae_lambda":0.9,
    "num_minibatches":2,
    "epoch_ppo":5,
    "clip_eps":0.2,
    "entropy_coeff":0.01,
    "critic_coeff":0.2,
    "num_updates":10,
    "lr":3e-04,
    "score_function":"MaxMC",
    "eval_freq":5,
    "use_accel":True,
    "num_edits":1,
}

In [4]:
ENV_CLASS = CraftaxSymbolicEnv
generate_world = generate_world_craftax
render_craftax_pixels = render_pixels

In [5]:
DEFAULT_STATICS = ENV_CLASS.default_static_params()
default_env = ENV_CLASS(DEFAULT_STATICS)
env = LogWrapper(default_env)
env = AutoReplayWrapper(env)
env_params = env.default_params

In [6]:
type(env), type(env._env), type(env._env._env)

(jaxued.wrappers.autoreplay.AutoReplayWrapper,
 examples.craftax.craftax_wrappers.LogWrapper,
 craftax.craftax.envs.craftax_symbolic_env.CraftaxSymbolicEnv)

In [7]:
def sample_random_level(rng):
    return generate_world(rng, env.default_params, DEFAULT_STATICS)

In [8]:
rng = jax.random.PRNGKey(1)

rng, rng_levels, rng_reset = jax.random.split(rng, 3)
new_levels = jax.vmap(sample_random_level)(
    jax.random.split(rng_levels, config["num_train_envs"])
)
print(type(new_levels))


<class 'craftax.craftax.craftax_state.EnvState'>


In [9]:
init_obs, init_env_state = jax.vmap(
    env.reset_to_level, in_axes=(0, 0, None)
)(
    jax.random.split(rng_reset, config["num_train_envs"]),
    new_levels,
    env_params,
)

In [10]:
print(init_obs.shape)

(4, 8268)


In [11]:
{'1':jnp.array([1])}

{'1': Array([1], dtype=int32)}

## Utils 


In [12]:
level_sampler = LevelSampler(
    capacity=config["level_buffer_capacity"],
    replay_prob=config["replay_prob"],
    staleness_coeff=config["staleness_coeff"],
    minimum_fill_ratio=config["minimum_fill_ratio"],
    prioritization=config["prioritization"],
    prioritization_params={
        "temperature": config["temperature"],
        "k": config["topk_k"],
    },
    duplicate_check=config["buffer_duplicate_check"],
)


@jax.jit
def create_train_state(rng) -> TrainState:
    # Creates the train state
    def linear_schedule(count):
        frac = 1.0 - (
            count // (config["num_minibatches"] * config["epoch_ppo"])
        ) / (config["num_updates"] * config["outer_rollout_steps"])
        return config["lr"] * frac

    obs, _ = env.reset_to_level(rng, sample_random_level(rng), env_params)
    obs = jax.tree_util.tree_map(
        lambda x: jnp.repeat(
            x[None, ...], 
            config["num_train_envs"], 
            axis=0
        )[None, ...],
        obs,
    )
    init_x = (
        obs, 
        jnp.zeros(
            (1, config["num_train_envs"],), 
            dtype=jnp.bool
        )
    )
    #print(f"init: {init_x[0].shape}, {init_x[1].shape}")
    network = ActorCritic(env.action_space(env_params).n)
    network_params = network.init(
        rng, 
        init_x,
        ActorCritic.initialize_carry((config["num_train_envs"],))
    )
    tx = optax.chain(
        optax.clip_by_global_norm(config["max_grad_norm"]),
            optax.adam(learning_rate=linear_schedule, eps=1e-5),
        )
    pholder_level = sample_random_level(jax.random.PRNGKey(0))
    sampler = level_sampler.initialize(pholder_level, {"max_return": -jnp.inf})
    pholder_level_batch = jax.tree_util.tree_map(
        lambda x: jnp.array([x]).repeat(config["num_train_envs"], axis=0),
        pholder_level,
    )
    return TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
        sampler=sampler,
        update_state=0,
        num_dr_updates=0,
        num_replay_updates=0,
        num_mutation_updates=0,
        dr_last_level_batch=pholder_level_batch,
        replay_last_level_batch=pholder_level_batch,
        mutation_last_level_batch=pholder_level_batch,
    )


from typing import Tuple 
from examples.craftax.craftax_plr import UpdateState 
from examples.craftax.mutators import make_mutator_craftax_claude_35_easy_hard
mutate_level = make_mutator_craftax_claude_35_easy_hard()

def train_step(carry: Tuple[chex.PRNGKey, TrainState], _):
        """
        This is the main training loop. It basically calls either `on_new_levels`, `on_replay_levels`, or `on_mutate_levels` at every step.
        """

        def on_new_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            Samples new (randomly-generated) levels and evaluates the policy on these. It also then adds the levels to the level buffer if they have high-enough scores.
            The agent is updated on these trajectories iff `config["exploratory_grad_updates"]` is True.
            """
            sampler = train_state.sampler

            # Reset
            rng, rng_levels, rng_reset = jax.random.split(rng, 3)
            new_levels = jax.vmap(sample_random_level)(
                jax.random.split(rng_levels, config["num_train_envs"])
            )
            init_obs, init_env_state = jax.vmap(
                env.reset_to_level, in_axes=(0, 0, None)
            )(
                jax.random.split(rng_reset, config["num_train_envs"]),
                new_levels,
                env_params,
            )
            #print(f"#### TRAIN init_obs: {init_obs.shape}")
            # Rollout
            (
                (rng, train_state, _, _, _),
                (
                    _,
                    _,
                    rewards,
                    dones,
                    _,
                    values,
                    info,
                    advantages,
                    _,
                    losses,
                    grad_norms,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                ActorCritic.initialize_carry((config["num_train_envs"],)),
                init_obs,
                init_env_state,
                update_grad=config["exploratory_grad_updates"],
            )
            # dones (total_steps, num_train_envs)
            # rewards (total_steps, num_train_envs)
            max_returns = compute_max_returns(dones, rewards)
            scores = compute_ued_score(config, dones, values, max_returns, advantages)
            sampler, _ = level_sampler.insert_batch(
                sampler, new_levels, scores, {"max_return": max_returns}
            )
            
            achievement_per_done_exp = ((info["achievements"] * dones[..., None]).sum(axis=0).sum(axis=0))/ dones.sum()
            metrics = {
                "losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),
                "achievements": achievement_per_done_exp,
                "achievement_count": (info["achievement_count"] * dones).sum()/ dones.sum(),
                "mean_returned_episode_length": (info["returned_episode_lengths"] * dones).sum()/ dones.sum(),
                "max_returned_episode_length": info["returned_episode_lengths"].max(),
                "levels_played": init_env_state.env_state,
                "mean_returns": (info["returned_episode_returns"] * dones).sum()/ dones.sum(),
                "grad_norms": grad_norms.mean(),
                "scores": scores,
                "stage":0,
            }

            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.DR,
                num_dr_updates=train_state.num_dr_updates + 1,
                dr_last_level_batch=new_levels,
            )
            return (rng, train_state), metrics

        def on_replay_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            This samples levels from the level buffer, and updates the policy on them.
            """
            sampler = train_state.sampler

            # Collect trajectories on replay levels
            rng, rng_levels, rng_reset = jax.random.split(rng, 3)
            sampler, (level_inds, levels) = level_sampler.sample_replay_levels(
                sampler, rng_levels, config["num_train_envs"]
            )
            init_obs, init_env_state = jax.vmap(
                env.reset_to_level, in_axes=(0, 0, None)
            )(jax.random.split(rng_reset, config["num_train_envs"]), levels, env_params)
            # Rollout
            (
                (rng, train_state, _, _, _),
                (
                    _,
                    _,
                    rewards,
                    dones,
                    _,
                    values,
                    info,
                    advantages,
                    _,
                    losses,
                    grad_norms,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                ActorCritic.initialize_carry((config["num_train_envs"],)),
                init_obs,
                init_env_state,
                update_grad=True,
            )

            max_returns = jnp.maximum(
                level_sampler.get_levels_extra(sampler, level_inds)["max_return"],
                compute_max_returns(dones, rewards),
            )
            scores = compute_ued_score(config, dones, values, max_returns, advantages) #(num_train_envs,)
            sampler = level_sampler.update_batch(
                sampler, level_inds, scores, {"max_return": max_returns}
            )

            achievement_per_done_exp = ((info["achievements"] * dones[..., None]).sum(axis=0).sum(axis=0))/ dones.sum()
            metrics = {
                "losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),  # this is a tuple of losses
                "achievements": achievement_per_done_exp, # this is not a scalar    
                "achievement_count": (info["achievement_count"] * dones).sum()/ dones.sum(),
                "mean_returned_episode_length": (info["returned_episode_lengths"] * dones).sum()/ dones.sum(),
                "max_returned_episode_length": info["returned_episode_lengths"].max(),
                "levels_played": init_env_state.env_state,
                "mean_returns": (info["returned_episode_returns"] * dones).sum()/ dones.sum(),
                "grad_norms": grad_norms.mean(),
                "scores": scores,
                "stage":1,
            }

            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.REPLAY,
                num_replay_updates=train_state.num_replay_updates + 1,
                replay_last_level_batch=levels,
            )
            return (rng, train_state), metrics

        def on_mutate_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            This mutates the previous batch of replay levels and potentially adds them to the level buffer.
            This also updates the policy iff `config["exploratory_grad_updates"]` is True.
            """
            sampler = train_state.sampler
            rng, rng_mutate, rng_reset = jax.random.split(rng, 3)

            # mutate
            parent_levels = train_state.replay_last_level_batch
            child_levels = jax.vmap(mutate_level, (0, 0, None))(
                jax.random.split(rng_mutate, config["num_train_envs"]),
                parent_levels,
                config["num_edits"],
            )
            init_obs, init_env_state = jax.vmap(
                env.reset_to_level, in_axes=(0, 0, None)
            )(
                jax.random.split(rng_reset, config["num_train_envs"]),
                child_levels,
                env_params,
            )

            # Rollout
            (
                (rng, train_state, _, _, _),
                (
                    _,
                    _,
                    rewards,
                    dones,
                    _,
                    values,
                    info,
                    advantages,
                    _,
                    losses,
                    grad_norms,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                ActorCritic.initialize_carry((config["num_train_envs"],)),
                init_obs,
                init_env_state,
                update_grad=config["exploratory_grad_updates"],
            )

            max_returns = compute_max_returns(dones, rewards)
            scores = compute_ued_score(config, dones, values, max_returns, advantages)
            sampler, _ = level_sampler.insert_batch(
                sampler, child_levels, scores, {"max_return": max_returns}
            )

            achievement_per_done_exp = ((info["achievements"] * dones[..., None]).sum(axis=0).sum(axis=0))/ dones.sum()
            metrics = {
                "losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),
                "achievements": achievement_per_done_exp,
                "achievement_count": (info["achievement_count"] * dones).sum()/ dones.sum(),
                "mean_returned_episode_length": (info["returned_episode_lengths"] * dones).sum() / dones.sum(),
                "max_returned_episode_length": info["returned_episode_lengths"].max(),
                "levels_played": init_env_state.env_state,
                "mean_returns": (info["returned_episode_returns"] * dones).sum()/ dones.sum(),
                "grad_norms": grad_norms.mean(),
                "scores": scores,
                "stage":2,
            }

            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.DR,
                num_mutation_updates=train_state.num_mutation_updates + 1,
                mutation_last_level_batch=child_levels,
            )
            return (rng, train_state), metrics

        rng, train_state = carry
        rng, rng_replay = jax.random.split(rng)

        # The train step makes a decision on which branch to take, either on_new, on_replay or on_mutate.
        # on_mutate is only called if the replay branch has been taken before (as it uses `train_state.update_state`).
        branches = [
            on_new_levels,
            on_replay_levels,
        ]
        if config["use_accel"]:
            s = train_state.update_state
            branch = (1 - s) * level_sampler.sample_replay_decision(
                train_state.sampler, rng_replay
            ) + 2 * s
            branches.append(on_mutate_levels)
        else:
            branch = level_sampler.sample_replay_decision(
                train_state.sampler, rng_replay
            ).astype(int)

        return jax.lax.switch(branch, branches, rng, train_state)

## Testing 

In [13]:
train_state: TrainState = create_train_state(rng)
rng, subrng = jax.random.split(rng)

In [14]:
print(init_obs.shape)

(4, 8268)


In [15]:
rng_init, rng_train = jax.random.split(rng)

runner_state = (rng_train, train_state)
(rng, train_state), metrics = jax.lax.scan(
    train_step, runner_state, None, config["eval_freq"]
)

In [16]:
#metrics

In [17]:
import numpy as np

def log_train_eval(stats:chex.ArrayTree, train_state_info:chex.ArrayTree):
        """stats is nothing but the metrics from the eval and train loop

        Args:
            stats (chex.ArrayTree): _description_
            train_state_info (chex.ArrayTree): _description_
        """
        #print(f"Logging update: {stats['update_count']}")
        stats.pop("levels_played", None)

        # generic stats
        # env_steps = (
        #     stats["update_count"]
        #     * config["num_train_envs"]
        #     * config["num_steps"]
        #     * config["outer_rollout_steps"]
        # )
        # env_steps_delta = (
        #     config["eval_freq"]
        #     * config["num_train_envs"]
        #     * config["num_steps"]
        #     * config["outer_rollout_steps"]
        # )
        log_dict = {}
        def _get_stage_train_metrics(stats_:chex.ArrayTree, stage_:int) -> chex.ArrayTree:
            print(f"stage_{stage_}")
            prefix = {
                0: "gen",
                1: "replay",
                2: "mutation",
            }[stage_]

            stage_selector = stats_["stage"] == stage_
            if stage_selector.sum() == 0:
                return {}
            stage_stats = {}
            for k, v in stats.items():
                if k in {
                    "achievements",
                    "achievement_count",
                    "grad_norms",
                    "mean_returned_episode_lengths",
                    "max_returned_episode_length",
                    "mean_returns",
                    "scores",
                }:
                    print(k, v.shape)
                    if "max" in k:
                        stage_stats[f"{prefix}/{k}"] = v[stage_selector].max(axis=0)
                    elif k == "scores":
                        stage_stats[f"{prefix}/{k}"] = v[stage_selector].mean()
                    else:
                        stage_stats[f"{prefix}/{k}"] = v[stage_selector].mean(axis=0)

            stage_stats.update(
                jax.tree_util.tree_map(
                    lambda idx: stage_stats[f"{prefix}/achievements"].at[idx].get(),
                    {
                        f"achievements_{ac.name.lower()}" : ac.value \
                        for ac in Achievement
                    }
                )
            )
            del stage_stats[f"{prefix}/achievements"]
            del stage_selector

            return stage_stats

        # train performance 
        #### random 
        log_dict.update(
            _get_stage_train_metrics(stats, 0)
        )

        #### replay 
        log_dict.update(
            _get_stage_train_metrics(stats, 1)
        )

        #### mutate 
        if config['use_accel']:
            log_dict.update(
                _get_stage_train_metrics(stats, 2)
            )

        # evaluation performance
        # returns = stats["eval_returns"]
        # # return 
        # log_dict.update({"eval/mean_returns": returns.mean()})
        # log_dict.update({"eval/max_returns": returns.max()})  
        # log_dict.update({"eval/min_returns": returns.min()})

        # # eps length
        # log_dict.update({"eval/mean_ep_lengths": stats["eval_ep_lengths"].mean()})
        # log_dict.update({"eval/max_ep_lengths": stats["eval_ep_lengths"].max()})
        # log_dict.update({"eval/min_ep_lengths": stats["eval_ep_lengths"].min()})

        # # achievements
        # log_dict.update(
        #     jax.tree_util.tree_map(
        #         lambda idx: stats["eval_achievements"].at[idx].get(),
        #         {
        #             f"eval/achievements_{ac.name.lower()}" : ac.value \
        #             for ac in Achievement
        #         }
        #     )
        # )

        # level sampler
        #log_dict.update(train_state_info["log"])

        # # images
        # log_dict.update(
        #     {
        #         "images/highest_scoring_level": wandb.Image(
        #             np.array(stats["highest_scoring_level"]),
        #             caption="Highest scoring level",
        #         )
        #     }
        # )
        # log_dict.update(
        #     {
        #         "images/highest_weighted_level": wandb.Image(
        #             np.array(stats["highest_weighted_level"]),
        #             caption="Highest weighted level",
        #         )
        #     }
        # )

        # for s in ["dr", "replay", "mutation"]:
        #     if train_state_info["info"][f"num_{s}_updates"] > 0:
        #         log_dict.update(
        #             {
        #                 f"images/{s}_levels": [
        #                     wandb.Image(np.array(image))
        #                     for image in stats[f"{s}_levels"]
        #                 ]
        #             }
        #         )

        # i = 0
        # frames, episode_length = (
        #     stats["eval_animation"][0][:, i],
        #     stats["eval_animation"][1][i],
        # )
        # frames = np.array(frames[:episode_length])
        # log_dict.update({f"animations/animation": wandb.Video(frames, fps=4)})
        return log_dict 

In [None]:
log_dict_ = log_train_eval(metrics, train_state)

stage_0


In [247]:
(
    (rng, train_state, _, _, _),
    (
        _,
        _,
        rewards,
        dones,
        _,
        values,
        info,
        advantages,
        _,
        losses,
        grad_norms,
    ),
) = sample_trajectories_and_learn(
    env,
    env_params,
    config,
    rng,
    train_state,
    ActorCritic.initialize_carry((config["num_train_envs"],)),
    init_obs,
    init_env_state,
    update_grad=config["exploratory_grad_updates"],
)

In [249]:
max_returns = jnp.maximum(
    level_sampler.get_levels_extra(
        train_state.sampler, 
        jnp.full((config["num_train_envs"],), fill_value=-1)
    )["max_return"],
    compute_max_returns(dones, rewards),
)

In [253]:
ued_scores = compute_ued_score(
    config, 
    dones, 
    values, 
    max_returns=max_returns, 
    advantages=advantages
)
ued_scores.shape

(4,)

In [41]:
dones.shape, rewards.shape, values.shape, advantages.shape# done = True (defat both, dead, run out time )

((4096, 4), (4096, 4), (4096, 4), (4096, 4))

((4096, 4),)

In [29]:
losses[0].shape

(320, 2)

In [19]:
dones[-1, :]

Array([False, False, False, False], dtype=bool)

In [81]:
dones.mean()
print(f"Total number of dones is {dones.sum()}")

Total number of dones is 62


In [30]:
grad_norms.shape

(320, 2)

In [238]:
info["floor"].shape

(4096, 4)

In [33]:
info["Achievements/cast_fireball"].shape

(4096, 4)

In [42]:
info["achievements"].shape, dones.shape

((4096, 4, 67), (4096, 4))

In [137]:

print(info["achievements"].sum())
print(
    (info["achievements"] * dones[..., None]).sum(axis=0).sum(axis=0) 
) # this gives the total number of achievements per category
achievement_per_exp = (info["achievements"] * dones[..., None]).sum(axis=0).sum(axis=0) / dones.sum()
print(
    achievement_per_exp
) # this gives the total number of achievements for each category, per completed episode

23068
[30  7  0 22  6  0  0 18  0  0  0  0  0  0  0 56  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
[0.48387095 0.11290322 0.         0.3548387  0.09677419 0.
 0.         0.29032257 0.         0.         0.         0.
 0.         0.         0.         0.9032258  0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.        ]


In [167]:
info["returned_episode_lengths"].sum()/dones.sum()

Array(251.08064, dtype=float32)

In [166]:
(info["returned_episode_lengths"] * dones).sum()

Array(15567, dtype=int32)

In [173]:
(info["returned_episode_returns"] * dones).sum(), (info["returned_episode_returns"]).sum()

(Array(83.2, dtype=float32), Array(83.2, dtype=float32))

In [139]:
info["Achievements/collect_drink"].sum()

Array(600., dtype=float32)

In [77]:
print(info["Achievements/place_table"].sum())
print(info["Achievements/place_table"].shape)
# such achievemt is calculated using 

700.0
(4096, 4)


In [174]:
grad_norms.shape

(320, 2)

In [133]:
achievement_per_exp.shape

(67,)

In [198]:
jax.tree_util.tree_map(
    lambda idx: achievement_per_exp.at[idx].get(),
    {
        f"eval/achievements_{ac.name.lower()}" : ac.value \
        for ac in Achievement
    }
    )

{'eval/achievements_cast_fireball': Array(0., dtype=float32),
 'eval/achievements_cast_iceball': Array(0., dtype=float32),
 'eval/achievements_collect_coal': Array(0., dtype=float32),
 'eval/achievements_collect_diamond': Array(0., dtype=float32),
 'eval/achievements_collect_drink': Array(0.09677419, dtype=float32),
 'eval/achievements_collect_iron': Array(0., dtype=float32),
 'eval/achievements_collect_ruby': Array(0., dtype=float32),
 'eval/achievements_collect_sapling': Array(0.3548387, dtype=float32),
 'eval/achievements_collect_sapphire': Array(0., dtype=float32),
 'eval/achievements_collect_stone': Array(0., dtype=float32),
 'eval/achievements_collect_wood': Array(0.48387095, dtype=float32),
 'eval/achievements_damage_necromancer': Array(0., dtype=float32),
 'eval/achievements_defeat_archer': Array(0., dtype=float32),
 'eval/achievements_defeat_deep_thing': Array(0., dtype=float32),
 'eval/achievements_defeat_fire_elemental': Array(0., dtype=float32),
 'eval/achievements_defeat_f

In [154]:
out

{'achievements/cast_fireball': Array(0., dtype=float32),
 'achievements/cast_iceball': Array(0., dtype=float32),
 'achievements/collect_coal': Array(0., dtype=float32),
 'achievements/collect_diamond': Array(0., dtype=float32),
 'achievements/collect_drink': Array(0.09677419, dtype=float32),
 'achievements/collect_iron': Array(0., dtype=float32),
 'achievements/collect_ruby': Array(0., dtype=float32),
 'achievements/collect_sapling': Array(0.3548387, dtype=float32),
 'achievements/collect_sapphire': Array(0., dtype=float32),
 'achievements/collect_stone': Array(0., dtype=float32),
 'achievements/collect_wood': Array(0.48387095, dtype=float32),
 'achievements/damage_necromancer': Array(0., dtype=float32),
 'achievements/defeat_archer': Array(0., dtype=float32),
 'achievements/defeat_deep_thing': Array(0., dtype=float32),
 'achievements/defeat_fire_elemental': Array(0., dtype=float32),
 'achievements/defeat_frost_troll': Array(0., dtype=float32),
 'achievements/defeat_gnome_archer': Arra

In [181]:
def func1(x):
    return {"plr":'2'}


def func2(x):
    return {"plr":'1'}


def func3(x):
    return {"plr":'0'}


slider = jnp.array([0,1,2])

jax.lax.switch(
    0,
    [
        func1, func2, func3
    ],
    [0,1,2]
)

TypeError: Value '2' with type <class 'str'> is not a valid JAX type

In [194]:
def _step(carry, _):

    x = carry 
    return x+1, {'loss': jnp.full((6,), fill_value=x*2)}

In [195]:
f_carry, xs = jax.lax.scan(
    _step,
    1,
    None ,
    4
)

In [196]:
f_carry

Array(5, dtype=int32, weak_type=True)

In [197]:
xs['loss'].shape 

(4, 6)

In [235]:
input_ = {
    "stage": jnp.array([0,1,1,2,1]), 
    "array":jnp.array(
        [
            [1,2,3,4], #0
            [0,0,0,0], # 1,  
            [-1,-1,-1,-1], #1 
            [4,4,4,4], #2,
            [1,1,1,1],
        ]
    )
}



#out
out = {
    "stage": jnp.array([1,1,]), 
    "array":jnp.array(
        [
            [0,0,0,0], # 1,  
            [-1,-1,-1,-1], #1 
        ]
    )
} 

def filter_by_stage(array:chex.Array, target_stage:int=1):
    # Get our mask for the target stage
    mask = array == target_stage
    
    # Calculate maximum possible matches (static size)
    max_size = array.shape[0]
    
    # Get indices that would sort True values first
    indices = jnp.argsort(~mask)
    
    # Reorder our arrays using these indices
    sorted_arrays = jnp.take(array, indices, axis=0)
    
    # Instead of slicing, we'll mask the values
    n_matches = jnp.sum(mask)
    valid_mask = jnp.arange(max_size) < n_matches
    out_array =  jnp.where(valid_mask[:, None], sorted_arrays, 0)

    return out_array.sum(axis=0)/n_matches

In [239]:
out["array"][out["stage"]==1]

Array([[ 0,  0,  0,  0],
       [-1, -1, -1, -1]], dtype=int32)

In [236]:
jax.tree_util.tree_map(
    filter_by_stage,
    input_,
)

ValueError: Incompatible shapes for broadcasting: shapes=[(5, 1), (5, 4, 4), ()]

In [222]:
filter_by_stage(input_, 1)

{'array': Array([[ 0,  0,  0,  0],
        [-1, -1, -1, -1],
        [ 1,  1,  1,  1],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]], dtype=int32),
 'avg': Array([0., 0., 0., 0.], dtype=float32),
 'n_matches': Array(3, dtype=int32)}