From 9bff377962c8716f243aa42c9572a967aa1edd72 Mon Sep 17 00:00:00 2001 From: Jae Song Date: Wed, 7 May 2025 21:02:38 +0000 Subject: [PATCH 1/2] Remove CUDA_VERSION env variable and deprecated apt-key --- Dockerfile.tmpl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index 144672ad..0d508a13 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -81,17 +81,9 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo && \ uv pip install --system /tmp/lightgbm/*.whl && \ rm -rf /tmp/lightgbm && \ /tmp/clean-layer.sh - -# Remove CUDA_VERSION from non-GPU image. -{{ else }} -ENV CUDA_VERSION="" {{ end }} -# Update GPG key per documentation at https://cloud.google.com/compute/docs/troubleshooting/known-issues -RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - -RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - - # Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections, # as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346 RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \ From 806bbc4dc599d682251f29336d832d15a70513e9 Mon Sep 17 00:00:00 2001 From: Jae Song Date: Wed, 7 May 2025 22:22:19 +0000 Subject: [PATCH 2/2] Add isGPU function instead of env variable --- tests/common.py | 5 ++++- tests/test_jax.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/common.py b/tests/common.py index 30a7bb0f..469033dd 100644 --- a/tests/common.py +++ b/tests/common.py @@ -11,7 +11,10 @@ def getAcceleratorName(): except FileNotFoundError: return("nvidia-smi not found.") -gpu_test = unittest.skipIf(len(os.environ.get('CUDA_VERSION', '')) == 0, 'Not running GPU tests') +def isGPU(): + return os.path.isfile('/proc/driver/nvidia/version') + +gpu_test = unittest.skipIf(not isGPU(), 'Not running GPU tests') # b/342143152 P100s are slowly being unsupported in new release of popular ml tools such as RAPIDS. p100_exempt = unittest.skipIf(getAcceleratorName() == "Tesla P100-PCIE-16GB", 'Not running p100 exempt tests') tpu_test = unittest.skipIf(len(os.environ.get('ISTPUVM', '')) == 0, 'Not running TPU tests') diff --git a/tests/test_jax.py b/tests/test_jax.py index b5e0898e..f8eca3bb 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -6,7 +6,7 @@ import jax import jax.numpy as np -from common import gpu_test +from common import gpu_test, isGPU from jax import grad, jit @@ -21,5 +21,5 @@ def test_grad(self): self.assertEqual(0.4199743, ag) def test_backend(self): - expected_backend = 'cpu' if len(os.environ.get('CUDA_VERSION', '')) == 0 else 'gpu' + expected_backend = 'cpu' if not isGPU() else 'gpu' self.assertEqual(expected_backend, jax.default_backend())