In [18]:
%load_ext autoreload
%autoreload 2

import os  
import sys 
from PIL import Image 
import json 
os.environ['JAX_PLATFORMS'] = 'cpu'

import flax
import flax.linen as nn
import jax 
from jax import random
import jax2d
import jax.numpy as jnp 
import chex

import optax  
from kinetix.environment.env_state import EnvParams, StaticEnvParams
from kinetix.environment.env import make_kinetix_env_from_args
from kinetix.environment.ued.ued import UEDParams
from kinetix.environment.ued.distributions import sample_kinetix_level
from kinetix.models.transformer_model import ActorCriticTransformer
from kinetix.models.actor_critic import ScannedRNN 
from kinetix.models import make_network_from_config
from kinetix.render.renderer_symbolic_entity import EntityObservation

from jaxued.wrappers.autoreplay import AutoReplayWrapper

from editax.moed import EditorManager
from editax.learning import (
    EditorPolicyTrainState, 
    update_editr_actor_critic_rnn,
)
from editax.policy.lstm import EditorActorCritic, ResetLSTM
from editax.upomdp import LogWrapper
from editors.kinetix.editor_o3_mini_8inner_v1 import (
    mmp_cluster_shapes,
    mmp_increase_spacing, 
    mmp_reduce_density,
    mmp_increase_density,
    mmp_disable_motor_auto,
    mmp_enable_motor_auto
)

# Hyper-parameters shared by the policy, optimiser and rollout cells below.
config = dict(
    # Optimisation / update schedule
    num_minibatches=8,
    update_epochs=4,
    num_updates=1,
    outer_rollout_steps=4,
    num_steps=256,
    num_train_envs=32,
    anneal_lr=True,
    lr=1e-4,
    max_grad_norm=1.0,
    # Transformer policy architecture
    transformer_depth=2,
    transformer_size=16,
    transformer_encoder_size=128,
    num_heads=8,
    full_attention_mask=False,
    aggregate_mode="dummy_and_mean",
    fc_layer_depth=5,
    fc_layer_width=128,
    activation="tanh",
    recurrent_model=True,
    # Environment
    env_name="Kinetix-Entity-MultiDiscrete-v1",
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load the environment and sample a level
*** 

In [3]:
# Default environment, static and UED parameter objects (all constructed with
# their library defaults).
env_params = EnvParams()
static_env_params = StaticEnvParams()
ued_params = UEDParams()

# Create the environment
# Entity observations + multi-discrete actions; "replay" reset type.
env = make_kinetix_env_from_args(
    obs_type="entity",
    action_type="multidiscrete",
    reset_type="replay",
    static_env_params=static_env_params,
)

# Sample a random level
# The root key is split so `rng` can keep being threaded through later cells.
rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
# NOTE(review): "m" presumably selects a medium-sized level template — confirm
# against kinetix's sample_kinetix_level.
level = sample_kinetix_level(
    _rng, 
    env.physics_engine, 
    env_params, 
    static_env_params, 
    ued_params, 
    env_size_name="m"
)
print(type(level))

<class 'kinetix.environment.env_state.EnvState'>


In [4]:
vars(level).keys()  # list the EnvState fields of the sampled level

dict_keys(['polygon', 'circle', 'joint', 'thruster', 'collision_matrix', 'acc_rr_manifolds', 'acc_cr_manifolds', 'acc_cc_manifolds', 'gravity', 'thruster_bindings', 'motor_bindings', 'motor_auto', 'polygon_shape_roles', 'circle_shape_roles', 'polygon_highlighted', 'circle_highlighted', 'polygon_densities', 'circle_densities', 'timestep'])

In [5]:
# Reset the environment state to this level
# Returns the initial observation and the environment state for `level`.
# NOTE(review): the `_rng` produced here is used again by later cells
# (transformer init and the editor-setup reset); those should split fresh keys.
rng, _rng = jax.random.split(rng)
obs, env_state = env.reset_to_level(_rng, level, env_params)

In [6]:
# Shape of the circle feature array in the single, unbatched observation.
print(obs.circles.shape)

(12, 19)


In [7]:
batch_size = 32
seq_len = 10

# Turn the single observation into a time-major dummy batch: every leaf of the
# pytree becomes (seq_len, batch_size, *leaf.shape), with each copy identical.
obs = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (seq_len, batch_size) + jnp.shape(leaf)),
    obs,
)

print(obs.circles.shape)

(10, 32, 12, 19)


# RNN: ResetLSTM shape check
*** 

In [8]:
n_editor = 4
out_feat = 256

# Three independent keys for the RNN smoke test below.
rng_keys = random.split(random.PRNGKey(0), 3)
key_1, key_2, key_3 = rng_keys

# ResetLSTM wrapping a single optimized LSTM cell of width `out_feat`.
lstm_cell = nn.OptimizedLSTMCell(features=out_feat)
model = ResetLSTM(lstm_cell)
print(model)

ResetLSTM(
    # attributes
    cell = OptimizedLSTMCell(
        # attributes
        features = 256
        gate_fn = sigmoid
        activation_fn = tanh
        kernel_init = init
        recurrent_kernel_init = init
        bias_init = zeros
        dtype = None
        param_dtype = float32
        carry_init = zeros
    )
)


In [9]:
# Dummy RNN inputs: random embeddings (feature width 19, matching the circle
# feature width printed above) and random episode-done flags.
# The two draws use independent subkeys. Reusing `key_2` for both calls would
# make `embeds` and `dones` derive from the same randomness; JAX keys must
# never be reused.
embed_key, done_key = jax.random.split(key_2)
embeds = jax.random.uniform(embed_key, (seq_len, batch_size, 19))
dones = jax.random.uniform(done_key, (seq_len, batch_size)) > 0.5

xs = (embeds, dones)
print(xs[0].shape)
print(xs[1].shape)

(10, 32, 19)
(10, 32)


In [10]:
# Carry input shape drops the leading time axis: (batch, features).
carry_input_shape = xs[0].shape[1:]
init_carry = model.cell.initialize_carry(key_3, carry_input_shape)
for carry_part in init_carry[:2]:
    print(carry_part.shape)


(32, 256)
(32, 256)


In [11]:
# Initialise ResetLSTM parameters on the (embeds, dones) input tuple.
# NOTE(review): key_3 was also passed to initialize_carry in the previous cell;
# prefer a dedicated subkey per use.
variables = model.init(key_3, xs)

In [12]:
# Run the LSTM over the whole sequence from the explicit initial carry, then
# show the final carry parts and the per-step outputs.
out_carry, out_val = model.apply(variables, xs, initial_carry=init_carry)
for tensor in (out_carry[0], out_carry[1], out_val):
    print(tensor.shape)

(32, 256)
(32, 256)
(10, 32, 256)


# Policy: recurrent actor-critic transformer
*** 

In [21]:
# Recurrent actor-critic transformer with one discrete head over the editors.
network = ActorCriticTransformer(
    action_dim=(n_editor,),
    action_mode="discrete",
    activation=config["activation"],
    aggregate_mode=config["aggregate_mode"],
    fc_layer_depth=config["fc_layer_depth"],
    fc_layer_width=config["fc_layer_width"],
    full_attention_mask=config["full_attention_mask"],
    num_heads=config["num_heads"],
    transformer_depth=config["transformer_depth"],
    transformer_encoder_size=config["transformer_encoder_size"],
    transformer_size=config["transformer_size"],
    hybrid_action_continuous_dim=(n_editor,),
    multi_discrete_number_of_dims_per_distribution=[n_editor],
    recurrent=True,
)
# Both names are used by later cells and refer to the same module.
policy = network

In [20]:
# Initialise the transformer policy on the batched dummy observation plus an
# all-False dones mask of shape (seq_len, num_train_envs).
# NOTE(review): this raised `TypeError: unsupported operand type(s) for //:
# 'tuple' and 'tuple'` in the recorded run, possibly from a tuple-valued
# constructor argument such as `action_dim=(n_editor,)`; TODO confirm against
# the ActorCriticTransformer signature.
# NOTE(review): `_rng` was already consumed by `env.reset_to_level`; split a
# fresh key instead of reusing it.
init_x = (
    obs, 
    jnp.zeros(
        (seq_len, config["num_train_envs"]), dtype=jnp.bool_)
)
network_params = network.init(
    _rng, 
    ScannedRNN.initialize_carry(config["num_train_envs"]), 
    init_x
)

TypeError: unsupported operand type(s) for //: 'tuple' and 'tuple'

In [14]:
# Fresh subkey from the notebook's main key stream. key_3 was already consumed
# by the ResetLSTM cells, so splitting it again would reuse randomness.
rng, subrng = random.split(rng)
# `policy_init_carry` was never defined anywhere in this notebook; it only
# existed as stale kernel state. Build it the same way as the network.init
# cell above, so the cell runs on a fresh kernel.
# NOTE(review): the recorded successful output shows policy_out[0][0] with shape
# (32, 256), which suggests a (c, h)-style tuple carry. Confirm this carry type
# matches what the recurrent transformer expects.
policy_init_carry = ScannedRNN.initialize_carry(config["num_train_envs"])
xs = (obs, dones)
policy_variables = policy.init(subrng, policy_init_carry, xs)
policy_out = policy.apply(policy_variables, policy_init_carry, xs)

TypeError: where requires ndarray or scalar arguments, got <class 'tuple'> at position 2.

In [19]:
# policy_out is indexed as (hidden state, policy distribution, value).
out_hstate, out_pi, out_value = policy_out[0], policy_out[1], policy_out[2]
print(f" hidden state: {out_hstate[0].shape}")
print(f" policy: {out_pi.logits.shape}")
print(f" value: {out_value.shape}")

 hidden state: (32, 256)
 policy: (10, 32, 4)
 value: (10, 32)


# Sample edit trajectories
*** 

In [23]:
@jax.jit
def create_train_state(rng: chex.PRNGKey) -> EditorPolicyTrainState:
    """Build the editor policy's train state (network params + optimiser).

    The EditorActorCritic network is initialised on the global batched ``obs``
    plus an all-False dones mask of shape ``(seq_len, num_train_envs)``, with a
    fresh recurrent carry. The optimiser clips gradients by global norm and
    then applies Adam. The learning rate follows ``linear_schedule`` when
    ``config["anneal_lr"]`` is set, and is the constant ``config["lr"]``
    otherwise.
    """

    def linear_schedule(count):
        # Linearly decay lr as optimiser steps accumulate (one outer update
        # every num_minibatches * update_epochs steps).
        steps_per_update = config["num_minibatches"] * config["update_epochs"]
        total_updates = config["num_updates"] * config["outer_rollout_steps"]
        frac = 1.0 - (count // steps_per_update) / total_updates
        return config["lr"] * frac

    dummy_dones = jnp.zeros((seq_len, config["num_train_envs"]), dtype=jnp.bool_)
    init_x = (obs, dummy_dones)
    network = EditorActorCritic(hidden_dim=out_feat, action_dim=n_editor)
    rng, init_rng = jax.random.split(rng)
    init_hstate = EditorActorCritic.initialize_carry(
        (config["num_train_envs"],),
        out_feat,
    )
    network_params = network.init(init_rng, init_x, init_hstate)

    # A single optimiser chain; only the learning-rate argument differs.
    learning_rate = linear_schedule if config["anneal_lr"] else config["lr"]
    tx = optax.chain(
        optax.clip_by_global_norm(config["max_grad_norm"]),
        optax.adam(learning_rate=learning_rate, eps=1e-5),
    )

    return EditorPolicyTrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
        num_updates=0,
    )

In [55]:
# Split *before* building the train state. create_train_state splits the key it
# receives, so splitting the same `rng` again afterwards would hand out exactly
# the keys used for network initialisation.
rng, train_rng = random.split(rng)
train_state = create_train_state(train_rng)
rng, subrng = random.split(rng)

num_edits = 8
edit_eps_length = 256
# Same value as before (32), but tied to the config so the editor rollouts and
# the carry built from config["num_train_envs"] cannot drift apart.
num_envs = config["num_train_envs"]

editors = [
    mmp_cluster_shapes,
    mmp_increase_spacing,
    mmp_reduce_density,
    mmp_increase_density,
    mmp_disable_motor_auto,
    mmp_enable_motor_auto,
]

with open("experiments/kinetix.json", "r") as f:
    editor_config = json.load(f)

# Use the fresh subkey. `_rng` was already consumed by the earlier reset cell
# (and by the transformer init), so reusing it repeats the same randomness.
obs, env_state = env.reset_to_level(subrng, level, env_params)

editor_manager = EditorManager(**editor_config)
# NOTE(review): the meaning of reset's second argument is not visible here. It
# equals num_edits (8); confirm, and pass the named constant if so.
_ = editor_manager.reset(
    env_state,
    8,
)

2025-02-23 19:50:54,661 - editax.moed - INFO - File name: editor_o3_mini_8inner_v1
2025-02-23 19:50:54,661 - editax.moed - INFO - File name: editor_o3_mini_8inner_v1
2025-02-23 19:50:54,661 - editax.moed - INFO - File name: editor_o3_mini_8inner_v1
2025-02-23 19:50:54,661 - editax.moed - INFO - File name: editor_o3_mini_8inner_v1
2025-02-23 19:50:54,661 - editax.moed - INFO - File name: editor_o3_mini_8inner_v1
2025-02-23 19:50:54,661 - editax.moed - INFO - File name: editor_o3_mini_8inner_v1
2025-02-23 19:50:54,661 - editax.moed - INFO - File name: editor_o3_mini_8inner_v1
2025-02-23 19:50:54,664 - editax.moed - INFO - Init editors -> editors/kinetix/editor_o3_mini_8inner_v1_tmp_0.py
2025-02-23 19:50:54,664 - editax.moed - INFO - Init editors -> editors/kinetix/editor_o3_mini_8inner_v1_tmp_0.py
2025-02-23 19:50:54,664 - editax.moed - INFO - Init editors -> editors/kinetix/editor_o3_mini_8inner_v1_tmp_0.py
2025-02-23 19:50:54,664 - editax.moed - INFO - Init editors -> editors/kinetix/e

KeyboardInterrupt: 

In [53]:
rng, subrng = random.split(rng)

# Fresh recurrent carry for every parallel environment (two arrays).
init_hstate = EditorActorCritic.initialize_carry((config["num_train_envs"],), out_feat)
for carry_part in (init_hstate[0], init_hstate[1]):
    print(f"carry shape: {carry_part.shape}")

# NOTE(review): `obs` is last reassigned by env.reset_to_level in the editor
# setup cell (an unbatched observation), yet the recorded output shows obs[0]
# as a (32, 1675) array. This cell appears to rely on stale kernel state.
# Despite its name, init_env_state holds an observation, not an env state.
init_env_state = obs[0]
print(f"x shape: {init_env_state.shape}")

# NOTE(review): the recorded run failed with "'EditorManager' object has no
# attribute 'editors'", apparently because the preceding editor_manager.reset
# was interrupted (KeyboardInterrupt).
(rng, train_state, hstate, last_env_state, last_value), traj = (
    editor_manager.sample_edit_trajectories_rnn(
        rng,
        train_state,
        init_hstate,
        init_env_state,
        num_envs,
        edit_eps_length,
    )
)

carry shape: (32, 256)
carry shape: (32, 256)
x shape: (32, 1675)
(1, 32, 1675)
(1, 32)


AttributeError: 'EditorManager' object has no attribute 'editors'

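# Update 
*** 
**TODO:** editor-policy update step (not yet implemented).

# Eval / Test 
*** 
**TODO:** evaluation of the edited levels (not yet implemented).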