From 56529098eee8c1fe413dd031fd91ece93c3baf3a Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 28 Sep 2023 17:22:34 -0700 Subject: [PATCH] early progress of using pip-compile. Dockerfile.fw builds jax/upstream-t5x/upstream-pax --- .github/container/Dockerfile.base | 126 +++++++++++++++++- .github/container/Dockerfile.fw | 82 ++++++++++++ .github/container/jax-requirements.in | 4 + .../container/upstream-paxml-requirements.in | 6 + .../container/upstream-t5x-requirements.in | 2 + 5 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 .github/container/Dockerfile.fw create mode 100644 .github/container/jax-requirements.in create mode 100644 .github/container/upstream-paxml-requirements.in create mode 100644 .github/container/upstream-t5x-requirements.in diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index bb1590a63..2c8e4b4e2 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -1,5 +1,29 @@ +# syntax=docker/dockerfile:1-labs +ARG REPO_JAX="https://github.com/google/jax.git" +ARG REPO_XLA="https://github.com/openxla/xla.git" +ARG REPO_TE="https://github.com/NVIDIA/TransformerEngine.git" +ARG REPO_T5X="https://github.com/google-research/t5x.git" +ARG REPO_PAXML="https://github.com/google/paxml.git" +ARG REPO_PRAXIS="https://github.com/google/praxis.git" +ARG REPO_FLAX="https://github.com/google/flax.git" +ARG REF_JAX=main +ARG REF_XLA=main +ARG REF_TE=main +ARG REF_T5X=main +ARG REF_PAXML=main +ARG REF_PRAXIS=main +ARG REF_FLAX=main +ARG SRC_PATH_JAX=/opt/jax-source +ARG SRC_PATH_XLA=/opt/xla-source +ARG SRC_PATH_TE=/opt/transformer-engine +ARG SRC_PATH_T5X=/opt/t5x +ARG SRC_PATH_PAXML=/opt/paxml +ARG SRC_PATH_PRAXIS=/opt/praxis +ARG SRC_PATH_FLAX=/opt/flax +ARG BUILD_DATE + ARG BASE_IMAGE=nvidia/cuda:12.2.0-devel-ubuntu22.04 -FROM ${BASE_IMAGE} +FROM ${BASE_IMAGE} as lib-base ############################################################################### ## Install Python @@ -50,3 +74,103 @@ ENV PATH=/opt/amazon/efa/bin:${PATH} ############################################################################### RUN ln -s /opt/nvidia/nsight-compute/*/host/target-linux-x64/nsys /usr/local/cuda/bin + +FROM lib-base AS src-base +ARG REPO_JAX +ENV REPO_JAX=${REPO_JAX} +ARG REPO_XLA +ENV REPO_XLA=${REPO_XLA} +ARG REPO_TE +ENV REPO_TE=${REPO_TE} +ARG REPO_T5X +ENV REPO_T5X=${REPO_T5X} +ARG REPO_PAXML +ENV REPO_PAXML=${REPO_PAXML} +ARG REPO_PRAXIS +ENV REPO_PRAXIS=${REPO_PRAXIS} +ARG REPO_FLAX +ENV REPO_FLAX=${REPO_FLAX} +ARG REF_JAX +ENV REF_JAX=${REF_JAX} +ARG REF_XLA +ENV REF_XLA=${REF_XLA} +ARG REF_TE +ENV REF_TE=${REF_TE} +ARG REF_T5X +ENV REF_T5X=${REF_T5X} +ARG REF_PAXML +ENV REF_PAXML=${REF_PAXML} +ARG REF_PRAXIS +ENV REF_PRAXIS=${REF_PRAXIS} +ARG REF_FLAX +ENV REF_FLAX=${REF_FLAX} +ARG SRC_PATH_JAX +ENV SRC_PATH_JAX=${SRC_PATH_JAX} +ARG SRC_PATH_XLA +ENV SRC_PATH_XLA=${SRC_PATH_XLA} +ARG SRC_PATH_TE +ENV SRC_PATH_TE=${SRC_PATH_TE} +ARG SRC_PATH_T5X +ENV SRC_PATH_T5X=${SRC_PATH_T5X} +ARG SRC_PATH_PAXML +ENV SRC_PATH_PAXML=${SRC_PATH_PAXML} +ARG SRC_PATH_PRAXIS +ENV SRC_PATH_PRAXIS=${SRC_PATH_PRAXIS} +ARG SRC_PATH_FLAX +ENV SRC_PATH_FLAX=${SRC_PATH_FLAX} +ARG BUILD_DATE +ENV BUILD_DATE=${BUILD_DATE} + +RUN --mount=type=ssh \ + --mount=type=secret,id=SSH_KNOWN_HOSTS,target=/root/.ssh/known_hosts \ + <<"EOF" bash -e +git clone -b ${REF_JAX} ${REPO_JAX} ${SRC_PATH_JAX} +git clone -b ${REF_XLA} ${REPO_XLA} ${SRC_PATH_XLA} +git clone -b ${REF_TE} ${REPO_TE} ${SRC_PATH_TE} +git -C ${SRC_PATH_TE} submodule update --init --recursive +git clone -b ${REF_T5X} ${REPO_T5X} ${SRC_PATH_T5X} +git clone -b ${REF_PAXML} ${REPO_PAXML} ${SRC_PATH_PAXML} +git clone -b ${REF_PRAXIS} ${REPO_PRAXIS} ${SRC_PATH_PRAXIS} +git clone -b ${REF_FLAX} ${REPO_FLAX} ${SRC_PATH_FLAX} +EOF + +RUN pip install --no-cache-dir pip-tools + +ADD --link jax-requirements.in /opt/ +ADD --link upstream-t5x-requirements.in /opt/ +ADD --link upstream-paxml-requirements.in /opt/ + +WORKDIR /opt +# Note: This will not specify the compiled jaxlib wheel under $REPO_JAX. That needs to be replaced after +# the requirements are compiled. This should be okay since at the time of writing, jaxlib's requirements +# are a subset of jax's +RUN pip-compile jax-requirements.in && rm -rf /root/.cache/pip-tools +RUN pip-compile upstream-t5x-requirements.in && rm -rf /root/.cache/pip-tools +RUN SKIP_HEAD_INSTALLS=true pip-compile upstream-paxml-requirements.in && rm -rf /root/.cache/pip-tools + +# Handle head installs by appending their SHAs to the requirement file. Aggregate all of them +# to ensure we specify the same commit if a package is repeated +RUN <<"EOF" bash -e -u -o pipefail +declare -A old_to_new_req +IFS=$'\n' +for head_req in $(cat *requirements.txt | egrep '^[^#].+ @ git\+'); do + pkg_name=$(echo "$head_req" | awk '{print $1}') + pkg_url=$(echo "$head_req" | awk '{sub(/^git\+/,"",$3); print $3}') + old_to_new_req["$head_req"]="${head_req}@$(git ls-remote $pkg_url HEAD | awk '{print $1}')" +done +unset IFS + +for req_file in *requirements.txt; do + rm -f "${req_file}.tmp" + while IFS= read -r line; do + # Check if the line should be replaced + if [[ -v old_to_new_req["$line"] ]]; then + echo "${old_to_new_req["$line"]}" >> "${req_file}.tmp" + else + echo "$line" >> "${req_file}.tmp" + fi + done < "${req_file}" + mv "${req_file}.tmp" "${req_file}" +done +EOF + diff --git a/.github/container/Dockerfile.fw b/.github/container/Dockerfile.fw new file mode 100644 index 000000000..00cd40fff --- /dev/null +++ b/.github/container/Dockerfile.fw @@ -0,0 +1,82 @@ +# syntax=docker/dockerfile:1-labs +ARG BASE_IMAGE=ghcr.io/nvidia/jax-toolbox:base +ARG BAZEL_CACHE=/tmp +ARG BUILD_DATE + +############################################################################### +## Build JAX +############################################################################### + +FROM ${BASE_IMAGE} as jax-builder +ARG BAZEL_CACHE + +ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/ +RUN build-jax.sh \ + --bazel-cache ${BAZEL_CACHE} \ + --src-path-jax ${SRC_PATH_JAX} \ + --src-path-xla ${SRC_PATH_XLA} \ + --sm all \ + --clean && [ -d ${BAZEL_CACHE} ] && rm -r ${BAZEL_CACHE} + +# TODO(terry): don't delete /tmp and re-create. Only doing this for now b/c I want to cache ^ +RUN mkdir /tmp + +FROM jax-builder as runtime-image +# The following environment variables tune performance +ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false" +ENV CUDA_DEVICE_MAX_CONNECTIONS=1 +ENV NCCL_IB_SL=1 +ENV NCCL_NVLS_ENABLE=0 +ENV NVTE_FRAMEWORK=jax + +# TODO: properly configure entrypoint +# COPY entrypoint.d/ /opt/nvidia/entrypoint.d/ + +FROM jax-builder as jax +# Since jaxlib is not explicitly listed by the jax package itself, we append it +RUN <<"EOF" bash -e +if [[ $(ls ${SRC_PATH_JAX}/dist/ | wc -l) -ne 1 ]]; then + echo "Expect only 1 jax wheel to be built" + exit 1 +fi +echo "jaxlib @ file://${SRC_PATH_JAX}/dist/"$(ls ${SRC_PATH_JAX}/dist/) >> /opt/jax-requirements.txt +EOF + +RUN pip-sync /opt/jax-requirements.txt && rm -rf /root/.cache/pip + +FROM jax-builder as upstream-t5x + +# Since upstream-t5x adds jaxlib from pypi, we need to replace it +RUN <<"EOF" bash -e +if [[ $(ls ${SRC_PATH_JAX}/dist/ | wc -l) -ne 1 ]]; then + echo "Expect only 1 jax wheel to be built" + exit 1 +fi +sed -i "s|^jaxlib[ =].*|jaxlib @ file://${SRC_PATH_JAX}/dist/$(ls ${SRC_PATH_JAX}/dist/)|" /opt/upstream-t5x-requirements.txt +EOF + +# T5x uses a flax distribution so we will use the flax source, which should be identical to one we'd pull from VCS in t5x's requirements +RUN sed -i "s|^flax[ =].*|flax @ file://${SRC_PATH_FLAX}|" /opt/upstream-t5x-requirements.txt + +# We will also comment out t5x as a requirement since it introduces dependencies like "seqio @ git+${URL}" which are in conflict with the ones we manually add like "seqio @ git+${URL}@$SHA" +RUN sed -i 's/^t5x/#t5x/' /opt/upstream-t5x-requirements.txt + +RUN pip-sync /opt/upstream-t5x-requirements.txt && rm -rf /root/.cache/pip +RUN pip install --no-deps --no-cache-dir ${SRC_PATH_T5X} + +FROM jax-builder as upstream-pax + +# Since upstream-pax adds jaxlib from pypi, we need to replace it +RUN <<"EOF" bash -e +if [[ $(ls ${SRC_PATH_JAX}/dist/ | wc -l) -ne 1 ]]; then + echo "Expect only 1 jax wheel to be built" + exit 1 +fi +sed -i "s|^jaxlib[ =].*|jaxlib @ file://${SRC_PATH_JAX}/dist/$(ls ${SRC_PATH_JAX}/dist/)|" /opt/upstream-paxml-requirements.txt +EOF +# We will also comment out paxml/praxis as a requirement since it introduces dependencies like "fiddle @ git+${URL}" which are in conflict with the ones we manually add like "fiddle @ git+${URL}@$SHA" +RUN sed -i 's/^paxml/#paxml/' /opt/upstream-paxml-requirements.txt +RUN sed -i 's/^praxis/#praxis/' /opt/upstream-paxml-requirements.txt + +RUN pip-sync /opt/upstream-paxml-requirements.txt && rm -rf /root/.cache/pip +RUN pip install --no-deps --no-cache-dir ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS} diff --git a/.github/container/jax-requirements.in b/.github/container/jax-requirements.in new file mode 100644 index 000000000..d81931ed2 --- /dev/null +++ b/.github/container/jax-requirements.in @@ -0,0 +1,4 @@ +jax @ file:///opt/jax-source +# TODO: Brittle, may be possible to use --find-links=./wheel_dir instead +#jaxlib @ file:///opt/jax-source/dist/jaxlib-0.4.14-cp310-cp310-linux_x86_64.whl +transformer_engine @ file:///opt/transformer-engine diff --git a/.github/container/upstream-paxml-requirements.in b/.github/container/upstream-paxml-requirements.in new file mode 100644 index 000000000..024d371b8 --- /dev/null +++ b/.github/container/upstream-paxml-requirements.in @@ -0,0 +1,6 @@ +-c jax-requirements.in +/opt/paxml +/opt/praxis +# Re-introduce the dependency that was skipped by SKIP_HEAD_INSTALLS=true +# Omitting jax since it is included by the jax-requirements.in constraint +fiddle @ git+https://github.com/google/fiddle diff --git a/.github/container/upstream-t5x-requirements.in b/.github/container/upstream-t5x-requirements.in new file mode 100644 index 000000000..d1e2dc72b --- /dev/null +++ b/.github/container/upstream-t5x-requirements.in @@ -0,0 +1,2 @@ +-c jax-requirements.in +/opt/t5x