@@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
232
232
if (I->isBinaryOp ())
233
233
return true ;
234
234
235
+ if (auto *II = dyn_cast<IntrinsicInst>(V))
236
+ switch (II->getIntrinsicID ()) {
237
+ case Intrinsic::abs:
238
+ case Intrinsic::fabs:
239
+ return true ;
240
+ default :
241
+ return false ;
242
+ }
243
+
235
244
switch (I->getOpcode ()) {
236
245
case Instruction::FNeg:
237
246
return true ;
@@ -618,7 +627,7 @@ class LowerMatrixIntrinsics {
618
627
case Intrinsic::matrix_column_major_store:
619
628
return true ;
620
629
default :
621
- return false ;
630
+ return isUniformShape (II) ;
622
631
}
623
632
return isUniformShape (V) || isa<StoreInst>(V) || isa<LoadInst>(V);
624
633
}
@@ -1064,8 +1073,8 @@ class LowerMatrixIntrinsics {
1064
1073
VisitBinaryOperator (BinOp, SI);
1065
1074
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1066
1075
VisitUnaryOperator (UnOp, SI);
1067
- else if (CallInst *CInst = dyn_cast<CallInst >(Inst))
1068
- VisitCallInst (CInst );
1076
+ else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst >(Inst))
1077
+ VisitIntrinsicInst (Intr, SI );
1069
1078
else if (match (Inst, m_Load (m_Value (Op1))))
1070
1079
VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder);
1071
1080
else if (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
@@ -1111,23 +1120,48 @@ class LowerMatrixIntrinsics {
1111
1120
}
1112
1121
1113
1122
// / Replace intrinsic calls.
1114
- void VisitCallInst (CallInst *Inst) {
1115
- assert (Inst->getCalledFunction () &&
1116
- Inst->getCalledFunction ()->isIntrinsic ());
1117
-
1118
- switch (Inst->getCalledFunction ()->getIntrinsicID ()) {
1123
+ void VisitIntrinsicInst (IntrinsicInst *Inst, const ShapeInfo &Shape) {
1124
+ switch (Inst->getIntrinsicID ()) {
1119
1125
case Intrinsic::matrix_multiply:
1120
1126
LowerMultiply (Inst);
1121
- break ;
1127
+ return ;
1122
1128
case Intrinsic::matrix_transpose:
1123
1129
LowerTranspose (Inst);
1124
- break ;
1130
+ return ;
1125
1131
case Intrinsic::matrix_column_major_load:
1126
1132
LowerColumnMajorLoad (Inst);
1127
- break ;
1133
+ return ;
1128
1134
case Intrinsic::matrix_column_major_store:
1129
1135
LowerColumnMajorStore (Inst);
1130
- break ;
1136
+ return ;
1137
+ case Intrinsic::abs:
1138
+ case Intrinsic::fabs: {
1139
+ IRBuilder<> Builder (Inst);
1140
+ MatrixTy Result;
1141
+ MatrixTy M = getMatrix (Inst->getOperand (0 ), Shape, Builder);
1142
+ Builder.setFastMathFlags (getFastMathFlags (Inst));
1143
+
1144
+ for (auto &Vector : M.vectors ()) {
1145
+ switch (Inst->getIntrinsicID ()) {
1146
+ case Intrinsic::abs:
1147
+ Result.addVector (Builder.CreateBinaryIntrinsic (Intrinsic::abs, Vector,
1148
+ Inst->getOperand (1 )));
1149
+ continue ;
1150
+ case Intrinsic::fabs:
1151
+ Result.addVector (
1152
+ Builder.CreateUnaryIntrinsic (Inst->getIntrinsicID (), Vector));
1153
+ continue ;
1154
+ default :
1155
+ llvm_unreachable (" unexpected intrinsic" );
1156
+ }
1157
+ }
1158
+
1159
+ finalizeLowering (Inst,
1160
+ Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
1161
+ Result.getNumVectors ()),
1162
+ Builder);
1163
+ return ;
1164
+ }
1131
1165
default :
1132
1166
llvm_unreachable (
1133
1167
" only intrinsics supporting shape info should be seen here" );
0 commit comments