Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 125 additions & 1 deletion .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
@@ -1,5 +1,29 @@
# syntax=docker/dockerfile:1-labs
# Build arguments declared ahead of the first FROM. Every project gets a clone
# URL (REPO_*), a git ref (REF_*) and a checkout location (SRC_PATH_*). A stage
# has to redeclare an ARG before it can read the value.

# JAX
ARG REPO_JAX="https://github.com/google/jax.git"
ARG REF_JAX=main
ARG SRC_PATH_JAX=/opt/jax-source

# XLA
ARG REPO_XLA="https://github.com/openxla/xla.git"
ARG REF_XLA=main
ARG SRC_PATH_XLA=/opt/xla-source

# Transformer Engine
ARG REPO_TE="https://github.com/NVIDIA/TransformerEngine.git"
ARG REF_TE=main
ARG SRC_PATH_TE=/opt/transformer-engine

# T5X
ARG REPO_T5X="https://github.com/google-research/t5x.git"
ARG REF_T5X=main
ARG SRC_PATH_T5X=/opt/t5x

# Paxml
ARG REPO_PAXML="https://github.com/google/paxml.git"
ARG REF_PAXML=main
ARG SRC_PATH_PAXML=/opt/paxml

# Praxis
ARG REPO_PRAXIS="https://github.com/google/praxis.git"
ARG REF_PRAXIS=main
ARG SRC_PATH_PRAXIS=/opt/praxis

# Flax
ARG REPO_FLAX="https://github.com/google/flax.git"
ARG REF_FLAX=main
ARG SRC_PATH_FLAX=/opt/flax

# No default: the caller supplies the value with --build-arg.
ARG BUILD_DATE

ARG BASE_IMAGE=nvidia/cuda:12.2.0-devel-ubuntu22.04
FROM ${BASE_IMAGE}
FROM ${BASE_IMAGE} as lib-base

###############################################################################
## Install Python
Expand Down Expand Up @@ -50,3 +74,103 @@ ENV PATH=/opt/amazon/efa/bin:${PATH}
###############################################################################

# Link the nsys binary from the versioned install tree into the CUDA bin directory.
# NOTE(review): the `*` presumably stands for a single version directory — confirm against the base image.
RUN ln -s /opt/nvidia/nsight-compute/*/host/target-linux-x64/nsys /usr/local/cuda/bin

FROM lib-base AS src-base

# Redeclare each global build argument inside this stage and persist it as an
# environment variable of the same name, so that RUN steps in this image and in
# images built on top of it can read the values.

# JAX
ARG REPO_JAX
ARG REF_JAX
ARG SRC_PATH_JAX
ENV REPO_JAX=${REPO_JAX} \
    REF_JAX=${REF_JAX} \
    SRC_PATH_JAX=${SRC_PATH_JAX}

# XLA
ARG REPO_XLA
ARG REF_XLA
ARG SRC_PATH_XLA
ENV REPO_XLA=${REPO_XLA} \
    REF_XLA=${REF_XLA} \
    SRC_PATH_XLA=${SRC_PATH_XLA}

# Transformer Engine
ARG REPO_TE
ARG REF_TE
ARG SRC_PATH_TE
ENV REPO_TE=${REPO_TE} \
    REF_TE=${REF_TE} \
    SRC_PATH_TE=${SRC_PATH_TE}

# T5X
ARG REPO_T5X
ARG REF_T5X
ARG SRC_PATH_T5X
ENV REPO_T5X=${REPO_T5X} \
    REF_T5X=${REF_T5X} \
    SRC_PATH_T5X=${SRC_PATH_T5X}

# Paxml
ARG REPO_PAXML
ARG REF_PAXML
ARG SRC_PATH_PAXML
ENV REPO_PAXML=${REPO_PAXML} \
    REF_PAXML=${REF_PAXML} \
    SRC_PATH_PAXML=${SRC_PATH_PAXML}

# Praxis
ARG REPO_PRAXIS
ARG REF_PRAXIS
ARG SRC_PATH_PRAXIS
ENV REPO_PRAXIS=${REPO_PRAXIS} \
    REF_PRAXIS=${REF_PRAXIS} \
    SRC_PATH_PRAXIS=${SRC_PATH_PRAXIS}

# Flax
ARG REPO_FLAX
ARG REF_FLAX
ARG SRC_PATH_FLAX
ENV REPO_FLAX=${REPO_FLAX} \
    REF_FLAX=${REF_FLAX} \
    SRC_PATH_FLAX=${SRC_PATH_FLAX}

# Kept last among the arguments, as in the original ordering.
ARG BUILD_DATE
ENV BUILD_DATE=${BUILD_DATE}

# Clone every framework repository at its requested ref into its SRC_PATH_*.
# The SSH agent socket and the known_hosts secret are mounted for this step only;
# build mounts are not written into an image layer.
# The heredoc delimiter is quoted, so the ${...} references are expanded by bash
# from the stage's environment. They are double-quoted so a value containing
# whitespace or glob characters is passed to git as a single argument, and
# `-u -o pipefail` matches the strictness of the pinning script later in this file.
RUN --mount=type=ssh \
    --mount=type=secret,id=SSH_KNOWN_HOSTS,target=/root/.ssh/known_hosts \
    <<"EOF" bash -e -u -o pipefail
git clone -b "${REF_JAX}" "${REPO_JAX}" "${SRC_PATH_JAX}"
git clone -b "${REF_XLA}" "${REPO_XLA}" "${SRC_PATH_XLA}"
git clone -b "${REF_TE}" "${REPO_TE}" "${SRC_PATH_TE}"
# Transformer Engine is the only checkout whose submodules are initialised here.
git -C "${SRC_PATH_TE}" submodule update --init --recursive
git clone -b "${REF_T5X}" "${REPO_T5X}" "${SRC_PATH_T5X}"
git clone -b "${REF_PAXML}" "${REPO_PAXML}" "${SRC_PATH_PAXML}"
git clone -b "${REF_PRAXIS}" "${REPO_PRAXIS}" "${SRC_PATH_PRAXIS}"
git clone -b "${REF_FLAX}" "${REPO_FLAX}" "${SRC_PATH_FLAX}"
EOF

# pip-tools provides pip-compile (used below) and pip-sync (used by downstream images).
RUN pip install --no-cache-dir pip-tools

# These are plain local files: neither archive extraction nor URL fetching is
# needed, so COPY is used instead of ADD.
COPY --link jax-requirements.in upstream-t5x-requirements.in upstream-paxml-requirements.in /opt/

WORKDIR /opt
# Note: This will not specify the compiled jaxlib wheel under $SRC_PATH_JAX/dist. That needs to be
# added after the requirements are compiled. This should be okay since at the time of writing,
# jaxlib's requirements are a subset of jax's.
# Each pip-compile writes the matching *-requirements.txt next to its .in file; the pip-tools
# cache is removed in the same layer so it does not end up in the image.
RUN pip-compile jax-requirements.in && rm -rf /root/.cache/pip-tools
RUN pip-compile upstream-t5x-requirements.in && rm -rf /root/.cache/pip-tools
RUN SKIP_HEAD_INSTALLS=true pip-compile upstream-paxml-requirements.in && rm -rf /root/.cache/pip-tools

# Handle head installs by appending their SHAs to the requirement file. All compiled
# requirements files are scanned together and every distinct "pkg @ git+URL" line is
# resolved once, so a package repeated across files is pinned to the same commit.
RUN <<"EOF" bash -e -u -o pipefail
# Maps an unpinned requirement line to the same line with "@<commit sha>" appended.
declare -A old_to_new_req

# Read the distinct head-install lines one per line (no word splitting or globbing).
# `grep -h` drops the file-name prefix that grep adds when given several files.
while IFS= read -r head_req; do
  # The third whitespace-separated field is "git+<url>"; strip the "git+" prefix.
  pkg_url=$(awk '{sub(/^git\+/,"",$3); print $3}' <<< "$head_req")
  # Detach git from the loop's stdin so it can never consume the remaining lines.
  commit_sha=$(git ls-remote "$pkg_url" HEAD </dev/null | awk 'NR==1 {print $1}')
  # An empty answer would yield a requirement ending in a bare "@"; fail instead.
  if [[ -z "$commit_sha" ]]; then
    echo "Could not resolve HEAD of ${pkg_url}" >&2
    exit 1
  fi
  old_to_new_req["$head_req"]="${head_req}@${commit_sha}"
done < <(grep -hE '^[^#].+ @ git\+' *requirements.txt | sort -u)

# Rewrite each requirements file, swapping in the pinned form where one exists.
for req_file in *requirements.txt; do
  tmp_file="${req_file}.tmp"
  # `|| [[ -n "$line" ]]` keeps a final line that lacks a trailing newline.
  while IFS= read -r line || [[ -n "$line" ]]; do
    # Test for a blank line first: an empty string is not a usable associative-array key.
    if [[ -n "$line" ]] && [[ -n "${old_to_new_req[$line]:-}" ]]; then
      printf '%s\n' "${old_to_new_req[$line]}"
    else
      printf '%s\n' "$line"
    fi
  done < "$req_file" > "$tmp_file"
  mv "$tmp_file" "$req_file"
done
EOF

82 changes: 82 additions & 0 deletions .github/container/Dockerfile.fw
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=ghcr.io/nvidia/jax-toolbox:base
ARG BAZEL_CACHE=/tmp
ARG BUILD_DATE

###############################################################################
## Build JAX
###############################################################################

FROM ${BASE_IMAGE} AS jax-builder
ARG BAZEL_CACHE

# The helper scripts are plain local files, so COPY is used instead of ADD.
COPY build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

# Run the JAX build, then remove the Bazel cache in the same layer so it is not
# stored in the image. SRC_PATH_JAX / SRC_PATH_XLA are not declared in this file;
# they are expected to come from the base image's environment.
# `rm -rf` is used so the step also succeeds when the cache directory does not exist.
RUN build-jax.sh \
        --bazel-cache "${BAZEL_CACHE}" \
        --src-path-jax "${SRC_PATH_JAX}" \
        --src-path-xla "${SRC_PATH_XLA}" \
        --sm all \
        --clean \
    && rm -rf "${BAZEL_CACHE}"

# TODO(terry): don't delete /tmp and re-create. Only doing this for now b/c I want to cache ^
# With the default BAZEL_CACHE=/tmp the step above removes /tmp itself, so recreate it.
# `-p` keeps this working when BAZEL_CACHE points elsewhere and /tmp still exists;
# mode 1777 (sticky, world-writable) is the conventional mode for /tmp.
RUN mkdir -p /tmp && chmod 1777 /tmp

FROM jax-builder AS runtime-image
# The following environment variables tune performance
# NOTE(review): the jax, upstream-t5x and upstream-pax stages further down start FROM
# jax-builder, not from this stage, so they do not inherit these variables — confirm
# that this is intended.
ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line could cause hangs. Could it be related to the recent hangs we have seen with Pax?

# Further runtime environment defaults for this stage.
# NOTE(review): the NCCL_* values are presumably cluster-specific and NVTE_FRAMEWORK is
# presumably read by Transformer Engine — confirm against the NCCL / TE documentation.
ENV NCCL_IB_SL=1
ENV NCCL_NVLS_ENABLE=0
ENV NVTE_FRAMEWORK=jax

# TODO: properly configure entrypoint
# COPY entrypoint.d/ /opt/nvidia/entrypoint.d/

FROM jax-builder AS jax
# Since jaxlib is not explicitly listed by the jax package itself, we append it
RUN <<"EOF" bash -e -u -o pipefail
# Collect the build output with a glob instead of parsing `ls`; with nullglob an
# empty or missing dist directory yields an empty array rather than a literal "*".
shopt -s nullglob
wheels=("${SRC_PATH_JAX}"/dist/*)
# Exactly one artifact is expected; anything else means the build output is ambiguous.
if [[ ${#wheels[@]} -ne 1 ]]; then
  echo "Expect only 1 jax wheel to be built"
  exit 1
fi
echo "jaxlib @ file://${wheels[0]}" >> /opt/jax-requirements.txt
EOF

# Install exactly what the requirements file lists, then drop pip's cache in the same layer.
RUN pip-sync /opt/jax-requirements.txt && rm -rf /root/.cache/pip

FROM jax-builder AS upstream-t5x

# Since upstream-t5x adds jaxlib from pypi, we need to replace it
RUN <<"EOF" bash -e -u -o pipefail
# Collect the build output with a glob instead of parsing `ls`; with nullglob an
# empty or missing dist directory yields an empty array rather than a literal "*".
shopt -s nullglob
wheels=("${SRC_PATH_JAX}"/dist/*)
# Exactly one artifact is expected; anything else means the build output is ambiguous.
if [[ ${#wheels[@]} -ne 1 ]]; then
  echo "Expect only 1 jax wheel to be built"
  exit 1
fi
sed -i "s|^jaxlib[ =].*|jaxlib @ file://${wheels[0]}|" /opt/upstream-t5x-requirements.txt
EOF

# T5X's requirements pull flax from VCS; point the requirement at the local flax
# checkout instead, which should be identical to the one that would be pulled.
RUN sed -i "s|^flax[ =].*|flax @ file://${SRC_PATH_FLAX}|" /opt/upstream-t5x-requirements.txt

# We will also comment out t5x as a requirement since it introduces dependencies like "seqio @ git+${URL}" which are in conflict with the ones we manually add like "seqio @ git+${URL}@$SHA"
RUN sed -i 's/^t5x/#t5x/' /opt/upstream-t5x-requirements.txt

# Sync the environment to the edited requirements, then install t5x itself from
# its checkout without letting pip resolve its dependencies again.
RUN pip-sync /opt/upstream-t5x-requirements.txt && rm -rf /root/.cache/pip
RUN pip install --no-deps --no-cache-dir "${SRC_PATH_T5X}"

FROM jax-builder AS upstream-pax

# Since upstream-pax adds jaxlib from pypi, we need to replace it
RUN <<"EOF" bash -e -u -o pipefail
# Collect the build output with a glob instead of parsing `ls`; with nullglob an
# empty or missing dist directory yields an empty array rather than a literal "*".
shopt -s nullglob
wheels=("${SRC_PATH_JAX}"/dist/*)
# Exactly one artifact is expected; anything else means the build output is ambiguous.
if [[ ${#wheels[@]} -ne 1 ]]; then
  echo "Expect only 1 jax wheel to be built"
  exit 1
fi
sed -i "s|^jaxlib[ =].*|jaxlib @ file://${wheels[0]}|" /opt/upstream-paxml-requirements.txt
EOF
# We will also comment out paxml/praxis as a requirement since it introduces dependencies like "fiddle @ git+${URL}" which are in conflict with the ones we manually add like "fiddle @ git+${URL}@$SHA"
RUN sed -i 's/^paxml/#paxml/' /opt/upstream-paxml-requirements.txt
RUN sed -i 's/^praxis/#praxis/' /opt/upstream-paxml-requirements.txt

# Sync the environment to the edited requirements, then install paxml and praxis
# from their checkouts without letting pip resolve their dependencies again.
RUN pip-sync /opt/upstream-paxml-requirements.txt && rm -rf /root/.cache/pip
RUN pip install --no-deps --no-cache-dir "${SRC_PATH_PAXML}" "${SRC_PATH_PRAXIS}"
4 changes: 4 additions & 0 deletions .github/container/jax-requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Input for pip-compile. The file:// paths are the default SRC_PATH_JAX and
# SRC_PATH_TE locations from Dockerfile.base; overriding those build arguments
# requires updating the paths here as well.
jax @ file:///opt/jax-source
# jaxlib is left out on purpose: Dockerfile.fw appends the locally built wheel to the
# compiled jax-requirements.txt once its file name is known.
# TODO: Brittle, may be possible to use --find-links=./wheel_dir instead
#jaxlib @ file:///opt/jax-source/dist/jaxlib-0.4.14-cp310-cp310-linux_x86_64.whl
transformer_engine @ file:///opt/transformer-engine
6 changes: 6 additions & 0 deletions .github/container/upstream-paxml-requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Input for pip-compile; jax-requirements.in is applied as a constraints file.
-c jax-requirements.in
# Local checkouts at the default SRC_PATH_PAXML / SRC_PATH_PRAXIS locations from Dockerfile.base.
/opt/paxml
/opt/praxis
# Re-introduce the dependency that was skipped by SKIP_HEAD_INSTALLS=true
# Omitting jax since it is included by the jax-requirements.in constraint
# The URL carries no ref here; Dockerfile.base appends the resolved HEAD commit SHA
# to this line in the compiled requirements.txt.
fiddle @ git+https://github.com/google/fiddle
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is fiddle pinned?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is. The logic to pin it is here: https://github.com/NVIDIA/JAX-Toolbox/pull/271/files#diff-8386cfe5c29a3be212a911c9550a56c57420a17cae190b80f60e52c0ce265645R151-R175

After pip-compile runs, fiddle appears in the *requirements.txt file as fiddle @ git+https://github.com/google/fiddle; once the pinning step linked above has run, that line becomes fiddle @ git+https://github.com/google/fiddle@$COMMIT_SHA

2 changes: 2 additions & 0 deletions .github/container/upstream-t5x-requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Input for pip-compile; jax-requirements.in is applied as a constraints file.
-c jax-requirements.in
# Local checkout at the default SRC_PATH_T5X location from Dockerfile.base.
/opt/t5x