diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index 278221c56..8260b7dbf 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -11,9 +11,9 @@ print_var() { supported_compute_capabilities() { ARCH=$1 if [[ "${ARCH}" == "amd64" ]]; then - echo "5.2,6.0,6.1,7.0,7.5,8.0,8.6,8.9,9.0,9.0a,10.0,10.0a" + echo "sm_75,sm_80,sm_86,sm_90,sm_100,compute_120" elif [[ "${ARCH}" == "arm64" ]]; then - echo "5.3,6.2,7.0,7.2,7.5,8.0,8.6,8.7,8.9,9.0,9.0a,10.0,10.0a" + echo "sm_80,sm_86,sm_90,sm_100,compute_120" else echo "Invalid arch '$ARCH' (expected 'amd64' or 'arm64')" 1>&2 return 1