From 79dffe5c118f9cdd1eccef15af677d7fb5849e58 Mon Sep 17 00:00:00 2001 From: Vincent Roseberry Date: Wed, 25 Nov 2020 18:00:07 +0000 Subject: [PATCH] Upgrade JAX to 0.2.6 http://b/174243372 --- gpu.Dockerfile | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/gpu.Dockerfile b/gpu.Dockerfile index 2c0f2556..cb634c1d 100644 --- a/gpu.Dockerfile +++ b/gpu.Dockerfile @@ -80,16 +80,7 @@ RUN pip uninstall -y lightgbm && \ /tmp/clean-layer.sh # Install JAX -# b/154150582#comment9: JAX 0.1.63 with jaxlib 0.1.43 is causing the GPU tests to hang. -ENV JAX_VERSION=0.1.62 -ENV JAXLIB_VERSION=0.1.41 -ENV JAX_PYTHON_VERSION=cp37 -ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -ENV JAX_PLATFORM=linux_x86_64 -ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases" - -RUN pip install $JAX_BASE_URL/$JAX_CUDA_VERSION/jaxlib-$JAXLIB_VERSION-$JAX_PYTHON_VERSION-none-$JAX_PLATFORM.whl && \ - pip install jax==$JAX_VERSION && \ +RUN pip install jax==0.2.6 jaxlib==0.1.57+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \ /tmp/clean-layer.sh # Reinstall packages with a separate version for GPU support.