Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/MaxText/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from maxtext.trainers.post_train.dpo import dpo_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import model_creation_utils
from maxtext.utils.model_creation_utils import from_config

Transformer = models.Transformer
transformer_as_linen = models.transformer_as_linen
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax.sharding import Mesh
from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE
from MaxText.globals import MAXTEXT_PKG_DIR
from MaxText.globals import MAXTEXT_CONFIGS_DIR
from maxtext.utils import max_logging
from maxtext.utils import model_creation_utils

Expand Down Expand Up @@ -73,7 +73,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")

# Add base config path to positional args
base_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
argv_list = ["", str(base_config_path)]

maxtext_config = pyconfig.initialize(argv_list, **overrides)
Expand Down Expand Up @@ -151,7 +151,7 @@ def __call__(

with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
aux_hidden_states = []
hidden, updated_kv_caches = self.model(
hidden, kv_caches = self.model(
decoder_input_tokens=input_ids,
decoder_positions=input_positions,
kv_caches=kv_caches,
Expand All @@ -163,7 +163,7 @@ def __call__(
# To be compatible with vLLM, we reshape to (batch * seq, dim).
hidden = hidden.reshape((-1, hidden.shape[-1]))

return updated_kv_caches, hidden, aux_hidden_states
return kv_caches, hidden, aux_hidden_states

def forward(self, *args, **kwargs):
"""Alias for __call__ for compatibility.
Expand Down
6 changes: 2 additions & 4 deletions src/MaxText/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
os.environ["SKIP_JAX_PRECOMPILE"] = "1"

from MaxText import pyconfig
from MaxText.globals import MAXTEXT_PKG_DIR
from MaxText.globals import MAXTEXT_CONFIGS_DIR
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from MaxText.rl.evaluate_rl import evaluate
from MaxText.rl import utils_rl
Expand Down Expand Up @@ -370,7 +370,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)


if trainer_config.debug.rl:
max_logging.log("Policy Model initialized successfully")
nnx.display(actor_model)
Expand Down Expand Up @@ -495,8 +494,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
"enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
)

configs_dir = os.environ.get("MAXTEXT_CONFIGS_DIR", os.path.join(MAXTEXT_PKG_DIR, "configs"))
vllm_config_path = epath.Path(configs_dir) / "vllm.yml"
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
argv_list = ["", str(vllm_config_path), "log_config=False"]
vllm_config = pyconfig.initialize(argv_list)

Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions src/maxtext/vllm_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from maxtext.utils import model_creation_utils
from MaxText import pyconfig
from MaxText.common_types import Config
from MaxText.globals import MAXTEXT_PKG_DIR
from MaxText.globals import MAXTEXT_CONFIGS_DIR
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from tunix.rl.rollout import base_rollout
from tunix.rl.rollout.vllm_rollout import VllmRollout
Expand Down Expand Up @@ -185,7 +185,7 @@ def decode_with_vllm(
f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
)

vllm_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
argv_list = ["", str(vllm_config_path), "log_config=False"]
vllm_config = pyconfig.initialize(argv_list)

Expand Down
4 changes: 2 additions & 2 deletions tools/gcs_benchmarks/standalone_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@

from flax.linen import partitioning as nn_partitioning

import MaxText as mt
from MaxText import pyconfig
from MaxText.train import get_first_step
from MaxText.layers import models
from maxtext.common import checkpointing
from maxtext.utils import max_logging
from maxtext.utils import maxtext_utils
from maxtext.utils import train_utils
from maxtext.utils.model_creation_utils import from_config

Transformer = models.transformer_as_linen

Expand All @@ -52,7 +52,7 @@ def checkpoint_loop(config, state=None):
ckpt_path:
Returns:
"""
model = mt.from_config(config)
model = from_config(config)
mesh = model.mesh
init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(config, model, mesh)

Expand Down
Loading