-
Notifications
You must be signed in to change notification settings - Fork 66
[PoC] Early progress of using pip-compile. Dockerfile.fw builds jax/upstream-t5x/upstream-pax #271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # syntax=docker/dockerfile:1-labs | ||
| ARG BASE_IMAGE=ghcr.io/nvidia/jax-toolbox:base | ||
| ARG BAZEL_CACHE=/tmp | ||
| ARG BUILD_DATE | ||
|
|
||
| ############################################################################### | ||
| ## Build JAX | ||
| ############################################################################### | ||
|
|
||
| FROM ${BASE_IMAGE} as jax-builder | ||
| ARG BAZEL_CACHE | ||
|
|
||
| ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/ | ||
| RUN build-jax.sh \ | ||
| --bazel-cache ${BAZEL_CACHE} \ | ||
| --src-path-jax ${SRC_PATH_JAX} \ | ||
| --src-path-xla ${SRC_PATH_XLA} \ | ||
| --sm all \ | ||
| --clean && [ -d ${BAZEL_CACHE} ] && rm -r ${BAZEL_CACHE} | ||
|
|
||
| # TODO(terry): don't delete /tmp and re-create. Only doing this for now b/c I want to cache ^ | ||
| RUN mkdir /tmp | ||
|
|
||
| FROM jax-builder as runtime-image | ||
| # The following environment variables tune performance | ||
| ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false" | ||
| ENV CUDA_DEVICE_MAX_CONNECTIONS=1 | ||
| ENV NCCL_IB_SL=1 | ||
| ENV NCCL_NVLS_ENABLE=0 | ||
| ENV NVTE_FRAMEWORK=jax | ||
|
|
||
| # TODO: properly configure entrypoint | ||
| # COPY entrypoint.d/ /opt/nvidia/entrypoint.d/ | ||
|
|
||
| FROM jax-builder as jax | ||
| # Since jaxlib is not explicitly listed by the jax package itself, we append it | ||
| RUN <<"EOF" bash -e | ||
| if [[ $(ls ${SRC_PATH_JAX}/dist/ | wc -l) -ne 1 ]]; then | ||
| echo "Expect only 1 jax wheel to be built" | ||
| exit 1 | ||
| fi | ||
| echo "jaxlib @ file://${SRC_PATH_JAX}/dist/"$(ls ${SRC_PATH_JAX}/dist/) >> /opt/jax-requirements.txt | ||
| EOF | ||
|
|
||
| RUN pip-sync /opt/jax-requirements.txt && rm -rf /root/.cache/pip | ||
|
|
||
| FROM jax-builder as upstream-t5x | ||
|
|
||
| # Since upstream-t5x adds jaxlib from pypi, we need to replace it | ||
| RUN <<"EOF" bash -e | ||
| if [[ $(ls ${SRC_PATH_JAX}/dist/ | wc -l) -ne 1 ]]; then | ||
| echo "Expect only 1 jax wheel to be built" | ||
| exit 1 | ||
| fi | ||
| sed -i "s|^jaxlib[ =].*|jaxlib @ file://${SRC_PATH_JAX}/dist/$(ls ${SRC_PATH_JAX}/dist/)|" /opt/upstream-t5x-requirements.txt | ||
| EOF | ||
|
|
||
| # T5x uses a flax distribution so we will use the flax source, which should be identical to one we'd pull from VCS in t5x's requirements | ||
| RUN sed -i "s|^flax[ =].*|flax @ file://${SRC_PATH_FLAX}|" /opt/upstream-t5x-requirements.txt | ||
|
|
||
| # We will also comment out t5x as a requirement since it introduces dependencies like "seqio @ git+${URL}" which are in conflict with the ones we manually add like "seqio @ git+${URL}@$SHA" | ||
| RUN sed -i 's/^t5x/#t5x/' /opt/upstream-t5x-requirements.txt | ||
|
|
||
| RUN pip-sync /opt/upstream-t5x-requirements.txt && rm -rf /root/.cache/pip | ||
| RUN pip install --no-deps --no-cache-dir ${SRC_PATH_T5X} | ||
|
|
||
| FROM jax-builder as upstream-pax | ||
|
|
||
| # Since upstream-pax adds jaxlib from pypi, we need to replace it | ||
| RUN <<"EOF" bash -e | ||
| if [[ $(ls ${SRC_PATH_JAX}/dist/ | wc -l) -ne 1 ]]; then | ||
| echo "Expect only 1 jax wheel to be built" | ||
| exit 1 | ||
| fi | ||
| sed -i "s|^jaxlib[ =].*|jaxlib @ file://${SRC_PATH_JAX}/dist/$(ls ${SRC_PATH_JAX}/dist/)|" /opt/upstream-paxml-requirements.txt | ||
| EOF | ||
| # We will also comment out paxml/praxis as a requirement since it introduces dependencies like "fiddle @ git+${URL}" which are in conflict with the ones we manually add like "fiddle @ git+${URL}@$SHA" | ||
| RUN sed -i 's/^paxml/#paxml/' /opt/upstream-paxml-requirements.txt | ||
| RUN sed -i 's/^praxis/#praxis/' /opt/upstream-paxml-requirements.txt | ||
|
|
||
| RUN pip-sync /opt/upstream-paxml-requirements.txt && rm -rf /root/.cache/pip | ||
| RUN pip install --no-deps --no-cache-dir ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS} | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| jax @ file:///opt/jax-source | ||
| # TODO: Brittle, may be possible to use --find-links=./wheel_dir instead | ||
| #jaxlib @ file:///opt/jax-source/dist/jaxlib-0.4.14-cp310-cp310-linux_x86_64.whl | ||
| transformer_engine @ file:///opt/transformer-engine |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| -c jax-requirements.in | ||
| /opt/paxml | ||
| /opt/praxis | ||
| # Re-introduce the dependency that was skipped by SKIP_HEAD_INSTALLS=true | ||
| # Omitting jax since it is included by the jax-requirements.in constraint | ||
| fiddle @ git+https://github.com/google/fiddle | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is fiddle pinned?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is. The logic to pin it is here: https://github.com/NVIDIA/JAX-Toolbox/pull/271/files#diff-8386cfe5c29a3be212a911c9550a56c57420a17cae190b80f60e52c0ce265645R151-R175 fiddle will appear as |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| -c jax-requirements.in | ||
| /opt/t5x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line could cause hangs. Would it be related to our late hangs with Pax?