Skip to content

Commit

Permalink
Support cl_bf16_conversions (#1441)
Browse files Browse the repository at this point in the history
  • Loading branch information
KornevNikita committed Mar 16, 2022
1 parent 002ede2 commit e95eb30
Show file tree
Hide file tree
Showing 12 changed files with 468 additions and 1 deletion.
120 changes: 120 additions & 0 deletions lib/SPIRV/OCLToSPIRV.cpp
Expand Up @@ -263,6 +263,11 @@ class OCLToSPIRVBase : public InstVisitor<OCLToSPIRVBase> {
void visitCallLdexp(CallInst *CI, StringRef MangledName,
StringRef DemangledName);

/// Handles calls to the intel_convert_bfloat16[N]_as_ushort[N] built-ins
/// (scalar form and N = 2, 3, 4, 8, 16): checks the call's return and
/// argument types and rewrites it to the OpConvertFToBF16INTEL SPIR-V
/// function.
void visitCallConvertBFloat16AsUshort(CallInst *CI, StringRef DemangledName);
/// Handles calls to the intel_convert_as_bfloat16[N]_float[N] built-ins
/// (scalar form and N = 2, 3, 4, 8, 16): checks the call's return and
/// argument types and rewrites it to the OpConvertBF16ToFINTEL SPIR-V
/// function.
void visitCallConvertAsBFloat16Float(CallInst *CI, StringRef DemangledName);

/// Stores \p OCLTypeToSPIRV in OCLTypeToSPIRVPtr; the pointer is kept as-is
/// (no copy of the pointee is made here).
void setOCLTypeToSPIRV(OCLTypeToSPIRVBase *OCLTypeToSPIRV) {
OCLTypeToSPIRVPtr = OCLTypeToSPIRV;
}
Expand Down Expand Up @@ -574,6 +579,24 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) {
visitCallLdexp(&CI, MangledName, DemangledName);
return;
}
if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort ||
DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2 ||
DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3 ||
DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4 ||
DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8 ||
DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
visitCallConvertBFloat16AsUshort(&CI, DemangledName);
return;
}
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
visitCallConvertAsBFloat16Float(&CI, DemangledName);
return;
}
visitCallBuiltinSimple(&CI, MangledName, DemangledName);
}

Expand Down Expand Up @@ -1916,6 +1939,103 @@ void OCLToSPIRVBase::visitCallLdexp(CallInst *CI, StringRef MangledName,
visitCallBuiltinSimple(CI, MangledName, DemangledName);
}

/// Translate a call to intel_convert_bfloat16[N]_as_ushort[N] into a call to
/// the OpConvertFToBF16INTEL SPIR-V function.
///
/// \param CI            the call being translated.
/// \param DemangledName demangled builtin name; selects the scalar form or
///                      the vector width (2, 3, 4, 8, 16) to validate against.
///
/// The scalar form must return i16 and take float. Vector forms must return
/// <N x i16> and take <N x float>, with N equal to the width encoded in the
/// builtin name. Any mismatch aborts translation through report_fatal_error.
void OCLToSPIRVBase::visitCallConvertBFloat16AsUshort(CallInst *CI,
                                                      StringRef DemangledName) {
  Type *RetTy = CI->getType();
  Type *ArgTy = CI->getOperand(0)->getType();
  if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort) {
    if (!RetTy->isIntegerTy(16U) || !ArgTy->isFloatTy())
      report_fatal_error(
          "OpConvertBFloat16AsUShort must be of i16 and take float");
  } else {
    // dyn_cast, not cast: cast<> asserts (or is undefined in release builds)
    // on a non-vector type and never yields null, so a vector-named builtin
    // declared with scalar types would never reach the diagnostic below.
    auto *RetTyVec = dyn_cast<FixedVectorType>(RetTy);
    auto *ArgTyVec = dyn_cast<FixedVectorType>(ArgTy);
    if (!RetTyVec || !RetTyVec->getElementType()->isIntegerTy(16U) ||
        !ArgTyVec || !ArgTyVec->getElementType()->isFloatTy())
      report_fatal_error("OpConvertBFloat16NAsUShortN must be of <N x i16> and "
                         "take <N x float>");
    unsigned RetTyVecSize = RetTyVec->getNumElements();
    unsigned ArgTyVecSize = ArgTyVec->getNumElements();
    // The element count on both sides must match the width in the name.
    if (DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2) {
      if (RetTyVecSize != 2 || ArgTyVecSize != 2)
        report_fatal_error("ConvertBFloat162AsUShort2 must be of <2 x i16> and "
                           "take <2 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3) {
      if (RetTyVecSize != 3 || ArgTyVecSize != 3)
        report_fatal_error("ConvertBFloat163AsUShort3 must be of <3 x i16> and "
                           "take <3 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4) {
      if (RetTyVecSize != 4 || ArgTyVecSize != 4)
        report_fatal_error("ConvertBFloat164AsUShort4 must be of <4 x i16> and "
                           "take <4 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8) {
      if (RetTyVecSize != 8 || ArgTyVecSize != 8)
        report_fatal_error("ConvertBFloat168AsUShort8 must be of <8 x i16> and "
                           "take <8 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
      if (RetTyVecSize != 16 || ArgTyVecSize != 16)
        report_fatal_error("ConvertBFloat1616AsUShort16 must be of <16 x i16> "
                           "and take <16 x float>");
    }
  }

  // Arguments are passed through unchanged; only the callee name is replaced.
  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
  mutateCallInstSPIRV(
      M, CI,
      [=](CallInst *, std::vector<Value *> &Args) {
        return getSPIRVFuncName(internal::OpConvertFToBF16INTEL);
      },
      &Attrs);
}

/// Translate a call to intel_convert_as_bfloat16[N]_float[N] into a call to
/// the OpConvertBF16ToFINTEL SPIR-V function.
///
/// \param CI            the call being translated.
/// \param DemangledName demangled builtin name; selects the scalar form or
///                      the vector width (2, 3, 4, 8, 16) to validate against.
///
/// The scalar form must return float and take i16. Vector forms must return
/// <N x float> and take <N x i16>, with N equal to the width encoded in the
/// builtin name. Any mismatch aborts translation through report_fatal_error.
void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
                                                     StringRef DemangledName) {
  Type *RetTy = CI->getType();
  Type *ArgTy = CI->getOperand(0)->getType();
  if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float) {
    if (!RetTy->isFloatTy() || !ArgTy->isIntegerTy(16U))
      report_fatal_error(
          "OpConvertAsBFloat16Float must be of float and take i16");
  } else {
    // dyn_cast, not cast: cast<> asserts (or is undefined in release builds)
    // on a non-vector type and never yields null, so a vector-named builtin
    // declared with scalar types would never reach the diagnostic below.
    auto *RetTyVec = dyn_cast<FixedVectorType>(RetTy);
    auto *ArgTyVec = dyn_cast<FixedVectorType>(ArgTy);
    if (!RetTyVec || !RetTyVec->getElementType()->isFloatTy() || !ArgTyVec ||
        !ArgTyVec->getElementType()->isIntegerTy(16U))
      report_fatal_error("OpConvertAsBFloat16NFloatN must be of <N x float> "
                         "and take <N x i16>");
    unsigned RetTyVecSize = RetTyVec->getNumElements();
    unsigned ArgTyVecSize = ArgTyVec->getNumElements();
    // The element count on both sides must match the width in the name.
    if (DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2) {
      if (RetTyVecSize != 2 || ArgTyVecSize != 2)
        report_fatal_error("ConvertAsBFloat162Float2 must be of <2 x float> "
                           "and take <2 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3) {
      if (RetTyVecSize != 3 || ArgTyVecSize != 3)
        report_fatal_error("ConvertAsBFloat163Float3 must be of <3 x float> "
                           "and take <3 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4) {
      if (RetTyVecSize != 4 || ArgTyVecSize != 4)
        report_fatal_error("ConvertAsBFloat164Float4 must be of <4 x float> "
                           "and take <4 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8) {
      if (RetTyVecSize != 8 || ArgTyVecSize != 8)
        report_fatal_error("ConvertAsBFloat168Float8 must be of <8 x float> "
                           "and take <8 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
      if (RetTyVecSize != 16 || ArgTyVecSize != 16)
        report_fatal_error("ConvertAsBFloat1616Float16 must be of <16 x float> "
                           "and take <16 x i16>");
    }
  }

  // Arguments are passed through unchanged; only the callee name is replaced.
  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
  mutateCallInstSPIRV(
      M, CI,
      [=](CallInst *, std::vector<Value *> &Args) {
        return getSPIRVFuncName(internal::OpConvertBF16ToFINTEL);
      },
      &Attrs);
}
} // namespace SPIRV

INITIALIZE_PASS_BEGIN(OCLToSPIRVLegacy, "ocl-to-spv",
Expand Down
20 changes: 20 additions & 0 deletions lib/SPIRV/OCLUtil.h
Expand Up @@ -305,6 +305,26 @@ const static char SubgroupBlockWriteINTELPrefix[] =
const static char SubgroupImageMediaBlockINTELPrefix[] =
"intel_sub_group_media_block";
const static char LDEXP[] = "ldexp";
// Defines the string constants ConvertBFloat16<N>AsUShort<N>, each holding the
// built-in name "intel_convert_bfloat16<N>_as_ushort<N>". The empty argument
// produces the scalar name; 2, 3, 4, 8 and 16 produce the vector variants.
#define _SPIRV_OP(x) \
const static char ConvertBFloat16##x##AsUShort##x[] = \
"intel_convert_bfloat16" #x "_as_ushort" #x;
_SPIRV_OP()
_SPIRV_OP(2)
_SPIRV_OP(3)
_SPIRV_OP(4)
_SPIRV_OP(8)
_SPIRV_OP(16)
#undef _SPIRV_OP
// Defines the string constants ConvertAsBFloat16<N>Float<N>, each holding the
// built-in name "intel_convert_as_bfloat16<N>_float<N>". The empty argument
// produces the scalar name; 2, 3, 4, 8 and 16 produce the vector variants.
#define _SPIRV_OP(x) \
const static char ConvertAsBFloat16##x##Float##x[] = \
"intel_convert_as_bfloat16" #x "_float" #x;
_SPIRV_OP()
_SPIRV_OP(2)
_SPIRV_OP(3)
_SPIRV_OP(4)
_SPIRV_OP(8)
_SPIRV_OP(16)
#undef _SPIRV_OP
} // namespace kOCLBuiltinName

/// Offset for OpenCL image channel order enumeration values.
Expand Down
31 changes: 31 additions & 0 deletions lib/SPIRV/SPIRVToOCL.cpp
Expand Up @@ -200,6 +200,11 @@ void SPIRVToOCLBase::visitCallInst(CallInst &CI) {
visitCallSPIRVRelational(&CI, OC);
return;
}
if (OC == internal::OpConvertFToBF16INTEL ||
OC == internal::OpConvertBF16ToFINTEL) {
visitCallSPIRVBFloat16Conversions(&CI, OC);
return;
}
if (OCLSPIRVBuiltinMap::rfind(OC))
visitCallSPIRVBuiltin(&CI, OC);
}
Expand Down Expand Up @@ -981,6 +986,32 @@ void SPIRVToOCLBase::visitCallSPIRVGenericPtrMemSemantics(CallInst *CI) {
&Attrs);
}

/// Rename a bfloat16 conversion SPIR-V call to the matching OpenCL built-in.
///
/// OpConvertFToBF16INTEL becomes intel_convert_bfloat16[N]_as_ushort[N] and
/// OpConvertBF16ToFINTEL becomes intel_convert_as_bfloat16[N]_float[N], where
/// N is the element count of the first operand when it is a vector and is
/// omitted otherwise. Any other opcode yields an empty name.
void SPIRVToOCLBase::visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC) {
  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
  mutateCallInstOCL(
      M, CI,
      [=](CallInst *, std::vector<Value *> &Args) {
        // Width suffix shared by both halves of the built-in name; stays
        // empty for the scalar overload.
        std::string Width;
        Type *OperandTy = CI->getOperand(0)->getType();
        if (OperandTy->isVectorTy())
          Width = std::to_string(
              cast<FixedVectorType>(OperandTy)->getNumElements());

        switch (static_cast<uint32_t>(OC)) {
        case internal::OpConvertFToBF16INTEL:
          return "intel_convert_bfloat16" + Width + "_as_ushort" + Width;
        case internal::OpConvertBF16ToFINTEL:
          return "intel_convert_as_bfloat16" + Width + "_float" + Width;
        default:
          break; // not a bfloat16 conversion: fall through to the empty name
        }
        return std::string();
      },
      &Attrs);
}

void SPIRVToOCLBase::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
assert(CI->getCalledFunction() && "Unexpected indirect call");
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
Expand Down
7 changes: 7 additions & 0 deletions lib/SPIRV/SPIRVToOCL.h
Expand Up @@ -161,6 +161,13 @@ class SPIRVToOCLBase : public InstVisitor<SPIRVToOCLBase> {
/// %1 = shl i31 %0, 8
void visitCallSPIRVGenericPtrMemSemantics(CallInst *CI);

/// Transform __spirv_ConvertFToBF16INTELDv(N)_f to:
/// intel_convert_bfloat16(N)_as_ushort(N)Dv(N)_f;
/// and transform __spirv_ConvertBF16ToFINTELDv(N)_s to:
/// intel_convert_as_bfloat16(N)_float(N)Dv(N)_t;
/// where N is the vector size of the first operand. For a non-vector operand
/// the (N) suffixes are omitted from the built-in name.
void visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC);

/// Transform __spirv_* builtins to OCL 2.0 builtins.
/// No change with arguments.
void visitCallSPIRVBuiltin(CallInst *CI, Op OC);
Expand Down
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertAsBFloat16Float must be of float and take i16

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
; Negative test: the scalar built-in is deliberately called with an i32
; argument and a double result instead of i16 and float, so llvm-spirv is
; expected to abort with the diagnostic checked at the top of this file.
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext 0)
ret void
}

; Function Attrs: convergent
declare spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertAsBFloat16NFloatN must be of <N x float> and take <N x i16>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
; Negative test: the 2-element built-in is deliberately called with
; <2 x i32> and <2 x double> (wrong element types, correct element count), so
; llvm-spirv is expected to abort with the generic vector diagnostic checked
; at the top of this file.
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32> zeroinitializer)
ret void
}

; Function Attrs: convergent
declare spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: ConvertAsBFloat162Float2 must be of <2 x float> and take <2 x i16>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
; Negative test: element types are correct (i16 in, float out) but the
; 2-element built-in is deliberately called with <4 x i16> and <8 x float>,
; so llvm-spirv is expected to abort with the size-specific diagnostic
; checked at the top of this file.
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16> zeroinitializer)
ret void
}

; Function Attrs: convergent
declare spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertBFloat16AsUShort must be of i16 and take float

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
; Negative test: the scalar built-in returns i16 as required but is
; deliberately given a double argument instead of float, so llvm-spirv is
; expected to abort with the diagnostic checked at the top of this file.
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double 0.000000e+00)
ret void
}

; Function Attrs: convergent
declare spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertBFloat16NAsUShortN must be of <N x i16> and take <N x float>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
; Negative test: the 2-element built-in is deliberately called with
; <2 x double> and a <2 x i32> result (wrong element types, correct element
; count), so llvm-spirv is expected to abort with the generic vector
; diagnostic checked at the top of this file.
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double> zeroinitializer)
ret void
}

; Function Attrs: convergent
declare spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: ConvertBFloat162AsUShort2 must be of <2 x i16> and take <2 x float>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
; Negative test: element types are correct (float in, i16 out) but the
; 2-element built-in is deliberately called with <4 x float> and <8 x i16>,
; so llvm-spirv is expected to abort with the size-specific diagnostic
; checked at the top of this file.
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float> zeroinitializer)
ret void
}

; Function Attrs: convergent
declare spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}

0 comments on commit e95eb30

Please sign in to comment.