diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c18147bc22..3b34007902 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -18,9 +18,9 @@ src/MaxText/elastic_train.py @lukebaumann @shauryagup @richjames0 @shralex src/MaxText/layers/quantizations.py @khatwanimohit @jshin1394 @liudangyi @richjames0 @shralex # Inference -src/MaxText/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0 -src/MaxText/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0 -src/MaxText/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0 +src/maxtext/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0 +src/maxtext/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0 +src/maxtext/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0 # Dockerfiles and dependencies *.Dockerfile @bvandermoon @parambole @richjames0 @shralex diff --git a/.vscode/launch.json b/.vscode/launch.json index c0d04607f2..3a04c1d2db 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -8,7 +8,7 @@ "console": "integratedTerminal", "justMyCode": false, "python": "python3", - "module": "MaxText.decode", + "module": "maxtext.decode", "args": ["src/MaxText/configs/base.yml", "run_name=runner_$(date +%Y-%m-%d-%H-%M)", "base_output_directory=gs://test-maxtext-output", @@ -35,9 +35,9 @@ "console": "integratedTerminal", "justMyCode": false, "python": "python3", - "module": "MaxText.decode", + "module": "maxtext.decode", "args": ["src/MaxText/configs/base.yml", - "run_name=runner_$(date +%Y-%m-%d-%H-%M)", + "run_name=runner_$(date +%Y-%m-%d-%H-%M)", "base_output_directory=gs://test-maxtext-output", "dataset_path=gs://test-maxtext-dataset", "steps=2", @@ -53,7 +53,7 @@ "python": "python3", "module": "MaxText.train", "args": ["src/MaxText/configs/base.yml", - "run_name=runner_$(date +%Y-%m-%d-%H-%M)", + "run_name=runner_$(date +%Y-%m-%d-%H-%M)", "base_output_directory=gs://test-maxtext-output", "dataset_path=gs://test-maxtext-dataset", "steps=2", @@ -66,7 +66,7 @@ "console": "integratedTerminal", "justMyCode": false, "python": "python3", - "module": "MaxText.inference_microbenchmark", + "module": "maxtext.inference_microbenchmark", "args": [ "src/MaxText/configs/base.yml", "model_name=llama2-7b", @@ -82,7 +82,7 @@ "inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024", "inference_microbenchmark_stages=generate", "inference_microbenchmark_loop_iters=1", - "run_name=runner_$(date +%Y-%m-%d-%H-%M)", + "run_name=runner_$(date +%Y-%m-%d-%H-%M)", "base_output_directory=gs://test-maxtext-output", "prefill_cache_axis_order=0,2,1,3", "ar_cache_axis_order=0,2,1,3", diff --git a/codecov.yml b/codecov.yml index cf25ff4c9c..d8fd9822c5 100644 --- a/codecov.yml +++ b/codecov.yml @@ -24,7 +24,7 @@ # During scheduled runs, the 'regular' flag is carried forward from the last PR. 
# Exclude non-source code, deprecated and experimental folders from coverage tracking -codecov: +codecov: token: 35742a22-fb1f-4839-97ff-b54da5588689 # By default file names in the coverage report will have their path in the file system, which in our # runners would be /__w/maxtext/maxtext/src/MaxText/* but Codecov expects src/MaxText/* so we need to fix the path @@ -36,8 +36,8 @@ ignore: - "src/MaxText/configs" - "src/MaxText/examples" - "src/MaxText/experimental" - - "src/MaxText/inference" - - "src/MaxText/inference_mlperf" + - "src/maxtext/inference" + - "src/maxtext/inference_mlperf" - "src/MaxText/scratch_code" - "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation - "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft diff --git a/docs/guides/optimization/pallas_kernels_performance.md b/docs/guides/optimization/pallas_kernels_performance.md index a6536f429d..a7fa0e6370 100644 --- a/docs/guides/optimization/pallas_kernels_performance.md +++ b/docs/guides/optimization/pallas_kernels_performance.md @@ -26,8 +26,8 @@ This guide explains **when** to consider Pallas, a **workflow** for developing a Think in **roofline** terms ([All About Rooflines](https://jax-ml.github.io/scaling-book/roofline/)) and in terms of **structure the compiler can’t see**: -* **Roofline framing.** Is your op **compute-limited** (MXU at or near peak) or **bandwidth-limited** (HBM↔on-chip transfers dominate)? Pallas tends to shine when you can reduce bandwidth pressure or avoid wasted work via better tiling and scheduling. -* **Compiler invisibles.** Irregular sparsity, ragged batch shapes, non-contiguous memory access, and domain-specific invariants are all signals that a custom kernel could help. +- **Roofline framing.** Is your op **compute-limited** (MXU at or near peak) or **bandwidth-limited** (HBM↔on-chip transfers dominate)? Pallas tends to shine when you can reduce bandwidth pressure or avoid wasted work via better tiling and scheduling. +- **Compiler invisibles.** Irregular sparsity, ragged batch shapes, non-contiguous memory access, and domain-specific invariants are all signals that a custom kernel could help. **Know when XLA is enough.** Before writing a custom kernel, always [profile your baseline](#1-high-level-profiling). If a standard operation (like a dense [`jnp.matmul`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html)) is already performing well, the XLA compiler is doing its job. In these cases, a Pallas kernel will increase code complexity and maintenance burden with minimal performance improvement. @@ -42,29 +42,34 @@ it is very difficult to automatically infer the dual of the memory pipeline. For dense, regular GEMMs, XLA’s libraries are hard to beat. The exception is **Mixture-of-Experts (MoE)** MLPs with **ragged token→expert layouts** (some tokens routed to different experts; shapes are irregular). Zero-padding to make dense tensors wastes FLOPs; a custom kernel can operate only on the actually-selected tokens. -* In MaxText, we use Grouped Matrix Multiplication (GMM) via **Megablox** to compute per-expert matmuls on ragged batches. Precomputed metadata (e.g., token→expert indices and ranges) guides the grouped computation and avoids work on padded regions. +- In MaxText, we use Grouped Matrix Multiplication (GMM) via **Megablox** to compute per-expert matmuls on ragged batches. Precomputed metadata (e.g., token→expert indices and ranges) guides the grouped computation and avoids work on padded regions. 
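To make the ragged-GMM idea concrete, here is a minimal, hypothetical sketch of the metadata-driven pattern in plain JAX — it is not MaxText's Megablox kernel, only an illustration of how host-built token→expert metadata lets the device do dense per-expert work with no zero padding:

```python
# Illustrative only: plain-JAX sketch of a metadata-driven grouped matmul (not Megablox).
import jax.numpy as jnp
import numpy as np

num_tokens, d_model, d_ff, num_experts = 8, 4, 16, 3
rng = np.random.default_rng(0)
x = jnp.asarray(rng.normal(size=(num_tokens, d_model)), dtype=jnp.float32)
w = jnp.asarray(rng.normal(size=(num_experts, d_model, d_ff)), dtype=jnp.float32)
expert_of_token = jnp.array([0, 2, 1, 0, 0, 2, 1, 1])  # hypothetical top-1 router output

# "Host" metadata: sort tokens by expert and record contiguous group sizes,
# so each expert sees one dense tile holding exactly its tokens.
order = jnp.argsort(expert_of_token)
group_sizes = jnp.bincount(expert_of_token, length=num_experts)
x_sorted = x[order]

# "Device" work: one dense matmul per group (in a real kernel this loop is the grid).
outs, start = [], 0
for e in range(num_experts):
    n = int(group_sizes[e])
    outs.append(x_sorted[start:start + n] @ w[e])  # dense matmul on this expert's tokens only
    start += n
out = jnp.zeros((num_tokens, d_ff)).at[order].set(jnp.concatenate(outs))  # restore token order
```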
**Note:** *Megablox* is an efficient, non-capped MoE implementation in JAX. *Megablocks* refers to the equivalent PyTorch implementation. See [arXiv:2211.15841](https://arxiv.org/abs/2211.15841) for more details. ### 2. Memory-Access-Bound work (attention) -Attention kernels are classically **bandwidth-limited** if you materialize the full \[L,L\] score matrix. A Pallas kernel can block **Q/K/V** into tiles that fit on-chip and perform **online softmax accumulation**, never storing the massive intermediate. +Attention kernels are classically **bandwidth-limited** if you materialize the full [L,L] score matrix. A Pallas kernel can block **Q/K/V** into tiles that fit on-chip and perform **online softmax accumulation**, never storing the massive intermediate. -* MaxText uses a Pallas attention kernel for training (Flash/Splash-style) and **paged/ragged** attention for inference to efficiently fetch KV cache pages and handle non-contiguous layouts. +- MaxText uses a Pallas attention kernel for training (Flash/Splash-style) and **paged/ragged** attention for inference to efficiently fetch KV cache pages and handle non-contiguous layouts. ## 🛠️ Pallas kernels in MaxText To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth-bound or structurally irregular operations that a general-purpose compiler cannot optimize as effectively. Below are the key kernels we use. **Note**: Examples evolve; treat this list as guidance. -* **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large \[L,L\] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation. - * [`src/MaxText/kernels/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/splash_attention_kernel.py) -* **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine. - * [`src/MaxText/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention.py) - * [`src/MaxText/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention_kernel_v2.py) -* **MoE Grouped Matmul (Megablox GMM):** Sparse/irregular grouped GEMMs driven by host-built metadata. +- **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large [L,L] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation. - > This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts. 
- * [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py) + - [`src/MaxText/kernels/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/splash_attention_kernel.py) + +- **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine. + + - [`src/maxtext/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/inference/paged_attention.py) + - [`src/maxtext/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/inference/paged_attention_kernel_v2.py) + +- **MoE Grouped Matmul (Megablox GMM):** Sparse/irregular grouped GEMMs driven by host-built metadata. + + > This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts. + + - [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py) **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/MaxText/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/moe.py)). @@ -74,7 +79,7 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth Give the kernel a clear name in traces and capture a profile. Always use [`jax.block_until_ready()`](https://docs.jax.dev/en/latest/_autosummary/jax.block_until_ready.html) when timing your operations. -``` python +```python import jax from jax import profiler @@ -104,12 +109,12 @@ For a more automated approach, consider using libraries like [tune-jax](https:// Pallas exposes the underlying hardware primitives for you to control. -* **HBM:** High-Bandwidth Memory (standard device memory). -* **VMEM:** On-chip vector SRAM for array tiles; your kernel primarily reads/writes VMEM refs. -* **SMEM:** On-chip scalar SRAM for control/metadata (e.g., counters, small tables). -* **Semaphores:** Available for advanced async/barrier patterns in manual pipelines. -* **MXU:** The Matrix Unit, optimized for large block GEMMs/convolutions. -* **VPU:** The Vector Processing Unit, used for elementwise/vector work. +- **HBM:** High-Bandwidth Memory (standard device memory). +- **VMEM:** On-chip vector SRAM for array tiles; your kernel primarily reads/writes VMEM refs. +- **SMEM:** On-chip scalar SRAM for control/metadata (e.g., counters, small tables). +- **Semaphores:** Available for advanced async/barrier patterns in manual pipelines. +- **MXU:** The Matrix Unit, optimized for large block GEMMs/convolutions. +- **VPU:** The Vector Processing Unit, used for elementwise/vector work. **Alignment & Constraints:** Respect TPU BlockSpec constraints (divisibility/shape rules for trailing dimensions and supported block shapes). 
Start with tile shapes that fit in VMEM and meet these requirements, then sweep different sizes to find the optimum. Let profiling guide you; don't assume powers of two are always best. @@ -117,11 +122,11 @@ Pallas exposes the underlying hardware primitives for you to control. These are the common techniques used in MaxText's Pallas kernels. -* **Tiling & Blocking:** Move just a tile that fits on-chip, compute on it, and write it back. -* **Explicit Pipelining:** Overlap HBM↔VMEM loads with compute to hide latency (e.g., double-buffering). -* **Online Accumulation:** Combine partial results as you go; don’t materialize huge intermediate arrays. -* **Auxiliary Metadata:** Precompute control tables (e.g., token-to-expert ranges) and keep them in fast scalar memory. -* **Compute↔Communication Overlap:** In distributed runs, overlap local work with cross-device traffic when possible. +- **Tiling & Blocking:** Move just a tile that fits on-chip, compute on it, and write it back. +- **Explicit Pipelining:** Overlap HBM↔VMEM loads with compute to hide latency (e.g., double-buffering). +- **Online Accumulation:** Combine partial results as you go; don’t materialize huge intermediate arrays. +- **Auxiliary Metadata:** Precompute control tables (e.g., token-to-expert ranges) and keep them in fast scalar memory. +- **Compute↔Communication Overlap:** In distributed runs, overlap local work with cross-device traffic when possible. ## ✍️ Writing & integrating a kernel @@ -136,9 +141,11 @@ import jax import jax.numpy as jnp from jax.experimental import pallas as pl + def add_vectors_kernel(x_ref, y_ref, o_ref): o_ref[:] = x_ref[:] + y_ref[:] + def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: assert x.shape == y.shape return pl.pallas_call( @@ -156,14 +163,16 @@ import jax import jax.numpy as jnp from jax.experimental import pallas as pl + def tile_add_kernel(x_ref, y_ref, o_ref): # Operate on the tile slices handed in by BlockSpecs (already in VMEM on TPU). o_ref[:, :] = x_ref[:, :] + y_ref[:, :] + def tile_add(x: jax.Array, y: jax.Array) -> jax.Array: assert x.shape == y.shape and x.ndim == 2 B0 = min(128, x.shape[0]) # Example choice; tune this with a sweep - B1 = x.shape[1] # Full width tile (for illustration) + B1 = x.shape[1] # Full width tile (for illustration) # Map program id (tile index) -> tile origin in the full (HBM) array. # NOTE: The runtime advances origins by `block_shape`, so `i` is already a tile @@ -192,29 +201,28 @@ def tile_add(x: jax.Array, y: jax.Array) -> jax.Array: Prefer `pl.pallas_call` with scratch buffers allocated in the appropriate memory space (VMEM/SMEM) and use multi-buffering to overlap HBM loads with compute. Advanced pipelining to consider: custom prefetch block order via a scalar prefetch grid (for details see [here](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html)), which lets you control block execution order based on runtime values. - ## 🌐 Distributed execution Dispatch a kernel on multiple devices with `jax.shard_map`. It’s usually simpler and more maintainable than in-kernel cross-device communication. While Pallas supports low-level comms, `shard_map` is the right first choice for multi-device parallelism, and you can **communicate with `shard_map` collectives** when needed. ## 🐞 Debugging tips -* Use `interpret=True` in `pallas_call` to run the kernel body in a Python interpreter backend, simulating device execution on CPU without lowering through XLA. -* Start with a tiny problem size and assert on invariants inside the kernel. 
-* Add `jax.named_scope` liberally so kernels are easy to spot in performance traces. +- Use `interpret=True` in `pallas_call` to run the kernel body in a Python interpreter backend, simulating device execution on CPU without lowering through XLA. +- Start with a tiny problem size and assert on invariants inside the kernel. +- Add `jax.named_scope` liberally so kernels are easy to spot in performance traces. ## ✅ Putting it all together (checklist) 1. **Profile** the baseline using `named_scope` and `block_until_ready`. -2. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels). -3. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices). -4. **Validate** end-to-end performance in the model, not just microbenchmarks. -5. Consider **maintainability** and guard the new kernel with tests. -6. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically. +1. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels). +1. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices). +1. **Validate** end-to-end performance in the model, not just microbenchmarks. +1. Consider **maintainability** and guard the new kernel with tests. +1. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically. ## 📚 References -* **Pallas Docs & Quickstart:** [docs.jax.dev/en/latest/pallas/index.html](https://docs.jax.dev/en/latest/pallas/index.html) -* **JAX Profiling Guides:** [jax.readthedocs.io/en/latest/profiling.html](https://jax.readthedocs.io/en/latest/profiling.html) -* **Manual Parallelism (shard_map):** [docs.jax.dev/en/latest/notebooks/shard_map.html](https://docs.jax.dev/en/latest/notebooks/shard_map.html) -* **Distributed Pallas on TPU:** [docs.jax.dev/en/latest/pallas/tpu/distributed.html](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) +- **Pallas Docs & Quickstart:** [docs.jax.dev/en/latest/pallas/index.html](https://docs.jax.dev/en/latest/pallas/index.html) +- **JAX Profiling Guides:** [jax.readthedocs.io/en/latest/profiling.html](https://jax.readthedocs.io/en/latest/profiling.html) +- **Manual Parallelism (shard_map):** [docs.jax.dev/en/latest/notebooks/shard_map.html](https://docs.jax.dev/en/latest/notebooks/shard_map.html) +- **Distributed Pallas on TPU:** [docs.jax.dev/en/latest/pallas/tpu/distributed.html](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) diff --git a/docs/run_maxtext/run_maxtext_localhost.md b/docs/run_maxtext/run_maxtext_localhost.md index 2e095c8838..66bc917c04 100644 --- a/docs/run_maxtext/run_maxtext_localhost.md +++ b/docs/run_maxtext/run_maxtext_localhost.md @@ -1,42 +1,49 @@ # Via localhost or single-host VM ## Objective + This guide provides comprehensive instructions for setting up MaxText on a local machine or single-host environment, covering everything from cloning the repo and dependency installation to building with Docker. By walking through the process of pre-training a small model, you will gain the foundational knowledge to run jobs on TPUs/GPUs. ## Prerequisites + Before you can begin a training run, you need to configure your storage environment and set up the basic MaxText configuration. 
### Setup Google Cloud storage bucket + You'll need a GCS bucket to store all your training artifacts, such as logs, metrics, and model checkpoints. -1. In your Google Cloud project, create a new storage bucket. -2. Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the `Storage Admin` (`roles/storage.admin`) role to the service account associated with your VMs. +1. In your Google Cloud project, create a new storage bucket. +1. Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the `Storage Admin` (`roles/storage.admin`) role to the service account associated with your VMs. ### Setup MaxText + MaxText uses a primary YAML file, `configs/base.yml`, to manage its settings. This default configuration sets up a Llama2-style decoder-only model with approximately 1 billion parameters. -* Before running your first model, take a moment to review this file. Pay special attention to these core settings: +- Before running your first model, take a moment to review this file. Pay special attention to these core settings: - `run_name`: The name for your experiment. - `per_device_batch_size`: Controls how many examples are processed per chip. You may need to lower this for larger models to avoid running out of memory. - `max_target_length`: The maximum sequence length for the model. - `learning_rate`: The core hyperparameter for the optimizer. - Model shape parameters: `base_num_decoder_layers`, `base_emb_dim`, `base_num_query_heads`, `base_num_kv_heads`, and `head_dim`. -* **Override settings (optional):** You can modify training parameters in two ways: by editing `configs/base.yml` directly or by passing them as command-line arguments to the training script which is the recommended method. For example, to change the number of training steps, you can pass `--steps=500` when running `train.py`. -* **Note**: You **must** update the variable `base_output_directory` which is initialized in `configs/base.yml` to point to a folder within the GCS bucket you just created (e.g., `gs://your-bucket-name/maxtext-output`). +- **Override settings (optional):** You can modify training parameters in two ways: by editing `configs/base.yml` directly or by passing them as command-line arguments to the training script, which is the recommended method. For example, to change the number of training steps, you can pass `--steps=500` when running `train.py`. +- **Note**: You **must** update the variable `base_output_directory`, which is initialized in `configs/base.yml`, to point to a folder within the GCS bucket you just created (e.g., `gs://your-bucket-name/maxtext-output`). ## Development + Local development on a single-host TPU/GPU VM is a convenient way to run MaxText. It doesn't scale to multiple hosts but is a good way to learn about MaxText. The following describes how to run MaxText on TPU/GPU VMs. ### Run MaxText on single host VM -1. Create and SSH to the single host VM of your choice. You can use any available single host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. For GPUs, you can use `nvidia-h100-mega-80gb`, `nvidia-h200-141gb`, or `nvidia-b200`. For setting up a TPU VM, use the Cloud TPU documentation available at https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm. For a GPU setup, refer to the guide at https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus. -2. Clone MaxText onto that VM.
- ```bash - git clone https://github.com/google/maxtext.git - cd maxtext - ``` -3. Once you have cloned the repository, you have two primary options for setting up the necessary dependencies on your VM: Installing in a Python Environment, or building a Docker container. For single host workloads, we recommend to install dependencies in a python environment, and for multihost workloads we recommend the containerized approach. +1. Create and SSH to the single host VM of your choice. You can use any available single host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. For GPUs, you can use `nvidia-h100-mega-80gb`, `nvidia-h200-141gb`, or `nvidia-b200`. For setting up a TPU VM, use the Cloud TPU documentation available at https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm. For a GPU setup, refer to the guide at https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus. +1. Clone MaxText onto that VM. + + ```bash + git clone https://github.com/google/maxtext.git + cd maxtext + ``` + +1. Once you have cloned the repository, you have two primary options for setting up the necessary dependencies on your VM: installing in a Python environment or building a Docker container. For single-host workloads, we recommend installing dependencies in a Python environment, and for multihost workloads we recommend the containerized approach. Within the root directory of the cloned repo, create a virtual environment and install dependencies and the pre-commit hook by running: @@ -47,6 +54,7 @@ bash tools/setup/setup.sh DEVICE={tpu|gpu} ``` #### Run a Test Training Job + After the installation is complete, run a short training job using synthetic data to confirm everything is working correctly. This command trains a model for just 10 steps. Remember to replace `$YOUR_JOB_NAME` with a unique name for your run and `gs://` with the path to the GCS bucket you configured in the prerequisites. ```bash @@ -64,7 +72,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ To demonstrate model output, run the following command: ```bash -python3 -m MaxText.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/MaxText/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ per_device_batch_size=1 ``` **Note:** Because the model hasn't been properly trained, the output text will be random. To generate meaningful output, you need to load a trained checkpoint using the `load_parameters_path` argument. ### Running models using provided configs + MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in `src/MaxText/configs/models` for TPU-oriented defaults, and `src/MaxText/configs/models/gpu` for GPU-oriented defaults. #### Training on TPUs + To use a pre-configured model for TPUs, you override the `model_name` parameter, and MaxText will automatically load the corresponding configuration from the `src/MaxText/configs/models` directory and merge it with the settings from `src/MaxText/configs/base.yml`.
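For instance, a hedged example of such an override — the run name and bucket are placeholders, and `gemma-2b` is assumed to be one of the provided model configs:

```bash
# Hypothetical values: replace the run name and bucket with your own.
python3 -m MaxText.train src/MaxText/configs/base.yml \
  model_name=gemma-2b \
  run_name=runner_model_override_test \
  base_output_directory=gs://your-bucket-name/maxtext-output \
  dataset_type=synthetic \
  steps=10
```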
@@ -89,6 +99,7 @@ python3 -m MaxText.train MaxText/configs/base.yml \ dataset_type=synthetic \ steps=10 ``` +
@@ -102,9 +113,11 @@ python3 -m MaxText.train MaxText/configs/base.yml \ dataset_type=synthetic \ steps=10 ``` +
#### Training on GPUs + To use a GPU-optimized configuration, you should specify the path to the model's YAML file within the `src/MaxText/configs/models/gpu` directory as the main config file in the command. These files typically inherit from `base.yml` and set the appropriate `model_name` internally, as well as GPU-specific settings.
@@ -117,7 +130,9 @@ python3 -m MaxText.train src/MaxText/configs/models/gpu/mixtral_8x7b.yml \ dataset_type=synthetic \ steps=10 ``` + This will load `gpu/mixtral_8x7b.yml`, which inherits from `base.yml`. +
@@ -130,5 +145,5 @@ python3 -m MaxText.train src/MaxText/configs/models/gpu/llama3-8b.yml \ dataset_type=synthetic \ steps=10 ``` -
+ diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 9e16f034fa..7fefa80344 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -15,32 +15,39 @@ --> (first-run)= + # Getting started: First run This topic provides a basic introduction to get your MaxText workload up and running on single host and multihost environments using Cloud TPUs or NVIDIA GPUs. To help you get familiar with MaxText, we recommend starting with a single host first and then moving to multihost. ## Prerequisites: Set up storage and configure MaxText + 1. To store logs and checkpoints, [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) in your project. To run MaxText, the TPU or GPU VMs must have read/write permissions for the bucket. These permissions are granted by service account roles, such as the `STORAGE ADMIN` role. -2. MaxText reads a yaml file for configuration. We also recommend reviewing the configurable options in `configs/base.yml`. This file includes a decoder-only model of ~1B parameters. The configurable options can be overwritten from the command line. For instance, you can change the `steps` or `log_period` by either modifying `configs/base.yml` or by passing in `steps` and `log_period` as additional arguments to the `train.py` call. Set `base_output_directory` to a folder in the bucket you just created. +1. MaxText reads a YAML file for configuration. We also recommend reviewing the configurable options in `configs/base.yml`. This file includes a decoder-only model of ~1B parameters. The configurable options can be overwritten from the command line. For instance, you can change the `steps` or `log_period` by either modifying `configs/base.yml` or by passing in `steps` and `log_period` as additional arguments to the `train.py` call. Set `base_output_directory` to a folder in the bucket you just created. ## Local development for single host + This procedure describes how to run MaxText on a single GPU or TPU host. ### Run MaxText on cloud TPUs + Local development is a convenient way to run MaxText on a single host. It doesn't scale to multiple hosts but is a good way to learn about MaxText. 1. [Create and SSH to the single host VM of your choice](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm). You can use any available single host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. -2. Clone MaxText onto that TPU VM. -3. Within the root directory of the cloned repo, install dependencies and pre-commit hook by running: +1. Clone MaxText onto that TPU VM. +1. Within the root directory of the cloned repo, install dependencies and the pre-commit hook by running: + ```sh python3 -m venv ~/venv-maxtext source ~/venv-maxtext/bin/activate bash tools/setup/setup.sh pre-commit install ``` + 4. After installation completes, run training on synthetic data with the following command: + ```sh python3 -m MaxText.train src/MaxText/configs/base.yml \ run_name=$YOUR_JOB_NAME \ @@ -48,44 +55,52 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ dataset_type=synthetic \ steps=10 ``` + Optional: If you want to try training on a Hugging Face dataset, see [Data Input Pipeline](../guides/data_input_pipeline.md) for data input options. 5. To demonstrate model output, run the following command: + ```sh -python3 -m MaxText.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/MaxText/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ per_device_batch_size=1 ``` -This command uses a model with randomly initialized weights, so the outputs are also random. To get high quality output you need pass in a checkpoint, typically via the `load_parameters_path` argument. +This command uses a model with randomly initialized weights, so the outputs are also random. To get high-quality output, you need to pass in a checkpoint, typically via the `load_parameters_path` argument. ### Run MaxText via notebook + -In the same TPU VM where you just installed all the dependencies of MaxText, You can also run training and decoding in MaxText via Notebook (for e.g., via Jupyter or Colab). +In the same TPU VM where you just installed all the dependencies of MaxText, you can also run training and decoding in MaxText via a notebook (e.g., Jupyter or Colab). #### Decoding in MaxText via notebook + You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. ### Run MaxText on NVIDIA GPUs + 1. Use `bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu` to build a container with the required dependencies. -2. After installation is complete, run training with the following command on synthetic data: +1. After installation is complete, run training with the following command on synthetic data: + ```sh python3 -m MaxText.train src/MaxText/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ dataset_type=synthetic \ - steps=10 + steps=10 ``` -3. To demonstrate model output, run the following command: +3. To demonstrate model output, run the following command: + ```sh -python3 -m MaxText.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/MaxText/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ - per_device_batch_size=1 + per_device_batch_size=1 ``` If you see the following error when running inside a container, set a larger `--shm-size` (for example, `--shm-size=1g`): + ``` Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.all_reduce' failed: external/xla/xla/service/gpu/nccl_utils.cc:297: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: unhandled cuda error (run with NCCL_DEBUG=INFO for details); current tracing scope: all-reduce-start.2; current profiling annotation: XlaModule:#hlo_module=jit__unnamed_wrapped_function_,program_id=7#. ```
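For reference, a hypothetical container invocation — the image name and workload are placeholders; the point is only where the `--shm-size` flag goes:

```sh
# Hypothetical: only --shm-size matters here; the image name is a placeholder.
docker run --gpus all --shm-size=1g maxtext_base_image \
  python3 -m MaxText.train src/MaxText/configs/base.yml \
  run_name=$YOUR_JOB_NAME base_output_directory=gs:// dataset_type=synthetic steps=10
```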
diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md index e845bd1ffe..999e4d8970 100644 --- a/docs/tutorials/posttraining/multimodal.md +++ b/docs/tutorials/posttraining/multimodal.md @@ -1,20 +1,21 @@ - - # Multimodal support This document provides a guide to using the multimodal functionalities in MaxText, including: + - **Checkpoint Conversion**: Convert a MaxText-compatible orbax checkpoint from HuggingFace. - **Multimodal Decode**: Inference with text+images as input. - **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset. We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb) demonstrating the multimodal features. The following table provides a list of models and modalities we currently support: -| Models | Input Modalities | Output Modalities | -| :---- | :---- | :---- | -| - Gemma3-4B/12B/27B <br> - Llama4-Scout/Maverick | Text, images | Text | + +| Models | Input Modalities | Output Modalities | +| :--------------------------------------------- | :--------------- | :---------------- | +| - Gemma3-4B/12B/27B <br> - Llama4-Scout/Maverick | Text, images | Text | ## Introduction -Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline: +Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline: + - **Data Preprocessing**: We apply modality-specific preprocessing steps to prepare the raw input data (e.g., image resizing and normalization), transforming them into a format which neural networks can understand. - **Modality-Specific Encoders**: Modality-specific encoders will transform the preprocessed data into high-dimensional representations (e.g., vision transformers for images). - **Projection and Merge**: Projection layers will map these modality-specific embeddings into the shared embedding space of the language model, usually aligned with the dimension of text embeddings. These projected embeddings are then merged with text token embeddings, allowing the unified model to process and reason over multiple modalities simultaneously within a single coherent framework. @@ -22,12 +23,12 @@ Multimodal Large Language Models (LLMs) extend traditional text-only models by i ![Illustration of multimodal MaxText.](../../_static/multimodal_overview.png) *Figure 1: Overview of multimodal dataflow in MaxText.* - ## Checkpoint Conversion Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md)). Install PyTorch: + ``` python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu ``` @@ -58,7 +59,9 @@ python -m MaxText.utils.ckpt_scripts.llama4_ckpt_unscanned \ ``` ## Multimodal Decode + MaxText supports multimodal decoding, allowing you to input text with multiple images to get a text output. To use this feature, you need three main settings: + - `use_multimodal=True`: Initializes the multimodal preprocessing steps and network components. - `prompt`: Specifies the position of image placeholder tokens in your input. If you don't manually place them, MaxText will automatically append the required placeholder (e.g., `` for Gemma3, `<|image|>` for Llama4). The exact placeholder is listed under the `image_placeholder` field in each model's configuration file. - `image_path`: The path(s) to the image file(s) MaxText will load and process. @@ -69,7 +72,7 @@ To run a forward pass and verify the model's output, use the following command: ```shell # Gemma3 decode -python -m MaxText.decode \ +python -m maxtext.decode \ MaxText/configs/base.yml \ model_name=gemma3-4b \ hf_access_token=$HF_ACCESS_TOKEN \ @@ -89,6 +92,7 @@ python -m MaxText.decode \ ``` The decoding results will look like this: + ``` Input `user Describe image @@ -104,7 +108,7 @@ To decode with multiple images at once, you can provide multiple image paths lik export TARGET_LENGTH=... # Adjust to fit expected output length export PREDICT_LENGTH=... # Adjust to fit image tokens + text prompt -python -m MaxText.decode \ +python -m maxtext.decode \ MaxText/configs/base.yml \ model_name=gemma3-4b \ ... \ @@ -123,7 +127,6 @@ Supervised Fine-Tuning (SFT) of multimodal LLMs in MaxText focuses specifically Here, we use [ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) as an example to demonstrate SFT functionality: - ```shell export UNSCANNED_CKPT_PATH=... # either set to an already available MaxText ckpt or to the one we just converted in the previous step python -m MaxText.sft_trainer \ @@ -148,14 +151,16 @@ python -m MaxText.sft_trainer \ ``` ## Other Recommendations + - **Setting appropriate prefill length**: To prevent truncation and ensure your full input (text + image) is processed, the prefill length should be set longer than the total combined length of your text tokens and image tokens. This combined length makes up the final sequence fed to the decoder. We recommend estimating the combined sequence length from your full input and then adding a buffer when setting your `max_prefill_predict_length` for decoding. Token estimation rules: - - For text tokens, a good estimate is: - - $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$. - - For Gemma3, each image is resized to 896*896 and contributes 256 tokens: - - $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$. - - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens: - - $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$. + - For text tokens, a good estimate is: + + $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$. + + - For Gemma3, each image is resized to 896\*896 and contributes 256 tokens: + + $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} \times 256$. + + - For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens: + $\text{Total Tokens} \approx \text{Text Tokens} + 144 \times \sum_{i=1}^{N} \text{Number of Tiles of Image}_i$.
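As a quick sanity check of these rules, here is a hypothetical worked example (the numbers are illustrative only):

```python
# Hypothetical Gemma3 input: a 50-word prompt with 2 images.
words_in_prompt = 50
num_images = 2

text_tokens = 1.3 * words_in_prompt            # ~65 tokens
image_tokens = num_images * 256                # 512 tokens (256 per Gemma3 image)
estimated_total = text_tokens + image_tokens   # ~577 tokens

# Add a buffer before setting max_prefill_predict_length, e.g. ~20% headroom.
print(int(estimated_total * 1.2))              # ~692 -> round up to a comfortable value
```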
diff --git a/end_to_end/gpu/a3/test_llama2_7b.sh b/end_to_end/gpu/a3/test_llama2_7b.sh index 2f8fe9af88..449bd6b234 100644 --- a/end_to_end/gpu/a3/test_llama2_7b.sh +++ b/end_to_end/gpu/a3/test_llama2_7b.sh @@ -64,4 +64,4 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 export TF_FORCE_GPU_ALLOW_GROWTH=true -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} diff --git a/end_to_end/gpu/mixtral/test_8x7b.sh b/end_to_end/gpu/mixtral/test_8x7b.sh index ece8f5f600..2a35e0601d 100644 --- a/end_to_end/gpu/mixtral/test_8x7b.sh +++ b/end_to_end/gpu/mixtral/test_8x7b.sh @@ -8,7 +8,7 @@ if [ -z "${BASE_OUTPUT_PATH}" ]; then echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}" fi -# `SCANNED_CHECKPOINT` refers to the checkpoint that used for both `train.py` and `decode.py` +# `SCANNED_CHECKPOINT` refers to the checkpoint that is used for both `train.py` and `decode.py` if [ -z "${SCANNED_CHECKPOINT}" ]; then # Non-Googlers please remember to point SCANNED_CHECKPOINT to GCS buckets that you own export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/8x7/scanned_ckpt/0/items @@ -49,7 +49,7 @@ echo "Finished fine-tuning" # # TODO(b/391864113): Add this once the bug is fixed # # Run decoding with converted ckpt - dropping implementation -# python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=mixtral-8x7b hardware=gpu \ +# python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=mixtral-8x7b hardware=gpu \ # run_name=unscanned_decoding load_parameters_path=${UNSCANNED_CKPT_PATH} \ # async_checkpointing=false attention=dot_product capacity_factor=0.1 \ # ici_expert_parallelism=8 ici_fsdp_parallelism=1 max_prefill_predict_length=11 \ diff --git a/end_to_end/test_generate_param_only_checkpoint.sh b/end_to_end/test_generate_param_only_checkpoint.sh index 6dabc3a990..4c7b9381ed 100644 --- a/end_to_end/test_generate_param_only_checkpoint.sh +++ b/end_to_end/test_generate_param_only_checkpoint.sh @@ -104,7 +104,7 @@ fi echo echo "Run decode using the generated checkpoint" echo -$cmd python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +$cmd python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ run_name=${run_id}-decode-steps-50 \ base_output_directory=${base_output_directory} \ dataset_path=${dataset_path} \ diff --git a/end_to_end/tpu/deepseek/Run_DeepSeek.md b/end_to_end/tpu/deepseek/Run_DeepSeek.md index 122a5b3c5a..63e5a94978 100644 --- a/end_to_end/tpu/deepseek/Run_DeepSeek.md +++ b/end_to_end/tpu/deepseek/Run_DeepSeek.md @@ -18,9 +18,9 @@ DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The currently supported models are DeepSeek V3.1 (671B), DeepSeek V3 (671B), DeepSeek R1 (671B), and DeepSeek V2-Lite (16B). -* DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision designed for enhanced efficiency and performance. +* DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision designed for enhanced efficiency and performance. -* DeepSeek V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency. +* DeepSeek V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency. * DeepSeek R1 also uses V3 architecture. It utilizes cold-start data and large-scale reinforcement learning to incentivize chain-of-thought reasoning without relying solely on supervised fine-tuning. @@ -63,7 +63,7 @@ To get started, follow the instructions at HuggingFace ([V3](https://huggingface ## Fine-tuning -After you have a MaxText compatible checkpoint, you could fine-tune it with different datasets. +After you have a MaxText-compatible checkpoint, you can fine-tune it with different datasets. One example command to run general finetuning with V3 on v5p-256. @@ -140,7 +140,7 @@ python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \ One example command to run decoding with V3 on v5p-256 with unscanned checkpoint for fast decoding. ```sh -python3 -m MaxText.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ run_name=decode \
@@ -34,7 +34,7 @@ echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} # Step 1: Checkpoint conversion # You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite, and dequantize it to bf16 -# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET +# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET # Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own # Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory if [ -z "${CKPT_DISK_LOCATION}" ]; then @@ -43,13 +43,13 @@ if [ -z "${CKPT_DISK_LOCATION}" ]; then export CKPT_DISK_LOCATION=/tmp/hf fi -# 1.1 Convert checkpoint to `scanned` format, more suitable for training +# 1.1 Convert checkpoint to `scanned` format, more suitable for training JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_ckpt --base_model_path ${CKPT_DISK_LOCATION} --maxtext_model_path ${BASE_OUTPUT_PATH}/scanned --model_size ${MODEL_NAME} # 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_unscanned_ckpt --base_model_path ${CKPT_DISK_LOCATION} --maxtext_model_path ${BASE_OUTPUT_PATH}/unscanned --model_size ${MODEL_NAME} -# Step 2: +# Step 2: # We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items @@ -75,4 +75,4 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT # Run decoding - tokamax_gmm implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m MaxText.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=4 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " +python3 -m maxtext.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=4 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " diff --git a/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh b/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh index e70b5fb792..e59c7be5b3 100644 --- a/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh +++ b/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh @@ -3,12 +3,12 @@ # This file is documentation for how to get started with DeepSeek v3. # This file runs Step 2 on v5p-128 on a daily basis. -# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): +# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): # Scanned format is better for training; unscanned format is better for decoding. # 2. Run logit check, pre-training, fine-tuning, and decoding. # The golden logit can be generated by: -# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V3 --output-path=golden_data_deepseek3-671b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16 +# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V3 --output-path=golden_data_deepseek3-671b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16 set -ex @@ -30,7 +30,7 @@ fi BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/} echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} -# Step 2: +# Step 2: # We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands # export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items # export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items @@ -53,4 +53,4 @@ python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/bas # Run decoding - tokamax_gmm implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m MaxText.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " +python3 -m maxtext.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " diff --git a/end_to_end/tpu/gemma/2b/test_gemma.sh b/end_to_end/tpu/gemma/2b/test_gemma.sh index a8b68795ab..811f38b604 100644 --- a/end_to_end/tpu/gemma/2b/test_gemma.sh +++ b/end_to_end/tpu/gemma/2b/test_gemma.sh @@ -1,6 +1,6 @@ #!/bin/bash -# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma-2b. +# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma-2b. # The flow of this file is as follows: # 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText @@ -37,12 +37,12 @@ python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune_${idx} @@ -51,14 +51,14 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b -# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. +# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt_${idx} python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma-2b python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=gemma-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.015 diff --git a/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/end_to_end/tpu/gemma/7b/2_test_gemma.sh index 7979a28af6..edcd091cd5 100644 --- a/end_to_end/tpu/gemma/7b/2_test_gemma.sh +++ b/end_to_end/tpu/gemma/7b/2_test_gemma.sh @@ -1,7 +1,7 @@ #!/bin/bash -# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma-7b. -# Please make sure you have run end_to_end/tpu/gemma/7b/1_test_gemma.sh before running commands from this file. +# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma-7b. +# Please make sure you have run end_to_end/tpu/gemma/7b/1_test_gemma.sh before running commands from this file. # The flow of this file is as follows: # 1. Run decoding, finetuning of Gemma 7B with the converted checkpoint obtained from end_to_end/tpu/gemma/7b/1_test_gemma.sh. Also, run pretraining of Gemma 7B @@ -11,7 +11,7 @@ # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/gemma/7b/1_test_gemma.sh # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/gemma/7b/1_test_gemma.sh -# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and +# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs. 
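# (`set -ex` below makes bash exit on the first failing command and echo each command before
# running it, so the CI log doubles as a worked transcript of the steps in this file.)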
set -ex @@ -37,12 +37,12 @@ export RUN_NAME=unscanned_chkpt export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items export ASYNC_CHECKPOINTING=True # True so that the jax distributed system is initialized -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. 
Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune @@ -51,14 +51,14 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b -# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. +# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" # We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. 
# This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B. diff --git a/end_to_end/tpu/gemma2/27b/2_test_gemma.sh b/end_to_end/tpu/gemma2/27b/2_test_gemma.sh index 9f9d6a1ba5..b9dad40208 100644 --- a/end_to_end/tpu/gemma2/27b/2_test_gemma.sh +++ b/end_to_end/tpu/gemma2/27b/2_test_gemma.sh @@ -1,7 +1,7 @@ #!/bin/bash -# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma2-27b. -# Please make sure you have run end_to_end/tpu/gemma2/27b/1_test_gemma.sh before running commands from this file. +# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma2-27b. +# Please make sure you have run end_to_end/tpu/gemma2/27b/1_test_gemma.sh before running commands from this file. # The flow of this file is as follows: # 1. Run decoding, finetuning of Gemma2 27b with the converted checkpoint obtained from end_to_end/tpu/gemma2/27b/1_test_gemma.sh. Also, run pretraining of Gemma2 27b @@ -11,7 +11,7 @@ # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/gemma2/27b/1_test_gemma.sh # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/gemma2/27b/1_test_gemma.sh -# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and +# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs. set -ex @@ -37,12 +37,12 @@ export RUN_NAME=unscanned_chkpt # We defined path to unscanned checkpoint created in 1_test_gemma.sh export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-27b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-27b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-27b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-27b attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma2-27b # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` diff --git a/end_to_end/tpu/gemma2/2b/test_gemma2.sh b/end_to_end/tpu/gemma2/2b/test_gemma2.sh index ba4e45530b..c4ceb5b1d1 100644 --- a/end_to_end/tpu/gemma2/2b/test_gemma2.sh +++ b/end_to_end/tpu/gemma2/2b/test_gemma2.sh @@ -41,10 +41,10 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-2b prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-2b prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune_${idx} @@ -60,7 +60,7 @@ export PARAM_RUN_NAME=param_chkpt_${idx} python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma2-2b' force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. 
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma2-2b # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` diff --git a/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh b/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh index d7f9f521c3..ba807727d4 100644 --- a/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh +++ b/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh @@ -59,7 +59,7 @@ python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_ --run_hf_model=true # We can run decoding for unscanned checkpoints. -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data export DATASET_PATH=gs://maxtext-dataset @@ -72,4 +72,4 @@ export FINETUNE_RUN_NAME=runner_finetune_${idx} python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${SCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=${MODEL_NAME} checkpoint_period=5 scan_layers=true # Now, run decoding on the checkpoint generated from our finetune run. 
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=decode_test_${FINETUNE_RUN_NAME} max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=true prompt='I love to' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=decode_test_${FINETUNE_RUN_NAME} max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=true prompt='I love to' attention=\'dot_product\' diff --git a/end_to_end/tpu/gemma2/9b/2_test_gemma.sh b/end_to_end/tpu/gemma2/9b/2_test_gemma.sh index dfd2c54b50..2049335418 100644 --- a/end_to_end/tpu/gemma2/9b/2_test_gemma.sh +++ b/end_to_end/tpu/gemma2/9b/2_test_gemma.sh @@ -1,7 +1,7 @@ #!/bin/bash -# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma2-9b. -# Please make sure you have run end_to_end/tpu/gemma2/9b/1_test_gemma.sh before running commands from this file. +# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma2-9b. +# Please make sure you have run end_to_end/tpu/gemma2/9b/1_test_gemma.sh before running commands from this file. # The flow of this file is as follows: # 1. Run decoding, finetuning of Gemma2 9b with the converted checkpoint obtained from end_to_end/tpu/gemma2/9b/1_test_gemma.sh. Also, run pretraining of Gemma2 9b @@ -11,7 +11,7 @@ # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/gemma2/9b/1_test_gemma.sh # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/gemma2/9b/1_test_gemma.sh -# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and +# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs. set -ex @@ -38,12 +38,12 @@ export RUN_NAME=unscanned_chkpt # We defined path to unscanned checkpoint created in 1_test_gemma.sh export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-9b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-9b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-9b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-9b attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma2-9b # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` diff --git a/end_to_end/tpu/gemma3/12b/test_gemma3.sh b/end_to_end/tpu/gemma3/12b/test_gemma3.sh index 10a4e7372e..388909a1b1 100644 --- a/end_to_end/tpu/gemma3/12b/test_gemma3.sh +++ b/end_to_end/tpu/gemma3/12b/test_gemma3.sh @@ -43,7 +43,7 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 diff --git a/end_to_end/tpu/gemma3/27b/test_gemma3.sh b/end_to_end/tpu/gemma3/27b/test_gemma3.sh index f3ddf8e74a..5dc2d2cb61 100644 --- a/end_to_end/tpu/gemma3/27b/test_gemma3.sh +++ b/end_to_end/tpu/gemma3/27b/test_gemma3.sh @@ -43,7 +43,7 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # to get higher precision (eg. 
float32) run on CPU with `JAX_PLATFORMS=cpu` python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 diff --git a/end_to_end/tpu/gemma3/4b/test_gemma3.sh b/end_to_end/tpu/gemma3/4b/test_gemma3.sh index f8da8ce5da..f9a7204c99 100644 --- a/end_to_end/tpu/gemma3/4b/test_gemma3.sh +++ b/end_to_end/tpu/gemma3/4b/test_gemma3.sh @@ -43,7 +43,7 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 diff --git a/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh b/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh index 7f6a33faa5..eb97eae60e 100644 --- a/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh +++ b/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh @@ -40,7 +40,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEX # 2. 
Decode the converted checkpoint to make sure it works export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' # 3. SFT the MaxText converted checkpoint on ChartQA dataset export BASE_OUTPUT_DIRECTORY=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/sft @@ -61,7 +61,7 @@ python -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src # 4. Decode from the finetuned checkpoint from step 3 export FINAL_CKPT_STEP=$((SFT_STEPS - 1)) export FINETUNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${idx}/checkpoints/${FINAL_CKPT_STEP}/items -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${FINETUNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${FINETUNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' # 5. Convert the SFT checkpoint back to HuggingFace format. export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} diff --git a/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh b/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh index cdc570a745..0c122a0581 100644 --- a/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh +++ b/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh @@ -67,9 +67,9 @@ python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_ # We can run decoding for unscanned checkpoints. 
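# (The two branches below differ mainly in prompt handling: with USE_MULTIMODAL=true an image is
# passed via image_path and the prefill window is widened to max_prefill_predict_length=272 so
# the image tokens fit; the text-only branch keeps a short 8-token prefill.)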
if [ ${USE_MULTIMODAL} == true ]; then
- python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
+ python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
- python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\'
+ python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\'
fi
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
@@ -84,7 +84,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT
# Now, run decoding on the checkpoint generated from our finetune run.
if [ ${USE_MULTIMODAL} == true ]; then
- python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
+ python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\'
else
- python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\'
+ python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\'
fi
diff --git a/end_to_end/tpu/gemma3/Run_Gemma3.md b/end_to_end/tpu/gemma3/Run_Gemma3.md
index 8265afd07f..e1d0f6358a 100644
--- a/end_to_end/tpu/gemma3/Run_Gemma3.md
+++ b/end_to_end/tpu/gemma3/Run_Gemma3.md
@@ -16,9 +16,9 @@
# Gemma3
-[Gemma3](https://ai.google.dev/gemma) is an iteration of the Gemma family, designed for enhanced performance and efficiency which is capable of running on a single-accelerator ([Developer Blog](https://blog.google/technology/developers/gemma-3/)).
+[Gemma3](https://ai.google.dev/gemma) is an iteration of the Gemma family, designed for enhanced performance and efficiency which is capable of running on a single-accelerator ([Developer Blog](https://blog.google/technology/developers/gemma-3/)).
-We provide examples for checkpoint conversion and decoding/training/finetuning Gemma3 in test scripts at [end_to_end/tpu/gemma3](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu/gemma3).
+We provide examples for checkpoint conversion and decoding/training/finetuning Gemma3 in test scripts at [end_to_end/tpu/gemma3](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu/gemma3).
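The decoding example at the end of this page assumes `CONVERTED_CHECKPOINT` already points at a converted checkpoint. A minimal sketch of wiring that up by hand (bucket and path layout are hypothetical; substitute the output location of your own conversion step):

```
# Hypothetical paths; the exact layout depends on which conversion script produced the checkpoint.
export BASE_OUTPUT_PATH=gs://your-bucket/gemma3-4b
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/0/items
```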
## Pre-training @@ -42,5 +42,5 @@ python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_ One example to use a converted checkpoint to decode with prompt "I love to": ``` -python3 -m MaxText.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_decode_gemma3_4b max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false prompt="I love to" +python3 -m maxtext.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma3 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_decode_gemma3_4b max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false prompt="I love to" ``` \ No newline at end of file diff --git a/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh b/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh index 5d59aea1aa..751d2e7865 100644 --- a/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh +++ b/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh @@ -3,7 +3,7 @@ # This file is documentation for how to get started with gpt-oss-120b on v5p-64. # The flow of this file is as follows: -# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16), on a separate CPU: +# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16), on a separate CPU: # Scanned format is better for training; unscanned format is better for decoding. # 2. Run logit check, pre-training, fine-tuning, and decoding. @@ -27,7 +27,7 @@ fi python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu # Step 1: Checkpoint conversion -# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET +# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET # Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own # Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory if [ -z "${CKPT_DISK_LOCATION}" ]; then @@ -36,13 +36,13 @@ if [ -z "${CKPT_DISK_LOCATION}" ]; then export CKPT_DISK_LOCATION=/tmp/hf-bf16 fi -# 1.1 Convert checkpoint to `scanned` format, more suitable for training +# 1.1 Convert checkpoint to `scanned` format, more suitable for training JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_gpt_oss_ckpt --base-model-path ${CKPT_DISK_LOCATION} --maxtext-model-path ${BASE_OUTPUT_PATH}/scanned --model-size ${MODEL_NAME} # 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_gpt_oss_unscanned_ckpt --base-model-path ${CKPT_DISK_LOCATION} --maxtext-model-path ${BASE_OUTPUT_PATH}/unscanned --model-size ${MODEL_NAME} -# Step 2: +# Step 2: # We define the checkpoint paths. 
This way it is easier to use these paths in the `train.py` and `decode.py` commands export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items @@ -64,4 +64,4 @@ python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/sr # Run decoding - megablox implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=32 ici_tensor_parallelism=1 +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=32 ici_tensor_parallelism=1 diff --git a/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh b/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh index 333af655e4..b71e19c415 100644 --- a/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh +++ b/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh @@ -3,7 +3,7 @@ # This file is documentation for how to get started with gpt-oss-20b on v5p-8. # The flow of this file is as follows: -# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): +# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): # Scanned format is better for training; unscanned format is better for decoding. # 2. Run logit check, pre-training, fine-tuning, and decoding. 
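# (As with the 120b variant above, the scanned checkpoint feeds training and fine-tuning while
# the unscanned one feeds decoding; both conversions run under JAX_PLATFORMS=cpu, so they can be
# done on a separate CPU machine before any TPU job starts.)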
@@ -29,7 +29,7 @@ echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu # Step 1: Checkpoint conversion -# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET +# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET # Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own # Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory if [ -z "${CKPT_DISK_LOCATION}" ]; then @@ -38,13 +38,13 @@ if [ -z "${CKPT_DISK_LOCATION}" ]; then export CKPT_DISK_LOCATION=/tmp/hf-bf16 fi -# 1.1 Convert checkpoint to `scanned` format, more suitable for training +# 1.1 Convert checkpoint to `scanned` format, more suitable for training JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_gpt_oss_ckpt --base-model-path ${CKPT_DISK_LOCATION} --maxtext-model-path ${BASE_OUTPUT_PATH}/scanned --model-size ${MODEL_NAME} # 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_gpt_oss_unscanned_ckpt --base-model-path ${CKPT_DISK_LOCATION} --maxtext-model-path ${BASE_OUTPUT_PATH}/unscanned --model-size ${MODEL_NAME} -# Step 2: +# Step 2: # We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items @@ -68,4 +68,4 @@ python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/sr # Run decoding - megablox implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=1 ici_tensor_parallelism=4 +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=1 ici_tensor_parallelism=4 diff --git a/end_to_end/tpu/gpt_oss/run_gpt_oss.md b/end_to_end/tpu/gpt_oss/run_gpt_oss.md index bd5ceeb0b1..039c2eee88 100644 --- a/end_to_end/tpu/gpt_oss/run_gpt_oss.md +++ b/end_to_end/tpu/gpt_oss/run_gpt_oss.md @@ -39,7 +39,7 @@ python3 -m MaxText.utils.ckpt_scripts.dequantize_mxfp4 --input-path= \ @@ -79,7 +79,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ ## Finetuning -After you have a MaxText-compatible scanned checkpoint, you could finetune it with different datasets. +After you have a MaxText-compatible scanned checkpoint, you could finetune it with different datasets. 
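Fine-tuning starts from the scanned checkpoint produced by the conversion step. A minimal sketch of the paths the commands below assume (bucket name hypothetical; the `scanned/0/items` layout matches the test scripts above):

```sh
# Hypothetical bucket; path layout as written by the conversion step.
export BASE_OUTPUT_PATH=gs://your-bucket/gpt-oss-20b
export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
```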
One example command to run general finetuning with gpt-oss-20b on v5p-8. @@ -137,7 +137,7 @@ python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \ One example command to run decoding with gpt-oss-20b on v5p-8 with unscanned checkpoint for fast decoding. ```sh -python3 -m MaxText.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_PATH} \ run_name=decode \ model_name=gpt-oss-20b \ diff --git a/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh b/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh index e4a951beb9..e0f1fc8824 100644 --- a/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh +++ b/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh @@ -1,7 +1,7 @@ #!/bin/bash -# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama2-13b. -# Please make sure you have run end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh before running commands from this file. +# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama2-13b. +# Please make sure you have run end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh before running commands from this file. # The flow of this file is as follows: # 1. Run decoding, finetuning of Llama2-13B with the converted checkpoint obtained from end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh. Also, run pretraining of Llama2-13B @@ -11,7 +11,7 @@ # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh -# Please note that in these two scripts (1_test_llama2_13b.sh and 2_test_llama2_13b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and +# Please note that in these two scripts (1_test_llama2_13b.sh and 2_test_llama2_13b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs. set -ex @@ -36,13 +36,13 @@ export RUN_NAME=unscanned_chkpt # We defined path to unscanned checkpoint created in 1_test_llama2_13b.sh export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` # We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach." 
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach."

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach."
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach."

 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
@@ -51,11 +51,11 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT

 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
 python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}

-# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. 
+# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters.
 # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run.
 # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding
 export PARAM_RUN_NAME=param_chkpt
 python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true

 # Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
diff --git a/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh b/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
index 72ea1c212d..afc3210e0f 100644
--- a/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
+++ b/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

-# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama2-70b. 
-# Please make sure you have run end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh before running commands from this file. 
+# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama2-70b.
+# Please make sure you have run end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh before running commands from this file.
 # The flow of this file is as follows:
 # 1. Run decoding, finetuning of Llama2-70B with the converted checkpoint obtained from end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh. Also, run pretraining of Llama2-70B
@@ -11,7 +11,7 @@
 # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
 # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh
-# Please note that in these two scripts (1_test_llama2_70b.sh and 2_test_llama2_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and 
+# Please note that in these two scripts (1_test_llama2_70b.sh and 2_test_llama2_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
 # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.
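The commands above repeatedly use nested `${VAR:-default}` expansions to locate the MaxText package. As a standalone illustration of how that fallback chain resolves (this is plain bash parameter-expansion behavior, nothing MaxText-specific; the example paths are ours):

```sh
#!/bin/bash
# ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText} resolves, in order:
#   1. $MAXTEXT_PKG_DIR, if set and non-empty;
#   2. else $MAXTEXT_REPO_ROOT/src/MaxText, if MAXTEXT_REPO_ROOT is set;
#   3. else $PWD/src/MaxText (i.e. running from a repo checkout).
unset MAXTEXT_PKG_DIR MAXTEXT_REPO_ROOT
echo "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"  # -> $PWD/src/MaxText

export MAXTEXT_REPO_ROOT=/opt/maxtext   # hypothetical install location
echo "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"  # -> /opt/maxtext/src/MaxText

export MAXTEXT_PKG_DIR=/opt/maxtext/src/MaxText
echo "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"  # -> /opt/maxtext/src/MaxText
```

The same pattern with `MAXTEXT_ASSETS_ROOT` selects the tokenizer directory, and `run_name=runner_$(date +%Y-%m-%d-%H-%M)` timestamps each run so repeated invocations do not collide.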
 set -ex
@@ -42,12 +42,12 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
 export ASYNC_CHECKPOINTING=true # True so that jax distributed system is initialized

-# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
+# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
 # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
@@ -56,14 +56,14 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT

 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
 python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}

-# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. 
+# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters.
 # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run.
 # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding
 export PARAM_RUN_NAME=param_chkpt
 python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true

 # Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # We also test whether the forward pass logits match the golden logits for Llama2-70b
 python3 -m tests.utils.forward_pass_logit_checker --atol=0.2 --rtol=0.2 "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-70b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false async_checkpointing=${ASYNC_CHECKPOINTING}
diff --git a/end_to_end/tpu/llama2/7b/test_llama2_7b.sh b/end_to_end/tpu/llama2/7b/test_llama2_7b.sh
index 65343a5515..b29d8994aa 100644
--- a/end_to_end/tpu/llama2/7b/test_llama2_7b.sh
+++ b/end_to_end/tpu/llama2/7b/test_llama2_7b.sh
@@ -4,7 +4,7 @@
 # Additionally, this file serves as integration test for context parallelism for training in TPUs in MaxText

 # The flow of this file is as follows:
-# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) in your local directory. Convert this PyTorch checkpoint into Orbax checkpoint format for use in MaxText. 
+# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) in your local directory. Convert this PyTorch checkpoint into Orbax checkpoint format for use in MaxText.
 # 2. Run decoding, finetuning of Llama2-7b with this converted checkpoint. Also, run pretraining of Llama2-7b.
 # 3. Run decoding from the finetuned weights.
 # 4. Convert the scanned checkpoint from step #1 into unscanned checkpoint format and run more efficient decoding.
@@ -25,7 +25,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

 # We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint
 export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt

-# In the following command, we are copying Meta's checkpoint into a local directory `tmp`. 
+# In the following command, we are copying Meta's checkpoint into a local directory `tmp`.
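As a hedged sketch of this copy step with a non-default local directory (the directory name below is just an example, not part of the test script; keep the converter's `--base-model-path` in sync with wherever you copy the weights):

```sh
#!/bin/bash
# Example only: stage the Meta checkpoint somewhere other than /tmp.
# If you change the directory, pass the same path as `--base-model-path`
# to `python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt` below.
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt  # point this at your own bucket
export LOCAL_CKPT_DIR=/mnt/disks/ckpt                               # hypothetical local directory
mkdir -p "${LOCAL_CKPT_DIR}"
gcloud storage cp -r "${META_CHECKPOINT_PATH}" "${LOCAL_CKPT_DIR}/"
```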
 # You can use a different local directory than /tmp/, if you do so, please use the same local path for `base-model-path` when running `python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt`
 gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/
@@ -39,7 +39,7 @@ python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt --base-model-path /t
 export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items

 # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
-# We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. 
+# We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
 export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx}
 python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true
@@ -47,11 +47,11 @@ python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_
 export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items

 # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint aka `CONVERTED_CHECKPOINT`. Note that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product

 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 # We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism
@@ -70,7 +70,7 @@ python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_
 export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items

 # We run decoding on the fine-tuned parameter checkpoint
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false

 # We also test whether the forward pass logits match the golden logits for Llama2-7b
 python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false --rtol=0.1 --atol=0.1
diff --git a/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh b/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
index 44712d261a..297bf415ae 100644
--- a/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
+++ b/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

-# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.1-405b. 
-# Please make sure you have run end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh before running commands from this file. 
+# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.1-405b.
+# Please make sure you have run end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh before running commands from this file.
 # The flow of this file is as follows:
 # 1. Run decoding, finetuning of Llama3.1-405B with the converted checkpoint obtained from end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh. Also, run pretraining of Llama3.1-70B
@@ -11,7 +11,7 @@
 # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
 # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh
-# Please note that in these two scripts (1_test_llama3.1_405b.sh and 2_test_llama3.1_405b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and 
+# Please note that in these two scripts (1_test_llama3.1_405b.sh and 2_test_llama3.1_405b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
 # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.

 set -ex
@@ -43,7 +43,7 @@ export FINETUNE_RUN_NAME=runner_finetune
 python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_type=synthetic tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${SCANNED_CHECKCKPOINT} per_device_batch_size=0.25 ici_tensor_parallelism=4 run_name=${FINETUNE_RUN_NAME} steps=10 enable_checkpointing=false model_name=${MODEL_VARIATION} logits_dot_in_fp32=false weight_dtype=bfloat16 opt_type=sgd

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${SCANNED_CHECKCKPOINT} per_device_batch_size=0.0625 ici_tensor_parallelism=4 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product weight_dtype=bfloat16 prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${SCANNED_CHECKCKPOINT} per_device_batch_size=0.0625 ici_tensor_parallelism=4 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product weight_dtype=bfloat16 prompt="I love to"

 # We also test whether the forward pass logits match the golden logits for Llama3.1-405B
 python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=0.0625 ici_tensor_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype=float32 activations_in_float32=true matmul_precision=float32 weight_dtype=float32 async_checkpointing=false --max_kl_div=1e-4
diff --git a/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh b/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh
index 6f70bf2e14..c4c77b69e2 100644
--- a/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh
+++ b/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

-# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.1-70b. 
-# Please make sure you have run end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh before running commands from this file. 
+# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.1-70b.
+# Please make sure you have run end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh before running commands from this file.
 # The flow of this file is as follows:
 # 1. Run decoding, finetuning of Llama3.1-70B with the converted checkpoint obtained from end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh. Also, run pretraining of Llama3.1-70B
@@ -11,7 +11,7 @@
 # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh
 # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh
-# Please note that in these two scripts (1_test_llama3.1_70b.sh and 2_test_llama3.1_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and 
+# Please note that in these two scripts (1_test_llama3.1_70b.sh and 2_test_llama3.1_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
 # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.

 set -ex
@@ -41,12 +41,12 @@ export RUN_NAME=unscanned_chkpt
 export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items

 # TODO(mohitkhatwani): Fix XLAResourceExhaustion when loading unscanned model
-# # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
+# # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
 # # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-# python3 -m MaxText.decode src/MaxText/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+# python3 -m maxtext.decode src/MaxText/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
diff --git a/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh b/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh
index c61729fbde..947ecc7c26 100644
--- a/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh
+++ b/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh
@@ -33,10 +33,10 @@ JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_P
 python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS

 # Now we are good to go, serve with performance!
-JAX_PLATFORMS=tpu python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED
+JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED

 # You can also check the results from scanned version, just double check, not necessary
-JAX_PLATFORMS=tpu python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items
+JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items

 # Example output
 # Input `I love to` -> ` read, but I don't have much time. How can I read more books?
diff --git a/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh b/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh
index 82fab78140..0cf2db2290 100644
--- a/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh
+++ b/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh
@@ -1,8 +1,8 @@
 #!/bin/bash

-# This file is both an integration test that runs once a day on a v5p-8 and documentation for how to get started with LLama3.1-8b. 
+# This file is both an integration test that runs once a day on a v5p-8 and documentation for how to get started with LLama3.1-8b.
 # Additionally, this file serves as integration test for context parallelism for training in TPUs in MaxText
-# Please make sure you have run end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh before running commands from this file. 
+# Please make sure you have run end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh before running commands from this file.
 # The flow of this file is as follows:
 # 1. Run decoding, finetuning of LLama3.1-8B with the converted checkpoint obtained from end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh. Also, run pretraining of LLama3.1-8B
@@ -12,7 +12,7 @@
 # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh
 # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh
-# Please note that in these two scripts (1_test_llama3.1_8b.sh and 2_test_llama3.1_8b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and 
+# Please note that in these two scripts (1_test_llama3.1_8b.sh and 2_test_llama3.1_8b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
 # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.

 set -ex
@@ -41,12 +41,12 @@ export RUN_NAME=unscanned_chkpt
 # We defined path to unscanned checkpoint created in 1_test_llama3.1_8b.sh
 export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items

-# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
+# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
 # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
diff --git a/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh b/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh
index 6ff6910080..e8efa91780 100644
--- a/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh
+++ b/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh
@@ -39,10 +39,10 @@ JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_P
 python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS

 # Now we are good to go, serve with performance!
-JAX_PLATFORMS=tpu python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED
+JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED

 # You can also check the results from scanned version, just double check, not necessary
-JAX_PLATFORMS=tpu python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items
+JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items

 ##### Output from huggingface llama 8B Instruct checkpoint on MaxText:
 #Input `I love to` -> ` travel and explore new places, but I also love to stay at home and relax. I'm a bit of a homebody, and I enjoy spending time with my family and friends. I'm a bit of a foodie, and I love trying new recipes and experimenting with different flavors and ingredients. I'm also a bit of a movie buff, and I love watching classic films and new releases alike.
diff --git a/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh b/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh
index a374ba2ee2..4c61c4fac2 100644
--- a/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh
+++ b/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

-# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.3-70B-Instruct. 
-# Please make sure you have run end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh before running commands from this file. 
+# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.3-70B-Instruct.
+# Please make sure you have run end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh before running commands from this file.
 # The flow of this file is as follows:
 # 1. Run decoding, finetuning of Llama3.3-70B-Instruct with the converted checkpoint obtained from end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh. Also, run pretraining of Llama3.3-70B-Instruct
@@ -11,7 +11,7 @@
 # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh
 # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh
-# Please note that in these two scripts (1_test_llama3.3_70b.sh and 2_test_llama3.3_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and 
+# Please note that in these two scripts (1_test_llama3.3_70b.sh and 2_test_llama3.3_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
 # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.

 set -ex
@@ -45,12 +45,12 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items
 # export UNSCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/llama3.3-70b-instruct/2025-02-15-07-58/unscanned/checkpoints/0/items

 # TODO(mohitkhatwani): Fix XLAResourceExhaustion when loading unscanned model
-# # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
+# # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
 # # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-# python3 -m MaxText.decode src/MaxText/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+# python3 -m maxtext.decode src/MaxText/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
diff --git a/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh b/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh
index 649faf6f1d..0c53b026f7 100644
--- a/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh
+++ b/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

-# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Llama3-70b. 
-# Please make sure you have run end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh before running commands from this file. 
+# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Llama3-70b.
+# Please make sure you have run end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh before running commands from this file.
 # The flow of this file is as follows:
 # 1. Run decoding, finetuning of Llama3-70B with the converted checkpoint obtained from end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh. Also, run pretraining of Llama3-70B
@@ -11,7 +11,7 @@
 # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh
 # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh
-# Please note that in these two scripts (1_test_llama3_70b.sh and 2_test_llama3_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and 
+# Please note that in these two scripts (1_test_llama3_70b.sh and 2_test_llama3_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
 # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.

 set -ex
@@ -40,12 +40,12 @@ export RUN_NAME=unscanned_chkpt
 # We defined path to unscanned checkpoint created in 1_test_llama3_70b.sh
 export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items

-# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. 
+# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
 # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune
@@ -54,14 +54,14 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT

 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
 python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}

-# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. 
+# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters.
 # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run.
 # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding
 export PARAM_RUN_NAME=param_chkpt
 python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true

 # Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

 # We also test whether the forward pass logits match the golden logits for Llama3-70B
 python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false scan_layers=false --atol=0.1 --rtol=0.1
diff --git a/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh b/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh
index d6058a6d68..98ee195a8a 100644
--- a/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh
+++ b/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

-# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Llama3-8b. 
-# Please make sure you have run end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh before running commands from this file. 
+# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Llama3-8b.
+# Please make sure you have run end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh before running commands from this file.
 # The flow of this file is as follows:
 # 1. Run decoding, finetuning of Llama3-8B with the converted checkpoint obtained from end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh. Also, run pretraining of Llama3-8B
@@ -11,7 +11,7 @@
 # Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh
 # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh
-# Please note that in these two scripts (1_test_llama3_8b.sh and 2_test_llama3_8b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and 
+# Please note that in these two scripts (1_test_llama3_8b.sh and 2_test_llama3_8b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
 # the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.
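Since every one of these scripts assumes `BASE_OUTPUT_PATH` is unique per run while the RUN_NAMEs are static, a small guard at the top of a wrapper script can catch reuse early. A minimal sketch (the timestamp suffix is our suggestion, not part of the test scripts):

```sh
#!/bin/bash
set -ex
# Fail fast if the caller forgot to set BASE_OUTPUT_PATH at all.
: "${BASE_OUTPUT_PATH:?set BASE_OUTPUT_PATH to a unique GCS path, e.g. gs://<bucket>/<unique-run>}"
# One way to guarantee uniqueness across runs: append a timestamp once,
# mirroring the run_name=runner_$(date +%Y-%m-%d-%H-%M) convention above.
export BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH}/$(date +%Y-%m-%d-%H-%M)
echo "using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
```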
set -ex @@ -40,12 +40,12 @@ export RUN_NAME=unscanned_chkpt # We define the path to the unscanned checkpoint created in 1_test_llama3_8b.sh export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# We run decoding on the `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint, for efficient decoding. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We can also run decoding (albeit in a somewhat unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
Note that the scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune @@ -54,14 +54,14 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT # We also run pre-training; this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} -# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. +# Note that the finetune run checkpoint generates the `full state`, which has both parameters and optimizer state. For decoding, we only need the parameters. # So, we can use `src/MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Llama3-8B python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false scan_layers=false --atol=0.1 --rtol=0.1 diff --git a/end_to_end/tpu/llama4/Run_Llama4.md b/end_to_end/tpu/llama4/Run_Llama4.md index 28f53d32a1..c4660b3bb5 100644 --- a/end_to_end/tpu/llama4/Run_Llama4.md +++ b/end_to_end/tpu/llama4/Run_Llama4.md @@ -65,7 +65,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ In order to run an example decoding with Llama4 Scout, you can use a command such as the following: ```sh -python3 -m MaxText.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ run_name=decode \ model_name=llama4-17b-16e \ @@ -74,7 +74,7 @@ python3 -m MaxText.decode src/MaxText/configs/base.yml \ load_parameters_path=${UNSCANNED_CKPT_PATH} \ scan_layers=false \ attention=dot_product \ - sparse_matmul=false \ + sparse_matmul=false \ megablox=false \ dtype=bfloat16 \ weight_dtype=bfloat16 \ diff --git a/end_to_end/tpu/mistral/7b/test_mistral-7b.sh b/end_to_end/tpu/mistral/7b/test_mistral-7b.sh index 94fcae05f4..9fef5e75f8 100644 --- a/end_to_end/tpu/mistral/7b/test_mistral-7b.sh +++ b/end_to_end/tpu/mistral/7b/test_mistral-7b.sh @@ -40,7 +40,7 @@ echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN export DATASET_PATH=gs://maxtext-dataset # Run decoding with converted ckpt - matmul implementation -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 
max_target_length=16 prompt='"[INST] I love to [/INST]"' attention=dot_product megablox=False sparse_matmul=False +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt='"[INST] I love to [/INST]"' attention=dot_product megablox=False sparse_matmul=False # Test whether the forward pass logits match the golden logits - matmul implementation python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False sparse_matmul=False --atol=3 --rtol=1 --token_size=4 diff --git a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh index 7535b8e337..5203eeb5c7 100644 --- a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh +++ b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh @@ -35,10 +35,10 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned_ckpt/checkpoints/0/item # Run decoding with converted ckpt - matmul implementation # TODO(ranran): add decoding test for megablox implementation -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false # Run decoding with converted ckpt - dropping implementation -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false capacity_factor=1.25 
+python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false capacity_factor=1.25 # Test whether the forward pass logits match the golden logits - matmul implementation python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=11 dtype=float32 megablox=False sparse_matmul=False scan_layers=false --token_size=4 --max_kl_div=3e-3 diff --git a/end_to_end/tpu/qwen/moe/run_qwen_moe.md b/end_to_end/tpu/qwen/moe/run_qwen_moe.md index 9bdcbd2ab7..676cddd8bd 100644 --- a/end_to_end/tpu/qwen/moe/run_qwen_moe.md +++ b/end_to_end/tpu/qwen/moe/run_qwen_moe.md @@ -67,7 +67,7 @@ Decoding To generate text with a trained model, use the `decode` command. The command below is an example for decoding on a v5p-512 slice. ``` -python3 -m MaxText.decode src/MaxText/configs/base.yml\ +python3 -m maxtext.decode src/MaxText/configs/base.yml\ load_parameters_path=gs://your-gcs-bucket/qwen3_maxtext_ckpt/0/items\ tokenizer_type=huggingface\ tokenizer_path=src/MaxText/assets/qwen3-tokenizer\ diff --git a/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh b/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh index ff2a34396c..581f8926e5 100644 --- a/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh +++ b/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh @@ -48,7 +48,7 @@ python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_ --run_hf_model=True # We can run decoding for unscanned checkpoints. 
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data export DATASET_PATH=gs://maxtext-dataset @@ -61,4 +61,4 @@ export FINETUNE_RUN_NAME=runner_finetune_${idx} python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} checkpoint_period=5 # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" diff --git a/end_to_end/tpu/test_decode_save_quantized_ckpt.sh b/end_to_end/tpu/test_decode_save_quantized_ckpt.sh index a6afc59c9b..a52161cf09 100644 --- a/end_to_end/tpu/test_decode_save_quantized_ckpt.sh +++ b/end_to_end/tpu/test_decode_save_quantized_ckpt.sh @@ -50,7 +50,7 @@ export OUTFILE="${OUTDIR}/decode.txt" mkdir -p $OUTDIR echo # Run command -${cmd} python3 -m MaxText.decode \ +${cmd} python3 -m maxtext.decode \ "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${LOAD_PARAMETERS_PATH} \ diff --git a/end_to_end/tpu/test_sft_trainer.sh b/end_to_end/tpu/test_sft_trainer.sh index cb9a283ec8..82c154a478 100755 --- a/end_to_end/tpu/test_sft_trainer.sh +++ b/end_to_end/tpu/test_sft_trainer.sh @@ -45,7 +45,7 @@ largest_dir="${sorted_dirs[-1]}" FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/items # Decode -python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/sft.yml \ +python3 -m maxtext.decode 
"${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/sft.yml \ run_name=${RUN_NAME}-hf-decode \ model_name=${PRE_TRAINED_MODEL} tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} tokenizer_type=huggingface \ load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \ diff --git a/src/MaxText/examples/demo_decoding.ipynb b/src/MaxText/examples/demo_decoding.ipynb index 9b913318e9..f10dd963d4 100644 --- a/src/MaxText/examples/demo_decoding.ipynb +++ b/src/MaxText/examples/demo_decoding.ipynb @@ -135,12 +135,12 @@ "\n", "import MaxText as mt\n", "from MaxText import common_types\n", - "from MaxText import inference_utils\n", "from MaxText import maxtext_utils\n", "from MaxText import max_logging\n", "from MaxText import pyconfig\n", "from MaxText.input_pipeline import _input_pipeline_utils\n", "from MaxText.utils.ckpt_conversion import to_maxtext\n", + "from maxtext import inference_utils\n", "\n", "from google.colab import userdata\n", "from huggingface_hub import login\n", diff --git a/src/MaxText/examples/multimodal_gemma3_demo.ipynb b/src/MaxText/examples/multimodal_gemma3_demo.ipynb index a664fe576b..2b289fa4bc 100644 --- a/src/MaxText/examples/multimodal_gemma3_demo.ipynb +++ b/src/MaxText/examples/multimodal_gemma3_demo.ipynb @@ -117,7 +117,7 @@ "metadata": {}, "outputs": [], "source": [ - "!python -m MaxText.decode \\\n", + "!python -m maxtext.decode \\\n", " $MAXTEXT_REPO_ROOT/configs/base.yml \\\n", " model_name=$MODEL_NAME \\\n", " tokenizer_path=assets/tokenizer.gemma3 \\\n", diff --git a/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md b/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md index 42633bc99b..2e298e101d 100644 --- a/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md +++ b/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md @@ -1,4 +1,4 @@ -# Checkpoint conversion agent +# Checkpoint conversion agent The agent is used to automate the model-specific mappings of checkpoint conversion. It is designed to cooperate with the new checkpoint conversion [framework](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/utils/ckpt_conversion). ## Quick starts @@ -10,9 +10,9 @@ To begin, you'll need: ``` pip install -q -U "google-genai>=1.0.0" ``` -4. The target/source models must be implemented in MaxText and Hugging Face and we can retrieve random weights to learn its parameter names and tensor shapes. +4. The target/source models must be implemented in MaxText and Hugging Face and we can retrieve random weights to learn its parameter names and tensor shapes. -5. A full run of the agent typically takes about 30 minutes. +5. A full run of the agent typically takes about 30 minutes. ## 1. Prepare the context file @@ -53,25 +53,25 @@ You can automatically verify the output by comparing the generated code against ```bash python3 -m MaxText.experimental.agent.ckpt_conversion_agent.evaluation --files ground_truth/.py \ - outputs/hook_fn.py --api_key= --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent + outputs/hook_fn.py --api_key= --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent ``` ### Manual Debugging (No Ground-Truth Code) If a ground-truth version isn't available, you'll need to debug the conversion manually. The recommended process is to: -1. Add the model mappings into [checkpoint conversion framework](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md#adding-support-for-new-models). +1. 
Add the model mappings into the [checkpoint conversion framework](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md#adding-support-for-new-models). 2. Execute the conversion process layer-by-layer, using [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md#hugging-face-to-maxtext) or [to_huggingface.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md#maxtext-to-hugging-face). - - If the tensor shape are not matched after conversion, error message will print out the parameter name that caused error. + - If the tensor shapes do not match after conversion, the error message will print the name of the parameter that caused the error. -3. After the conversion is done, run a decode to check the correctness of the generated code. +3. After the conversion is done, run a decode to check the correctness of the generated code. Example command: ```bash -python3 -m MaxText.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path=assets/tokenizer.gemma3 \ +python3 -m maxtext.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path=assets/tokenizer.gemma3 \ load_parameters_path= per_device_batch_size=1 run_name=ht_test \ max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=true \ prompt='I love to' attention='dot_product' ``` -If outputs are wrong, you can use jax.debug.print() to print the layer-wise mean/max/min values for debugging. +If outputs are wrong, you can use jax.debug.print() to print the layer-wise mean/max/min values for debugging. 4. To further validate the converted checkpoint, we recommend using [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md#verifying-conversion-correctness) to compare the original ckpt with the converted ckpt: ```bash @@ -92,14 +92,14 @@ python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \ * `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`). * `scan_layers`: Indicates if the output checkpoint is scanned (scan_layers=true) or unscanned (scan_layers=false). * `use_multimodal`: Indicates if multimodality is used. - * `--run_hf_model`: Indicates if loading Hugging Face model from the hf_model_path. If not set, it will compare the maxtext logits with pre-saved golden logits. + * `--run_hf_model`: Indicates whether to load the Hugging Face model from `hf_model_path`. If not set, the MaxText logits are compared with pre-saved golden logits. * `--hf_model_path`: The path to the Hugging Face checkpoint. * `--max_kl_div`: Max KL divergence tolerance during comparisons. ## Debugging tips -1. If a response from Gemini is `None`, wait for a moment and retry. +1. If a response from Gemini is `None`, wait for a moment and retry. 2.
If a converted checkpoint loads without errors but produces incorrect output, consider these common issues: diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/config.py b/src/MaxText/experimental/agent/integrative_rag_agent/config.py index 16ea98a59c..21dc6a8aa8 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/config.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/config.py @@ -103,7 +103,7 @@ # for converting PyTorch code to JAX block_for_rag = [ "src/MaxText/layers", # Neural network layers and building blocks - "src/MaxText/inference", # Inference and prediction code + "src/maxtext/inference", # Inference and prediction code "src/MaxText/common_types.py", # Common data types and structures "src/MaxText/maxtext_utils.py", # Utility functions and helpers ] diff --git a/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py b/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py index 955c76643d..1906e92906 100644 --- a/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py +++ b/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py @@ -93,7 +93,7 @@ def visit_Attribute(self, node): if base_name in self.git_aliases: # It's an external dependency. We need to format it with the attribute path. # Example: base_name='page_manager', attr_chain='PageState' - # self.git_dependencies['page_manager'] might be 'src/MaxText/inference/page_manager.py#page_manager' + # self.git_dependencies['page_manager'] might be 'src/maxtext/inference/page_manager.py#page_manager' path, obj = self.git_dependencies[base_name].split("#", 1) # As per the user request, we append the attribute access to the object name. @@ -197,9 +197,9 @@ def convert_package_to_path(self, path): """Convert an absolute import line to a mapping of names to file anchors. Example: - "from MaxText.inference import page_manager, utils" -> - {"page_manager": "src/MaxText/inference.py#page_manager", - "utils": "src/MaxText/inference.py#utils"} + "from maxtext.inference import page_manager, utils" -> + {"page_manager": "src/maxtext/inference.py#page_manager", + "utils": "src/maxtext/inference.py#utils"} Args: path (str): A normalized absolute import string. @@ -215,8 +215,8 @@ def convert_package_to_path(self, path): # or a module 'pkg' corresponds to 'path_form/pkg.py' # The logic in get_absolute_imports should ideally resolve this ambiguity. # A heuristic could be used here (e.g., checking casing) but we stick to the current logic. - # The user's example `from MaxText.inference import page_manager` creates a path - # `src/MaxText/inference.py#page_manager`, which is what the new visitor expects to correct. + # The user's example `from maxtext.inference import page_manager` creates a path + # `src/maxtext/inference.py#page_manager`, which is what the new visitor expects to correct. 
import_dict[pkg.strip()] = path_form + ".py#" + pkg.strip() return import_dict diff --git a/src/MaxText/experimental/rl/grpo_trainer.py b/src/MaxText/experimental/rl/grpo_trainer.py index bbad5620b3..12ca37368b 100644 --- a/src/MaxText/experimental/rl/grpo_trainer.py +++ b/src/MaxText/experimental/rl/grpo_trainer.py @@ -75,7 +75,6 @@ from MaxText import train_utils from MaxText import pyconfig from MaxText.utils import gcs_utils -from MaxText.inference import offline_engine from MaxText.experimental.rl import grpo_input_pipeline from MaxText.experimental.rl import grpo_utils from MaxText.globals import EPS @@ -91,6 +90,7 @@ ) from maxtext.common.metric_logger import MetricLogger from maxtext.common.vertex_tensorboard import VertexTensorboardManager +from maxtext.inference import offline_engine # pylint: disable=too-many-positional-arguments diff --git a/src/MaxText/experimental/rl/grpo_utils.py b/src/MaxText/experimental/rl/grpo_utils.py index fb6b748a5c..ced9fc10e6 100644 --- a/src/MaxText/experimental/rl/grpo_utils.py +++ b/src/MaxText/experimental/rl/grpo_utils.py @@ -24,7 +24,7 @@ from MaxText import max_logging from MaxText import max_utils from MaxText.common_types import DecoderBlockType -from MaxText.inference.offline_engine import InputData +from maxtext.inference.offline_engine import InputData from pathwaysutils.experimental import reshard as experimental_reshard from pathwaysutils.experimental import split_by_mesh_axis diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py index 819f430983..43dd69403b 100644 --- a/src/MaxText/layers/attention_mla.py +++ b/src/MaxText/layers/attention_mla.py @@ -61,10 +61,7 @@ AttentionType, DEFAULT_MASK_VALUE, ) -from MaxText.inference import kvcache -from MaxText.inference import page_manager -from MaxText.inference import paged_attention -from MaxText.inference.kvcache import KVQuant + from MaxText.sharding import create_sharding from MaxText.layers import nnx_wrappers from MaxText.layers.attentions import Attention @@ -72,6 +69,10 @@ from MaxText.layers.linears import DenseGeneral from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant +from maxtext.inference import kvcache +from maxtext.inference import page_manager +from maxtext.inference import paged_attention +from maxtext.inference.kvcache import KVQuant class Indexer(nnx.Module): diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index b44667b281..2354c93d67 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -69,8 +69,7 @@ Q_LENGTH, Q_LENGTH_NO_EXP, ) -from MaxText.inference import page_manager -from MaxText.inference.kvcache import KVQuant, KVTensor + from MaxText.kernels import jax_flash_attention from MaxText.kernels.ragged_attention import ragged_gqa from MaxText.kernels.ragged_attention import ragged_mha @@ -78,6 +77,8 @@ from MaxText.layers.initializers import variable_to_logically_partitioned from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.sharding import logical_to_mesh_axes, maybe_shard_with_name +from maxtext.inference import page_manager +from maxtext.inference.kvcache import KVQuant, KVTensor import numpy as np from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask diff --git a/src/MaxText/layers/attentions.py 
b/src/MaxText/layers/attentions.py index 194960d58c..6baa1c5387 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -54,10 +54,6 @@ AttentionType, ) from MaxText.sharding import maybe_shard_with_logical, create_sharding -from MaxText.inference import kvcache -from MaxText.inference import page_manager -from MaxText.inference import paged_attention -from MaxText.inference.kvcache import KVQuant from MaxText.layers import nnx_wrappers from MaxText.layers.attention_op import AttentionOp from MaxText.layers.embeddings import ( @@ -72,6 +68,8 @@ from MaxText.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant +from maxtext.inference import kvcache, page_manager, paged_attention +from maxtext.inference.kvcache import KVQuant # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes # pytype: disable=attribute-error diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index a304fee56f..b491fd0646 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -33,7 +33,6 @@ from MaxText import max_logging from MaxText import max_utils from MaxText.sharding import create_sharding -from MaxText.inference import page_manager from MaxText.layers import linears from MaxText.layers import normalizations from MaxText.layers import quantizations @@ -61,6 +60,7 @@ simple_layer, olmo3, ) +from maxtext.inference import page_manager # ------------------------------------------------------------------------------ # The network: Decoder Definitions diff --git a/src/MaxText/layers/deepseek.py b/src/MaxText/layers/deepseek.py index ddad5866fd..b5b6553f58 100644 --- a/src/MaxText/layers/deepseek.py +++ b/src/MaxText/layers/deepseek.py @@ -27,7 +27,6 @@ from MaxText import max_utils from MaxText.common_types import Config from MaxText.common_types import MODEL_MODE_PREFILL -from MaxText.inference import page_manager from MaxText.layers import attention_mla from MaxText.layers import initializers from MaxText.layers import linears @@ -37,6 +36,7 @@ from MaxText.layers.linears import Dropout from MaxText.layers.normalizations import RMSNorm from MaxText.sharding import maybe_shard_with_logical, create_sharding +from maxtext.inference import page_manager # ----------------------------------------- # The Decoder Layer for DeepSeek v3 diff --git a/src/MaxText/layers/deepseek_batchsplit.py b/src/MaxText/layers/deepseek_batchsplit.py index 21bfe7c7f0..33a5f843f6 100644 --- a/src/MaxText/layers/deepseek_batchsplit.py +++ b/src/MaxText/layers/deepseek_batchsplit.py @@ -23,7 +23,6 @@ from MaxText import common_types from MaxText import max_utils from MaxText.common_types import Config -from MaxText.inference import page_manager from MaxText.layers import attention_mla from MaxText.layers import initializers from MaxText.layers import linears @@ -32,6 +31,7 @@ from MaxText.layers import nnx_wrappers from MaxText.layers import quantizations from MaxText.sharding import maybe_shard_with_logical, create_sharding +from maxtext.inference import page_manager class DeepSeekBatchSplitGenericLayer(nnx.Module): """Generic DeepSeek layer with Multi-Head Latent Attention. 
diff --git a/src/MaxText/layers/llama2.py b/src/MaxText/layers/llama2.py index 3e767bac8e..5f03698a96 100644 --- a/src/MaxText/layers/llama2.py +++ b/src/MaxText/layers/llama2.py @@ -23,7 +23,7 @@ from flax import nnx -from MaxText.inference import page_manager +from maxtext.inference import page_manager from MaxText.common_types import Config from MaxText import max_utils from MaxText.sharding import maybe_shard_with_logical, create_sharding diff --git a/src/MaxText/layers/llama4.py b/src/MaxText/layers/llama4.py index 957824c00c..a920d56d5b 100644 --- a/src/MaxText/layers/llama4.py +++ b/src/MaxText/layers/llama4.py @@ -27,7 +27,6 @@ from MaxText.common_types import Config, Array, MODEL_MODE_TRAIN, AttentionType from MaxText import max_utils -from MaxText.inference import page_manager from MaxText.layers import initializers from MaxText.layers import nnx_wrappers from MaxText.layers import linears @@ -39,7 +38,7 @@ from MaxText.layers.linears import Dropout from MaxText.layers.moe import RoutedAndSharedMoE from MaxText.common_types import MODEL_MODE_PREFILL - +from maxtext.inference import page_manager #### Multi modal model implementation diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py index 2d84eda09a..e5795e4df2 100644 --- a/src/MaxText/layers/models.py +++ b/src/MaxText/layers/models.py @@ -26,7 +26,6 @@ from MaxText.layers import initializers from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR -from MaxText.inference import page_manager from MaxText import multimodal_utils from MaxText import max_utils from MaxText.layers import nnx_wrappers @@ -35,6 +34,7 @@ from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen, AudioEncoder, audio_encoder_as_linen from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen +from maxtext.inference import page_manager # ------------------------------------------------------------------------------ # The network: Transformer Definitions diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py index eb06a19af8..c1d0314925 100644 --- a/src/MaxText/layers/quantizations.py +++ b/src/MaxText/layers/quantizations.py @@ -37,7 +37,7 @@ import flax.linen as nn from MaxText.common_types import DType, Config -from MaxText.inference.kvcache import KVQuant +from maxtext.inference.kvcache import KVQuant # Params used to define mixed precision quantization configs DEFAULT = "__default__" # default config diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py index e5dec8b40c..5ed59ffcc0 100644 --- a/src/MaxText/layers/qwen3.py +++ b/src/MaxText/layers/qwen3.py @@ -38,11 +38,12 @@ from MaxText.layers.embeddings import Qwen3OmniMoeVisionPosEmbedInterpolate, PositionalEmbedding from MaxText.layers.normalizations import RMSNorm, l2norm, Qwen3NextRMSNorm, Qwen3NextRMSNormGated from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.inference import page_manager from MaxText.layers.attentions import Attention from MaxText.layers.linears import DenseGeneral, MlpBlock from MaxText.layers.moe import RoutedMoE from MaxText.layers.initializers import nd_dense_init, variable_to_logically_partitioned +from maxtext.inference import page_manager + # ----------------------------------------- # Qwen3-Next Layer Implementations # ----------------------------------------- diff --git 
a/src/MaxText/maxengine.py b/src/MaxText/maxengine.py index 36d6652942..ef8ba9e5eb 100644 --- a/src/MaxText/maxengine.py +++ b/src/MaxText/maxengine.py @@ -43,17 +43,16 @@ from jetstream.engine.tokenizer_pb2 import TokenizerParameters from jetstream.engine.tokenizer_pb2 import TokenizerType -from MaxText import inference_utils from MaxText import max_utils from MaxText import maxtext_utils from MaxText import multimodal_utils from MaxText import pyconfig from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE from MaxText.globals import MAXTEXT_PKG_DIR -from MaxText.inference.page_manager import PageManager, PageState from MaxText.layers import models, quantizations from MaxText.utils import lora_utils - +from maxtext import inference_utils +from maxtext.inference.page_manager import PageManager, PageState warnings.simplefilter("ignore", category=FutureWarning) DecodeState = Any diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index bd4f102e21..c59729933d 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -43,7 +43,7 @@ from MaxText.configs import types from MaxText.utils import gcs_utils from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE -from MaxText.inference.page_manager import PageState +from maxtext.inference.page_manager import PageState from maxtext.common import checkpointing OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" diff --git a/src/MaxText/prefill_packing.py b/src/MaxText/prefill_packing.py index e39fc564ac..bd5a8a4189 100644 --- a/src/MaxText/prefill_packing.py +++ b/src/MaxText/prefill_packing.py @@ -265,7 +265,7 @@ def _process_bucket( ) -> tuple[list[tuple[engine_api.ResultTokens, int]], DecodeState]: """Process all items in a bucket.""" # pylint: disable=import-outside-toplevel - from MaxText.inference.offline_engine import PrefillResult # type: ignore + from maxtext.inference.offline_engine import PrefillResult # type: ignore slots = bucket.slots lengths = [len(prompt) for prompt in bucket.token_ids] diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index fd236adb95..ec03df0e6d 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -30,7 +30,7 @@ from MaxText.common_types import DecoderBlockType, ShardMode from MaxText.configs import types from MaxText.configs.types import MaxTextConfig -from MaxText.inference_utils import str2bool +from maxtext.inference_utils import str2bool logger = logging.getLogger(__name__) logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) diff --git a/src/MaxText/scratch_code/demo_from_config.ipynb b/src/MaxText/scratch_code/demo_from_config.ipynb index d0f22a5dea..9a9ca50279 100644 --- a/src/MaxText/scratch_code/demo_from_config.ipynb +++ b/src/MaxText/scratch_code/demo_from_config.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "0ab2e1dd", "metadata": {}, "outputs": [ @@ -57,7 +57,7 @@ "from MaxText import max_logging\n", "from MaxText import common_types\n", "import jax\n", - "from MaxText import inference_utils" + "from maxtext import inference_utils" ] }, { diff --git a/src/MaxText/scratch_code/gemma_7b.sh b/src/MaxText/scratch_code/gemma_7b.sh index 7fd4bf9e0d..7d3e012d90 100644 --- a/src/MaxText/scratch_code/gemma_7b.sh +++ b/src/MaxText/scratch_code/gemma_7b.sh @@ -3,6 +3,6 @@ export M_PER_DEVICE_BATCH_SIZE=24 export M_MAX_PREFILL_PREDICT_LENGTH=1024 export M_MAX_TARGET_LENGTH=2048 
-#python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false +#python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false python3 -m MaxText.maxengine_server "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false diff --git a/src/MaxText/scratch_code/run_inference_microbenchmark.sh b/src/MaxText/scratch_code/run_inference_microbenchmark.sh index 15cfcc8cab..00ab321efe 100644 --- a/src/MaxText/scratch_code/run_inference_microbenchmark.sh +++ b/src/MaxText/scratch_code/run_inference_microbenchmark.sh @@ -1,5 +1,5 @@ -# llama2-7b -python3 -m MaxText.inference_microbenchmark \ +# llama2-7b +python3 -m maxtext.inference_microbenchmark \ "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ async_checkpointing=false \ attention=autoselected \ diff --git a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh index 03a38b650b..79d51c9a63 100644 --- a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh +++ b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh @@ -27,14 +27,14 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext \ per_device_batch_size="${PER_DEVICE_BATCH_SIZE}" \ run_name="run_to_mt" \ async_checkpointing="${ASYNC_CHECKPOINTING}" \ - scan_layers="${SCAN_LAYERS}" + scan_layers="${SCAN_LAYERS}" echo "--- Checkpoint Conversion Complete ---" # --- Step 2 (Optional): Decode using the Converted Checkpoint --- echo "--- Starting Decoding ---" -python3 -m MaxText.decode \ +python3 -m maxtext.decode \ ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml \ model_name="${MODEL_NAME}" \ tokenizer_path="${TOKENIZER_PATH}" \ diff --git a/src/MaxText/utils/ckpt_conversion/to_maxtext.py b/src/MaxText/utils/ckpt_conversion/to_maxtext.py index 6efed5533b..595de466c0 100644 --- a/src/MaxText/utils/ckpt_conversion/to_maxtext.py +++ b/src/MaxText/utils/ckpt_conversion/to_maxtext.py @@ -79,10 +79,10 @@ from MaxText import maxtext_utils from MaxText import pyconfig from MaxText.common_types import MODEL_MODE_TRAIN -from MaxText.inference_utils import str2bool from MaxText.layers import models, quantizations from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys +from maxtext.inference_utils import str2bool from maxtext.common import checkpointing
jax.config.update("jax_platform_name", "cpu") diff --git a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py index feb565d2e3..395d1772fc 100644 --- a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py +++ b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_ckpt.py @@ -41,8 +41,8 @@ from safetensors import safe_open from MaxText import max_logging -from MaxText.inference_utils import str2bool from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt +from maxtext.inference_utils import str2bool absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log diff --git a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py index 5c2f1155ff..2a9194b08f 100644 --- a/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py +++ b/src/MaxText/utils/ckpt_scripts/convert_deepseek_family_unscanned_ckpt.py @@ -39,7 +39,7 @@ from MaxText.utils.ckpt_scripts import convert_deepseek_family_ckpt as ds_ckpt from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt from MaxText import max_logging -from MaxText.inference_utils import str2bool +from maxtext.inference_utils import str2bool from safetensors import safe_open absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log diff --git a/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py b/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py index 5910259a1f..549ffeda63 100644 --- a/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py +++ b/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py @@ -37,9 +37,9 @@ from tqdm import tqdm from MaxText import max_logging -from MaxText.inference_utils import str2bool from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint from MaxText.utils.ckpt_scripts.convert_gpt_oss_unscanned_ckpt import MODEL_PARAMS_DICT, _hf_to_maxtext_mapping, _pt_to_np +from maxtext.inference_utils import str2bool absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log diff --git a/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_unscanned_ckpt.py b/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_unscanned_ckpt.py index a4977e8182..fe9fd75bee 100644 --- a/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_unscanned_ckpt.py +++ b/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_unscanned_ckpt.py @@ -37,8 +37,8 @@ from tqdm import tqdm from MaxText import max_logging -from MaxText.inference_utils import str2bool from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint +from maxtext.inference_utils import str2bool absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log diff --git a/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py b/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py index e7d089c7b9..5e4a8126da 100644 --- a/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py +++ b/src/MaxText/utils/ckpt_scripts/convert_qwen3_moe.py @@ -33,8 +33,8 @@ from tqdm import tqdm from MaxText import max_logging -from MaxText.inference_utils import str2bool from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt +from maxtext.inference_utils import str2bool # Static model parameters dictionary MODEL_PARAMS_DICT = { diff --git a/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_scanned.py b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_scanned.py index 78725c4e5d..346e29aa4b 100644 --- 
a/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_scanned.py +++ b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_scanned.py @@ -37,7 +37,7 @@ from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt from MaxText import max_logging -from MaxText.inference_utils import str2bool +from maxtext.inference_utils import str2bool MODEL_PARAMS_DICT = { "qwen3-next-80b-a3b": { diff --git a/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_unscanned.py b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_unscanned.py index c49ba151cb..495ab99f46 100644 --- a/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_unscanned.py +++ b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_unscanned.py @@ -38,9 +38,9 @@ from typing import Any, Dict from MaxText import max_logging -from MaxText.inference_utils import str2bool from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint from MaxText.utils.ckpt_scripts.convert_qwen3_next_scanned import MODEL_PARAMS_DICT +from maxtext.inference_utils import str2bool # NOTE: numpy doesn't have native support for bfloat16, so diff --git a/src/MaxText/utils/ckpt_scripts/llama4_ckpt_unscanned.py b/src/MaxText/utils/ckpt_scripts/llama4_ckpt_unscanned.py index 15a9e4d491..f1ec65b4d3 100644 --- a/src/MaxText/utils/ckpt_scripts/llama4_ckpt_unscanned.py +++ b/src/MaxText/utils/ckpt_scripts/llama4_ckpt_unscanned.py @@ -54,8 +54,8 @@ from tqdm import tqdm from MaxText import max_logging -from MaxText.inference_utils import str2bool from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint, MODEL_PARAMS_DICT +from maxtext.inference_utils import str2bool SIMULATED_CPU_DEVICES_COUNT = 16 diff --git a/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py b/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py index a257f11145..147a06a3ca 100644 --- a/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py +++ b/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py @@ -62,8 +62,8 @@ from MaxText import max_logging from MaxText import max_utils -from MaxText.inference_utils import str2bool from MaxText.utils import gcs_utils +from maxtext.inference_utils import str2bool from maxtext.common import checkpointing MODEL_PARAMS_DICT = { diff --git a/src/MaxText/decode.py b/src/maxtext/decode.py similarity index 100% rename from src/MaxText/decode.py rename to src/maxtext/decode.py diff --git a/src/MaxText/inference/__init__.py b/src/maxtext/inference/__init__.py similarity index 100% rename from src/MaxText/inference/__init__.py rename to src/maxtext/inference/__init__.py diff --git a/src/MaxText/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml b/src/maxtext/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml similarity index 100% rename from src/MaxText/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml rename to src/maxtext/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml diff --git a/src/MaxText/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml b/src/maxtext/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml similarity index 100% rename from src/MaxText/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml rename to src/maxtext/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml diff --git a/src/MaxText/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml b/src/maxtext/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml similarity index 100% rename from 
src/MaxText/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml rename to src/maxtext/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml diff --git a/src/MaxText/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml b/src/maxtext/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml similarity index 100% rename from src/MaxText/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml rename to src/maxtext/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml diff --git a/src/MaxText/inference/decode_multi.py b/src/maxtext/inference/decode_multi.py similarity index 100% rename from src/MaxText/inference/decode_multi.py rename to src/maxtext/inference/decode_multi.py diff --git a/src/MaxText/inference/gpu/README.md b/src/maxtext/inference/gpu/README.md similarity index 100% rename from src/MaxText/inference/gpu/README.md rename to src/maxtext/inference/gpu/README.md diff --git a/src/MaxText/inference/gpu/microbenchmark_llama2-70b_h100-8.sh b/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh similarity index 97% rename from src/MaxText/inference/gpu/microbenchmark_llama2-70b_h100-8.sh rename to src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh index d124b849de..a08f49babf 100755 --- a/src/MaxText/inference/gpu/microbenchmark_llama2-70b_h100-8.sh +++ b/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh @@ -102,7 +102,7 @@ cd $(dirname $0)/../../../ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION --xla_disable_hlo_passes=rematerialization" \ TF_FORCE_GPU_ALLOW_GROWTH=true \ XLA_PYTHON_CLIENT_MEM_FRACTION=0.94 \ -python3 -m MaxText.inference_microbenchmark $MAXENGINE_CONFIG_FILEPATH \ +python3 -m maxtext.inference_microbenchmark $MAXENGINE_CONFIG_FILEPATH \ base_output_directory=$BASE_OUTPUT_DIRECTORY \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2 \ model_name='llama2-70b' \ @@ -125,4 +125,4 @@ python3 -m MaxText.inference_microbenchmark $MAXENGINE_CONFIG_FILEPATH \ kv_quant_dtype=$KV_QUANT_DTYPE \ quantize_kvcache=$QUANTIZE_KVCACHE \ quantization=$QUANTIZATION$PROFILER_STR \ - gcs_metrics=$GCS_METRICS + gcs_metrics=$GCS_METRICS diff --git a/src/MaxText/inference/jetstream_pathways/README.md b/src/maxtext/inference/jetstream_pathways/README.md similarity index 100% rename from src/MaxText/inference/jetstream_pathways/README.md rename to src/maxtext/inference/jetstream_pathways/README.md diff --git a/src/MaxText/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh b/src/maxtext/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh similarity index 100% rename from src/MaxText/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh rename to src/maxtext/inference/jetstream_pathways/jetstream_pathways_entrypoint.sh diff --git a/src/MaxText/inference/kvcache.py b/src/maxtext/inference/kvcache.py similarity index 100% rename from src/MaxText/inference/kvcache.py rename to src/maxtext/inference/kvcache.py diff --git a/src/MaxText/inference/maxengine_server/README.md b/src/maxtext/inference/maxengine_server/README.md similarity index 100% rename from src/MaxText/inference/maxengine_server/README.md rename to src/maxtext/inference/maxengine_server/README.md diff --git a/src/MaxText/inference/maxengine_server/maxengine_server_entrypoint.sh b/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh similarity index 100% rename from 
src/MaxText/inference/maxengine_server/maxengine_server_entrypoint.sh rename to src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh diff --git a/src/MaxText/inference/offline_engine.py b/src/maxtext/inference/offline_engine.py similarity index 100% rename from src/MaxText/inference/offline_engine.py rename to src/maxtext/inference/offline_engine.py diff --git a/src/MaxText/inference/page_manager.py b/src/maxtext/inference/page_manager.py similarity index 100% rename from src/MaxText/inference/page_manager.py rename to src/maxtext/inference/page_manager.py diff --git a/src/MaxText/inference/paged_attention.py b/src/maxtext/inference/paged_attention.py similarity index 99% rename from src/MaxText/inference/paged_attention.py rename to src/maxtext/inference/paged_attention.py index 3698011c07..640fbcfd26 100644 --- a/src/MaxText/inference/paged_attention.py +++ b/src/maxtext/inference/paged_attention.py @@ -28,8 +28,8 @@ from flax import linen as nn from flax import nnx -from MaxText.inference import page_manager -from MaxText.inference import paged_attention_kernel_v2 +from maxtext.inference import page_manager +from maxtext.inference import paged_attention_kernel_v2 from MaxText.sharding import logical_to_mesh_axes from MaxText.common_types import Array, DType, AxisNames, BATCH, LENGTH, HEAD, D_KV, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from MaxText.layers.initializers import variable_to_logically_partitioned diff --git a/src/MaxText/inference/paged_attention_kernel_v2.py b/src/maxtext/inference/paged_attention_kernel_v2.py similarity index 100% rename from src/MaxText/inference/paged_attention_kernel_v2.py rename to src/maxtext/inference/paged_attention_kernel_v2.py diff --git a/src/MaxText/inference/scripts/decode_multi.py b/src/maxtext/inference/scripts/decode_multi.py similarity index 100% rename from src/MaxText/inference/scripts/decode_multi.py rename to src/maxtext/inference/scripts/decode_multi.py diff --git a/src/MaxText/inference/scripts/notebooks/sharding_utils.ipynb b/src/maxtext/inference/scripts/notebooks/sharding_utils.ipynb similarity index 100% rename from src/MaxText/inference/scripts/notebooks/sharding_utils.ipynb rename to src/maxtext/inference/scripts/notebooks/sharding_utils.ipynb diff --git a/src/MaxText/inference/scripts/sharding_utils.py b/src/maxtext/inference/scripts/sharding_utils.py similarity index 100% rename from src/MaxText/inference/scripts/sharding_utils.py rename to src/maxtext/inference/scripts/sharding_utils.py diff --git a/src/MaxText/inference/scripts/test_sharding_utils.py b/src/maxtext/inference/scripts/test_sharding_utils.py similarity index 99% rename from src/MaxText/inference/scripts/test_sharding_utils.py rename to src/maxtext/inference/scripts/test_sharding_utils.py index 535fb0abce..582bb06921 100644 --- a/src/MaxText/inference/scripts/test_sharding_utils.py +++ b/src/maxtext/inference/scripts/test_sharding_utils.py @@ -23,7 +23,7 @@ import unittest -from MaxText.inference.scripts.sharding_utils import calculate_matmul_resources, latency_bound_comms +from maxtext.inference.scripts.sharding_utils import calculate_matmul_resources, latency_bound_comms # Common test parameters M, K, F = 64, 128, 256 diff --git a/src/MaxText/inference_microbenchmark.py b/src/maxtext/inference_microbenchmark.py similarity index 100% rename from src/MaxText/inference_microbenchmark.py rename to src/maxtext/inference_microbenchmark.py diff --git a/src/MaxText/inference_microbenchmark_sweep.py b/src/maxtext/inference_microbenchmark_sweep.py 
diff --git a/src/MaxText/inference_microbenchmark_sweep.py b/src/maxtext/inference_microbenchmark_sweep.py
similarity index 99%
rename from src/MaxText/inference_microbenchmark_sweep.py
rename to src/maxtext/inference_microbenchmark_sweep.py
index 36febfca02..3f2de28a56 100644
--- a/src/MaxText/inference_microbenchmark_sweep.py
+++ b/src/maxtext/inference_microbenchmark_sweep.py
@@ -22,7 +22,7 @@
 
 import jax
 
-from MaxText import inference_microbenchmark
+from maxtext import inference_microbenchmark
 from MaxText import pyconfig
 
 try:
diff --git a/src/MaxText/inference_mlperf/README.md b/src/maxtext/inference_mlperf/README.md
similarity index 97%
rename from src/MaxText/inference_mlperf/README.md
rename to src/maxtext/inference_mlperf/README.md
index 8731913fd1..326a68e089 100644
--- a/src/MaxText/inference_mlperf/README.md
+++ b/src/maxtext/inference_mlperf/README.md
@@ -64,7 +64,7 @@ cd ~
 git clone https://github.com/AI-Hypercomputer/maxtext.git
 cd maxtext
 bash setup.sh
-python3 -m pip install -r src/MaxText/inference_mlperf/requirements.txt
+python3 -m pip install -r src/maxtext/inference_mlperf/requirements.txt
 ```
 
 ### Generate quantized checkpoint
@@ -100,7 +100,7 @@ export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-70b-chat
 # other tokenizers under src/MaxText/assets/ directory.
 export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"'/tokenizer.llama2'
 cd maxtext && \
-python3 -m MaxText.decode src/MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
+python3 -m maxtext.decode src/MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
 ```
 
 Your checkpoint is generated at `$SAVE_QUANT_PARAMS_PATH`. This is used to set `load_parameters_path` param below in `MAXENGINE_ARGS` env variable.
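The README hunk above generates the quantized checkpoint with a long shell one-liner against the renamed entry point. As a hedged illustration, the same invocation can be driven from Python; the `gs://` paths below are placeholders, not values from this PR:

```python
# Programmatic equivalent of the `python3 -m maxtext.decode ...` command above.
# The gs:// paths are placeholders; flag names mirror the README command.
import subprocess

subprocess.run(
    [
        "python3", "-m", "maxtext.decode",
        "src/MaxText/configs/base.yml",
        "model_name=llama2-70b",
        "attention=dot_product",
        "quantization=int8",
        "load_parameters_path=gs://my-bucket/llama2-70b/param-only",       # placeholder
        "save_quantized_params_path=gs://my-bucket/quantized/llama2-70b",  # placeholder
    ],
    check=True,  # raise CalledProcessError if the quantization run fails
)
```

Either form resolves `maxtext.decode` through the same `-m` module path, so this only works once the rename in this patch is installed.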
@@ -141,7 +141,7 @@ huggingface-cli login --token $HUGGING_FACE_TOKEN
 #### For trillium
 #### LLama2-70b:
 ```
-cd ~/maxtext/src/MaxText/inference_mlperf/trillium
+cd ~/maxtext/src/maxtext/inference_mlperf/trillium
 ```
 
 ##### Test Run
diff --git a/src/MaxText/inference_mlperf/__init__.py b/src/maxtext/inference_mlperf/__init__.py
similarity index 100%
rename from src/MaxText/inference_mlperf/__init__.py
rename to src/maxtext/inference_mlperf/__init__.py
diff --git a/src/MaxText/inference_mlperf/evaluate-accuracy-fast.py b/src/maxtext/inference_mlperf/evaluate-accuracy-fast.py
similarity index 100%
rename from src/MaxText/inference_mlperf/evaluate-accuracy-fast.py
rename to src/maxtext/inference_mlperf/evaluate-accuracy-fast.py
diff --git a/src/MaxText/inference_mlperf/evaluate-accuracy.py b/src/maxtext/inference_mlperf/evaluate-accuracy.py
similarity index 100%
rename from src/MaxText/inference_mlperf/evaluate-accuracy.py
rename to src/maxtext/inference_mlperf/evaluate-accuracy.py
diff --git a/src/MaxText/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh b/src/maxtext/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
similarity index 95%
rename from src/MaxText/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
rename to src/maxtext/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
index 06d49143c2..71901049a6 100755
--- a/src/MaxText/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
+++ b/src/maxtext/inference_mlperf/gpu/benchmarks_llama2-70b-h100_8.sh
@@ -108,14 +108,14 @@ run_benchmark() {
   local type=$1
   case "$type" in
     "performance")
-      $cmd bash ./MaxText/inference_mlperf/llama_offline_run.sh ${RUN_OPTIONS} -r -benchmarks_performance_${RUN_DESC}
+      $cmd bash ./maxtext/inference_mlperf/llama_offline_run.sh ${RUN_OPTIONS} -r -benchmarks_performance_${RUN_DESC}
       ;;
     "audit")
-      $cmd bash ./MaxText/inference_mlperf/llama_offline_run.sh ${RUN_OPTIONS} -r -benchmarks_audit_${RUN_DESC} -d
+      $cmd bash ./maxtext/inference_mlperf/llama_offline_run.sh ${RUN_OPTIONS} -r -benchmarks_audit_${RUN_DESC} -d
       ;;
     "accuracy")
       export HF_CKPT="meta-llama/Llama-2-70b-chat-hf"
-      $cmd bash ./MaxText/inference_mlperf/llama_offline_run.sh ${RUN_OPTIONS} -r benchmarks_accuracy_${RUN_DESC} -a
+      $cmd bash ./maxtext/inference_mlperf/llama_offline_run.sh ${RUN_OPTIONS} -r benchmarks_accuracy_${RUN_DESC} -a
       ;;
   esac
 }
diff --git a/src/MaxText/inference_mlperf/llama_offline_run.sh b/src/maxtext/inference_mlperf/llama_offline_run.sh
similarity index 97%
rename from src/MaxText/inference_mlperf/llama_offline_run.sh
rename to src/maxtext/inference_mlperf/llama_offline_run.sh
index 26b5be29ea..f695c93004 100755
--- a/src/MaxText/inference_mlperf/llama_offline_run.sh
+++ b/src/maxtext/inference_mlperf/llama_offline_run.sh
@@ -117,7 +117,7 @@ else
   export DATASET_TYPE=full
   export DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl
   export TOTAL_SAMPLE_COUNT=24576
-  export USER_CONFIG=user.conf # NOTE: you may need to change this path(e.g. `src/MaxText/inference_mlperf/user.conf`)
+  export USER_CONFIG=user.conf # NOTE: you may need to change this path(e.g. `src/maxtext/inference_mlperf/user.conf`)
 fi
 
 # LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
@@ -142,7 +142,7 @@ run_loadgen() {
   echo "PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES: ${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES}"
   echo "MAXENGINE_ARGS: ${MAXENGINE_ARGS}"
   echo
-  ${CMD} python3 -m MaxText.inference_mlperf.offline_mode \
+  ${CMD} python3 -m maxtext.inference_mlperf.offline_mode \
     --maxengine_config_filepath=${MAXENGINE_CONFIG_FILEPATH} \
     --mlperf_test_mode=${TEST_MODE} \
     --input_mode tokenized \
@@ -191,7 +191,7 @@ run_loadgen_accuracy () {
     EVAL_SCRIPT="evaluate-accuracy"
   fi
   echo
-  ${CMD} python3 -m MaxText.inference_mlperf.${EVAL_SCRIPT} \
+  ${CMD} python3 -m maxtext.inference_mlperf.${EVAL_SCRIPT} \
     --checkpoint-path ${HF_CKPT} \
     --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
     --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
diff --git a/src/MaxText/inference_mlperf/matmul/__init__.py b/src/maxtext/inference_mlperf/matmul/__init__.py
similarity index 100%
rename from src/MaxText/inference_mlperf/matmul/__init__.py
rename to src/maxtext/inference_mlperf/matmul/__init__.py
diff --git a/src/MaxText/inference_mlperf/matmul/matmul_dtypes.py b/src/maxtext/inference_mlperf/matmul/matmul_dtypes.py
similarity index 96%
rename from src/MaxText/inference_mlperf/matmul/matmul_dtypes.py
rename to src/maxtext/inference_mlperf/matmul/matmul_dtypes.py
index c5abaeb65c..b9247ccc13 100644
--- a/src/MaxText/inference_mlperf/matmul/matmul_dtypes.py
+++ b/src/maxtext/inference_mlperf/matmul/matmul_dtypes.py
@@ -16,7 +16,7 @@
 
 import jax
 
-from MaxText.inference_mlperf.matmul import timing_util
+from maxtext.inference_mlperf.matmul import timing_util
 
 if __name__ == "__main__":
   _PROFILE = False
diff --git a/src/MaxText/inference_mlperf/matmul/matmul_sharding.py b/src/maxtext/inference_mlperf/matmul/matmul_sharding.py
similarity index 100%
rename from src/MaxText/inference_mlperf/matmul/matmul_sharding.py
rename to src/maxtext/inference_mlperf/matmul/matmul_sharding.py
diff --git a/src/MaxText/inference_mlperf/matmul/timing_util.py b/src/maxtext/inference_mlperf/matmul/timing_util.py
similarity index 100%
rename from src/MaxText/inference_mlperf/matmul/timing_util.py
rename to src/maxtext/inference_mlperf/matmul/timing_util.py
diff --git a/src/MaxText/inference_mlperf/mixtral_offline_run.sh b/src/maxtext/inference_mlperf/mixtral_offline_run.sh
similarity index 100%
rename from src/MaxText/inference_mlperf/mixtral_offline_run.sh
rename to src/maxtext/inference_mlperf/mixtral_offline_run.sh
diff --git a/src/MaxText/inference_mlperf/offline_inference.py b/src/maxtext/inference_mlperf/offline_inference.py
similarity index 100%
rename from src/MaxText/inference_mlperf/offline_inference.py
rename to src/maxtext/inference_mlperf/offline_inference.py
diff --git a/src/MaxText/inference_mlperf/offline_mode.py b/src/maxtext/inference_mlperf/offline_mode.py
similarity index 99%
rename from src/MaxText/inference_mlperf/offline_mode.py
rename to src/maxtext/inference_mlperf/offline_mode.py
index cce3d2ae29..57ff873c3d 100644
--- a/src/MaxText/inference_mlperf/offline_mode.py
+++ b/src/maxtext/inference_mlperf/offline_mode.py
@@ -36,7 +36,7 @@
 
 # pylint: disable=no-name-in-module
 from MaxText.maxengine import create_engine_from_config_flags
-from MaxText.inference_mlperf import offline_inference
+from maxtext.inference_mlperf import offline_inference
 
 
 warnings.simplefilter("ignore", category=FutureWarning)
diff --git a/src/MaxText/inference_mlperf/requirements.txt b/src/maxtext/inference_mlperf/requirements.txt
similarity index 100%
rename from src/MaxText/inference_mlperf/requirements.txt
rename to src/maxtext/inference_mlperf/requirements.txt
diff --git a/src/MaxText/inference_mlperf/trillium/__init__.py b/src/maxtext/inference_mlperf/trillium/__init__.py
similarity index 100%
rename from src/MaxText/inference_mlperf/trillium/__init__.py
rename to src/maxtext/inference_mlperf/trillium/__init__.py
diff --git a/src/MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh b/src/maxtext/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
similarity index 98%
rename from src/MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
rename to src/maxtext/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
index 2a98616e20..6cb9a3742b 100644
--- a/src/MaxText/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
+++ b/src/maxtext/inference_mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 
-# NOTE: please check the README located at src/MaxText/inference_mlperf/README.md for instructions on how
+# NOTE: please check the README located at src/maxtext/inference_mlperf/README.md for instructions on how
 # to set up the environment before running this script.
 # Run command:
 # bash benchmarks_llama2-70b-trillium_2x4.sh [-b benchmark_type]
diff --git a/src/MaxText/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh b/src/maxtext/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh
similarity index 100%
rename from src/MaxText/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh
rename to src/maxtext/inference_mlperf/trillium/microbenchmarks_llama2-70b-trillium_2x4.sh
diff --git a/src/MaxText/inference_mlperf/trillium/select_xla_flags.py b/src/maxtext/inference_mlperf/trillium/select_xla_flags.py
similarity index 100%
rename from src/MaxText/inference_mlperf/trillium/select_xla_flags.py
rename to src/maxtext/inference_mlperf/trillium/select_xla_flags.py
diff --git a/src/MaxText/inference_mlperf/user.conf b/src/maxtext/inference_mlperf/user.conf
similarity index 100%
rename from src/MaxText/inference_mlperf/user.conf
rename to src/maxtext/inference_mlperf/user.conf
diff --git a/src/MaxText/inference_mlperf/user100.conf b/src/maxtext/inference_mlperf/user100.conf
similarity index 100%
rename from src/MaxText/inference_mlperf/user100.conf
rename to src/maxtext/inference_mlperf/user100.conf
diff --git a/src/MaxText/inference_mlperf/user5000.conf b/src/maxtext/inference_mlperf/user5000.conf
similarity index 100%
rename from src/MaxText/inference_mlperf/user5000.conf
rename to src/maxtext/inference_mlperf/user5000.conf
diff --git a/src/MaxText/inference_utils.py b/src/maxtext/inference_utils.py
similarity index 100%
rename from src/MaxText/inference_utils.py
rename to src/maxtext/inference_utils.py
diff --git a/tests/assets/logits_generation/generate_hf_golden_logits.py b/tests/assets/logits_generation/generate_hf_golden_logits.py
index e3306fe5aa..f5e3cb9955 100644
--- a/tests/assets/logits_generation/generate_hf_golden_logits.py
+++ b/tests/assets/logits_generation/generate_hf_golden_logits.py
@@ -47,7 +47,7 @@
 
 import numpy as np
 from google.cloud import storage
 from PIL import Image
-from MaxText.inference_utils import str2bool
+from maxtext.inference_utils import str2bool
 
 # Load the tokenizer and model from Hugging Face
diff --git a/tests/inference/benchmark_offline_engine.py b/tests/inference/benchmark_offline_engine.py
index c503de1b98..be3f4cd745 100644
--- a/tests/inference/benchmark_offline_engine.py
+++ b/tests/inference/benchmark_offline_engine.py
@@ -31,7 +31,7 @@
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText import max_logging
 from MaxText import pyconfig
-from MaxText.inference.offline_engine import OfflineEngine, InputData, CompletionOutput
+from maxtext.inference.offline_engine import OfflineEngine, InputData, CompletionOutput
 
 
 def get_metrics(results: list[CompletionOutput], start_time, end_time):
diff --git a/tests/inference/kvcache_test.py b/tests/inference/kvcache_test.py
index cbedac20f4..372ce237ca 100644
--- a/tests/inference/kvcache_test.py
+++ b/tests/inference/kvcache_test.py
@@ -20,7 +20,7 @@
 import jax.numpy as jnp
 
 from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
-from MaxText.inference import kvcache
+from maxtext.inference import kvcache
 
 
 class MlaKVCacheTest(unittest.TestCase):
diff --git a/tests/inference/page_manager_test.py b/tests/inference/page_manager_test.py
index 22035c9dde..3ffe76e9d0 100644
--- a/tests/inference/page_manager_test.py
+++ b/tests/inference/page_manager_test.py
@@ -23,7 +23,7 @@
 
 from MaxText import pyconfig
 from MaxText.globals import MAXTEXT_PKG_DIR
-from MaxText.inference.page_manager import PageManager, PageState
+from maxtext.inference.page_manager import PageManager, PageState
 
 
 class TestPageManager(unittest.TestCase):
diff --git a/tests/inference/test_llama2_7b_bf16.sh b/tests/inference/test_llama2_7b_bf16.sh
index 672611932c..5ace9ba8ca 100755
--- a/tests/inference/test_llama2_7b_bf16.sh
+++ b/tests/inference/test_llama2_7b_bf16.sh
@@ -3,7 +3,7 @@
 # Define the arguments in an array
 args=(
   "-m"
-  "MaxText.decode"
+  "maxtext.decode"
   "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml"
   "tokenizer_path=${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}/tokenizer.llama2"
   "model_name=llama2-7b"
diff --git a/tests/inference/test_llama2_7b_int8.sh b/tests/inference/test_llama2_7b_int8.sh
index 50aa2c0dc9..10b071efdc 100755
--- a/tests/inference/test_llama2_7b_int8.sh
+++ b/tests/inference/test_llama2_7b_int8.sh
@@ -3,7 +3,7 @@
 # Define the arguments in an array
 args=(
   "-m"
-  "MaxText.decode"
+  "maxtext.decode"
   "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml"
   "tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/tokenizer.llama2"
   "model_name=llama2-7b"
diff --git a/tests/integration/decode_tests.py b/tests/integration/decode_tests.py
index 3cb61b83f3..74d584bce2 100644
--- a/tests/integration/decode_tests.py
+++ b/tests/integration/decode_tests.py
@@ -23,7 +23,7 @@
 from absl.testing import absltest
 from contextlib import redirect_stdout
 
-from MaxText.decode import main as decode_main
+from maxtext.decode import main as decode_main
 from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT
 
 pytestmark = pytest.mark.integration_test
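The test updates above switch imports one module at a time, which means out-of-tree code written against `MaxText.inference` breaks as soon as it syncs past this commit. A conventional bridge for such callers (an assumption for illustration, not something this PR adds) is an import fallback:

```python
# Hypothetical compatibility shim for downstream code, NOT part of this PR:
# prefer the new lowercase package, fall back to the pre-rename layout.
try:
  from maxtext.inference import kvcache, page_manager
except ImportError:
  from MaxText.inference import kvcache, page_manager
```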
diff --git a/tests/integration/generate_param_only_checkpoint_test.py b/tests/integration/generate_param_only_checkpoint_test.py
index 72777ed3ce..a89bfa6d6e 100644
--- a/tests/integration/generate_param_only_checkpoint_test.py
+++ b/tests/integration/generate_param_only_checkpoint_test.py
@@ -22,8 +22,8 @@
 
 from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR
 from MaxText.train import main as train_main
-from MaxText.decode import main as decode_main
 from MaxText.generate_param_only_checkpoint import main as generate_param_only_ckpt_main
+from maxtext.decode import main as decode_main
 
 from tests.integration.checkpointing_test import get_checkpointing_command
diff --git a/tests/integration/grpo_trainer_correctness_test.py b/tests/integration/grpo_trainer_correctness_test.py
index 2548020c88..a92e1f3bff 100644
--- a/tests/integration/grpo_trainer_correctness_test.py
+++ b/tests/integration/grpo_trainer_correctness_test.py
@@ -47,12 +47,12 @@
 from MaxText.common_types import MODEL_MODE_TRAIN
 from MaxText.experimental.rl.grpo_trainer import grpo_loss_fn, _merge_grpo_state, setup_train_loop
 from MaxText.experimental.rl.grpo_utils import compute_log_probs
-from MaxText.inference import offline_engine
 from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT, MAXTEXT_TEST_ASSETS_ROOT
 from MaxText.layers import models
 from MaxText.layers import quantizations
-from MaxText.inference.offline_engine import InputData
 from MaxText.experimental.rl import grpo_utils
+from maxtext.inference import offline_engine
+from maxtext.inference.offline_engine import InputData
 
 
 def get_golden_data(config):
diff --git a/tests/integration/smoke/inference_microbenchmark_smoke_test.py b/tests/integration/smoke/inference_microbenchmark_smoke_test.py
index 3ae010542d..d2f713a855 100644
--- a/tests/integration/smoke/inference_microbenchmark_smoke_test.py
+++ b/tests/integration/smoke/inference_microbenchmark_smoke_test.py
@@ -21,7 +21,7 @@
 
 from MaxText import pyconfig
 from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT
-from MaxText.inference_microbenchmark import run_benchmarks
+from maxtext.inference_microbenchmark import run_benchmarks
 
 
 class Inference_Microbenchmark(unittest.TestCase):
""" import os @@ -272,7 +272,7 @@ def test_kimi_configs(config_file): os.path.join( MAXTEXT_REPO_ROOT, "src", - "MaxText", + "maxtext", "inference", "configs", "multi_host", @@ -280,13 +280,13 @@ def test_kimi_configs(config_file): "llama3_405b_v6e-16-16.yml", ), os.path.join( - MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "configs", "multi_host", "interleaved", "llama2_70b_v5e-16.yml" + MAXTEXT_REPO_ROOT, "src", "maxtext", "inference", "configs", "multi_host", "interleaved", "llama2_70b_v5e-16.yml" ), os.path.join( - MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "configs", "multi_host", "interleaved", "llama3_70b_v5e-16.yml" + MAXTEXT_REPO_ROOT, "src", "maxtext", "inference", "configs", "multi_host", "interleaved", "llama3_70b_v5e-16.yml" ), os.path.join( - MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "configs", "multi_host", "interleaved", "llama3_405b_v5e-64.yml" + MAXTEXT_REPO_ROOT, "src", "maxtext", "inference", "configs", "multi_host", "interleaved", "llama3_405b_v5e-64.yml" ), ] diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 07dc2c14fd..41485991a6 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -35,13 +35,13 @@ from MaxText import max_utils from MaxText import maxtext_utils from MaxText import sharding -from MaxText import inference_utils from MaxText import pyconfig from MaxText.common_types import MODEL_MODE_TRAIN from MaxText.globals import MAXTEXT_PKG_DIR from MaxText.layers import models from MaxText.layers import quantizations from MaxText.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations +from maxtext import inference_utils Transformer = models.transformer_as_linen diff --git a/tests/unit/offline_engine_test.py b/tests/unit/offline_engine_test.py index b72b5c18ef..aa110f4dba 100644 --- a/tests/unit/offline_engine_test.py +++ b/tests/unit/offline_engine_test.py @@ -21,7 +21,7 @@ import jax import jax.numpy as jnp import numpy as np -from MaxText.inference.offline_engine import OfflineEngine, InputData, CompletionOutput +from maxtext.inference.offline_engine import OfflineEngine, InputData, CompletionOutput from MaxText import pyconfig from MaxText.globals import MAXTEXT_PKG_DIR diff --git a/tools/orchestration/multihost_job.py b/tools/orchestration/multihost_job.py index 6a2092e326..c622fb3df8 100644 --- a/tools/orchestration/multihost_job.py +++ b/tools/orchestration/multihost_job.py @@ -43,7 +43,7 @@ from datetime import datetime import os import shutil -from MaxText.inference_utils import str2bool +from maxtext.inference_utils import str2bool def get_project():