Skip to content

Commit ca5b71a

Browse files
authored
[Matrix] Propagate shape information through Select insts (#141876)
1 parent b6a56b8 commit ca5b71a

File tree

2 files changed

+188
-5
lines changed

2 files changed

+188
-5
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 42 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -323,9 +323,11 @@ computeShapeInfoForInst(Instruction *I,
323323
return OpShape->second;
324324
}
325325

326-
if (isUniformShape(I)) {
326+
if (isUniformShape(I) || isa<SelectInst>(I)) {
327+
auto Ops = I->operands();
328+
auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
327329
// Find the first operand that has a known shape and use that.
328-
for (auto &Op : I->operands()) {
330+
for (auto &Op : ShapedOps) {
329331
auto OpShape = ShapeMap.find(Op.get());
330332
if (OpShape != ShapeMap.end())
331333
return OpShape->second;
@@ -701,7 +703,8 @@ class LowerMatrixIntrinsics {
701703
default:
702704
return isUniformShape(II);
703705
}
704-
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
706+
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
707+
isa<SelectInst>(V);
705708
}
706709

707710
/// Propagate the shape information of instructions to their users.
@@ -788,10 +791,12 @@ class LowerMatrixIntrinsics {
788791
} else if (isa<StoreInst>(V)) {
789792
// Nothing to do. We forward-propagated to this so we would just
790793
// backward propagate to an instruction with an already known shape.
791-
} else if (isUniformShape(V)) {
794+
} else if (isUniformShape(V) || isa<SelectInst>(V)) {
795+
auto Ops = cast<Instruction>(V)->operands();
796+
auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
792797
// Propagate to all operands.
793798
ShapeInfo Shape = ShapeMap[V];
794-
for (Use &U : cast<Instruction>(V)->operands()) {
799+
for (Use &U : ShapedOps) {
795800
if (setShapeInfo(U.get(), Shape))
796801
pushInstruction(U.get(), WorkList);
797802
}
@@ -1148,6 +1153,8 @@ class LowerMatrixIntrinsics {
11481153
Result = VisitUnaryOperator(UnOp, SI);
11491154
else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
11501155
Result = VisitIntrinsicInst(Intr, SI);
1156+
else if (auto *Select = dyn_cast<SelectInst>(Inst))
1157+
Result = VisitSelectInst(Select, SI);
11511158
else if (match(Inst, m_Load(m_Value(Op1))))
11521159
Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
11531160
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2307,6 +2314,36 @@ class LowerMatrixIntrinsics {
23072314
Result.getNumVectors());
23082315
}
23092316

2317+
/// Lower selects.
2318+
MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
2319+
Value *Cond = Inst->getOperand(0);
2320+
Value *OpA = Inst->getOperand(1);
2321+
Value *OpB = Inst->getOperand(2);
2322+
2323+
IRBuilder<> Builder(Inst);
2324+
2325+
MatrixTy Result;
2326+
MatrixTy A = getMatrix(OpA, Shape, Builder);
2327+
MatrixTy B = getMatrix(OpB, Shape, Builder);
2328+
2329+
Value *CondV[2];
2330+
if (isa<FixedVectorType>(Cond->getType())) {
2331+
MatrixTy C = getMatrix(Cond, Shape, Builder);
2332+
CondV[0] = C.getVector(0);
2333+
CondV[1] = C.getVector(1);
2334+
} else {
2335+
CondV[0] = Cond;
2336+
CondV[1] = Cond;
2337+
}
2338+
2339+
for (unsigned I = 0, E = Shape.getNumVectors(); I != E; ++I)
2340+
Result.addVector(
2341+
Builder.CreateSelect(CondV[I], A.getVector(I), B.getVector(I)));
2342+
2343+
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2344+
Result.getNumVectors());
2345+
}
2346+
23102347
/// Helper to linearize a matrix expression tree into a string. Currently
23112348
/// matrix expressions are linearized by starting at an expression leaf and
23122349
/// linearizing bottom up.
Lines changed: 146 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,146 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
3+
4+
define void @select_2x2_bot(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
5+
; CHECK-LABEL: @select_2x2_bot(
6+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
7+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
8+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
9+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
10+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
11+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
12+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
13+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
14+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
15+
; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
16+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 4
17+
; CHECK-NEXT: ret void
18+
;
19+
%lhsv = load <4 x float>, ptr %lhs
20+
%rhsv = load <4 x float>, ptr %rhs
21+
%op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
22+
call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
23+
ret void
24+
}
25+
26+
define void @select_2x2_lhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
27+
; CHECK-LABEL: @select_2x2_lhs(
28+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 4
29+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
30+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 4
31+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
32+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
33+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
34+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
35+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
36+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
37+
; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
38+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
39+
; CHECK-NEXT: ret void
40+
;
41+
%lhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %lhs, i64 2, i1 false, i32 2, i32 2)
42+
%rhsv = load <4 x float>, ptr %rhs
43+
%op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
44+
store <4 x float> %op, ptr %out
45+
ret void
46+
}
47+
48+
define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
49+
; CHECK-LABEL: @select_2x2_rhs(
50+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
51+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS]], i64 2
52+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
53+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
54+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
55+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
56+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
57+
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
58+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
59+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
60+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
61+
; CHECK-NEXT: ret void
62+
;
63+
%lhsv = load <4 x float>, ptr %lhs
64+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
65+
%op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
66+
store <4 x float> %op, ptr %out
67+
ret void
68+
}
69+
70+
define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
71+
; CHECK-LABEL: @select_2x2_vcond_shape1(
72+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
73+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
74+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
75+
; CHECK-NEXT: [[CONDV:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
76+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
77+
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
78+
; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
79+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
80+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
81+
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
82+
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
83+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
84+
; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2
85+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8
86+
; CHECK-NEXT: ret void
87+
;
88+
%lhsv = load <4 x float>, ptr %lhs
89+
%condv = load <4 x i1>, ptr %cond
90+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
91+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
92+
store <4 x float> %op, ptr %out
93+
ret void
94+
}
95+
96+
define void @select_2x2_vcond_shape2(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
97+
; CHECK-LABEL: @select_2x2_vcond_shape2(
98+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
99+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
100+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
101+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[COND:%.*]], align 1
102+
; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[COND]], i64 2
103+
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
104+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
105+
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS]], i64 2
106+
; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
107+
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
108+
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
109+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
110+
; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2
111+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8
112+
; CHECK-NEXT: ret void
113+
;
114+
%lhsv = load <4 x float>, ptr %lhs
115+
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 2, i1 false, i32 2, i32 2)
116+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
117+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
118+
store <4 x float> %op, ptr %out
119+
ret void
120+
}
121+
122+
define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
123+
; CHECK-LABEL: @select_2x2_vcond_shape3(
124+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
125+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
126+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
127+
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
128+
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
129+
; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr float, ptr [[RHS]], i64 2
130+
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[VEC_GEP4]], align 4
131+
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
132+
; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
133+
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD3]]
134+
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[SPLIT6]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD5]]
135+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
136+
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr float, ptr [[OUT]], i64 2
137+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP7]], align 8
138+
; CHECK-NEXT: ret void
139+
;
140+
%lhsv = load <4 x float>, ptr %lhs
141+
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
142+
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
143+
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
144+
store <4 x float> %op, ptr %out
145+
ret void
146+
}

0 commit comments

Comments
 (0)