Skip to content

Commit 0b5ffd2

Browse files
committed
[Matrix] Propagate shape information through (f)abs insts
1 parent 79ae407 commit 0b5ffd2

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
232232
if (I->isBinaryOp())
233233
return true;
234234

235+
if (auto *II = dyn_cast<IntrinsicInst>(V))
236+
switch (II->getIntrinsicID()) {
237+
case Intrinsic::abs:
238+
case Intrinsic::fabs:
239+
return true;
240+
default:
241+
return false;
242+
}
243+
235244
switch (I->getOpcode()) {
236245
case Instruction::FNeg:
237246
return true;
@@ -621,7 +630,7 @@ class LowerMatrixIntrinsics {
621630
case Intrinsic::matrix_column_major_store:
622631
return true;
623632
default:
624-
return false;
633+
return isUniformShape(II);
625634
}
626635
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
627636
}
@@ -1127,6 +1136,9 @@ class LowerMatrixIntrinsics {
11271136
case Intrinsic::matrix_column_major_store:
11281137
LowerColumnMajorStore(Inst);
11291138
break;
1139+
case Intrinsic::abs:
1140+
case Intrinsic::fabs:
1141+
return VisitUniformIntrinsic(cast<IntrinsicInst>(Inst));
11301142
default:
11311143
return false;
11321144
}
@@ -2198,6 +2210,49 @@ class LowerMatrixIntrinsics {
21982210
return true;
21992211
}
22002212

2213+
/// Lower uniform shape intrinsics, if shape information is available.
2214+
bool VisitUniformIntrinsic(IntrinsicInst *Inst) {
2215+
auto I = ShapeMap.find(Inst);
2216+
if (I == ShapeMap.end())
2217+
return false;
2218+
2219+
IRBuilder<> Builder(Inst);
2220+
ShapeInfo &Shape = I->second;
2221+
2222+
MatrixTy Result;
2223+
2224+
switch (Inst->getIntrinsicID()) {
2225+
case Intrinsic::abs:
2226+
case Intrinsic::fabs: {
2227+
Value *Op = Inst->getOperand(0);
2228+
2229+
MatrixTy M = getMatrix(Op, Shape, Builder);
2230+
2231+
Builder.setFastMathFlags(getFastMathFlags(Inst));
2232+
2233+
for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2234+
switch (Inst->getIntrinsicID()) {
2235+
case Intrinsic::abs:
2236+
Result.addVector(Builder.CreateBinaryIntrinsic(
2237+
Intrinsic::abs, M.getVector(I), Inst->getOperand(1)));
2238+
break;
2239+
case Intrinsic::fabs:
2240+
Result.addVector(Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(),
2241+
M.getVector(I)));
2242+
break;
2243+
}
2244+
2245+
finalizeLowering(Inst,
2246+
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2247+
Result.getNumVectors()),
2248+
Builder);
2249+
return true;
2250+
}
2251+
default:
2252+
llvm_unreachable("unexpected intrinsic");
2253+
}
2254+
}
2255+
22012256
/// Helper to linearize a matrix expression tree into a string. Currently
22022257
/// matrix expressions are linarized by starting at an expression leaf and
22032258
/// linearizing bottom up.

llvm/test/Transforms/LowerMatrixIntrinsics/binop.ll

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,46 @@ define void @xor_2x2(ptr %lhs, ptr %rhs, ptr %out) {
432432
store <4 x i32> %optt, ptr %out
433433
ret void
434434
}
435+
436+
define void @fabs_2x2f64(ptr %in, ptr %out) {
437+
; CHECK-LABEL: @fabs_2x2f64(
438+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
439+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
440+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
441+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.fabs.v2f64(<2 x double> [[COL_LOAD]])
442+
; CHECK-NEXT: [[TMP2:%.*]] = call <2 x double> @llvm.fabs.v2f64(<2 x double> [[COL_LOAD1]])
443+
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
444+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
445+
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
446+
; CHECK-NEXT: ret void
447+
;
448+
%load = load <4 x double>, ptr %in
449+
%fabs = call <4 x double> @llvm.fabs.v4f64(<4 x double> %load)
450+
%fabst = call <4 x double> @llvm.matrix.transpose(<4 x double> %fabs, i32 2, i32 2)
451+
%fabstt = call <4 x double> @llvm.matrix.transpose(<4 x double> %fabst, i32 2, i32 2)
452+
store <4 x double> %fabstt, ptr %out
453+
ret void
454+
}
455+
456+
define void @fabs_2x2i32(ptr %in, ptr %out) {
457+
; CHECK-LABEL: @fabs_2x2i32(
458+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i32>, ptr [[IN:%.*]], align 16
459+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 2
460+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i32>, ptr [[VEC_GEP]], align 8
461+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[COL_LOAD]], i1 false)
462+
; CHECK-NEXT: [[TMP2:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[COL_LOAD1]], i1 false)
463+
; CHECK-NEXT: [[TMP3:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP1]], i1 true)
464+
; CHECK-NEXT: [[TMP4:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP2]], i1 true)
465+
; CHECK-NEXT: store <2 x i32> [[TMP3]], ptr [[OUT:%.*]], align 16
466+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
467+
; CHECK-NEXT: store <2 x i32> [[TMP4]], ptr [[VEC_GEP2]], align 8
468+
; CHECK-NEXT: ret void
469+
;
470+
%load = load <4 x i32>, ptr %in
471+
%abs = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %load, i1 false)
472+
%abst = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %abs, i32 2, i32 2)
473+
%abstt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %abst, i32 2, i32 2)
474+
%absabstt = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %abstt, i1 true)
475+
store <4 x i32> %absabstt, ptr %out
476+
ret void
477+
}

0 commit comments

Comments
 (0)