diff --git a/aten/src/THC/THCTensorTopK.cuh b/aten/src/THC/THCTensorTopK.cuh index 71d1bc98e8e28..773232dba92bf 100644 --- a/aten/src/THC/THCTensorTopK.cuh +++ b/aten/src/THC/THCTensorTopK.cuh @@ -117,7 +117,7 @@ struct TopKTypeConfig { typedef uint32_t RadixType; static inline __device__ RadixType convert(at::Half v) { -#if CUDA_VERSION >= 8000 +#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ RadixType x = __half_as_ushort(v); RadixType mask = -((x >> 15)) | 0x8000; return (x ^ mask); @@ -128,7 +128,7 @@ struct TopKTypeConfig { } static inline __device__ at::Half deconvert(RadixType v) { -#if CUDA_VERSION >= 8000 +#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ RadixType mask = ((v >> 15) - 1) | 0x8000; return __ushort_as_half(v ^ mask); #else