From 37d547d7d2daf05f931b83ce1ab08ec734b624c6 Mon Sep 17 00:00:00 2001 From: Nicolas Grande Date: Thu, 12 Feb 2026 22:18:22 +0000 Subject: [PATCH] removing duplicate KV cache instance. changing config paths. --- src/MaxText/__init__.py | 1 - .../integration/vllm/maxtext_vllm_adapter/adapter.py | 8 ++++---- src/MaxText/rl/train_rl.py | 6 ++---- src/maxtext/configs/{ => inference}/vllm.yml | 0 src/maxtext/vllm_decode.py | 4 ++-- tools/gcs_benchmarks/standalone_checkpointer.py | 4 ++-- 6 files changed, 10 insertions(+), 13 deletions(-) rename src/maxtext/configs/{ => inference}/vllm.yml (100%) diff --git a/src/MaxText/__init__.py b/src/MaxText/__init__.py index 1eca5831b8..411da60071 100644 --- a/src/MaxText/__init__.py +++ b/src/MaxText/__init__.py @@ -34,7 +34,6 @@ from maxtext.trainers.post_train.dpo import dpo_utils from maxtext.utils import maxtext_utils from maxtext.utils import model_creation_utils -from maxtext.utils.model_creation_utils import from_config Transformer = models.Transformer transformer_as_linen = models.transformer_as_linen diff --git a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py index ea63a371ae..670ee92d32 100644 --- a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -23,7 +23,7 @@ from jax.sharding import Mesh from MaxText import pyconfig from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE -from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.globals import MAXTEXT_CONFIGS_DIR from maxtext.utils import max_logging from maxtext.utils import model_creation_utils @@ -73,7 +73,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.") # Add base config path to positional args - base_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml") + base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") argv_list = ["", str(base_config_path)] maxtext_config = pyconfig.initialize(argv_list, **overrides) @@ -151,7 +151,7 @@ def __call__( with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): aux_hidden_states = [] - hidden, updated_kv_caches = self.model( + hidden, kv_caches = self.model( decoder_input_tokens=input_ids, decoder_positions=input_positions, kv_caches=kv_caches, @@ -163,7 +163,7 @@ def __call__( # To be compatible with vLLM, we reshape to (batch * seq, dim). hidden = hidden.reshape((-1, hidden.shape[-1])) - return updated_kv_caches, hidden, aux_hidden_states + return kv_caches, hidden, aux_hidden_states def forward(self, *args, **kwargs): """Alias for __call__ for compatibility. diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 70a76d9ce3..7cb3acdac6 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -73,7 +73,7 @@ os.environ["SKIP_JAX_PRECOMPILE"] = "1" from MaxText import pyconfig -from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.globals import MAXTEXT_CONFIGS_DIR from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter from MaxText.rl.evaluate_rl import evaluate from MaxText.rl import utils_rl @@ -370,7 +370,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): max_logging.log("Creating policy model with same config as reference model on trainer mesh") actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) - if trainer_config.debug.rl: max_logging.log("Policy Model initialized successfully") nnx.display(actor_model) @@ -495,8 +494,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics." ) - configs_dir = os.environ.get("MAXTEXT_CONFIGS_DIR", os.path.join(MAXTEXT_PKG_DIR, "configs")) - vllm_config_path = epath.Path(configs_dir) / "vllm.yml" + vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") argv_list = ["", str(vllm_config_path), "log_config=False"] vllm_config = pyconfig.initialize(argv_list) diff --git a/src/maxtext/configs/vllm.yml b/src/maxtext/configs/inference/vllm.yml similarity index 100% rename from src/maxtext/configs/vllm.yml rename to src/maxtext/configs/inference/vllm.yml diff --git a/src/maxtext/vllm_decode.py b/src/maxtext/vllm_decode.py index bc27e13e34..2e532d63af 100644 --- a/src/maxtext/vllm_decode.py +++ b/src/maxtext/vllm_decode.py @@ -47,7 +47,7 @@ from maxtext.utils import model_creation_utils from MaxText import pyconfig from MaxText.common_types import Config -from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.globals import MAXTEXT_CONFIGS_DIR from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter from tunix.rl.rollout import base_rollout from tunix.rl.rollout.vllm_rollout import VllmRollout @@ -185,7 +185,7 @@ def decode_with_vllm( f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..." ) - vllm_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml") + vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml") argv_list = ["", str(vllm_config_path), "log_config=False"] vllm_config = pyconfig.initialize(argv_list) diff --git a/tools/gcs_benchmarks/standalone_checkpointer.py b/tools/gcs_benchmarks/standalone_checkpointer.py index 429404d852..42cd69c423 100644 --- a/tools/gcs_benchmarks/standalone_checkpointer.py +++ b/tools/gcs_benchmarks/standalone_checkpointer.py @@ -31,7 +31,6 @@ from flax.linen import partitioning as nn_partitioning -import MaxText as mt from MaxText import pyconfig from MaxText.train import get_first_step from MaxText.layers import models @@ -39,6 +38,7 @@ from maxtext.utils import max_logging from maxtext.utils import maxtext_utils from maxtext.utils import train_utils +from maxtext.utils.model_creation_utils import from_config Transformer = models.transformer_as_linen @@ -52,7 +52,7 @@ def checkpoint_loop(config, state=None): ckpt_path: Returns: """ - model = mt.from_config(config) + model = from_config(config) mesh = model.mesh init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(config, model, mesh)