Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions experimental/jetstream-maxtext-stable-stack/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


FROM alpine/git:2.47.2 AS maxtext_cloner

ARG MAXTEXT_COMMIT_HASH

WORKDIR /src

RUN \
git clone --depth=1 https://github.com/AI-Hypercomputer/maxtext.git && \
if [ -n "${MAXTEXT_COMMIT_HASH}" ]; then \
cd maxtext && \
git fetch origin ${MAXTEXT_COMMIT_HASH} && \
git switch --detach ${MAXTEXT_COMMIT_HASH}; \
fi


FROM alpine/git:2.47.2 AS jetstream_cloner

ARG JETSTREAM_COMMIT_HASH

WORKDIR /src
RUN \
git clone --depth=1 https://github.com/AI-Hypercomputer/JetStream.git && \
if [ -n "${JETSTREAM_COMMIT_HASH}" ]; then \
cd JetStream && \
git fetch origin ${JETSTREAM_COMMIT_HASH} && \
git switch --detach ${JETSTREAM_COMMIT_HASH}; \
fi

FROM python:3.10-slim-bullseye AS runner

WORKDIR /jetstream_maxtext_stable_stack

# Environment variable for no-cache-dir and pip root user warning
ENV PIP_NO_CACHE_DIR=1
ENV PIP_ROOT_USER_ACTION=ignore

# Set environment variables for Google Cloud SDK and Python 3.10
ENV PYTHON_VERSION=3.10
ENV CLOUD_SDK_VERSION=latest

# Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors
ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update \
&& \
apt-get install -y --no-install-recommends git git-lfs \
&& \
rm -rf /var/lib/apt/lists/*

RUN python3 -m pip install --upgrade pip

# Install MaxText package
COPY --from=maxtext_cloner /src .
RUN cd maxtext && bash setup.sh

# MaxText install jetstream from the main. Need overwrite it.
# Install JetStream requirements
COPY --from=jetstream_cloner /src .
RUN python3 -m pip install ./JetStream
RUN python3 -m pip install -r ./JetStream/benchmarks/requirements.in

COPY generate_manifest.sh .
RUN \
bash ./generate_manifest.sh \
PREFIX=jetstream_maxtext \
MAXTEXT_COMMIT_HASH=$(git -C ./maxtext rev-parse HEAD) \
JETSTREAM_COMMIT_HASH=$(git -C ./JetStream rev-parse HEAD)
58 changes: 58 additions & 0 deletions experimental/jetstream-maxtext-stable-stack/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Jetstream MaxText Stable Stack

This provides a stable Docker image stack for running MaxText using JetStream on Cloud TPUs for inference.

## Overview

The goal of this project is to offer a reliable and up-to-date environment for deploying and serving MaxText efficiently on TPU hardware via the JetStream inference server.

## Getting Started

### Prerequisites

- Docker installed on your machine or VM.
- Access to Google Cloud Platform and authenticated `gcloud` CLI (if pulling from GCR).
- Access to TPU resources configured for your project.

### Pulling the Image

The stable stack is available as a nightly Docker image hosted on Google Container Registry (GCR). To pull the latest nightly image, replace `YYYYMMDD` with the desired date (e.g., `20231027`):

```bash
# Replace YYYYMMDD with the specific date, e.g., 20231027
export NIGHTLY_DATE=$(date +"%Y%m%d") # Or set manually, e.g., export NIGHTLY_DATE=20231027

docker pull gcr.io/cloud-tpu-inference-test/jetstream-maxtext-stable-stack/tpu:nightly-${NIGHTLY_DATE}

# Or the last nightly build
docker pull gcr.io/cloud-tpu-inference-test/jetstream-maxtext-stable-stack/tpu:nightly
```

## Running the Container

Run on the TPU VM.

```bash
docker run --net=host --privileged --rm -it \
# Add necessary volume mounts, TPU device access, network ports, etc.
gcr.io/cloud-tpu-inference-test/jetstream-maxtext-stable-stack/tpu:nightly \
bash
```

## Image Information

- Registry: Google Container Registry (GCR)
- Path: gcr.io/cloud-tpu-inference-test/jetstream-maxtext-stable-stack/tpu
- Tagging Scheme: nightly-YYYYMMDD (e.g., nightly-20231027)

A new image is built nightly, incorporating the latest updates and dependencies for the JetStream-MaxText stack on TPUs. Use the tag corresponding to the date you wish to use.

## Build the Image

- build.sh build the local docker image
- test.sh test all the .sh in test_script using the built image
- pipeline.sh build, test and upload the image if all success.

```bash
./pipeline.sh UPLOAD_IMAGE_TAG=gcr.io/cloud-tpu-inference-test/jetstream-maxtext-stable-stack/tpu:nightly-$(date +"%Y%m%d")
```
40 changes: 40 additions & 0 deletions experimental/jetstream-maxtext-stable-stack/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -xe

export LOCAL_IMAGE_TAG=${LOCAL_IMAGE_TAG:-jetstream-maxtext-stable-stack:latest}
export MAXTEXT_COMMIT_HASH=${MAXTEXT_COMMIT_HASH}
export JETSTREAM_COMMIT_HASH=${JETSTREAM_COMMIT_HASH}

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done


if [[ -z "$LOCAL_IMAGE_TAG" ]]; then
echo -e "\n\nError: You must specify an LOCAL_IMAGE_TAG.\n\n"
exit 1
fi

docker build --no-cache \
--build-arg MAXTEXT_COMMIT_HASH=${MAXTEXT_COMMIT_HASH} \
--build-arg JETSTREAM_COMMIT_HASH="${JETSTREAM_COMMIT_HASH}" \
-t ${LOCAL_IMAGE_TAG} \
-f ./Dockerfile .

echo "********* Sucessfully built Stable Stack Image with tag $LOCAL_IMAGE_TAG *********"
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script generates a manifest of currently installed Python packages, along with their versions.
# The manifest is named with a timestamp for easy versioning and tracking.

export PREFIX='default'

for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
echo "$KEY"="$VALUE"
done

# Set the Manifest file name with the date for versioning
TIMESTAMP=$(date +"%Y%m%d-%H%M%S")
MANIFEST_FILE="${PREFIX}_manifest_${TIMESTAMP}.txt"

# Freeze packages installed and their version to the Manifest file, with sorted and commented Manifest
pip freeze | sort > "$MANIFEST_FILE"

# Maxtext depend on main branch of jetstream we don't want.
# Remove google-jetstream from the Manifest file
grep -vE '^google-jetstream(==|>=|<=|>|<| |@|$)' "$MANIFEST_FILE" > temp && mv temp "$MANIFEST_FILE"

# Write commit details to the Manifest file
if [[ -n "$MAXTEXT_COMMIT_HASH" ]]; then
echo "# maxtext commit hash: $MAXTEXT_COMMIT_HASH" | cat - "$MANIFEST_FILE" > temp && mv temp "$MANIFEST_FILE"
fi
if [[ -n "$JETSTREAM_COMMIT_HASH" ]]; then
echo "# JetStream commit hash: $JETSTREAM_COMMIT_HASH" | cat - "$MANIFEST_FILE" > temp && mv temp "$MANIFEST_FILE"
fi

# Add a header comment to the Manifest file
echo "# Python Packages Frozen at: ${TIMESTAMP}" | cat - "$MANIFEST_FILE" > temp && mv temp "$MANIFEST_FILE"
48 changes: 48 additions & 0 deletions experimental/jetstream-maxtext-stable-stack/pipeline.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -xe

export LOCAL_IMAGE_TAG="jetstream-maxtext-stable-stack:nightly"
export MAXTEXT_COMMIT_HASH=""
export JETSTREAM_COMMIT_HASH=""
export UPLOAD_IMAGE_TAG=""

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

if [[ -z "$UPLOAD_IMAGE_TAG" ]]; then
echo -e "\n\nError: You must specify an UPLOAD_IMAGE_TAG.\n\n"
exit 1
fi


docker_image_upload()
{
local nightly_tag=${UPLOAD_IMAGE_TAG%:*}:nightly
docker tag ${LOCAL_IMAGE_TAG} ${UPLOAD_IMAGE_TAG}
docker tag ${LOCAL_IMAGE_TAG} ${nightly_tag}
docker push ${UPLOAD_IMAGE_TAG}
docker push ${nightly_tag}
echo "All done, check out your artifacts at: ${UPLOAD_IMAGE_TAG}"
}

gcloud auth configure-docker us-docker.pkg.dev --quiet
./build.sh
./test.sh
docker_image_upload
67 changes: 67 additions & 0 deletions experimental/jetstream-maxtext-stable-stack/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Docker image name to use for executing test scripts
export LOCAL_IMAGE_TAG=${LOCAL_IMAGE_TAG}

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

echo "--- Starting test execution ---"

shopt -s nullglob
test_script_files=(test_script/*.sh)
shopt -u nullglob

echo "Found the following test scripts:"
printf " - %s\n" "${test_script_files[@]}"

declare -a failed_scripts
overall_exit_status=0

for script_path in "${test_script_files[@]}"; do
if [[ -f "$script_path" ]]; then
echo ">>> Running test script: $script_path"

docker run --net=host --privileged --rm -i ${LOCAL_IMAGE_TAG} bash < "$script_path"
script_exit_status=$? # Capture the exit code of the docker run command

if [[ $script_exit_status -ne 0 ]]; then
echo "<<< FAILED test script: $script_path (Exit Code: $script_exit_status)"
failed_scripts+=("$script_path")
overall_exit_status=1
else
echo "<<< Finished test script successfully: $script_path"
fi
echo
else
echo "--- Skipping non-file entry: $script_path ---"
fi
done

echo

if [[ $overall_exit_status -ne 0 ]]; then
echo "--- Test Execution Summary: FAILURES DETECTED ---"
echo "The following scripts failed:"
printf " - %s\n" "${failed_scripts[@]}"
exit 1
else
echo "--- Test Execution Summary: All tests passed successfully ---"
exit 0
fi
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
cd maxtext

LIBTPU_INIT_ARGS="--xla_tpu_enable_windowed_einsum_for_reduce_scatter=false --xla_jf_spmd_threshold_for_windowed_einsum_mib=1000000" \
python -m MaxText.benchmark_chunked_prefill \
MaxText/configs/inference.yml \
tokenizer_path=assets/tokenizer.mistral-v1 \
max_prefill_predict_length=8192 \
max_target_length=8704 \
model_name=mixtral-8x7b \
ici_fsdp_parallelism=1 \
ici_autoregressive_parallelism=1 \
ici_tensor_parallelism=8 \
scan_layers=false \
weight_dtype=bfloat16 \
per_device_batch_size=8 \
megablox=False \
quantization=int8 \
quantize_kvcache=False \
checkpoint_is_quantized=True \
capacity_factor=1 \
attention=dot_product \
model_call_mode=inference \
sparse_matmul=False \
use_chunked_prefill=true \
prefill_chunk_size=2048

Loading