In [3]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import  FlattenObservationWrapper
import wandb


# Custom TrainState that carries one extra integer counter next to the stock fields.
class CustomTrainState(TrainState):
    """Flax ``TrainState`` extended with a single additional field, ``step_count``.

    The field is declared without a default value, so it has to be supplied
    whenever an instance is constructed.

    NOTE(review): flax's ``TrainState`` already exposes a ``step`` field --
    confirm that ``step_count`` is meant to track something different
    (presumably environment steps rather than optimizer updates).
    """

    # Extra counter stored with the train state; how it is advanced is not
    # visible in this cell.
    step_count: int


from wrappers import (
    LogWrapper,
    OptimisticResetVecEnvWrapper,
    AutoResetEnvWrapper,
    BatchEnvWrapper,
    ThinkingWrapper,
)


ModuleNotFoundError: No module named 'distutils'

In [2]:
# Derived schedule sizes, computed once from the user-supplied hyper-parameters.
# NUM_UPDATES    = total timesteps divided (floor) by steps-per-rollout, then by env count.
# MINIBATCH_SIZE = samples collected per rollout (envs * steps) divided by minibatch count.
config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
config["MINIBATCH_SIZE"] = config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]

# Build the vectorised environment. In both branches `action_dim_env` is read
# from the *unwrapped* env before ThinkingWrapper is applied
# (presumably because ThinkingWrapper changes the action space -- TODO confirm).
if config["ENV_NAME"] != "craftax":
    # Generic gymnax environment: flatten observations, add logging, add the
    # thinking wrapper, then batch over NUM_ENVS.
    env, env_params = gymnax.make(config["ENV_NAME"])
    action_dim_env = env.action_space(env_params).n
    env = BatchEnvWrapper(
        ThinkingWrapper(
            LogWrapper(FlattenObservationWrapper(env)),
            action_dim_env,
            config["R_THINK"],
        ),
        config["NUM_ENVS"],
    )
else:
    # Craftax is imported lazily so the dependency is only needed on this path.
    from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnvNoAutoReset

    env = CraftaxSymbolicEnvNoAutoReset()
    env_params = env.default_params
    action_dim_env = env.action_space(env_params).n
    # Same logging + thinking stack, vectorised with the optimistic-reset
    # wrapper; the reset ratio is capped at 16 or NUM_ENVS, whichever is smaller.
    env = OptimisticResetVecEnvWrapper(
        ThinkingWrapper(LogWrapper(env), action_dim_env, config["R_THINK"]),
        num_envs=config["NUM_ENVS"],
        reset_ratio=min(16, config["NUM_ENVS"]),
    )


NameError: name 'config' is not defined

In [14]:
# -------------------------------
# synthetic_masking_test.py
# -------------------------------
# Self-contained sanity check for the "thinking token" action mask:
# random logits are masked so that environments whose thinking streak has
# reached MAX_THINKING_LEN can only pick genuine env actions, then a
# categorical distribution is built, sampled, and the invariants are asserted.
import jax, jax.numpy as jnp
import distrax

# ---------------- hyper-params ----------------
BATCH            = 2          # parallel envs
ACTION_ENV       = 2          # genuine env actions (indices 0 .. ACTION_ENV-1)
THINKING_TOKENS  = 4          # "thinking" actions (indices ACTION_ENV .. TOTAL_ACTIONS-1)
TOTAL_ACTIONS    = ACTION_ENV + THINKING_TOKENS
MAX_THINKING_LEN = 3          # cap on consecutive thinking steps
# ---------------------------------------------

key = jax.random.PRNGKey(84)

# Fake network output: random logits (shape [B, A])
logits = jax.random.normal(key, (BATCH, TOTAL_ACTIONS))

# Fake env_state.thinking_streak (shape [B])
consecutive_streak = jnp.array([0, 6])  # last env already over the limit

# ---------- masking logic ----------
# An action is allowed when the env is still under the streak cap (any action),
# or when it is a genuine env action (always allowed). Because env actions are
# never masked, every row keeps at least ACTION_ENV valid entries.
mask = ((consecutive_streak[:, None] < MAX_THINKING_LEN) |
        (jnp.arange(TOTAL_ACTIONS) < ACTION_ENV))

masked_logits = jnp.where(mask, logits, -jnp.inf)   # -inf kills the probability
pi = distrax.Categorical(logits=masked_logits)

# ---------- sample & score ----------
# NOTE(review): `key` was already consumed by jax.random.normal above and is
# reused here as the input to split; JAX recommends splitting before first use.
# The key flow is left unchanged so the recorded output stays reproducible.
sample_key, _ = jax.random.split(key)
action    = pi.sample(seed=sample_key)
log_prob  = pi.log_prob(action)
probs     = pi.probs  # nice to inspect

# ---------- print results ----------
print("raw logits:\n", logits)
print("\nallowed mask:\n", mask)
print("\nmasked logits:\n", masked_logits)
print("\naction probabilities:\n", probs)
print("\nsampled actions:", action)
print("log-probs:", log_prob)
print("probs:", jnp.exp(log_prob))

# ---------- invariants ----------
# Checked after printing so the diagnostics above are visible if one fails.
masked_prob_mass = jnp.where(mask, 0.0, probs).sum(axis=-1)
assert bool(jnp.all(masked_prob_mass == 0.0)), "masked actions received probability mass"
assert bool(jnp.all(mask[jnp.arange(BATCH), action])), "sampled an action that the mask forbids"
assert bool(jnp.all(jnp.isfinite(log_prob))), "non-finite log-prob for a sampled action"
assert bool(jnp.allclose(probs.sum(axis=-1), 1.0, atol=1e-5)), "action probabilities do not sum to 1"


raw logits:
 [[-1.887619   -0.17193942 -0.7739322  -1.6001499  -1.3864341  -0.05387902]
 [ 0.31428617 -0.9276049  -0.03227857  0.14984944  1.2719766  -0.04556949]]

allowed mask:
 [[ True  True  True  True  True  True]
 [ True  True False False False False]]

masked logits:
 [[-1.887619   -0.17193942 -0.7739322  -1.6001499  -1.3864341  -0.05387902]
 [ 0.31428617 -0.9276049         -inf        -inf        -inf        -inf]]

action probabilities:
 [[0.05305887 0.29503137 0.1615943  0.0707301  0.08758301 0.33200237]
 [0.77589303 0.22410698 0.         0.         0.         0.        ]]

sampled actions: [2 1]
log-probs: [-1.8226664 -1.4956317]
probs: [0.1615943 0.224107 ]
