From 3bf9e4589ec850d0b413c24ac48c6dc9c33336a8 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Thu, 24 Apr 2025 15:33:56 +0000 Subject: [PATCH] Fix TE build broken by #1409 --- .github/container/Dockerfile.jax | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index c8ad8721b..330411bc5 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -59,10 +59,12 @@ RUN build-jax.sh \ ## Transformer engine: check out source and build wheel RUN <<"EOF" bash -ex -o pipefail pip install ninja && rm -rf ~/.cache/pip -# TransformerEngine now needs JAX at build time git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} pushd ${SRC_PATH_TRANSFORMER_ENGINE} export NVTE_BUILD_THREADS_PER_JOB=8 +export NVTE_FRAMEWORK=jax +# TransformerEngine needs FFI headers from XLA +export XLA_HOME=${SRC_PATH_XLA} python setup.py bdist_wheel && rm -rf build ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist" EOF @@ -114,7 +116,6 @@ echo "-e file://${SRC_PATH_FLAX}" >> /opt/pip-tools.d/requirements-flax.in EOF # Copy TransformerEngine wheel from the builder stage -ENV NVTE_FRAMEWORK=jax ENV SRC_PATH_TRANSFORMER_ENGINE=${SRC_PATH_TRANSFORMER_ENGINE} COPY --from=builder ${SRC_PATH_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} RUN <<"EOF" bash -ex