diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index ff2ef1f2377ce..d40ac15f132cf 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -260,13 +260,82 @@ typedef __half half; )"; #endif -#if defined(USE_ROCM) +#if defined(USE_ROCM) && ROCM_VERSION < 70000 +constexpr auto bfloat16_support_literal = + R"( +#ifndef __align__ +#define __align__(x) __attribute__((aligned(x))) +#endif + +typedef struct __align__(2) { + unsigned short x; +} +__nv_bfloat16_raw; + +#if defined(__cplusplus) +struct __align__(2) __nv_bfloat16 { + __host__ __device__ __nv_bfloat16() {} + + __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) { + __x = hr.x; + return *this; + } + + unsigned short __x; +}; + +__device__ unsigned short __internal_float2bfloat16( + const float f, + unsigned int& sign, + unsigned int& remainder) { + unsigned int x; + + x = __float_as_uint(f); + + if ((x & 0x7fffffffU) > 0x7f800000U) { + sign = 0U; + remainder = 0U; + return static_cast(0x7fffU); + } + sign = x >> 31; + remainder = x << 16; + return static_cast(x >> 16); +} + +/* Definitions of intrinsics */ +__device__ __nv_bfloat16 __float2bfloat16(const float a) { + __nv_bfloat16 val; + __nv_bfloat16_raw r; + unsigned int sign; + unsigned int remainder; + r.x = __internal_float2bfloat16(a, sign, remainder); + if ((remainder > 0x80000000U) || + ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) { + r.x++; + } + val = r; + return val; +} + +__device__ float __bfloat162float(const __nv_bfloat16 a) { + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(a.__x) << 16}; + return u.fp32; +} +#endif /* defined(__cplusplus) */ +)"; +#elif defined(USE_ROCM) && ROCM_VERSION >= 70000 constexpr auto bfloat16_support_literal = R"( #ifndef __align__ #define __align__(x) __attribute__((aligned(x))) #endif +typedef unsigned int uint32_t; + typedef struct __align__(2) { unsigned short x; }