diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 7e85d8f595..2ae7e5f8e5 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -20,7 +20,6 @@ import jax.numpy as jnp import optax -from maxtext.utils import max_logging from optax.contrib._muon import muon from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers @@ -140,8 +139,7 @@ def do_update(): return inner_opt.update(updates, state["inner_state"], params, **extra_args) def skip_update(): - # use callback to work with jax.jit and jax.lax.cond for logging - jax.debug.callback(lambda c: max_logging.warning(f"Step {c}: Optimizer step skipped due to spike."), count) + # b/500923599: Investigate logging compatible with jax.jit, jax.lax.cond, and Pathways inner_updates = jax.tree_util.tree_map(jnp.zeros_like, updates) return inner_updates, state["inner_state"] diff --git a/tests/unit/deepseek_scan_engram_test.py b/tests/unit/deepseek_scan_engram_test.py index a84e6909fa..ff51c5c62f 100644 --- a/tests/unit/deepseek_scan_engram_test.py +++ b/tests/unit/deepseek_scan_engram_test.py @@ -53,15 +53,26 @@ class TestDeepSeekScanEngram(unittest.TestCase): "first_num_dense_layers=5", "base_num_decoder_layers=10", "num_decoder_layers=10", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=2", + "base_num_kv_heads=2", + "head_dim=32", + "indexer_head_dim=32", + "qk_nope_head_dim=32", + "qk_rope_head_dim=16", + "v_head_dim=32", + "vocab_size=128", "mhc_expansion_rate=4", "attention=dot_product", - "per_device_batch_size=2", + "per_device_batch_size=1", "max_target_length=8", "max_prefill_predict_length=8", "enable_checkpointing=False", "engram_num_heads=1", - "engram_head_dim=32", - "engram_vocab_bases=[226240,226240]", + "engram_head_dim=8", + "engram_vocab_bases=[128,128]", "engram_max_ngram_size=3", "engram_kernel_size=4", "hf_access_token=dummy", @@ -78,7 +89,7 @@ class MockTokenizer: 
pad_token_id = 0 def __len__(self): - return 1000 + return 128 def __call__(self, x): return jnp.ones_like(x)