diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 581525e970ce..ad5797e0cc3c 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace vllm {
@@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
-    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
+    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
     out[token_idx * d + idx] = silu(x) * y;
   }
 }
@@ -57,7 +58,7 @@ __global__ void activation_kernel(
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
     out[token_idx * d + idx] = ACT_FN(x);
   }
 }
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index ee6b715adaef..babd15bb30fb 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -40,7 +40,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
   // Compute the sum per warp.
 #pragma unroll
   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
   }
 
   // Warp leaders store the data to shared memory.
@@ -59,11 +59,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
   // Parallel reduction inside the warp.
 #pragma unroll
   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
   }
 
   // Broadcast to other threads.
-  return __shfl_sync(uint32_t(-1), sum, 0);
+  return VLLM_SHFL_SYNC(sum, 0);
 }
 
 // TODO(woosuk): Merge the last two dimensions of the grid.
@@ -220,7 +220,7 @@ __device__ void paged_attention_kernel(
     // The 0-th thread of each thread group already has its max qk value.
 #pragma unroll
     for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
-      qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
     }
     if (lane == 0) {
       red_smem[warp_idx] = qk_max;
@@ -232,10 +232,10 @@ __device__ void paged_attention_kernel(
   qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
 #pragma unroll
   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
   }
   // Broadcast the max qk value to all threads.
-  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+  qk_max = VLLM_SHFL_SYNC(qk_max, 0);
 
   // Get the sum of the exp values.
   float exp_sum = 0.f;
@@ -320,7 +320,7 @@ __device__ void paged_attention_kernel(
       float acc = accs[i];
 #pragma unroll
       for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
-        acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
+        acc += VLLM_SHFL_XOR_SYNC(acc, mask);
       }
       accs[i] = acc;
     }
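Note (not part of the patch): the hunks above replace every raw `__shfl_xor_sync(uint32_t(-1), ...)` / `__shfl_sync(uint32_t(-1), ...)` call with the `VLLM_SHFL_XOR_SYNC` / `VLLM_SHFL_SYNC` macros that `csrc/cuda_compat.h` introduces later in this patch, so the same butterfly reduction compiles under both CUDA and HIP. The sketch below illustrates that reduction pattern in isolation; the fallback `WARP_SIZE` definition is an assumption for the example only (the real kernels define it themselves, and AMD wavefronts are 64 lanes wide).

```cpp
// Illustrative only: warp-wide sum via XOR-butterfly shuffles, written against
// the VLLM_SHFL_XOR_SYNC macro added by this patch (csrc/cuda_compat.h).
#include "cuda_compat.h"

#ifndef WARP_SIZE
#define WARP_SIZE 32  // assumption for this sketch; 64 on AMD wavefronts
#endif

__device__ float warp_sum(float val) {
  // Each XOR step folds together lanes that differ in one bit of the lane id;
  // after log2(WARP_SIZE) steps every lane holds the full warp total.
  #pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}
```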
@@ -486,7 +486,7 @@ __global__ void paged_attention_v2_reduce_kernel(
   // Reduce within the warp.
 #pragma unroll
   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
   }
   if (lane == 0) {
     red_smem[warp_idx] = max_logit;
@@ -496,10 +496,10 @@ __global__ void paged_attention_v2_reduce_kernel(
   max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
 #pragma unroll
   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
   }
   // Broadcast the max value to all threads.
-  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+  max_logit = VLLM_SHFL_SYNC(max_logit, 0);
 
   // Load rescaled exp sums to shared memory.
   float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
@@ -534,7 +534,7 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                  \
   cudaFuncSetAttribute(                                                       \
-    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,   \
+    (void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
     cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);            \
   vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>      \
   <<<grid, block, shared_mem_size, stream>>>(                                 \
diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh
index bb7df25b14f0..ff64c4bd8f80 100644
--- a/csrc/attention/attention_utils.cuh
+++ b/csrc/attention/attention_utils.cuh
@@ -17,6 +17,7 @@
  */
 #pragma once
 
+#include "../cuda_compat.h"
 #include "attention_dtypes.h"
 
 #include <float.h>
@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
   float qk = sum(qk_vec);
 #pragma unroll
   for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
-    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
+    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
   }
   return qk;
 }
diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh
index 5786f77f7bca..7f2b29de0d93 100644
--- a/csrc/attention/dtype_bfloat16.cuh
+++ b/csrc/attention/dtype_bfloat16.cuh
@@ -21,8 +21,17 @@
 #include "attention_generic.cuh"
 #include "dtype_float32.cuh"
 
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
+#ifndef USE_ROCM
+  #include <cuda_bf16.h>
+  #include <cuda_fp16.h>
+#else
+  #include <hip/hip_bf16.h>
+  #include <hip/hip_fp16.h>
+
+  typedef __hip_bfloat162 __nv_bfloat162;
+  typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
 #include <stdint.h>
 
 namespace vllm {
@@ -98,7 +107,17 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
 #else
-  return a + b;
+  #ifndef USE_ROCM
+    return a + b;
+  #else
+    // See https://github.com/RadeonOpenCompute/ROCm/issues/2534
+    hip_bfloat16 A, B;
+    __hip_bfloat16 c;
+    A.data = a.data;
+    B.data = b.data;
+    c.data = (A + B).data;
+    return c;
+  #endif
 #endif
 }
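Note (not part of the patch): at the time of this port, `__hip_bfloat16` — the type aliased to `__nv_bfloat16` above — did not provide `operator+` (see the ROCm issue linked in the hunk), so the addition is routed through `hip_bfloat16`, which does. The helper below only restates that workaround as a standalone function; the function name is illustrative, and the header names (`<hip/hip_bf16.h>` for `__hip_bfloat16`, `<hip/hip_bfloat16.h>` for `hip_bfloat16`) are assumptions about which headers expose the raw `.data` field used here.

```cpp
// Sketch of the ROCm bfloat16 addition workaround used above (illustrative).
#ifdef USE_ROCM
#include <hip/hip_bf16.h>       // __hip_bfloat16: plain struct with a .data payload
#include <hip/hip_bfloat16.h>   // hip_bfloat16: class that defines arithmetic operators

__device__ __hip_bfloat16 bf16_add(__hip_bfloat16 a, __hip_bfloat16 b) {
  hip_bfloat16 A, B;   // temporaries that support operator+
  __hip_bfloat16 c;
  A.data = a.data;     // copy the raw 16-bit payloads
  B.data = b.data;
  c.data = (A + B).data;  // add, then move the bits back
  return c;
}
#endif
```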
diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh
index e67921128d52..b9c9275aae3f 100644
--- a/csrc/attention/dtype_float16.cuh
+++ b/csrc/attention/dtype_float16.cuh
@@ -21,6 +21,10 @@
 #include "attention_generic.cuh"
 #include "dtype_float32.cuh"
 
+#ifdef USE_ROCM
+  #include <hip/hip_fp16.h>
+#endif
+
 #include <stdint.h>
 
 namespace vllm {
@@ -64,20 +68,46 @@ struct FloatVec {
 // Utility functions for type conversions.
 inline __device__ uint32_t h0_h0(uint16_t a) {
   uint32_t b;
+#ifndef USE_ROCM
   asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
+#else
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  tmp.u16[0] = a;
+  tmp.u16[1] = a;
+  b = tmp.u32;
+#endif
   return b;
 }
 
 inline __device__ float half_to_float(uint16_t h) {
   float f;
+#ifndef USE_ROCM
   asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+#else
+  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
+#endif
   return f;
 }
 
 inline __device__ float2 half2_to_float2(uint32_t v) {
+#ifndef USE_ROCM
   uint16_t lo, hi;
   asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
   return make_float2(half_to_float(lo), half_to_float(hi));
+#else
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  tmp.u32 = v;
+  float2 ret;
+  ret.x = half_to_float(tmp.u16[0]);
+  ret.y = half_to_float(tmp.u16[1]);
+  return ret;
+#endif
 }
 
 inline __device__ uint16_t float_to_half(float f) {
@@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
   union {
     uint32_t u32;
     uint16_t u16[2];
   } tmp;
+#ifndef USE_ROCM
   asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+#else
+  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
+#endif
   return tmp.u16[0];
 }
 
@@ -94,12 +128,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
   union {
     uint32_t u32;
     uint16_t u16[2];
   } tmp;
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+#ifndef USE_ROCM
+  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+  #else
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+  #endif
 #else
-  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
-  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+  tmp.u16[0] = float_to_half(f.x);
+  tmp.u16[1] = float_to_half(f.y);
 #endif
   return tmp.u32;
 }
@@ -107,13 +145,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
 // Vector addition.
 inline __device__ uint16_t add(uint16_t a, uint16_t b) {
   uint16_t c;
+#ifndef USE_ROCM
   asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
 inline __device__ uint32_t add(uint32_t a, uint32_t b) {
   uint32_t c;
+#ifndef USE_ROCM
   asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
@@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
 template<>
 inline __device__ uint16_t mul<uint16_t, uint16_t>(uint16_t a, uint16_t b) {
   uint16_t c;
+#ifndef USE_ROCM
   asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
 template<>
 inline __device__ uint32_t mul<uint32_t, uint32_t>(uint32_t a, uint32_t b) {
   uint32_t c;
+#ifndef USE_ROCM
   asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
@@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
 // Vector fused multiply-add.
 inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
   uint32_t d;
+#ifndef USE_ROCM
   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+#else
+  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+#endif
   return d;
 }
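Note (not part of the patch): the ROCm branches above drop the PTX `mov.b32` register packing and use a plain `union` instead. The snippet below is a host-testable sketch of that idea, with illustrative names; it assumes the usual little-endian lane order, i.e. `u16[0]` is the low half of the 32-bit word, which is also what the `{%1, %1}` / `{%0, %1}` PTX forms produce.

```cpp
// Illustrative union-based packing/unpacking of two fp16 bit patterns.
#include <cstdint>

inline uint32_t pack_half2(uint16_t lo, uint16_t hi) {
  union { uint32_t u32; uint16_t u16[2]; } tmp;
  tmp.u16[0] = lo;  // low 16 bits (little-endian assumption)
  tmp.u16[1] = hi;  // high 16 bits
  return tmp.u32;
}

inline void unpack_half2(uint32_t v, uint16_t& lo, uint16_t& hi) {
  union { uint32_t u32; uint16_t u16[2]; } tmp;
  tmp.u32 = v;
  lo = tmp.u16[0];
  hi = tmp.u16[1];
}
```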
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 4c8068281229..c3830df2f095 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 #include <algorithm>
@@ -28,8 +29,8 @@ void swap_blocks(
     TORCH_CHECK(false, "Invalid device combination");
   }
 
-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
+  char *src_ptr = static_cast<char*>(src.data_ptr());
+  char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel(
                                           + head_offset * block_size
                                           + block_offset;
-    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
-    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+    key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
+    value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
   }
 }
@@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized(
       src_key_indices[j] = src_key_idx;
       src_value_indices[j] = src_value_idx;
 
-      keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
-      values_to_store[j] = __ldg(&value_cache[src_value_idx]);
+      keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
+      values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
     }
 
     #pragma unroll
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
new file mode 100644
index 000000000000..8991462a862e
--- /dev/null
+++ b/csrc/cuda_compat.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#ifndef USE_ROCM
+  #define VLLM_LDG(arg) __ldg(arg)
+#else
+  #define VLLM_LDG(arg) *(arg)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#else
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+#else
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu
index f1c30fe7ea99..2439f5922a3f 100644
--- a/csrc/cuda_utils_kernels.cu
+++ b/csrc/cuda_utils_kernels.cu
@@ -1,3 +1,7 @@
+#ifdef USE_ROCM
+  #include <hip/hip_runtime.h>
+#endif
+
 int get_device_attribute(
     int attribute,
     int device_id)
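Note (not part of the patch): `csrc/cuda_compat.h` is the core of the port — on CUDA the macros expand to `__ldg`, `__shfl_xor_sync`, and `__shfl_sync` with a full mask, while on ROCm they fall back to a plain dereference and the maskless `__shfl_xor` / `__shfl` intrinsics. The toy kernel below shows the intended usage pattern; the kernel name and the launch-geometry assumption are illustrative only.

```cpp
// Illustrative usage of the compatibility macros; compiles as CUDA or HIP.
#include "cuda_compat.h"

__global__ void scale_by_lane0(const float* __restrict__ in, float* out, int n) {
  // Assumes n is a multiple of the warp/wavefront size, so no lane exits
  // early and the full-mask shuffle in VLLM_SHFL_SYNC stays well defined.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  float v = VLLM_LDG(&in[i]);      // read-only cached load on NVIDIA, plain load on ROCm
  float s = VLLM_SHFL_SYNC(v, 0);  // broadcast lane 0's value across the warp
  out[i] = v * s;
}
```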
diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index 41001ba64746..474509441563 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace vllm {
@@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding(
     // GPT-NeoX style rotary embedding.
     x_index = rot_offset;
     y_index = embed_dim + rot_offset;
-    cos = __ldg(cos_ptr + x_index);
-    sin = __ldg(sin_ptr + x_index);
+    cos = VLLM_LDG(cos_ptr + x_index);
+    sin = VLLM_LDG(sin_ptr + x_index);
   } else {
     // GPT-J style rotary embedding.
     x_index = 2 * rot_offset;
     y_index = 2 * rot_offset + 1;
-    cos = __ldg(cos_ptr + x_index / 2);
-    sin = __ldg(sin_ptr + x_index / 2);
+    cos = VLLM_LDG(cos_ptr + x_index / 2);
+    sin = VLLM_LDG(sin_ptr + x_index / 2);
   }
 
   const scalar_t x = arr[x_index];
diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh
index bc35aa0424b5..b95ccef16207 100644
--- a/csrc/reduction_utils.cuh
+++ b/csrc/reduction_utils.cuh
@@ -17,13 +17,15 @@
  */
 #pragma once
 
+#include "cuda_compat.h"
+
 namespace vllm {
 
 template<typename T>
 __inline__ __device__ T warpReduceSum(T val) {
 #pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1)
-    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
+    val += VLLM_SHFL_XOR_SYNC(val, mask);
   return val;
 }
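Note (not part of the patch): the rotation that `apply_rotary_embedding` performs after the `VLLM_LDG` loads in the pos_encoding hunks above is an ordinary 2-D rotation of each (x, y) pair — taken from split halves of the rotary dimension in the GPT-NeoX layout, or from adjacent elements in the GPT-J layout. The helper below is a plain host-side restatement of that update (illustrative name, float inputs), not code from the kernel.

```cpp
// x' = x*cos - y*sin,  y' = y*cos + x*sin  (the update applied to each pair).
inline void rotate_pair(float& x, float& y, float cos_v, float sin_v) {
  const float x_new = x * cos_v - y * sin_v;
  const float y_new = y * cos_v + x * sin_v;
  x = x_new;
  y = y_new;
}
```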
diff --git a/setup.py b/setup.py
index 6ffc03c25386..55a4358f734f 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 from packaging.version import parse, Version
 import setuptools
 import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
 
@@ -24,10 +24,14 @@
 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 
-if CUDA_HOME is None:
-    raise RuntimeError(
-        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
+if torch.version.hip:
+    if ROCM_HOME is not None:
+        NVCC_FLAGS += [f"-DUSE_ROCM"]
+if not torch.version.hip:
+    if CUDA_HOME is None:
+        raise RuntimeError(
+            "Cannot find CUDA_HOME. CUDA must be available to build the package.")
 
 
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     """Get the CUDA version from nvcc.
@@ -76,66 +80,72 @@ def get_torch_arch_list() -> Set[str]:
             f"{valid_archs}.")
     return arch_list
 
-
-# First, check the TORCH_CUDA_ARCH_LIST environment variable.
-compute_capabilities = get_torch_arch_list()
-if not compute_capabilities:
-    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
-    # GPUs on the current machine.
-    device_count = torch.cuda.device_count()
-    for i in range(device_count):
-        major, minor = torch.cuda.get_device_capability(i)
-        if major < 7:
-            raise RuntimeError(
-                "GPUs with compute capability below 7.0 are not supported.")
-        compute_capabilities.add(f"{major}.{minor}")
-
-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
-if not compute_capabilities:
-    # If no GPU is specified nor available, add all supported architectures
-    # based on the NVCC CUDA version.
-    compute_capabilities = SUPPORTED_ARCHS.copy()
+def get_cuda_compute_capabilities(nvcc_cuda_version):
+    # First, check the TORCH_CUDA_ARCH_LIST environment variable.
+    compute_capabilities = get_torch_arch_list()
+    if not compute_capabilities:
+        # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
+        # GPUs on the current machine.
+        device_count = torch.cuda.device_count()
+        for i in range(device_count):
+            major, minor = torch.cuda.get_device_capability(i)
+            if major < 7:
+                raise RuntimeError(
+                    "GPUs with compute capability below 7.0 are not supported.")
+            compute_capabilities.add(f"{major}.{minor}")
+
+    if not compute_capabilities:
+        # If no GPU is specified nor available, add all supported architectures
+        # based on the NVCC CUDA version.
+        compute_capabilities = SUPPORTED_ARCHS.copy()
+        if nvcc_cuda_version < Version("11.1"):
+            compute_capabilities.remove("8.6")
+        if nvcc_cuda_version < Version("11.8"):
+            compute_capabilities.remove("8.9")
+            compute_capabilities.remove("9.0")
+
+    return compute_capabilities
+
+def validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities):
+    if nvcc_cuda_version < Version("11.0"):
+        raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
     if nvcc_cuda_version < Version("11.1"):
-        compute_capabilities.remove("8.6")
+        if any(cc.startswith("8.6") for cc in compute_capabilities):
+            raise RuntimeError(
+                "CUDA 11.1 or higher is required for compute capability 8.6.")
     if nvcc_cuda_version < Version("11.8"):
-        compute_capabilities.remove("8.9")
-        compute_capabilities.remove("9.0")
-
-# Validate the NVCC CUDA version.
-if nvcc_cuda_version < Version("11.0"):
-    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if nvcc_cuda_version < Version("11.1"):
-    if any(cc.startswith("8.6") for cc in compute_capabilities):
-        raise RuntimeError(
-            "CUDA 11.1 or higher is required for compute capability 8.6.")
-if nvcc_cuda_version < Version("11.8"):
-    if any(cc.startswith("8.9") for cc in compute_capabilities):
-        # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
-        # However, GPUs with compute capability 8.9 can also run the code generated by
-        # the previous versions of CUDA 11 and targeting compute capability 8.0.
-        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
-        # instead of 8.9.
-        warnings.warn(
-            "CUDA 11.8 or higher is required for compute capability 8.9. "
-            "Targeting compute capability 8.0 instead.")
-        compute_capabilities = set(cc for cc in compute_capabilities
-                                   if not cc.startswith("8.9"))
-        compute_capabilities.add("8.0+PTX")
-    if any(cc.startswith("9.0") for cc in compute_capabilities):
-        raise RuntimeError(
-            "CUDA 11.8 or higher is required for compute capability 9.0.")
+        if any(cc.startswith("8.9") for cc in compute_capabilities):
+            # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+            # However, GPUs with compute capability 8.9 can also run the code generated by
+            # the previous versions of CUDA 11 and targeting compute capability 8.0.
+            # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+            # instead of 8.9.
+            warnings.warn(
+                "CUDA 11.8 or higher is required for compute capability 8.9. "
+                "Targeting compute capability 8.0 instead.")
+            compute_capabilities = set(cc for cc in compute_capabilities
                                       if not cc.startswith("8.9"))
+            compute_capabilities.add("8.0+PTX")
+        if any(cc.startswith("9.0") for cc in compute_capabilities):
+            raise RuntimeError(
+                "CUDA 11.8 or higher is required for compute capability 9.0.")
+
+if not torch.version.hip:
+    nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+    compute_capabilities = get_cuda_compute_capabilities(nvcc_cuda_version)
+    validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities)
 
-# Add target compute capabilities to NVCC flags.
-for capability in compute_capabilities:
-    num = capability[0] + capability[2]
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
-    if capability.endswith("+PTX"):
-        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
+    # Add target compute capabilities to NVCC flags.
+    for capability in compute_capabilities:
+        num = capability[0] + capability[2]
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+        if capability.endswith("+PTX"):
+            NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
 
-# Use NVCC threads to parallelize the build.
-if nvcc_cuda_version >= Version("11.2"):
-    num_threads = min(os.cpu_count(), 8)
-    NVCC_FLAGS += ["--threads", str(num_threads)]
+    # Use NVCC threads to parallelize the build.
+    if nvcc_cuda_version >= Version("11.2"):
+        num_threads = min(os.cpu_count(), 8)
+        NVCC_FLAGS += ["--threads", str(num_threads)]
 
 ext_modules = []
 
@@ -206,7 +216,8 @@ def get_torch_arch_list() -> Set[str]:
         "nvcc": NVCC_FLAGS,
     },
 )
-ext_modules.append(quantization_extension)
+if not torch.version.hip:
+    ext_modules.append(quantization_extension)
 
 # Misc. CUDA utils.
 cuda_utils_extension = CUDAExtension(
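Note (not part of the patch): on the build side, the only thing `-DUSE_ROCM` (added to `NVCC_FLAGS` when `torch.version.hip` is set and `ROCM_HOME` is found) does is flip the preprocessor branches used throughout the C++ sources in this patch. The fragment below illustrates that consumption pattern — it mirrors the shape of `csrc/cuda_utils_kernels.cu` rather than reproducing it, and the function name is made up for the example; `hipDeviceGetAttribute` / `cudaDeviceGetAttribute` and the shared-memory attribute enums are standard runtime APIs.

```cpp
// Illustrative: one source file, two runtimes, selected by -DUSE_ROCM.
#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#else
  #include <cuda_runtime.h>
#endif

int max_shared_memory_per_block(int device_id) {
  int value = 0;
#ifdef USE_ROCM
  hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock, device_id);
#else
  cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlock, device_id);
#endif
  return value;
}
```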