Skip to content

Commit 198f834

Browse files
jroelofs authored and DhruvSrivastavaX committed
[Matrix] Propagate shape information through (f)abs insts (llvm#141704)
1 parent c9eaa5d commit 198f834

File tree

2 files changed

+137
-103
lines changed

2 files changed

+137
-103
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 46 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
232232
if (I->isBinaryOp())
233233
return true;
234234

235+
if (auto *II = dyn_cast<IntrinsicInst>(V))
236+
switch (II->getIntrinsicID()) {
237+
case Intrinsic::abs:
238+
case Intrinsic::fabs:
239+
return true;
240+
default:
241+
return false;
242+
}
243+
235244
switch (I->getOpcode()) {
236245
case Instruction::FNeg:
237246
return true;
@@ -618,7 +627,7 @@ class LowerMatrixIntrinsics {
618627
case Intrinsic::matrix_column_major_store:
619628
return true;
620629
default:
621-
return false;
630+
return isUniformShape(II);
622631
}
623632
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
624633
}
@@ -1064,8 +1073,8 @@ class LowerMatrixIntrinsics {
10641073
VisitBinaryOperator(BinOp, SI);
10651074
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
10661075
VisitUnaryOperator(UnOp, SI);
1067-
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
1068-
VisitCallInst(CInst);
1076+
else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
1077+
VisitIntrinsicInst(Intr, SI);
10691078
else if (match(Inst, m_Load(m_Value(Op1))))
10701079
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
10711080
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -1111,23 +1120,48 @@ class LowerMatrixIntrinsics {
11111120
}
11121121

11131122
/// Replace intrinsic calls.
1114-
void VisitCallInst(CallInst *Inst) {
1115-
assert(Inst->getCalledFunction() &&
1116-
Inst->getCalledFunction()->isIntrinsic());
1117-
1118-
switch (Inst->getCalledFunction()->getIntrinsicID()) {
1123+
void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
1124+
switch (Inst->getIntrinsicID()) {
11191125
case Intrinsic::matrix_multiply:
11201126
LowerMultiply(Inst);
1121-
break;
1127+
return;
11221128
case Intrinsic::matrix_transpose:
11231129
LowerTranspose(Inst);
1124-
break;
1130+
return;
11251131
case Intrinsic::matrix_column_major_load:
11261132
LowerColumnMajorLoad(Inst);
1127-
break;
1133+
return;
11281134
case Intrinsic::matrix_column_major_store:
11291135
LowerColumnMajorStore(Inst);
1130-
break;
1136+
return;
1137+
case Intrinsic::abs:
1138+
case Intrinsic::fabs: {
1139+
IRBuilder<> Builder(Inst);
1140+
MatrixTy Result;
1141+
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);
1142+
Builder.setFastMathFlags(getFastMathFlags(Inst));
1143+
1144+
for (auto &Vector : M.vectors()) {
1145+
switch (Inst->getIntrinsicID()) {
1146+
case Intrinsic::abs:
1147+
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
1148+
Inst->getOperand(1)));
1149+
continue;
1150+
case Intrinsic::fabs:
1151+
Result.addVector(
1152+
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
1153+
continue;
1154+
default:
1155+
llvm_unreachable("unexpected intrinsic");
1156+
}
1157+
}
1158+
1159+
finalizeLowering(Inst,
1160+
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1161+
Result.getNumVectors()),
1162+
Builder);
1163+
return;
1164+
}
11311165
default:
11321166
llvm_unreachable(
11331167
"only intrinsics supporting shape info should be seen here");

0 commit comments

Comments
 (0)