Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/maxtext/optimizers/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import jax.numpy as jnp

import optax
from maxtext.utils import max_logging
from optax.contrib._muon import muon
from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers

Expand Down Expand Up @@ -140,8 +139,7 @@ def do_update():
return inner_opt.update(updates, state["inner_state"], params, **extra_args)

def skip_update():
# use callback to work with jax.jit and jax.lax.cond for logging
jax.debug.callback(lambda c: max_logging.warning(f"Step {c}: Optimizer step skipped due to spike."), count)
# b/500923599: Investigate logging compatible with jax.jit, jax.lax.cond, and Pathway
inner_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)
return inner_updates, state["inner_state"]

Expand Down
19 changes: 15 additions & 4 deletions tests/unit/deepseek_scan_engram_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,26 @@ class TestDeepSeekScanEngram(unittest.TestCase):
"first_num_dense_layers=5",
"base_num_decoder_layers=10",
"num_decoder_layers=10",
"base_emb_dim=64",
"base_mlp_dim=64",
"base_moe_mlp_dim=64",
"base_num_query_heads=2",
"base_num_kv_heads=2",
"head_dim=32",
"indexer_head_dim=32",
"qk_nope_head_dim=32",
"qk_rope_head_dim=16",
"v_head_dim=32",
"vocab_size=128",
"mhc_expansion_rate=4",
"attention=dot_product",
"per_device_batch_size=2",
"per_device_batch_size=1",
"max_target_length=8",
"max_prefill_predict_length=8",
"enable_checkpointing=False",
"engram_num_heads=1",
"engram_head_dim=32",
"engram_vocab_bases=[226240,226240]",
"engram_head_dim=8",
"engram_vocab_bases=[128,128]",
"engram_max_ngram_size=3",
"engram_kernel_size=4",
"hf_access_token=dummy",
Expand All @@ -78,7 +89,7 @@ class MockTokenizer:
pad_token_id = 0

def __len__(self):
return 1000
return 128

def __call__(self, x):
return jnp.ones_like(x)
Expand Down
Loading