Skip to content

Commit 4f008c5

Browse files
FP8 unit test error fix
1 parent 02b0689 commit 4f008c5

3 files changed

Lines changed: 13 additions & 9 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from jax.sharding import Mesh
3131

3232
from maxtext.common.common_types import (
33-
EP_AS_CONTEXT,
3433
MODEL_MODE_AUTOREGRESSIVE,
3534
MODEL_MODE_PREFILL,
3635
MODEL_MODE_TRAIN,
@@ -171,8 +170,6 @@ def __call__(
171170

172171
if self.model_mode == MODEL_MODE_PREFILL:
173172
logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
174-
elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN:
175-
logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed")
176173
else:
177174
logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed")
178175

src/maxtext/layers/nnx_wrappers.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,13 +286,18 @@ def __call__(
286286
# TODO(cgarciae): ideally we just do an update but currently dictionaries don't allow
287287
# insertion of new keys, we need to enable this in NNX to simplify the code below
288288
# to the simple nnx.update(self, nnx_attrs) above.
289+
def _to_nnx_dict(d):
290+
if isinstance(d, dict):
291+
return nnx.Dict({k: _to_nnx_dict(v) for k, v in d.items()})
292+
return d
293+
289294
for attr_name, value in nnx_attrs.items():
290-
if hasattr(self, attr_name) and isinstance(value, dict):
295+
if hasattr(self, attr_name) and isinstance(value, (dict, nnx.Dict)):
291296
original_value = getattr(self, attr_name)
292297
new_values = _recursive_merge(original_value, value)
293-
setattr(self, attr_name, nnx.data(new_values))
298+
setattr(self, attr_name, _to_nnx_dict(new_values))
294299
else:
295-
setattr(self, attr_name, nnx.data(value))
300+
setattr(self, attr_name, _to_nnx_dict(value))
296301

297302
return out
298303

@@ -466,7 +471,9 @@ def maybe_unbox(x):
466471

467472
warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")
468473

469-
nnx.update(module, new_state)
474+
filtered_state_flat = {path: v for path, v in new_state_flat.items() if path in current_state_flat}
475+
filtered_state = nnx.traversals.unflatten_mapping(filtered_state_flat)
476+
nnx.update(module, filtered_state)
470477

471478
_fix_for_qwix_quantization(module)
472479
method_fn = _get_module_method(module, nnx_method)

src/maxtext/layers/quantizations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def _get_quant_config(config):
592592
with open(config.quant_cfg_path, "rt", encoding="utf8") as config_file:
593593
mixed_precision_config = json.load(config_file)
594594
return _get_mixed_precision_quant_config(mixed_precision_config)
595-
if config.quantization == "fp8":
595+
if getattr(config.quantization, "name", str(config.quantization)) in ("FP8", "FP8_GPU"):
596596
return "fp8"
597597
if config.quantization == "nanoo_fp8":
598598
return "nanoo_fp8"
@@ -636,7 +636,7 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"):
636636
bwd_calibration_method=config.bwd_quantization_calibration_method,
637637
)
638638

639-
if config.use_qwix_quantization:
639+
if config.use_qwix_quantization and not getattr(config, "enable_nnx", False):
640640
return None
641641
quant_cfg = _get_quant_config(config)
642642
if quant_cfg:

0 commit comments

Comments (0)