Skip to content

Reland [SPIR-V] Support SPV_INTEL_int4 extension #141279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
- Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
* - ``SPV_INTEL_int4``
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.

To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:

Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_INTEL_ternary_bitwise_function",
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
{"SPV_INTEL_2d_block_io",
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
Expand Down
21 changes: 17 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
return Width;
if (Width <= 8)
Width = 8;
Expand All @@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
if ((!isPowerOf2_32(Width) || Width < 8) &&
ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
MIRBuilder.buildInstr(SPIRV::OpExtension)
.addImm(SPIRV::Extension::SPV_INTEL_int4);
MIRBuilder.buildInstr(SPIRV::OpCapability)
.addImm(SPIRV::Capability::Int4TypeINTEL);
} else if ((!isPowerOf2_32(Width) || Width < 8) &&
ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
MIRBuilder.buildInstr(SPIRV::OpExtension)
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
MIRBuilder.buildInstr(SPIRV::OpCapability)
Expand Down Expand Up @@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
const Type *ET = getTypeForSPIRVType(ElemType);
if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
MIRBuilder.buildInstr(SPIRV::OpCapability)
.addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
}
return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
bool IsExtendedInts =
ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
auto extendedScalarsAndVectors =
[IsExtendedInts](const LegalityQuery &Query) {
const LLT Ty = Query.Types[0];
Expand Down
41 changes: 30 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,16 +380,31 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
// To support current approach and limitations wrt. bit width here we widen a
// scalar register with a bit width greater than 1 to valid sizes and cap it to
// 64 width.
static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
static unsigned widenBitWidthToNextPow2(unsigned BitWidth) {
if (BitWidth == 1)
return 1; // No need to widen 1-bit values
return std::min(std::max(1u << Log2_32_Ceil(BitWidth), 8u), 64u);
}

static void widenScalarType(Register Reg, MachineRegisterInfo &MRI) {
LLT RegType = MRI.getType(Reg);
if (!RegType.isScalar())
return;
unsigned Sz = RegType.getScalarSizeInBits();
if (Sz == 1)
return;
unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
if (NewSz != Sz)
MRI.setType(Reg, LLT::scalar(NewSz));
unsigned CurrentWidth = RegType.getScalarSizeInBits();
unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
if (NewWidth != CurrentWidth)
MRI.setType(Reg, LLT::scalar(NewWidth));
}

static void widenCImmType(MachineOperand &MOP) {
const ConstantInt *CImmVal = MOP.getCImm();
unsigned CurrentWidth = CImmVal->getBitWidth();
unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
if (NewWidth != CurrentWidth) {
// Replace the immediate value with the widened version
MOP.setCImm(ConstantInt::get(CImmVal->getType()->getContext(),
CImmVal->getValue().zextOrTrunc(NewWidth)));
}
}

static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
Expand Down Expand Up @@ -492,7 +507,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
bool IsExtendedInts =
ST->canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);

for (MachineBasicBlock *MBB : post_order(&MF)) {
if (MBB->empty())
Expand All @@ -505,10 +521,13 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
unsigned MIOp = MI.getOpcode();

if (!IsExtendedInts) {
// validate bit width of scalar registers
for (const auto &MOP : MI.operands())
// validate bit width of scalar registers and constant immediates
for (auto &MOP : MI.operands()) {
if (MOP.isReg())
widenScalarLLTNextPow2(MOP.getReg(), MRI);
widenScalarType(MOP.getReg(), MRI);
else if (MOP.isCImm())
widenCImmType(MOP);
}
}

if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
defm SPV_INTEL_int4 : ExtensionOperand<123>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -521,6 +522,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: Capability Int4TypeINTEL
; CHECK-DAG: Capability CooperativeMatrixKHR
; CHECK-DAG: Extension "SPV_INTEL_int4"
; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"

; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
; CHECK: CompositeConstruct %[[#CoopMatTy]]

define spir_kernel void @foo() {
entry:
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
ret void
}

declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
29 changes: 29 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4

; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
; No error would be reported in comparison to Khronos llvm-spirv, because type adjustments to integer size are made
; in case no appropriate extension is enabled. Here we expect that the type is adjusted to 8 bits.

; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1

; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]

define spir_kernel void @foo() {
entry:
%0 = alloca i4
store i4 1, ptr %0
%1 = load i4, ptr %0
call spir_func void @boo(i4 %1)
ret void
}

declare spir_func void @boo(i4)
25 changes: 25 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}

; CHECK: Capability Int4TypeINTEL
; CHECK: Extension "SPV_INTEL_int4"
; CHECK: %[[#Int4:]] = OpTypeInt 4 0
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1

; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]

define spir_kernel void @foo() {
entry:
%0 = alloca i4
store i4 1, ptr %0
%1 = load i4, ptr %0
call spir_func void @boo(i4 %1)
ret void
}

declare spir_func void @boo(i4)
Loading