From e94714680bdf3b79acbe119e757b3f897acadf28 Mon Sep 17 00:00:00 2001 From: Ramya Ramineni <62723901+rraminen@users.noreply.github.com> Date: Wed, 13 Aug 2025 13:02:01 -0500 Subject: [PATCH] [rocm7.0_internal_testing] Define uint32 t when ROCM_VERSION >= 70000 (#2502) Fixes SWDEV-543698 (https://ontrack-internal.amd.com/browse/SWDEV-543698) This PR fixes the errors like below: ``` [rank3]: RuntimeError: The following operation failed in the TorchScript interpreter. [rank3]: Traceback of TorchScript (most recent call last): [rank3]: RuntimeError: /tmp/comgr-28f951/input/CompileSourceACC062:67:7: error: unknown type name 'uint32_t'; did you mean '__hip_internal::uint32_t'? [rank3]: 67 | uint32_t int32; [rank3]: | ^~~~~~~~ [rank3]: | __hip_internal::uint32_t ``` Earlier uint32_t was defined in HIP headers in std namespace. Now it is moved to __hip_internal namespace in hip headers. This change is made in ROCm 7.0. --- .../jit/codegen/fuser/cuda/resource_strings.h | 71 ++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index 9728d27d4d79b..0ac2c79d1e98a 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -260,7 +260,7 @@ typedef __half half; )"; #endif -#if defined(USE_ROCM) +#if defined(USE_ROCM) && ROCM_VERSION < 70000 constexpr auto bfloat16_support_literal = R"( #ifndef __align__ @@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) { return val; } +__device__ float __bfloat162float(const __nv_bfloat16 a) { + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(a.__x) << 16}; + return u.fp32; +} +#endif /* defined(__cplusplus) */ +)"; +#elif defined(USE_ROCM) && ROCM_VERSION >= 70000 +constexpr auto bfloat16_support_literal = + R"( +#ifndef __align__ +#define __align__(x) __attribute__((aligned(x))) +#endif + +typedef unsigned int uint32_t; + +typedef struct __align__(2) { + unsigned short x; +} +__nv_bfloat16_raw; + +#if defined(__cplusplus) +struct __align__(2) __nv_bfloat16 { + __host__ __device__ __nv_bfloat16() {} + + __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) { + __x = hr.x; + return *this; + } + + unsigned short __x; +}; + +__device__ unsigned short __internal_float2bfloat16( + const float f, + unsigned int& sign, + unsigned int& remainder) { + unsigned int x; + + x = __float_as_uint(f); + + if ((x & 0x7fffffffU) > 0x7f800000U) { + sign = 0U; + remainder = 0U; + return static_cast(0x7fffU); + } + sign = x >> 31; + remainder = x << 16; + return static_cast(x >> 16); +} + +/* Definitions of intrinsics */ +__device__ __nv_bfloat16 __float2bfloat16(const float a) { + __nv_bfloat16 val; + __nv_bfloat16_raw r; + unsigned int sign; + unsigned int remainder; + r.x = __internal_float2bfloat16(a, sign, remainder); + if ((remainder > 0x80000000U) || + ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) { + r.x++; + } + val = r; + return val; +} + __device__ float __bfloat162float(const __nv_bfloat16 a) { union {