Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ typedef __half half;
)";
#endif

#if defined(USE_ROCM)
#if defined(USE_ROCM) && ROCM_VERSION < 70000
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
Expand Down Expand Up @@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) {
return val;
}

/*
 * Widens a bfloat16 to float. The 16 stored bits are placed in the upper
 * half of a 32-bit word (lower half zero) and that word is read back as a
 * float through a union.
 */
__device__ float __bfloat162float(const __nv_bfloat16 a) {
  const uint32_t widened = static_cast<uint32_t>(a.__x) << 16;
  union {
    uint32_t bits;
    float value;
  } pun;
  pun.bits = widened;
  return pun.value;
}
#endif /* defined(__cplusplus) */
)";
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
constexpr auto bfloat16_support_literal =
R"(
/* Fallback definition of __align__ for toolchains that do not provide it. */
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif

/*
 * Local alias used by the bit-manipulation helpers in this preamble.
 * NOTE(review): assumes unsigned int is 32 bits wide on the target - confirm.
 */
typedef unsigned int uint32_t;

/*
 * Plain carrier for the raw 16-bit pattern of a bfloat16 value, declared
 * with __align__(2). Written as a C-style typedef because it sits outside
 * the __cplusplus guard that follows.
 */
typedef struct __align__(2) {
unsigned short x;
}
__nv_bfloat16_raw;

#if defined(__cplusplus)
/*
 * bfloat16 value type: a struct declared with __align__(2) that wraps a
 * single unsigned short holding the raw bit pattern.
 */
struct __align__(2) __nv_bfloat16 {
/* Empty user-provided constructor; __x is left uninitialized. */
__host__ __device__ __nv_bfloat16() {}

/* Copies the bit pattern out of a raw value and returns *this. */
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
__x = hr.x;
return *this;
}

/* Raw 16-bit storage; read directly by the conversion helpers. */
unsigned short __x;
};

/*
 * Truncating float -> bfloat16 helper.
 *
 * Returns the upper 16 bits of the bit pattern of f. On return, sign holds
 * bit 31 of that pattern and remainder holds the discarded lower 16 bits
 * moved into the top half of a 32-bit word, so the caller can apply
 * rounding.
 *
 * Inputs whose magnitude bits exceed 0x7f800000 (the IEEE-754 NaN
 * encodings) produce 0x7fff with sign and remainder both set to zero.
 */
__device__ unsigned short __internal_float2bfloat16(
    const float f,
    unsigned int& sign,
    unsigned int& remainder) {
  const unsigned int bits = __float_as_uint(f);
  const bool is_nan = (bits & 0x7fffffffU) > 0x7f800000U;

  sign = is_nan ? 0U : (bits >> 31);
  remainder = is_nan ? 0U : (bits << 16);
  return static_cast<unsigned short>(is_nan ? 0x7fffU : (bits >> 16));
}

/* Definitions of intrinsics */
/*
 * Converts a float to bfloat16 using round-to-nearest, ties-to-even.
 *
 * The truncated upper 16 bits come from __internal_float2bfloat16. They are
 * bumped by one when the discarded part is above the halfway point, or is
 * exactly halfway and the truncated value is odd. The sign output of the
 * helper is not used here.
 */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
  unsigned int sign_bit;
  unsigned int dropped;
  __nv_bfloat16_raw raw;
  raw.x = __internal_float2bfloat16(a, sign_bit, dropped);

  const bool above_half = dropped > 0x80000000U;
  const bool tie_on_odd =
      (dropped == 0x80000000U) && ((raw.x & 0x1U) != 0U);
  if (above_half || tie_on_odd) {
    raw.x++;
  }

  /* Assigned after declaration: the struct only offers operator= from raw. */
  __nv_bfloat16 result;
  result = raw;
  return result;
}

__device__ float __bfloat162float(const __nv_bfloat16 a) {
union
{
Expand Down