In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os  
import sys 
from PIL import Image 
import json 
# Request the CPU backend for JAX.
# NOTE(review): this assignment sits above `import jax`; presumably it must stay
# there for the setting to take effect — confirm before moving it.
os.environ['JAX_PLATFORMS'] = 'cpu'

import flax
import flax.linen as nn
import jax 
from jax import random
import jax2d
import jax.numpy as jnp 
import chex

import optax  
from kinetix.environment.env_state import EnvParams, StaticEnvParams
from kinetix.environment.env import make_kinetix_env_from_args
from kinetix.environment.ued.ued import UEDParams
from kinetix.environment.ued.distributions import sample_kinetix_level
from kinetix.models.transformer_model import ActorCriticTransformer
from kinetix.models.actor_critic import ScannedRNN 
from kinetix.models import make_network_from_config
from kinetix.render.renderer_symbolic_entity import EntityObservation

from jaxued.wrappers.autoreplay import AutoReplayWrapper

from editax.moed import EditorManager
from editax.learning import (
    EditorPolicyTrainState, 
    update_editr_actor_critic_rnn,
)
from editax.policy.lstm import EditorActorCritic, ResetLSTM
from editax.upomdp import LogWrapper
from editors.kinetix.editor_o3_mini_8inner_v1 import (
    mmp_cluster_shapes,
    mmp_increase_spacing, 
    mmp_reduce_density,
    mmp_increase_density,
    mmp_disable_motor_auto,
    mmp_enable_motor_auto
)

# Hyperparameters for this notebook, kept in one place so cells can look them up by name.
config = dict(
    # optimisation and rollout sizes
    num_minibatches=8,
    update_epochs=4,
    num_updates=1,
    outer_rollout_steps=4,
    num_steps=256,
    num_train_envs=32,
    anneal_lr=True,
    lr=1e-4,
    max_grad_norm=1.0,
    # transformer policy architecture
    transformer_depth=2,
    transformer_size=16,
    transformer_encoder_size=128,
    num_heads=8,
    full_attention_mask=False,
    aggregate_mode="dummy_and_mean",
    fc_layer_depth=5,
    fc_layer_width=128,
    activation="tanh",
    recurrent_model=True,
    # environment identifier
    env_name="Kinetix-Entity-MultiDiscrete-v1",
)

# Load 
*** 

In [3]:
# Instantiate the three parameter objects with their default values.
static_env_params = StaticEnvParams()
env_params = EnvParams()
ued_params = UEDParams()

# Build the Kinetix environment: entity observations, multi-discrete actions, "replay" resets.
env = make_kinetix_env_from_args(
    obs_type="entity",
    action_type="multidiscrete",
    reset_type="replay",
    static_env_params=static_env_params,
)

# Seed the notebook's main PRNG stream and spend one subkey on sampling a level
# with env_size_name="m".
rng, _rng = random.split(random.PRNGKey(0))
level = sample_kinetix_level(
    _rng, env.physics_engine, env_params, static_env_params, ued_params, env_size_name="m"
)
print(type(level))

<class 'kinetix.environment.env_state.EnvState'>


In [4]:
# Display the attribute names stored on the sampled level.
vars(level).keys()

dict_keys(['polygon', 'circle', 'joint', 'thruster', 'collision_matrix', 'acc_rr_manifolds', 'acc_cr_manifolds', 'acc_cc_manifolds', 'gravity', 'thruster_bindings', 'motor_bindings', 'motor_auto', 'polygon_shape_roles', 'circle_shape_roles', 'polygon_highlighted', 'circle_highlighted', 'polygon_densities', 'circle_densities', 'timestep'])

In [5]:
# Start the environment from the sampled level, spending a fresh subkey on the reset.
rng, _rng = random.split(rng)
obs, env_state = env.reset_to_level(_rng, level, env_params)

In [7]:
# Show the classes of the state and observation returned by the reset, one per line.
print(type(env_state), type(obs), sep="\n")

<class 'kinetix.environment.wrappers.LogEnvState'>
<class 'kinetix.render.renderer_symbolic_entity.EntityObservation'>


In [13]:
# Follow the nested `.env_state` attributes three levels down and display the
# attribute names of the innermost state.
vars(env_state.env_state.env_state.env_state).keys()

dict_keys(['polygon', 'circle', 'joint', 'thruster', 'collision_matrix', 'acc_rr_manifolds', 'acc_cr_manifolds', 'acc_cc_manifolds', 'gravity', 'thruster_bindings', 'motor_bindings', 'motor_auto', 'polygon_shape_roles', 'circle_shape_roles', 'polygon_highlighted', 'circle_highlighted', 'polygon_densities', 'circle_densities', 'timestep'])

In [14]:
# Display the attribute names carried by the observation.
vars(obs).keys()

dict_keys(['circles', 'polygons', 'joints', 'thrusters', 'circle_mask', 'polygon_mask', 'joint_mask', 'thruster_mask', 'attention_mask', 'joint_indexes', 'thruster_indexes'])

In [17]:
batch_size, seq_len = 32, 10

# Give every leaf of the observation two new leading axes, (seq_len, batch_size),
# filled with copies of the single observation.
obs = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (seq_len, batch_size) + jnp.shape(leaf)),
    obs,
)

print(obs.circles.shape)
# seq_len, batch_size, ... 

(10, 32, 12, 19)


# RNN
*** 

In [18]:
# Sizes for the RNN demo, plus three PRNG keys for the cells that follow.
out_feat = 256
n_editor = 4
key_1, key_2, key_3 = jax.random.split(jax.random.PRNGKey(0), num=3)

# ResetLSTM wrapped around an optimized LSTM cell with `out_feat` features.
model = ResetLSTM(nn.OptimizedLSTMCell(features=out_feat))
print(model)

ResetLSTM(
    # attributes
    cell = OptimizedLSTMCell(
        # attributes
        features = 256
        gate_fn = sigmoid
        activation_fn = tanh
        kernel_init = init
        recurrent_kernel_init = init
        bias_init = zeros
        dtype = None
        param_dtype = float32
        carry_init = zeros
    )
)


In [9]:
# Synthetic inputs for the ResetLSTM demo: random features plus random done flags.
# Each draw gets its own key. Feeding the same key to both `uniform` calls would
# reuse a PRNG key, which JAX forbids and which correlates the two samples.
embeds = jax.random.uniform(key_1, (seq_len, batch_size, 19))
dones = jax.random.uniform(key_2, (seq_len, batch_size)) > 0.5

# The model consumes the pair (inputs, done flags).
xs = (embeds, dones)
print(xs[0].shape)
print(xs[1].shape)

(10, 32, 19)
(10, 32)


In [10]:
# Initial carry for the LSTM cell, built from the per-step input shape
# (the embeddings' shape without the leading sequence axis).
init_carry = model.cell.initialize_carry(key_3, xs[0].shape[1:])
print(init_carry[0].shape, init_carry[1].shape, sep="\n")


(32, 256)
(32, 256)


In [11]:
# Initialise the ResetLSTM variables from the (embeds, dones) example input.
# NOTE(review): key_3 is also passed to `initialize_carry` in the previous cell;
# consider giving each consumer its own key.
variables = model.init(key_3, xs)

In [12]:
# Run the whole sequence through the LSTM from the initial carry, then report
# the shapes of the final carry pair and of the per-step outputs.
out_carry, out_val = model.apply(variables, xs, initial_carry=init_carry)
print(out_carry[0].shape, out_carry[1].shape, out_val.shape, sep="\n")

(32, 256)
(32, 256)
(10, 32, 256)


# Policy 
*** 

In [24]:
n_editor = 4

# Transformer actor-critic with a discrete action head of size `n_editor`.
# Architecture sizes are read from `config`.
policy = ActorCriticTransformer(
    # action head
    action_dim=n_editor,
    action_mode="discrete",
    hybrid_action_continuous_dim=n_editor,
    multi_discrete_number_of_dims_per_distribution=[n_editor],
    recurrent=True,
    # fully connected layers
    fc_layer_width=config["fc_layer_width"],
    fc_layer_depth=config["fc_layer_depth"],
    activation=config["activation"],
    # transformer encoder
    num_heads=config["num_heads"],
    transformer_depth=config["transformer_depth"],
    transformer_size=config["transformer_size"],
    transformer_encoder_size=config["transformer_encoder_size"],
    aggregate_mode=config["aggregate_mode"],
    full_attention_mask=config["full_attention_mask"],
)

In [28]:
# Initial recurrent carry for the policy, sized by the number of training envs.
# num_batch_envs, hidden_dim 
policy_init_carry = ScannedRNN.initialize_carry(config["num_train_envs"])
print(policy_init_carry.shape)

# All-False done flags with shape (seq_len, num_train_envs).
dones = jnp.zeros((seq_len, config["num_train_envs"]), dtype=bool)
print(dones.shape)

(32, 256)
(10, 32)


In [29]:
# Example (observation, done flags) input handed to `policy.init`.
init_x = (
    obs,
    jnp.zeros((seq_len, config["num_train_envs"]), dtype=jnp.bool_),
)

# Split a fresh subkey for parameter initialisation. The previous `_rng` was
# already consumed by `env.reset_to_level`, and JAX PRNG keys must not be reused.
rng, _rng = jax.random.split(rng)
network_params = policy.init(
    _rng,
    policy_init_carry,
    init_x,
)

In [35]:
# Advance the notebook's main PRNG stream and keep `subrng` for sampling actions.
# Splitting from `rng` (not `key_3`) avoids reusing a key that the LSTM cells
# above already consumed, and stops `rng` being replaced by a key_3 descendant.
rng, subrng = random.split(rng)

# Forward pass of the policy over the tiled observations with all-False done flags.
xs = (obs, dones)
policy_out = policy.apply(network_params, policy_init_carry, xs)

In [36]:
# Report the shapes of the three policy outputs, one per line.
# NOTE(review): the recorded carry shape is (32, 256), so `policy_out[0][0]`
# presumably selects a single env's row rather than the whole hidden state —
# confirm this is what should be reported.
print(
    f" hidden state: {policy_out[0][0].shape}",
    f" policy: {policy_out[1].logits.shape}",
    f" value: {policy_out[2].shape}",
    sep="\n",
)

 hidden state: (256,)
 policy: (10, 32, 4)
 value: (10, 32)


In [37]:
# Sample from the policy distribution (second element of `policy_out`) using the
# subkey split off earlier, then display the shape of the sampled indices.
editor_idx = policy_out[1].sample(seed=subrng)
editor_idx.shape

(10, 32)

# Sample 
*** 

In [44]:
@jax.jit
def create_editor_policy_train_state(rng: chex.PRNGKey) -> EditorPolicyTrainState:
    """Initialise the editor policy and wrap it in an ``EditorPolicyTrainState``.

    Reads the notebook globals ``obs``, ``seq_len``, ``config``, ``policy`` and
    ``policy_init_carry``. The function is jitted, so these are read when it is
    traced; rebinding them later does not change an already-compiled call for
    the same argument shapes.

    Args:
        rng: PRNG key; a subkey split from it seeds ``policy.init``.

    Returns:
        A train state created with ``policy.apply`` as ``apply_fn``, the freshly
        initialised parameters, the optimiser chain, and ``num_updates=0``.
    """
    # Second half of the split, i.e. the same subkey the initialiser has always used.
    _, init_key = jax.random.split(rng)

    # Example (observation, done flags) input handed to `policy.init`.
    init_x = (
        obs,
        jnp.zeros((seq_len, config["num_train_envs"]), dtype=jnp.bool_),
    )
    network_params = policy.init(init_key, policy_init_carry, init_x)

    # Clip gradients by global norm, then apply Adam at a constant learning rate.
    # NOTE(review): config["anneal_lr"] is not consulted here — confirm whether a
    # learning-rate schedule is expected.
    tx = optax.chain(
        optax.clip_by_global_norm(config["max_grad_norm"]),
        optax.adam(config["lr"], eps=1e-5),
    )

    return EditorPolicyTrainState.create(
        apply_fn=policy.apply,
        params=network_params,
        tx=tx,
        num_updates=0,
    )

In [None]:
# Reset the environment state to this level
rng, _rng = random.split(rng)
obs, env_state = env.reset_to_level(_rng, level, env_params)

# Give every observation leaf leading (seq_len, batch_size) axes; here a single step.
batch_size, seq_len = 32, 1
obs = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (seq_len, batch_size) + jnp.shape(leaf)),
    obs,
)

# create train state
rng, _rng = random.split(rng)
train_state = create_editor_policy_train_state(_rng)

# Sizes used by the editing rollout.
num_edits = 8
edit_eps_length = 8 * 10
num_envs = batch_size

# Load the editor-manager settings from JSON. "num_inner_loops" is removed from the
# settings (default 10) and passed to `reset` instead; "init_editors" is forced off.
with open("experiments/kinetix.json", "r") as f:
    editor_config = json.load(f)
num_inner_loops = editor_config.pop("num_inner_loops", 10)
editor_config["init_editors"] = False

editor_manager = EditorManager(**editor_config)
_ = editor_manager.reset(env_state, num_inner_loops)

2025-03-10 10:27:06,542 - editax.moed - INFO - File name: editor_o3_mini_10inner_comprehensive_new_v2
2025-03-10 10:27:06,542 - editax.moed - INFO - File name: editor_o3_mini_10inner_comprehensive_new_v2
2025-03-10 10:27:06,542 - editax.moed - INFO - File name: editor_o3_mini_10inner_comprehensive_new_v2
2025-03-10 10:27:06,543 - editax.moed - INFO - Init editors -> editors/kinetix/editor_o3_mini_10inner_comprehensive_new_v2_tmp_0.py
2025-03-10 10:27:06,543 - editax.moed - INFO - Init editors -> editors/kinetix/editor_o3_mini_10inner_comprehensive_new_v2_tmp_0.py
2025-03-10 10:27:06,543 - editax.moed - INFO - Init editors -> editors/kinetix/editor_o3_mini_10inner_comprehensive_new_v2_tmp_0.py
2025-03-10 10:27:06,544 - editax.utils - INFO - editors.kinetix.editor_o3_mini_10inner_comprehensive_new_v2_tmp_0
2025-03-10 10:27:06,550 - editax.utils - INFO - Loaded 10 functions
2025-03-10 10:27:06,561 - editax.utils - INFO - Editor mmp_dim_goal failed the test
2025-03-10 10:27:06,567 - editax

KeyboardInterrupt: 

In [53]:
# Sample editing trajectories with the recurrent editor policy.
# NOTE(review): `subrng` is not used anywhere in this cell; `rng` itself is what
# gets passed to the sampler below.
rng, subrng = random.split(rng)

# Initial carry for the editor actor-critic: a pair whose two shapes are printed below.
init_hstate = EditorActorCritic.initialize_carry(
    (config["num_train_envs"],), 
    out_feat
)
print(f"carry shape: {init_hstate[0].shape}")
print(f"carry shape: {init_hstate[1].shape}")
# Take index 0 of `obs` as the sampler's starting input.
# NOTE(review): the recorded output shows a (32, 1675) array here, whereas earlier
# cells build `obs` as an object with attribute access (e.g. `obs.circles`) —
# confirm which form `obs` has at this point and whether `obs[0]` / `.shape` are valid.
init_env_state = obs[0]
print(f"x shape: {init_env_state.shape}")

# The sampler returns an updated carry tuple (rng, train state, hidden state,
# last env state, last value) together with the collected trajectory.
# NOTE(review): the recorded run of this cell ended with
# "AttributeError: 'EditorManager' object has no attribute 'editors'", and the
# recorded run of the cell that builds `editor_manager` was interrupted
# (KeyboardInterrupt) — presumably the manager was not fully reset; verify.
(rng, train_state, hstate, last_env_state, last_value), traj = editor_manager.sample_edit_trajectories_rnn(
        rng,
        train_state,
        init_hstate, 
        init_env_state,
        num_envs,
        edit_eps_length,
    )

carry shape: (32, 256)
carry shape: (32, 256)
x shape: (32, 1675)
(1, 32, 1675)
(1, 32)


AttributeError: 'EditorManager' object has no attribute 'editors'

# Update 
*** 

# Eval / Test 
*** 