In [1]:
%load_ext autoreload
%autoreload 2

from typing import List, Tuple 

import os 
os.environ['JAX_PLATFORMS'] = 'cpu'

import tiktoken 
from langchain_anthropic import ChatAnthropic 

from craftax.craftax import game_logic 
from craftax.craftax import constants 
from craftax.craftax import craftax_state 
from craftax.craftax.envs import craftax_symbolic_env 
from craftax.craftax.world_gen import world_gen 

import jax 
import jax.numpy as jnp 

import matplotlib.pyplot as plt 

p = os.path.join(os.path.expanduser("~"), ".ssh", "anthropic.pem")
with open(p, "r") as f:
    anthropic_key = f.readline().strip()
os.environ["ANTHROPIC_API_KEY"] = anthropic_key

model = ChatAnthropic(
    model="claude-3-5-sonnet-latest",
    temperature=0.2,
    max_tokens=4096,    
)

import sys 
import inspect 

Loading textures from cache.
Textures successfully loaded from cache.


# INPUT 
*** 

In [9]:
game_logic_code:str = inspect.getsource(game_logic)
constants_code:str = inspect.getsource(constants)
state_code = inspect.getsource(craftax_state)
env_code = inspect.getsource(craftax_symbolic_env)
entry_point = "CraftaxSymbolicEnv"
generation_code:str = inspect.getsource(world_gen) 


inputs = [
    game_logic_code, 
    constants_code, 
    state_code, 
    env_code, 
    generation_code
]

delimiter = "\n***************\n"
title = "Craftax Symbolic Environment game logic, states, environment and world generation code as follows:\n"

input_string = title + delimiter.join(inputs) + f"\nEntry point of the environment is {entry_point}"
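The cell above concatenates several module sources into one prompt string. A small reusable sketch of the same idea follows; `build_source_prompt` is a hypothetical helper, and the standard-library modules `json` and `textwrap` stand in for the Craftax modules so the example runs anywhere.

```python
import inspect
import json
import textwrap
from types import ModuleType
from typing import Sequence

DELIMITER = "\n***************\n"


def build_source_prompt(modules: Sequence[ModuleType], title: str, entry_point: str) -> str:
    """Join the source code of several modules into a single prompt string."""
    # inspect.getsource returns the full text of each module's source file
    sources = [inspect.getsource(m) for m in modules]
    return title + DELIMITER.join(sources) + f"\nEntry point of the environment is {entry_point}"


prompt = build_source_prompt(
    [json, textwrap],
    title="Example modules as follows:\n",
    entry_point="json.loads",
)
print(len(prompt))
```

Keeping the delimiter and entry-point sentence in one function makes it easy to swap in a different environment (e.g. Craftax-Classic) without retyping the prompt scaffolding.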

In [10]:
tokens = tiktoken.get_encoding("o200k_base").encode(input_string)
print(len(tokens))

45019
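`o200k_base` is an OpenAI tokenizer, so 45,019 is only an approximation of the number of tokens Claude will count for the same text. A quick budget check, assuming Claude 3.5 Sonnet's 200k-token context window, the `max_tokens=4096` output reservation set earlier, and a hypothetical 20% margin for tokenizer mismatch:

```python
def fits_context(prompt_tokens: int, context_window: int = 200_000,
                 max_output_tokens: int = 4096, margin: float = 0.2) -> bool:
    """Return True if the prompt plus reserved output tokens fit in the window.

    `margin` inflates the prompt estimate because the count comes from a
    different tokenizer than the one the model uses.
    """
    estimated = int(prompt_tokens * (1.0 + margin))
    return estimated + max_output_tokens <= context_window


print(fits_context(45_019))
```

The multi-turn cells below resend the whole conversation each time, so the budget shrinks with every follow-up answer.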


# Metric Discovery 
*** 

## Metric Generation 

In [27]:

messages = [
    (
        "system",
        "You are an exceptional reinforcement learning environment evaluator",
    ),
    (
        "human", 
        f"""
        ### Task 
        Think step by step to summarise a few key metrics that comprehensively evaluate the learning potential of the following CraftaxSymbolicEnv environment.
        Be as exhaustive as possible.
        ```python
        {input_string}
        ```
        After selecting the metrics, write a Python function that computes each (scalar) metric for a given environment.
        Make sure that the function takes in: 
            - The environment state representing an unplayed environment <EnvState> 

        Make sure that the function outputs: 
            - A scalar value that represents the metric value (between 0 and 1) <jnp.float32> 

        Please code in jax and put all functions in a self-contained script end to end.
        Make sure all functions have @jax.jit decorators.
        """
    ),
]

out = model.invoke(messages)
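The prompt asks for jitted functions that map an unplayed `EnvState` to a scalar in [0, 1]. Below is a minimal sketch of one such metric, the fraction of map cells holding a resource block. It assumes a toy state with only a `map` field shaped `(num_floors, H, W)` and hypothetical integer block IDs; the real `EnvState` and `BlockType` values live in `craftax.craftax.craftax_state` and `craftax.craftax.constants`. Unlike a Python-level generator over block types, `jnp.isin` tests membership in a single vectorized call.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Stand-in for EnvState: only the map field is modeled."""
    map: jnp.ndarray  # int32, shape (num_floors, H, W)


# Hypothetical block IDs for tree, coal, iron, diamond (not Craftax's real values)
RESOURCE_IDS = jnp.array([2, 5, 6, 7], dtype=jnp.int32)


@jax.jit
def resource_density(state: ToyState) -> jnp.ndarray:
    """Fraction of cells, over all floors, that contain a resource block."""
    is_resource = jnp.isin(state.map, RESOURCE_IDS)
    return jnp.mean(is_resource.astype(jnp.float32))  # a mean of 0/1 values is in [0, 1]


toy_map = jnp.zeros((2, 4, 4), dtype=jnp.int32).at[0, 0, :2].set(2)
print(resource_density(ToyState(map=toy_map)))
```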

In [28]:
print(out.content)

Based on the provided code, I'll create a comprehensive evaluation function that measures key aspects of the learning potential for the CraftaxSymbolic environment. Here's my analysis and implementation:

Key Metrics to Consider:
1. Resource Accessibility (ores, trees, water)
2. Navigation Complexity (path connectivity)
3. Combat Opportunity (mob spawn potential)
4. Crafting Potential (proximity of crafting stations)
5. Survival Balance (food/water sources)

Here's the implementation:

```python
import jax
import jax.numpy as jnp
import chex
from functools import partial

@jax.jit
def compute_resource_density(map_layer: chex.Array) -> jnp.float32:
    """Compute density of important resources."""
    resource_blocks = jnp.array([
        BlockType.TREE.value,
        BlockType.COAL.value,
        BlockType.IRON.value,
        BlockType.DIAMOND.value,
        BlockType.SAPPHIRE.value,
        BlockType.RUBY.value
    ])
    
    total_resources = sum(map_layer == block for block in reso
```

In [31]:
messages.extend(
    [
        (
            "assistant", 
            out.content,
        ),
        (
            'human',
            "Think step by step to cover more evaluation metrics for the CraftaxSymbolicEnv environment, and code them up."
        )
    ]
)
out2 = model.invoke(messages) 
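Each follow-up cell appends the model's previous reply as an `assistant` turn and a new `human` instruction before re-invoking, so the model sees the whole conversation. A minimal pure-Python sketch of that bookkeeping; `add_followup` is a hypothetical helper that returns a new list instead of calling `messages.extend`, so re-running a cell cannot append the same turns twice.

```python
from typing import List, Tuple

Message = Tuple[str, str]  # (role, content), the tuple format used in the cells above


def add_followup(messages: List[Message], reply: str, question: str) -> List[Message]:
    """Return a new message list with the assistant reply and the next human turn."""
    return [*messages, ("assistant", reply), ("human", question)]


history: List[Message] = [("system", "You are an evaluator"), ("human", "List metrics.")]
history_2 = add_followup(history, "1. Resource density", "Cover more metrics.")
print([role for role, _ in history_2])
```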

In [33]:
print(out2.content)

Let's break down additional important metrics for a more comprehensive evaluation:

1. Combat Environment Evaluation
2. Progression Potential
3. Risk-Reward Balance
4. Exploration Incentives
5. Resource Progression Chain
6. Boss Fight Preparation Potential

Here's the enhanced implementation:

```python
import jax
import jax.numpy as jnp
import chex
from functools import partial

@jax.jit
def compute_combat_environment(
    map_layer: chex.Array,
    mob_map_layer: chex.Array
) -> jnp.float32:
    """
    Evaluate combat environment quality.
    Considers:
    - Combat space availability
    - Mob density
    - Strategic positions (walls/obstacles for cover)
    """
    # Combat space calculation
    combat_viable_blocks = jnp.array([
        BlockType.PATH.value,
        BlockType.GRASS.value,
        BlockType.FIRE_GRASS.value,
        BlockType.ICE_GRASS.value
    ])
    
    combat_space = sum(map_layer == block for block in combat_viable_blocks)
    
    # Strategic positions (blo
```

In [34]:
messages.extend(
    [
        (
            "assistant", 
            out2.content,
        ),
        (
            'human',
            "Think step by step to reflect on whether you can make the compute_risk_reward_balance function more comprehensive. Also write one function that measures how challenging combat is in the environment."
        )
    ]
)
out3 = model.invoke(messages) 
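As a concrete reference point for the combat-challenge request, here is a minimal sketch that scores mob pressure on the player's current floor. It assumes mob tables shaped `(num_floors, max_mobs)` with a boolean `mask`, which matches the per-level `(9, 3)` and `(9, 2)` shapes printed for `melee_mobs` and `ranged_mobs` in the Metric Testing section; the per-type weights are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def combat_challenge(melee_mask: jnp.ndarray, ranged_mask: jnp.ndarray,
                     player_level: jnp.ndarray) -> jnp.ndarray:
    """Weighted share of mob slots occupied on the player's floor, in [0, 1].

    melee_mask:  bool, shape (num_floors, max_melee)
    ranged_mask: bool, shape (num_floors, max_ranged)
    """
    melee_w, ranged_w = 1.0, 1.5  # hypothetical: ranged mobs weighted as harder
    melee_active = melee_mask[player_level].sum()
    ranged_active = ranged_mask[player_level].sum()
    raw = melee_w * melee_active + ranged_w * ranged_active
    # Normalise by the score obtained when every slot on the floor is filled
    max_raw = melee_w * melee_mask.shape[1] + ranged_w * ranged_mask.shape[1]
    return (raw / max_raw).astype(jnp.float32)


melee = jnp.zeros((9, 3), dtype=bool).at[0, :2].set(True)
ranged = jnp.zeros((9, 2), dtype=bool).at[0, 0].set(True)
print(combat_challenge(melee, ranged, jnp.int32(0)))
```

Dividing by the fully occupied score keeps the result in [0, 1], as the original metric prompt requires.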

In [35]:
print(out3.content)

Let's enhance these aspects:

1. For risk-reward balance, we should consider:
- Mob difficulty vs reward value
- Environmental hazards vs resource richness
- Distance-based risk scaling
- Safe zone availability
- Resource accessibility difficulty

2. For combat challenge, we should measure:
- Mob type distribution and difficulty
- Combat space constraints
- Mob positioning and density
- Projectile combat potential
- Escape route availability

Here's the improved implementation:

```python
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def compute_enhanced_risk_reward_balance(state: EnvState) -> jnp.float32:
    """
    Comprehensive evaluation of risk-reward balance.
    
    Considers:
    1. Resource value vs danger proximity
    2. Safe zones vs hazard zones
    3. Progressive difficulty scaling
    4. Resource accessibility
    5. Survival resource distribution
    """
    map_layer = state.map[state.player_level]
    mob_layer = state.mob_map[state.play
```

## Metric Testing

In [2]:
from examples.craftax import craftax_evaluation 
from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
from craftax.craftax.renderer import render_craftax_pixels as render_pixels
from craftax.craftax.world_gen.world_gen import generate_world as generate_world_craftax

from examples.craftax.craftax_wrappers import LogWrapper

from craftax.craftax.constants import Achievement
from jaxued.wrappers import AutoReplayWrapper
from jaxued.level_sampler import LevelSampler 

import jax 
import jax.numpy as jnp 

ENV_CLASS = CraftaxSymbolicEnv
generate_world = generate_world_craftax
render_craftax_pixels = render_pixels

DEFAULT_STATICS = ENV_CLASS.default_static_params()
default_env = ENV_CLASS(DEFAULT_STATICS)
env = LogWrapper(default_env)
env = AutoReplayWrapper(env)
env_params = env.default_params

def sample_random_level(rng):
    return generate_world(rng, env.default_params, DEFAULT_STATICS)

In [3]:
config = {
    'num_train_envs':4, 
    "level_buffer_capacity":100,
    "replay_prob":0.8,
    "staleness_coeff":0.3,
    "minimum_fill_ratio":0.1,
    "prioritization":"rank",
    "temperature":1.0,
    "topk_k":0.3,
    "max_grad_norm":1,
    "buffer_duplicate_check":True,
    "exploratory_grad_updates":False,
    "outer_rollout_steps":64,
    "num_steps":64,
    "gamma":0.99,
    "gae_lambda":0.9,
    "num_minibatches":2,
    "epoch_ppo":5,
    "clip_eps":0.2,
    "entropy_coeff":0.01,
    "critic_coeff":0.2,
    "num_updates":10,
    "lr":3e-04,
    "score_function":"MaxMC",
    "eval_freq":5,
    "use_accel":True,
    "num_edits":1,
}
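The `"prioritization": "rank"`, `"temperature"` and `"staleness_coeff"` entries parameterise the level replay distribution. Assuming jaxued's `LevelSampler` follows the rank-based scheme of Prioritized Level Replay (Jiang et al., 2021), each level $i$ gets $P_S(i) = h(S_i)^{1/\beta} / \sum_j h(S_j)^{1/\beta}$ with $h(S_i) = 1/\mathrm{rank}(S_i)$, a staleness term $P_C(i) = (c - C_i) / \sum_j (c - C_j)$, and the mixture $P(i) = (1-\rho) P_S(i) + \rho P_C(i)$, where $\beta$ is the temperature and $\rho$ the staleness coefficient. A sketch of that formula (not jaxued's code):

```python
import jax.numpy as jnp


def replay_distribution(scores, last_sampled, current_step,
                        temperature=1.0, staleness_coeff=0.3):
    """Rank-prioritised replay probabilities mixed with a staleness term (PLR-style).

    scores:       per-level scores S_i (higher = more useful to replay)
    last_sampled: step C_i at which each level was last sampled
    current_step: current step c
    """
    # argsort of argsort converts scores to 0-based ranks; rank 1 = highest score
    ranks = jnp.argsort(jnp.argsort(-scores)) + 1
    h = (1.0 / ranks) ** (1.0 / temperature)
    p_score = h / h.sum()
    # Staleness term: levels untouched for longer get more weight
    staleness = (current_step - last_sampled).astype(jnp.float32)
    total = staleness.sum()
    # Fall back to uniform when every level was just sampled
    p_stale = jnp.where(total > 0, staleness / jnp.maximum(total, 1e-8),
                        1.0 / staleness.shape[0])
    return (1.0 - staleness_coeff) * p_score + staleness_coeff * p_stale


p = replay_distribution(jnp.array([0.9, 0.1, 0.5]), jnp.array([10, 0, 5]), current_step=10)
print(p)
```

Lower temperatures sharpen the distribution towards the top-ranked levels, while a larger `staleness_coeff` keeps rarely replayed levels from going stale in the buffer.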

In [4]:
rng = jax.random.PRNGKey(1)

rng, rng_levels, rng_reset = jax.random.split(rng, 3)
new_levels = jax.vmap(sample_random_level)(
    jax.random.split(rng_levels, config["num_train_envs"])
)
print(type(new_levels))

<class 'craftax.craftax.craftax_state.EnvState'>


In [5]:
new_levels.map.size 

82944

In [6]:
new_levels.player_level

Array([0, 0, 0, 0], dtype=int32)

In [7]:
new_levels.melee_mobs.health.shape

(4, 9, 3)

In [8]:
4 * 9 * 48 * 48 

82944
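The check above confirms that `new_levels.map` holds `num_train_envs × num_floors × H × W` = 4 × 9 × 48 × 48 = 82,944 cells, matching the `(4, 9, 48, 48)` shape printed later. A metric that should only look at the floor each player starts on has to pick `map[i, player_level[i]]` per environment. A sketch with placeholder zeros instead of real maps:

```python
import jax.numpy as jnp

num_envs, num_floors, height, width = 4, 9, 48, 48
maps = jnp.zeros((num_envs, num_floors, height, width), dtype=jnp.int32)
player_level = jnp.array([0, 0, 0, 0], dtype=jnp.int32)  # as printed by new_levels.player_level

# Advanced indexing pairs environment i with floor player_level[i]
current_floor = maps[jnp.arange(num_envs), player_level]
print(maps.size, current_floor.shape)
```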

In [9]:
jax.vmap(craftax_evaluation.compute_natural_resource_density_all_level)(new_levels)

Array([0.18258102, 0.17472029, 0.19449267, 0.18696952], dtype=float32)

In [10]:
jax.vmap(craftax_evaluation.compute_path_density_all_level)(new_levels)

Array([0.29026812, 0.31052276, 0.31110147, 0.30709878], dtype=float32)

In [11]:
jax.vmap(craftax_evaluation.compute_survival_resource_density_all_level)(new_levels)

Array([0.5005787 , 0.4378858 , 0.40147567, 0.3915895 ],      dtype=float32, weak_type=True)

In [12]:
jax.vmap(craftax_evaluation.compute_crafting_potential_all_level)(new_levels)

Array([0.00192901, 0.00192901, 0.00192901, 0.00192901],      dtype=float32, weak_type=True)
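All four levels score exactly 0.00192901 for crafting potential. That value equals 40 / (9 × 48 × 48) = 40 / 20,736, which is consistent with the `_all_level` metric normalising by every cell on all nine floors and counting the same 40 cells in every generated world. If so, this metric does not discriminate between these levels. The arithmetic:

```python
num_cells = 9 * 48 * 48            # floors x height x width for one level
reported = 0.00192901              # crafting potential printed for every level
implied_count = round(reported * num_cells)
print(num_cells, implied_count, implied_count / num_cells)
```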

In [13]:
jax.vmap(craftax_evaluation.compute_progression_potential)(new_levels)

Array([0.19173178, 0.15581597, 0.13020834, 0.13617623], dtype=float32)

In [14]:
new_levels.melee_mobs.position.shape 

(4, 9, 3, 2)

In [19]:
print(jax.vmap(craftax_evaluation.compute_exploration_incentives)(new_levels))
print(new_levels.map.shape)

(48, 48)
(48, 48)
[0.66407037 0.6713666  0.68036294 0.6667115 ]
(4, 9, 48, 48)


In [34]:
print(jax.vmap(craftax_evaluation.compute_mob_challengeness)(new_levels))
print(new_levels.map.shape)

[3. 3. 3. 3.]
(4, 9, 48, 48)
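`compute_mob_challengeness` returns 3.0 for every level. This violates the [0, 1] range requested in the metric prompt, and a constant gives no signal for ranking levels. One way to bring a raw count back into range is to divide by its maximum possible value. The sketch below assumes the raw score counts active mob slots and uses a hypothetical bound of 5 (3 melee + 2 ranged slots per floor, from the mask shapes printed in this section):

```python
import jax
import jax.numpy as jnp


@jax.jit
def normalize_count(raw: jnp.ndarray, max_count: float) -> jnp.ndarray:
    """Map a non-negative raw count into [0, 1] by dividing by its upper bound."""
    return jnp.clip(raw / max_count, 0.0, 1.0).astype(jnp.float32)


scores = jnp.array([3.0, 3.0, 3.0, 3.0])
print(normalize_count(scores, 5.0))  # hypothetical bound: 3 melee + 2 ranged slots
```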


In [25]:
new_levels.ranged_mobs.mask.shape 

(4, 9, 2)

In [55]:
from examples.craftax.craftax_plr import evaluate_levels

In [56]:
evaluate_levels(new_levels)

{'compute_crafting_potential_all_level': Array(0.00192901, dtype=float32),
 'compute_mob_challengeness': Array(3., dtype=float32),
 'compute_natural_resource_density_all_level': Array(0.18469088, dtype=float32),
 'compute_path_density_all_level': Array(0.3047478, dtype=float32),
 'compute_progression_potential': Array(0.1534831, dtype=float32),
 'compute_survival_resource_density_all_level': Array(0.43288243, dtype=float32)}
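Each value in the dictionary equals the batch mean of the corresponding per-level results above. For example, the four natural-resource densities average to 0.18469088. This suggests that `evaluate_levels` vmaps each metric over the batch and then averages. Its source is not shown here, so the following is only a sketch of that aggregation pattern, using hypothetical toy metrics:

```python
from typing import Callable, Dict

import jax
import jax.numpy as jnp

Metric = Callable[[jnp.ndarray], jnp.ndarray]


def evaluate_batch(metrics: Dict[str, Metric], levels: jnp.ndarray) -> Dict[str, jnp.ndarray]:
    """Apply each per-level metric across the batch with vmap, then average."""
    return {name: jnp.mean(jax.vmap(fn)(levels)) for name, fn in metrics.items()}


# Toy "levels": a batch of 4 small integer grids
levels = jnp.arange(4 * 2 * 2).reshape(4, 2, 2) % 3
metrics = {
    "zero_fraction": lambda grid: jnp.mean((grid == 0).astype(jnp.float32)),
    "max_value": lambda grid: jnp.max(grid).astype(jnp.float32),
}
result = evaluate_batch(metrics, levels)
print(result)
```

A batch mean hides per-level spread. For level selection, the per-level vector (before `jnp.mean`) is the more useful signal.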

In [43]:
new_levels.player_level

Array([0, 0, 0, 0], dtype=int32)

# Editor Discovery 
*** 

In [21]:

messages = [
    (
        "system",
        "You are an exceptional reinforcement learning environment engineer for an underspecified POMDP",
    ),
    (
        "human", 
        f"""
        ### Task 
        Think step by step to write a mixture of editors for the following environment written in JAX.
        ```python
        {input_string}
        ```

        ### Task
        Think step by step to create a mixture of editor functions for the above environment in JAX. 
        Here are some fundamental principles for designing the mixture of editors:

        1.	Validity and Integrity: Maintain the environment's validity by respecting its rules and constraints.

        2.  State Coverage: Focus on editing the states that are more likely to encourage the agent to explore different parts of the environment or achieve different objectives.
        
        3.	Minimal Variability: Introduce the minimum necessary operations required to increase or decrease the difficulty of the environment.
        
        4.	Controlled Randomness: Use randomness thoughtfully to expand the agent's experiences while keeping modifications within acceptable bounds.
        
        5.	Agent-Centric Considerations: Ensure that modifications do not adversely affect the agent's ability to perceive and interact with the environment coherently.
        
        6.	Efficiency and Performance: Use computational techniques (e.g. masking and immutable updates) that align with high-performance computing frameworks, ensuring that the environment can be modified efficiently.
            Keep any iterative operation to 10 tries at most.

        ### Requirement
        Each function ONLY takes in: 
        - a chex.PRNGKey and 
        - an EnvState object as input, and returns a modified EnvState object.

        ### Usage
        Editors will be used in an evolutionary process of the environment States.
        Different editors will be selected to form a sequence of operations upon receiving an environment State.
        Design the mixture with this in mind, but do not chain different editors together yourself.

        ### Output Format
        ```txt
        # scratchpad  (text only in this section, no code)
        # ...
        # ...
        # ... 
        # ...
        ```
        ===============================

        ```Python
        # imports 
        import jax
        import chex 
        from functools import partial
        ...

        # env instance
        ...

        # utility functions
        ... 

        # mixture of editor
        @jax.jit
        def editor_{{editor_name}}(rng:chex.PRNGKey, env_state:EnvState) -> EnvState:
            ... doc string ...
            .... code ...
        
        @jax.jit
        def editor_{{editor_name}}(rng:chex.PRNGKey, env_state:EnvState) -> EnvState:
            ... doc string ...
            .... code ...
        ...
        
        ```
        """
    ),
]

out = model.invoke(messages)
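For orientation while reading the model's answer, here is a minimal sketch of two editors with the required `(rng, env_state) -> env_state` signature, plus one way a caller might draw editors from the mixture. The toy state is a hypothetical stand-in for `EnvState`, with `map` (floors × H × W), `player_position` and `player_level` fields. The block IDs `GRASS=1` and `TREE=2` are also hypothetical. `apply_random_edits` assumes that `num_edits` in the config is the number of editor applications per ACCEL mutation. Updates are immutable (`.at[].set` and `_replace`), and a validity mask leaves the player's tile and non-matching tiles unchanged instead of retrying, so no editor loops.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

GRASS, TREE = 1, 2  # hypothetical block IDs


class ToyState(NamedTuple):
    map: jax.Array              # int32, (num_floors, H, W)
    player_position: jax.Array  # int32, (2,) as (row, col)
    player_level: jax.Array     # int32 scalar


def _random_cell(rng: jax.Array, state: ToyState):
    """Sample a (row, col) uniformly on the map grid."""
    _, height, width = state.map.shape
    rng_r, rng_c = jax.random.split(rng)
    return (jax.random.randint(rng_r, (), 0, height),
            jax.random.randint(rng_c, (), 0, width))


@jax.jit
def editor_plant_tree(rng: jax.Array, state: ToyState) -> ToyState:
    """Turn one random grass tile on the player's floor into a tree."""
    row, col = _random_cell(rng, state)
    current = state.map[state.player_level, row, col]
    on_player = jnp.all(jnp.stack([row, col]) == state.player_position)
    # Validity mask: only grass tiles the player is not standing on may change
    new_block = jnp.where((current == GRASS) & ~on_player, TREE, current)
    return state._replace(map=state.map.at[state.player_level, row, col].set(new_block))


@jax.jit
def editor_remove_tree(rng: jax.Array, state: ToyState) -> ToyState:
    """Turn one random tree on the player's floor back into grass."""
    row, col = _random_cell(rng, state)
    current = state.map[state.player_level, row, col]
    new_block = jnp.where(current == TREE, GRASS, current)
    return state._replace(map=state.map.at[state.player_level, row, col].set(new_block))


EDITORS = (editor_plant_tree, editor_remove_tree)


def apply_random_edits(rng: jax.Array, state: ToyState, num_edits: int) -> ToyState:
    """Apply `num_edits` editors, each chosen uniformly at random (caller-side chaining)."""
    def body(_, carry):
        key, st = carry
        key, k_choice, k_edit = jax.random.split(key, 3)
        idx = jax.random.randint(k_choice, (), 0, len(EDITORS))
        st = jax.lax.switch(idx, EDITORS, k_edit, st)
        return key, st

    _, state = jax.lax.fori_loop(0, num_edits, body, (rng, state))
    return state


state = ToyState(
    map=jnp.full((2, 8, 8), GRASS, dtype=jnp.int32),
    player_position=jnp.array([4, 4], dtype=jnp.int32),
    player_level=jnp.int32(0),
)
edited = apply_random_edits(jax.random.PRNGKey(0), state, num_edits=5)
print(edited.map.shape)
```

The editors never call each other, which matches the prompt's instruction not to chain them. The sequencing happens on the caller side, where `jax.lax.switch` selects an editor by index inside a `fori_loop`.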

In [22]:
print(out.content)

Let me help you design a mixture of editors for the Craftax environment step by step.

```txt
# scratchpad

1. Analyzing Core State Components:
- Map and items (spatial layout)
- Player attributes (health, food, etc.)
- Inventory (resources and tools)
- Mobs (enemies and passive)
- Environment conditions (light level, etc.)

2. Key Editor Categories:
a) Resource Editors:
- Modify inventory levels
- Adjust resource distribution on map
- Change tool/weapon levels

b) Challenge Editors:
- Modify mob distributions
- Adjust mob attributes
- Change environmental hazards

c) Progression Editors:
- Modify player attributes
- Adjust achievement states
- Change level accessibility

3. Design Considerations:
- Use masking for valid state updates
- Employ jax.random for controlled randomness
- Limit iterations to 10 tries max
- Maintain state validity constraints
- Focus on high-impact modifications

4. Editor Priorities:
1. Basic resource availability
2. Combat difficulty
3. Environmental challen