Merged
2 changes: 1 addition & 1 deletion README.md
@@ -35,7 +35,7 @@ Check out our [Read The Docs site](https://maxtext.readthedocs.io/en/latest/) or
See our installation guide to [install MaxText with pip from PyPI](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-pypi-recommended).

## Decoupled mode
See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/guides/run_maxtext/decoupled_mode.html).
See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/run_maxtext/decoupled_mode.html).

<!-- NEWS START -->

2 changes: 1 addition & 1 deletion codecov.yml
@@ -65,7 +65,7 @@ coverage:
patch:
default:
target: auto
threshold: 5% # fail on 5+ percent degradation
threshold: 10% # fail on 10+ percent degradation
flags:
- regular

@@ -8,6 +8,7 @@ flax
grain>=0.2.12
grpcio>=1.75.1
huggingface_hub>=0.35.3
jax==0.7.1
jaxtyping>=0.3.3
jsonlines>=4.0.0
matplotlib>=3.10.3
@@ -19,6 +20,7 @@ omegaconf>=2.3.0
optax>=0.2.6
orbax-checkpoint>=0.11.25
pandas>=2.3.3
parameterized==0.9.0
pathwaysutils>=0.1.3
pillow>=11.3.0
protobuf>=5.29.5
@@ -39,5 +41,4 @@ tiktoken>=0.12.0
tqdm>=4.67.1
transformers>=4.57.0
urllib3>=2.5.0
jax==0.7.1
git+https://github.com/google/tunix.git
git+https://github.com/google/tunix.git
62 changes: 34 additions & 28 deletions docs/run_maxtext/decoupled_mode.md
@@ -14,38 +14,40 @@
limitations under the License.
-->


# Via Decoupled Mode (No Google Cloud Dependencies)

Set `DECOUPLE_GCLOUD=TRUE` to run MaxText tests and local development without any Google Cloud SDK, `gs://` buckets, JetStream, or Vertex AI integrations.

When enabled:
* Skips external integration tests with markers:
* `external_serving` (`jetstream`, `serving`, `decode_server`)
* `external_training` (`goodput`)
* `decoupled` – Applied by `tests/conftest.py` to tests that are runnable in decoupled mode (i.e. not skipped for TPU or external markers).
* Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** when decoupled. This prevents accidentally running partial serving logic locally when decoupled mode is ON.
* Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
* Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
* Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
* All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/utils/test_helper.py`. This helper ensures usage of `decoupled_base_test.yml`.

- Skips external integration tests with markers:
- `external_serving` (`jetstream`, `serving`, `decode_server`)
- `external_training` (`goodput`)
- `decoupled` – Applied by `tests/conftest.py` to tests that are runnable in decoupled mode (i.e. not skipped for TPU or external markers).
- Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** when decoupled. This prevents accidentally running partial serving logic locally when decoupled mode is ON.
- Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
- Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
- Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
- All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/utils/test_utils.py`. This helper ensures usage of `decoupled_base_test.yml`.
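
The helper's contract can be sketched as follows. This is a minimal illustration, not the real `tests/utils/test_utils.py` implementation, and the on-disk config layout is an assumption:

```python
import os


def get_test_config_path() -> str:
    """Pick the decoupled test config when DECOUPLE_GCLOUD=TRUE, else base.yml.

    Sketch only: the real helper lives in tests/utils/test_utils.py and the
    directory layout below is an assumption.
    """
    decoupled = os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"
    name = "decoupled_base_test.yml" if decoupled else "base.yml"
    return os.path.join("src", "MaxText", "configs", name)
```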

Minimal datasets included (checked into the repo):
* ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,

- ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,
located in `local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-{train,validation}.array_record-*`
* Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
- Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
located in `local_datasets/c4_en_dataset_minimal/hf/c4`


Run a local smoke test fully offline:

```bash
export DECOUPLE_GCLOUD=TRUE
pytest -k train_gpu_smoke_test -q
```

Optional environment variables:
* `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
* `LOCAL_BASE_OUTPUT` - override default local output directory used in tests.

- `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
- `LOCAL_BASE_OUTPUT` - override default local output directory used in tests.

## Centralized Decoupling API (`gcloud_stub.py`)

@@ -55,32 +57,36 @@ MaxText exposes a single module `MaxText.gcloud_stub` to avoid scattering enviro
from MaxText.gcloud_stub import is_decoupled, cloud_diagnostics, jetstream

if is_decoupled():
# Skip optional integrations or use local fallbacks
pass
# Skip optional integrations or use local fallbacks
pass

# Cloud diagnostics (returns diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration)
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = cloud_diagnostics()
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = (
cloud_diagnostics()
)

# JetStream (serving) components
config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream()
TokenizerParameters = getattr(token_params_ns, "TokenizerParameters", object)
```
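
A typical local-fallback pattern built on `is_decoupled()` looks like the sketch below. `resolve_output_dir` and its environment-variable fallback are illustrative, not MaxText API; the predicate is inlined here so the snippet stands alone:

```python
import os


def is_decoupled() -> bool:
    # Mirrors the behavior of MaxText.gcloud_stub.is_decoupled (assumption).
    return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"


def resolve_output_dir(gcs_dir: str) -> str:
    """Prefer a local directory over a gs:// path when decoupled (sketch)."""
    if is_decoupled():
        return os.environ.get("LOCAL_BASE_OUTPUT", "./maxtext_local_output")
    return gcs_dir
```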

Behavior when `DECOUPLE_GCLOUD=TRUE`:
* `is_decoupled()` returns True.
* Each helper returns lightweight stubs whose attributes are safe to access; calling methods raises a clear `RuntimeError` only when actually invoked.
* Prevents import-time failures for optional dependencies (JetStream).

- `is_decoupled()` returns True.
- Each helper returns lightweight stubs whose attributes are safe to access; calling methods raises a clear `RuntimeError` only when actually invoked.
- Prevents import-time failures for optional dependencies (JetStream).

## Guidelines
* Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
* Use `is_decoupled()` to avoid direct `os.environ["DECOUPLE_GCLOUD"]` checking.
* Use `get_test_config_path()` instead of hard-coded `base.yml`.
* Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
* Please add the appropriate external dependency marker (`external_serving` or `external_training`) for new tests. Prefer the smallest scope instead of module-wide `pytestmark` when only a part of a file needs an external dependency.
* Tests add a `decoupled` marker if DECOUPLE_GCLOUD && not marked with external dependency markers. Run tests with:

- Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
- Use `is_decoupled()` to avoid direct `os.environ["DECOUPLE_GCLOUD"]` checking.
- Use `get_test_config_path()` instead of hard-coded `base.yml`.
- Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
- Please add the appropriate external dependency marker (`external_serving` or `external_training`) for new tests. Prefer the smallest scope instead of module-wide `pytestmark` when only a part of a file needs an external dependency.
- Tests receive a `decoupled` marker when `DECOUPLE_GCLOUD=TRUE` and they carry no external dependency markers. Run them with:

```bash
pytest -m decoupled -vv tests
```
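
The marker decision can be approximated by a predicate like the one below. This is a hedged sketch of what `tests/conftest.py` is described as doing; the external marker names come from the list above, while the TPU marker name is an assumption:

```python
import os

EXTERNAL_MARKERS = {"external_serving", "external_training"}


def should_mark_decoupled(test_markers: set, tpu_marker: str = "tpu") -> bool:
    """True when DECOUPLE_GCLOUD=TRUE and the test has no external/TPU marker."""
    if os.environ.get("DECOUPLE_GCLOUD", "").upper() != "TRUE":
        return False
    return not (test_markers & (EXTERNAL_MARKERS | {tpu_marker}))
```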

This centralized approach keeps optional integrations cleanly separated from core MaxText logic, making local development (e.g. on ROCm/NVIDIA GPUs) frictionless.

22 changes: 9 additions & 13 deletions src/MaxText/configs/decoupled_base_test.yml
@@ -1,9 +1,9 @@
# Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
# Inherit all model defaults from base.yml but override any cloud-coupled paths and disable optional cloud features.
base_config: base.yml
# Inherit all model defaults (Pydantic already does this) but override any cloud-coupled paths and disable
# optional cloud features.

# Output goes to a local relative directory so tests do not require GCS.
base_output_directory: ./maxtext_local_output
base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs
run_name: test_decoupled

# Disable checkpointing by default for speed unless a test explicitly enables it.
@@ -23,7 +23,9 @@ profile_periodically_period: 0
profiler_steps: 0

# Leave dataset-related keys to be overridden by individual tests.
dataset_type: ""
dataset_path: "tests/assets/local_datasets/c4_en_dataset_minimal/"
dataset_name: 'c4/en:3.1.0'
eval_dataset_name: 'c4/en:3.1.0'

# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
attention: "dot_product"
@@ -44,6 +46,8 @@ ici_tensor_sequence_parallelism: 1
ici_autoregressive_parallelism: 1
ici_fsdp_parallelism: 1
ici_fsdp_transpose_parallelism: 1
# Allow higher unsharded parameter percentage for small device count
sharding_tolerance: 0.3

# DCN dimensions to 1 (no multi-slice expectation locally).
dcn_data_parallelism: 1
@@ -68,12 +72,4 @@ goodput_upload_interval_seconds: 0
enable_pathways_goodput: false
enable_gcp_goodput_metrics: false

# Disable any cloud logging / BigQuery or external metric uploads.
enable_cloud_logging: false
upload_metrics_to_bigquery: false
bigquery_project: ""
bigquery_dataset: ""
bigquery_table: ""

# Force local-only behavior for tests: avoid accidental env pickup.
tensorboard_dir: "./maxtext_local_output/tensorboard"
tensorboard_dir: "./maxtext_local_output/gcloud_decoupled_test_logs/tensorboard"
65 changes: 42 additions & 23 deletions src/MaxText/maxengine.py
@@ -15,7 +15,7 @@
"""Implementation of Engine API for MaxText."""

from collections import defaultdict
from typing import Any, Callable, Union
from typing import Any, Callable
import functools
import os.path
import uuid
Expand All @@ -36,13 +36,6 @@
from flax.linen import partitioning as nn_partitioning
import flax

from jetstream.core import config_lib
from jetstream.engine import engine_api
from jetstream.engine import token_utils
from jetstream.engine import tokenizer_api
from jetstream.engine.tokenizer_pb2 import TokenizerParameters
from jetstream.engine.tokenizer_pb2 import TokenizerType

from MaxText import multimodal_utils
from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
@@ -53,6 +46,11 @@
from maxtext.utils import lora_utils
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.common.gcloud_stub import jetstream, is_decoupled

config_lib, engine_api, token_utils, tokenizer_api, _token_params_ns = jetstream()
TokenizerParameters = getattr(_token_params_ns, "TokenizerParameters", object) # type: ignore[assignment]
TokenizerType = getattr(_token_params_ns, "TokenizerType", object) # type: ignore[assignment]


warnings.simplefilter("ignore", category=FutureWarning)
@@ -95,14 +93,17 @@ def get_keys(self):
return self.keys


class MaxEngine(engine_api.Engine):
_BaseEngine = engine_api.Engine if (not is_decoupled() and hasattr(engine_api, "Engine")) else object


class MaxEngine(_BaseEngine):
"""The computational core of the generative model server.

Engine defines an API that models must adhere to as they plug into the
JetStream efficient serving infrastructure.
"""

def __init__(self, config: Any, devices: Union[config_lib.Devices, None] = None):
def __init__(self, config: Any, devices: Any | None = None):
self.config = config

# Mesh definition
@@ -139,7 +140,7 @@ def print_stats(self, label: str):

def generate_aot(
self, params: Params, decode_state: DecodeState, rng: PRNGKeyType | None = None
) -> tuple[DecodeState, engine_api.ResultTokens]:
): # returns (new_decode_state, result_tokens)
"""Wrapper to generate for ahead of time compilation."""

return self.generate(params=params, decode_state=decode_state, rng=rng)
@@ -393,7 +394,7 @@ def prefill_aot(  # pylint: disable=too-many-positional-arguments
padded_tokens: jax.Array,
true_length: int,
rng: PRNGKeyType | None = None,
) -> tuple[Prefix, engine_api.ResultTokens]:
): # returns (new_prefix, result_tokens)
"""Wrapper for prefill for ahead-of-time compilation."""

return self.prefill(
@@ -426,7 +427,7 @@ def _prefill_jit(
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[Prefix, engine_api.ResultTokens]:
): # returns (new_prefix, result_tokens)
"""Performs a JIT-compiled prefill operation on a sequence of tokens.

This function processes an input sequence (prompt) through the model to compute
@@ -594,7 +595,7 @@ def prefill(
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[Prefix, engine_api.ResultTokens]:
): # returns (new_prefix, result_tokens)
"""Public API for prefill that updates page state outside JIT."""
# Update page state before JIT call
if self.config.attention == "paged" and self.page_manager is not None and self.page_state is not None:
@@ -643,7 +644,7 @@ def prefill_multisampling_aot(  # pylint: disable=too-many-positional-arguments
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[Prefix, engine_api.ResultTokens]:
): # returns (new_prefix, result_tokens)
"""Wrapper for multi-sampling prefill for ahead-of-time compilation."""
return self.prefill_multisampling(
params=params,
@@ -672,7 +673,7 @@ def prefill_multisampling(
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[Prefix, engine_api.ResultTokens]:
): # returns (new_prefix, result_tokens)
"""Public API for prefill multisampling."""

# Sample rng before JIT call
@@ -709,7 +710,7 @@ def _prefill_multisampling_jit(
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[Prefix, engine_api.ResultTokens]:
) -> tuple[Prefix, Any]:
"""Computes a kv-cache for a new generate request.

With multi-sampling, the engine will generate multiple first tokens in the
@@ -816,7 +817,7 @@ def prefill_concat(
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[Any, PackedPrefix, list[engine_api.ResultTokens]]:
): # returns (maybe_batch, packed_prefix, list_of_result_tokens)
"""Computes a kv-cache for a new packed generate request, which is a
concatenation of several shorter prompts. Experimentation shows that
longer prefill sequences gives approximately 15% boost in time per prefilled
@@ -933,7 +934,7 @@ def generate(
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[DecodeState, engine_api.ResultTokens]:
): # returns (decode_state, result_tokens)
"""Public API for generate that updates page state outside JIT."""

# Update page state before JIT call
@@ -976,7 +977,7 @@ def _generate_jit(
topk: int | None = None,
nucleus_topp: float | None = None,
temperature: float | None = None,
) -> tuple[DecodeState, engine_api.ResultTokens]:
): # returns (decode_state, result_tokens)
"""Performs a single, JIT-compiled autoregressive decoding step.

This function takes the current decoding state, which includes the KV cache
@@ -1497,8 +1498,19 @@ def get_prefix_destination_sharding(self) -> Any:
"token_logp": self.replicated_sharding,
}

def get_tokenizer(self) -> TokenizerParameters:
"""Return a protobuf of tokenizer info, callable from Py or C++."""
def get_tokenizer(self) -> Any:
"""Return tokenizer parameters; requires JetStream.

When DECOUPLE_GCLOUD is TRUE we raise a clear RuntimeError instead of
failing cryptically on attribute access.
"""
token_params_is_stub = getattr(_token_params_ns, "_IS_STUB", False)
engine_api_is_stub = getattr(engine_api, "_IS_STUB", False)
if is_decoupled() and (token_params_is_stub or engine_api_is_stub):
raise RuntimeError(
"JetStream disabled by DECOUPLE_GCLOUD=TRUE or stubbed; get_tokenizer is unsupported. "
"Unset DECOUPLE_GCLOUD or install JetStream to enable tokenizer functionality."
)
try:
tokenizer_type_val = TokenizerType.DESCRIPTOR.values_by_name[self.config.tokenizer_type].number
return TokenizerParameters(
@@ -1511,8 +1523,15 @@ def get_tokenizer(self) -> TokenizerParameters:
except KeyError as _:
raise KeyError(f"Unsupported tokenizer type: {self.config.tokenizer_type}") from None

def build_tokenizer(self, metadata: TokenizerParameters) -> tokenizer_api.Tokenizer:
def build_tokenizer(self, metadata: Any): # return type depends on JetStream
"""Return a tokenizer"""
token_params_is_stub = getattr(_token_params_ns, "_IS_STUB", False)
engine_api_is_stub = getattr(engine_api, "_IS_STUB", False)
if is_decoupled() and (token_params_is_stub or engine_api_is_stub):
raise RuntimeError(
"JetStream disabled by DECOUPLE_GCLOUD=TRUE or stubbed; build_tokenizer is unsupported. "
"Unset DECOUPLE_GCLOUD or install JetStream to enable tokenizer functionality."
)
if metadata.tokenizer_type == TokenizerType.tiktoken:
return token_utils.TikToken(metadata)
elif metadata.tokenizer_type == TokenizerType.sentencepiece: