Skip to content

Commit 3296ce5

Browse files
committed
[Matrix] Propagate shape information through cast instructions
1 parent 79ae407 commit 3296ce5

File tree

2 files changed

+240
-3
lines changed

2 files changed

+240
-3
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@
3232
#include "llvm/IR/CFG.h"
3333
#include "llvm/IR/DataLayout.h"
3434
#include "llvm/IR/DebugInfoMetadata.h"
35+
#include "llvm/IR/DerivedTypes.h"
3536
#include "llvm/IR/Function.h"
3637
#include "llvm/IR/IRBuilder.h"
38+
#include "llvm/IR/InstrTypes.h"
3739
#include "llvm/IR/Instructions.h"
3840
#include "llvm/IR/IntrinsicInst.h"
3941
#include "llvm/IR/MatrixBuilder.h"
@@ -229,7 +231,7 @@ static bool isUniformShape(Value *V) {
229231
if (!I)
230232
return true;
231233

232-
if (I->isBinaryOp())
234+
if (I->isBinaryOp() || I->isCast())
233235
return true;
234236

235237
switch (I->getOpcode()) {
@@ -1066,9 +1068,11 @@ class LowerMatrixIntrinsics {
10661068
Value *Op2;
10671069
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
10681070
Changed |= VisitBinaryOperator(BinOp);
1069-
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1071+
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
10701072
Changed |= VisitUnaryOperator(UnOp);
1071-
if (match(Inst, m_Load(m_Value(Op1))))
1073+
else if (auto *Cast = dyn_cast<CastInst>(Inst))
1074+
Changed |= VisitCastInstruction(Cast);
1075+
else if (match(Inst, m_Load(m_Value(Op1))))
10721076
Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
10731077
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
10741078
Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
@@ -2198,6 +2202,37 @@ class LowerMatrixIntrinsics {
21982202
return true;
21992203
}
22002204

2205+
/// Lower cast instructions, if shape information is available.
2206+
bool VisitCastInstruction(CastInst *Inst) {
2207+
auto I = ShapeMap.find(Inst);
2208+
if (I == ShapeMap.end())
2209+
return false;
2210+
2211+
Value *Op = Inst->getOperand(0);
2212+
2213+
IRBuilder<> Builder(Inst);
2214+
ShapeInfo &Shape = I->second;
2215+
2216+
MatrixTy Result;
2217+
MatrixTy M = getMatrix(Op, Shape, Builder);
2218+
2219+
Builder.setFastMathFlags(getFastMathFlags(Inst));
2220+
2221+
for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
2222+
auto *OrigTy = cast<VectorType>(Inst->getType());
2223+
auto *NewTy = VectorType::get(OrigTy->getElementType(),
2224+
ElementCount::getFixed(M.getStride()));
2225+
Result.addVector(
2226+
Builder.CreateCast(Inst->getOpcode(), M.getVector(I), NewTy));
2227+
}
2228+
2229+
finalizeLowering(Inst,
2230+
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2231+
Result.getNumVectors()),
2232+
Builder);
2233+
return true;
2234+
}
2235+
22012236
/// Helper to linearize a matrix expression tree into a string. Currently
22022237
/// matrix expressions are linarized by starting at an expression leaf and
22032238
/// linearizing bottom up.
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
3+
4+
define void @fneg_2x2(ptr %in, ptr %out) {
5+
; CHECK-LABEL: @fneg_2x2(
6+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
7+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
8+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
9+
; CHECK-NEXT: [[TMP1:%.*]] = fneg <2 x float> [[COL_LOAD]]
10+
; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x float> [[COL_LOAD1]]
11+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
12+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
13+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
14+
; CHECK-NEXT: ret void
15+
;
16+
%inv = load <4 x float>, ptr %in
17+
%op = fneg <4 x float> %inv
18+
%opt = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
19+
%optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
20+
store <4 x float> %optt, ptr %out
21+
ret void
22+
}
23+
24+
define void @trunc_2x2(ptr %in, ptr %out) {
25+
; CHECK-LABEL: @trunc_2x2(
26+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
27+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
28+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
29+
; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i64> [[COL_LOAD]] to <2 x i32>
30+
; CHECK-NEXT: [[TMP2:%.*]] = trunc <2 x i64> [[COL_LOAD1]] to <2 x i32>
31+
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
32+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
33+
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
34+
; CHECK-NEXT: ret void
35+
;
36+
%inv = load <4 x i64>, ptr %in
37+
%op = trunc <4 x i64> %inv to <4 x i32>
38+
%opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
39+
%optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
40+
store <4 x i32> %optt, ptr %out
41+
ret void
42+
}
43+
44+
define void @zext_2x2(ptr %in, ptr %out) {
45+
; CHECK-LABEL: @zext_2x2(
46+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i16>, ptr [[IN:%.*]], align 8
47+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i16, ptr [[IN]], i64 2
48+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i16>, ptr [[VEC_GEP]], align 4
49+
; CHECK-NEXT: [[TMP1:%.*]] = zext <2 x i16> [[COL_LOAD]] to <2 x i32>
50+
; CHECK-NEXT: [[TMP2:%.*]] = zext <2 x i16> [[COL_LOAD1]] to <2 x i32>
51+
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
52+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
53+
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
54+
; CHECK-NEXT: ret void
55+
;
56+
%inv = load <4 x i16>, ptr %in
57+
%op = zext <4 x i16> %inv to <4 x i32>
58+
%opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
59+
%optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
60+
store <4 x i32> %optt, ptr %out
61+
ret void
62+
}
63+
64+
define void @sext_2x2(ptr %in, ptr %out) {
65+
; CHECK-LABEL: @sext_2x2(
66+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i8>, ptr [[IN:%.*]], align 4
67+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i8, ptr [[IN]], i64 2
68+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i8>, ptr [[VEC_GEP]], align 2
69+
; CHECK-NEXT: [[TMP1:%.*]] = sext <2 x i8> [[COL_LOAD]] to <2 x i16>
70+
; CHECK-NEXT: [[TMP2:%.*]] = sext <2 x i8> [[COL_LOAD1]] to <2 x i16>
71+
; CHECK-NEXT: store <2 x i16> [[TMP1]], ptr [[OUT:%.*]], align 8
72+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i16, ptr [[OUT]], i64 2
73+
; CHECK-NEXT: store <2 x i16> [[TMP2]], ptr [[VEC_GEP2]], align 4
74+
; CHECK-NEXT: ret void
75+
;
76+
%inv = load <4 x i8>, ptr %in
77+
%op = sext <4 x i8> %inv to <4 x i16>
78+
%opt = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %op, i32 2, i32 2)
79+
%optt = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %opt, i32 2, i32 2)
80+
store <4 x i16> %optt, ptr %out
81+
ret void
82+
}
83+
84+
define void @fptoui_2x2(ptr %in, ptr %out) {
85+
; CHECK-LABEL: @fptoui_2x2(
86+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
87+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
88+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
89+
; CHECK-NEXT: [[TMP1:%.*]] = fptoui <2 x float> [[COL_LOAD]] to <2 x i32>
90+
; CHECK-NEXT: [[TMP2:%.*]] = fptoui <2 x float> [[COL_LOAD1]] to <2 x i32>
91+
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
92+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
93+
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
94+
; CHECK-NEXT: ret void
95+
;
96+
%inv = load <4 x float>, ptr %in
97+
%op = fptoui <4 x float> %inv to <4 x i32>
98+
%opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
99+
%optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
100+
store <4 x i32> %optt, ptr %out
101+
ret void
102+
}
103+
104+
define void @fptosi_2x2(ptr %in, ptr %out) {
105+
; CHECK-LABEL: @fptosi_2x2(
106+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
107+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
108+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
109+
; CHECK-NEXT: [[TMP1:%.*]] = fptosi <2 x float> [[COL_LOAD]] to <2 x i32>
110+
; CHECK-NEXT: [[TMP2:%.*]] = fptosi <2 x float> [[COL_LOAD1]] to <2 x i32>
111+
; CHECK-NEXT: store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
112+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
113+
; CHECK-NEXT: store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
114+
; CHECK-NEXT: ret void
115+
;
116+
%inv = load <4 x float>, ptr %in
117+
%op = fptosi <4 x float> %inv to <4 x i32>
118+
%opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
119+
%optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
120+
store <4 x i32> %optt, ptr %out
121+
ret void
122+
}
123+
124+
define void @uitofp_2x2(ptr %in, ptr %out) {
125+
; CHECK-LABEL: @uitofp_2x2(
126+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
127+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
128+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
129+
; CHECK-NEXT: [[TMP1:%.*]] = uitofp <2 x i64> [[COL_LOAD]] to <2 x double>
130+
; CHECK-NEXT: [[TMP2:%.*]] = uitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
131+
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
132+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
133+
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
134+
; CHECK-NEXT: ret void
135+
;
136+
%inv = load <4 x i64>, ptr %in
137+
%op = uitofp <4 x i64> %inv to <4 x double>
138+
%opt = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
139+
%optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
140+
store <4 x double> %optt, ptr %out
141+
ret void
142+
}
143+
144+
define void @sitofp_2x2(ptr %in, ptr %out) {
145+
; CHECK-LABEL: @sitofp_2x2(
146+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
147+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
148+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
149+
; CHECK-NEXT: [[TMP1:%.*]] = sitofp <2 x i64> [[COL_LOAD]] to <2 x double>
150+
; CHECK-NEXT: [[TMP2:%.*]] = sitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
151+
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
152+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
153+
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
154+
; CHECK-NEXT: ret void
155+
;
156+
%inv = load <4 x i64>, ptr %in
157+
%op = sitofp <4 x i64> %inv to <4 x double>
158+
%opt = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
159+
%optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
160+
store <4 x double> %optt, ptr %out
161+
ret void
162+
}
163+
164+
define void @fptrunc_2x2(ptr %in, ptr %out) {
165+
; CHECK-LABEL: @fptrunc_2x2(
166+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
167+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
168+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
169+
; CHECK-NEXT: [[TMP1:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD]] to <2 x float>
170+
; CHECK-NEXT: [[TMP2:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD1]] to <2 x float>
171+
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
172+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
173+
; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
174+
; CHECK-NEXT: ret void
175+
;
176+
%inv = load <4 x double>, ptr %in
177+
%op = fptrunc nnan <4 x double> %inv to <4 x float>
178+
%opt = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
179+
%optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
180+
store <4 x float> %optt, ptr %out
181+
ret void
182+
}
183+
184+
define void @fpext_2x2(ptr %in, ptr %out) {
185+
; CHECK-LABEL: @fpext_2x2(
186+
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
187+
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
188+
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
189+
; CHECK-NEXT: [[TMP1:%.*]] = fpext <2 x float> [[COL_LOAD]] to <2 x double>
190+
; CHECK-NEXT: [[TMP2:%.*]] = fpext <2 x float> [[COL_LOAD1]] to <2 x double>
191+
; CHECK-NEXT: store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
192+
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
193+
; CHECK-NEXT: store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
194+
; CHECK-NEXT: ret void
195+
;
196+
%inv = load <4 x float>, ptr %in
197+
%op = fpext <4 x float> %inv to <4 x double>
198+
%opt = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
199+
%optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
200+
store <4 x double> %optt, ptr %out
201+
ret void
202+
}

0 commit comments

Comments
 (0)