Skip to content

Commit

Permalink
Translate complex nested vector expressions instead of lowering them (#…
Browse files Browse the repository at this point in the history
…1335)

Allow translator to go deeper into complex nested instructions.
Enabled a possibility to translate expressions in `extractelement` and binary
operator instructions (e.g., `fadd`, `fmul`).

Also this change removes lowering for constant expression vector as the
alternative approach was introduced - to translate any complicated nested
instruction instead of lowering.
  • Loading branch information
vmaksimo authored and MrSidims committed Mar 18, 2022
1 parent 1c808cd commit d7a0304
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 176 deletions.
39 changes: 1 addition & 38 deletions lib/SPIRV/SPIRVLowerConstExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,53 +167,16 @@ void SPIRVLowerConstExprBase::visit(Module *M) {
};

WorkList.pop_front();
auto LowerConstantVec = [&II, &LowerOp, &WorkList,
&M](ConstantVector *Vec,
unsigned NumOfOp) -> Value * {
if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
return isa<ConstantExpr>(V) || isa<Function>(V);
})) {
// Expand a vector of constexprs and construct it back with
// series of insertelement instructions
std::list<Value *> OpList;
std::transform(Vec->op_begin(), Vec->op_end(),
std::back_inserter(OpList),
[LowerOp](Value *V) { return LowerOp(V); });
Value *Repl = nullptr;
unsigned Idx = 0;
auto *PhiII = dyn_cast<PHINode>(II);
auto *InsPoint =
PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
std::list<Instruction *> ReplList;
for (auto V : OpList) {
if (auto *Inst = dyn_cast<Instruction>(V))
ReplList.push_back(Inst);
Repl = InsertElementInst::Create(
(Repl ? Repl : UndefValue::get(Vec->getType())), V,
ConstantInt::get(Type::getInt32Ty(M->getContext()), Idx++), "",
InsPoint);
}
WorkList.splice(WorkList.begin(), ReplList);
return Repl;
}
return nullptr;
};

for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
auto *Op = II->getOperand(OI);
if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
Value *ReplInst = LowerConstantVec(Vec, OI);
if (ReplInst)
II->replaceUsesOfWith(Op, ReplInst);
} else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
if (auto *CE = dyn_cast<ConstantExpr>(Op)) {
WorkList.push_front(cast<Instruction>(LowerOp(CE)));
} else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
Metadata *MD = MDAsVal->getMetadata();
if (auto ConstMD = dyn_cast<ConstantAsMetadata>(MD)) {
Constant *C = ConstMD->getValue();
Value *ReplInst = nullptr;
if (auto *Vec = dyn_cast<ConstantVector>(C))
ReplInst = LowerConstantVec(Vec, OI);
if (auto *CE = dyn_cast<ConstantExpr>(C))
ReplInst = LowerOp(CE);
if (ReplInst) {
Expand Down
7 changes: 5 additions & 2 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,8 @@ SPIRVValue *LLVMToSPIRVBase::transValue(Value *V, SPIRVBasicBlock *BB,

SPIRVDBG(dbgs() << "[transValue] " << *V << '\n');
assert((!isa<Instruction>(V) || isa<GetElementPtrInst>(V) ||
isa<CastInst>(V) || BB) &&
isa<CastInst>(V) || isa<ExtractElementInst>(V) ||
isa<BinaryOperator>(V) || BB) &&
"Invalid SPIRV BB");

auto BV = transValueWithoutDecoration(V, BB, CreateForward, FuncTrans);
Expand All @@ -978,7 +979,9 @@ SPIRVInstruction *LLVMToSPIRVBase::transBinaryInst(BinaryOperator *B,
transBoolOpCode(Op0, OpCodeMap::map(LLVMOC)), transType(B->getType()),
Op0, transValue(B->getOperand(1), BB), BB);

if (isUnfusedMulAdd(B)) {
// BinaryOperator can have no parent if it is handled as an expression inside
// another instruction.
if (B->getParent() && isUnfusedMulAdd(B)) {
Function *F = B->getFunction();
SPIRVDBG(dbgs() << "[fp-contract] disabled for " << F->getName()
<< ": possible fma candidate " << *B << '\n');
Expand Down
13 changes: 12 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,18 @@ SPIRVSpecConstantOp *createSpecConstantOpInst(SPIRVInstruction *Inst) {
auto OC = Inst->getOpCode();
assert(isSpecConstantOpAllowedOp(OC) &&
"Op code not allowed for OpSpecConstantOp");
auto Ops = Inst->getIds(Inst->getOperands());
std::vector<SPIRVWord> Ops;

// CompositeExtract/Insert operations use zero-based numbering for their
// indexes (containted in instruction operands). All their operands are
// Literals, so we can pass them as is for further handling.
if (OC == OpCompositeExtract || OC == OpCompositeInsert) {
auto *SPIRVInst = static_cast<SPIRVInstTemplateBase *>(Inst);
Ops = SPIRVInst->getOpWords();
} else {
Ops = Inst->getIds(Inst->getOperands());
}

Ops.insert(Ops.begin(), OC);
return static_cast<SPIRVSpecConstantOp *>(SPIRVSpecConstantOp::create(
OpSpecConstantOp, Inst->getType(), Inst->getId(), Ops, nullptr,
Expand Down
20 changes: 19 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2757,11 +2757,29 @@ _SPIRV_OP(ImageQuerySamples, true, 4)
#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate<SPIRVInstTemplateBase, Op##x, __VA_ARGS__> SPIRV##x;
// Other instructions
_SPIRV_OP(SpecConstantOp, true, 4, true, 0)
_SPIRV_OP(GenericPtrMemSemantics, true, 4, false)
_SPIRV_OP(GenericCastToPtrExplicit, true, 5, false, 1)
#undef _SPIRV_OP

class SPIRVSpecConstantOpBase : public SPIRVInstTemplateBase {
public:
bool isOperandLiteral(unsigned I) const override {
// If SpecConstant results from CompositeExtract/Insert operation, then all
// operands are expected to be literals.
switch (Ops[0]) { // Opcode of underlying SpecConstant operation
case OpCompositeExtract:
case OpCompositeInsert:
return true;
default:
return SPIRVInstTemplateBase::isOperandLiteral(I);
}
}
};

typedef SPIRVInstTemplate<SPIRVSpecConstantOpBase, OpSpecConstantOp, true, 4,
true, 0>
SPIRVSpecConstantOp;

class SPIRVAssumeTrueKHR : public SPIRVInstruction {
public:
static const Op OC = OpAssumeTrueKHR;
Expand Down
94 changes: 94 additions & 0 deletions test/complex-constexpr-vector.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64"

define linkonce_odr hidden spir_func void @foo() {
entry:
; CHECK-SPIRV-DAG: Constant [[#]] [[#CONSTANT1:]] 65793
; CHECK-SPIRV-DAG: Constant [[#]] [[#CONSTANT2:]] 131586

; CHECK-SPIRV: ConstantComposite [[#]] [[#COMPOS0:]] [[#CONSTANT1]]
; 124 is OpBitcast opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES0:]] 124 [[#COMPOS0]]

; 81 is OpCompositeExtract opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES0:]] 81 [[#BITCAST_RES0]] 0
; CHECK-SPIRV: ConstantComposite [[#]] [[#COMPOS1:]] [[#CONSTANT2]]

; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES1:]] 124 [[#COMPOS1]]
; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES1:]] 81 [[#BITCAST_RES1]] 0
; 129 is OpFAdd opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#MEMBER_1:]] 129 [[#EXTRACT_RES0:]] [[#EXTRACT_RES1]]

; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES2:]] 81 [[#BITCAST_RES0]] 1
; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES3:]] 81 [[#BITCAST_RES1]] 1
; CHECK-SPIRV: SpecConstantOp [[#]] [[#MEMBER_2:]] 129 [[#EXTRACT_RES2]] [[#EXTRACT_RES3]]

; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES2:]] 81 [[#BITCAST_RES0]] 2
; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES2:]] 81 [[#BITCAST_RES1]] 2
; CHECK-SPIRV: SpecConstantOp [[#]] [[#MEMBER_3:]] 129 [[#]] [[#BITCAST_RES2]]

; CHECK-SPIRV: Undef [[#]] [[#MEMBER_4:]]
; CHECK-SPIRV: ConstantComposite [[#]] [[#FINAL_COMPOS:]] [[#MEMBER_1]] [[#MEMBER_2]] [[#MEMBER_3]] [[#MEMBER_4]]
; CHECK-SPIRV: DebugValue [[#]] [[#FINAL_COMPOS]]

; CHECK-LLVM: call void @llvm.dbg.value(
; CHECK-LLVM-SAME: metadata <4 x half> <
; CHECK-LLVM-SAME: half fadd (
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 0),
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 0)),
; CHECK-LLVM-SAME: half fadd (
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 1),
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 1)),
; CHECK-LLVM-SAME: half fadd (
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 2),
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 2)),
; CHECK-LLVM-SAME: half undef>,
; CHECK-LLVM-SAME: metadata ![[#]], metadata !DIExpression()), !dbg ![[#]]
call void @llvm.dbg.value(
metadata <4 x half> <
half fadd (
half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 0),
half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 0)),
half fadd (
half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 1),
half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 1)),
half fadd (
half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 2),
half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 2)),
half undef>,
metadata !12, metadata !DIExpression()), !dbg !7
ret void
}

; Function Attrs: nofree nosync nounwind readnone speculatable willreturn
declare void @llvm.dbg.value(metadata, metadata, metadata)

!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!3, !4}
!opencl.used.extensions = !{!2}
!opencl.used.optional.core.features = !{!2}
!opencl.compiler.options = !{!2}
!llvm.ident = !{!5}

!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1, producer: "clang version 13.0.0 (https://github.com/intel/llvm.git)", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, enums: !2, nameTableKind: None)
!1 = !DIFile(filename: "main.cpp", directory: "/export/users")
!2 = !{}
!3 = !{i32 2, !"Debug Info Version", i32 3}
!4 = !{i32 1, !"wchar_size", i32 4}
!5 = !{!"clang version 13.0.0"}
!6 = distinct !DISubprogram(name: "main", scope: !1, file: !1, line: 1, type: !8, scopeLine: 4, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition, unit: !0, retainedNodes: !2)
!7 = !DILocation(line: 1, scope: !6, inlinedAt: !11)
!8 = !DISubroutineType(types: !9)
!9 = !{!10}
!10 = !DIBasicType(name: "int", size: 32, encoding: DW_ATE_signed)
!11 = !DILocation(line: 1, column: 0, scope: !6)
!12 = !DILocalVariable(name: "resVec", scope: !6, file: !1, line: 1, type: !13)
!13 = distinct !DICompositeType(tag: DW_TAG_class_type, name: "vec<cl::sycl::detail::half_impl::half, 3>", scope: !6, file: !1, line: 1, size: 64, flags: DIFlagTypePassByValue, elements: !2)
15 changes: 9 additions & 6 deletions test/constexpr_phi.ll
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@
; RUN: FileCheck < %t.r.ll %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV: Name [[#F:]] "_Z3runiiPi"

; 117 is OpConvertPtrToU opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#SpecConst0:]] 117 [[#F1Ptr:]]
; CHECK-SPIRV: SpecConstantOp [[#]] [[#SpecConst1:]] 117 [[#F2Ptr:]]
; CHECK-SPIRV: ConstantComposite [[#]] [[#Compos0:]] [[#SpecConst0]] [[#SpecConst0]]
; CHECK-SPIRV: ConstantComposite [[#]] [[#Compos1:]] [[#SpecConst0]] [[#SpecConst1]]

; CHECK-SPIRV: Function [[#]] [[#F]] [[#]] [[#]]
; CHECK-SPIRV: Label [[#L1:]]
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins1:]] [[#]] [[#]] 0
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins2:]] [[#]] [[#Ins1]] 1
; CHECK-SPIRV: BranchConditional [[#]] [[#L2:]] [[#L3:]]
; CHECK-SPIRV: Label [[#L2]]
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins3:]] [[#]] [[#]] 0
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins4:]] [[#]] [[#Ins3]] 1
; CHECK-SPIRV: Branch [[#L3]]
; CHECK-SPIRV: Label [[#L3]]
; CHECK-NEXT-SPIRV: Phi [[#]] [[#]]
; CHECK-SAME-SPIRV: [[#Ins2]] [[#L1]]
; CHECK-SAME-SPIRV: [[#Ins4]] [[#L2]]
; CHECK-SAME-SPIRV: [[#Compos0]] [[#L1]]
; CHECK-SAME-SPIRV: [[#Compos1]] [[#L2]]

; CHECK-LLVM: br label %[[#L:]]
; CHECK-LLVM: [[#L]]:
Expand Down
Loading

0 comments on commit d7a0304

Please sign in to comment.