Skip to content

Commit 3779c5b

Browse files
committed
[Matrix] Propagate shape information through cast instructions
1 parent 79ae407 commit 3779c5b

File tree

2 files changed

+300
-2
lines changed

2 files changed

+300
-2
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@
3232
#include "llvm/IR/CFG.h"
3333
#include "llvm/IR/DataLayout.h"
3434
#include "llvm/IR/DebugInfoMetadata.h"
35+
#include "llvm/IR/DerivedTypes.h"
3536
#include "llvm/IR/Function.h"
3637
#include "llvm/IR/IRBuilder.h"
38+
#include "llvm/IR/InstrTypes.h"
3739
#include "llvm/IR/Instructions.h"
3840
#include "llvm/IR/IntrinsicInst.h"
3941
#include "llvm/IR/MatrixBuilder.h"
@@ -232,6 +234,32 @@ static bool isUniformShape(Value *V) {
232234
if (I->isBinaryOp())
233235
return true;
234236

237+
if (auto *Cast = dyn_cast<CastInst>(V))
238+
switch (Cast->getOpcode()) {
239+
case llvm::Instruction::Trunc:
240+
case llvm::Instruction::ZExt:
241+
case llvm::Instruction::SExt:
242+
case llvm::Instruction::FPToUI:
243+
case llvm::Instruction::FPToSI:
244+
case llvm::Instruction::UIToFP:
245+
case llvm::Instruction::SIToFP:
246+
case llvm::Instruction::FPTrunc:
247+
case llvm::Instruction::FPExt:
248+
return true;
249+
case llvm::Instruction::AddrSpaceCast:
250+
case CastInst::PtrToInt:
251+
case CastInst::IntToPtr:
252+
return false;
253+
case CastInst::BitCast: {
254+
if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy()))
255+
if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy()))
256+
return SrcVTy->getNumElements() == DestVTy->getNumElements();
257+
return false;
258+
}
259+
case llvm::Instruction::CastOpsEnd:
260+
llvm_unreachable("not an actual cast op");
261+
}
262+
235263
switch (I->getOpcode()) {
236264
case Instruction::FNeg:
237265
return true;
@@ -1066,9 +1094,11 @@ class LowerMatrixIntrinsics {
10661094
Value *Op2;
10671095
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
10681096
Changed |= VisitBinaryOperator(BinOp);
1069-
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1097+
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
10701098
Changed |= VisitUnaryOperator(UnOp);
1071-
if (match(Inst, m_Load(m_Value(Op1))))
1099+
else if (auto *Cast = dyn_cast<CastInst>(Inst))
1100+
Changed |= VisitCastInstruction(Cast);
1101+
else if (match(Inst, m_Load(m_Value(Op1))))
10721102
Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
10731103
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
10741104
Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
@@ -2198,6 +2228,37 @@ class LowerMatrixIntrinsics {
21982228
return true;
21992229
}
22002230

2231+
/// Lower cast instructions, if shape information is available.
2232+
bool VisitCastInstruction(CastInst *Inst) {
2233+
auto I = ShapeMap.find(Inst);
2234+
if (I == ShapeMap.end())
2235+
return false;
2236+
2237+
Value *Op = Inst->getOperand(0);
2238+
2239+
IRBuilder<> Builder(Inst);
2240+
ShapeInfo &Shape = I->second;
2241+
2242+
MatrixTy Result;
2243+
MatrixTy M = getMatrix(Op, Shape, Builder);
2244+
2245+
Builder.setFastMathFlags(getFastMathFlags(Inst));
2246+
2247+
auto *OrigVTy = cast<VectorType>(Inst->getType());
2248+
auto *NewVTy = VectorType::get(OrigVTy->getElementType(),
2249+
ElementCount::getFixed(M.getStride()));
2250+
2251+
for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2252+
Result.addVector(
2253+
Builder.CreateCast(Inst->getOpcode(), M.getVector(I), NewVTy));
2254+
2255+
finalizeLowering(Inst,
2256+
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2257+
Result.getNumVectors()),
2258+
Builder);
2259+
return true;
2260+
}
2261+
22012262
/// Helper to linearize a matrix expression tree into a string. Currently
22022263
/// matrix expressions are linearized by starting at an expression leaf and
22032264
/// linearizing bottom up.
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
3+
4+
; Baseline: fneg already has shape information propagated; the 2x2 matrix is
; lowered as two <2 x float> column ops, and the transpose pair cancels out.
define void @fneg_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fneg_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x float> [[COL_LOAD]]
; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x float> [[COL_LOAD1]]
; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT:    ret void
;
  %inv = load <4 x float>, ptr %in
  %op = fneg <4 x float> %inv
  %opt = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
  store <4 x float> %optt, ptr %out
  ret void
}
23+
24+
; Shape information propagates through trunc: the cast is split into two
; <2 x i64> -> <2 x i32> column casts instead of one flat <4 x i64> trunc.
define void @trunc_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @trunc_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i64> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT:    [[TMP2:%.*]] = trunc <2 x i64> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT:    ret void
;
  %inv = load <4 x i64>, ptr %in
  %op = trunc <4 x i64> %inv to <4 x i32>
  %opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
  store <4 x i32> %optt, ptr %out
  ret void
}
43+
44+
; Shape information propagates through zext: lowered as two <2 x i16>
; -> <2 x i32> column casts.
define void @zext_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @zext_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i16>, ptr [[IN:%.*]], align 8
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i16, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i16>, ptr [[VEC_GEP]], align 4
; CHECK-NEXT:    [[TMP1:%.*]] = zext <2 x i16> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT:    [[TMP2:%.*]] = zext <2 x i16> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT:    ret void
;
  %inv = load <4 x i16>, ptr %in
  %op = zext <4 x i16> %inv to <4 x i32>
  %opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
  store <4 x i32> %optt, ptr %out
  ret void
}
63+
64+
; Shape information propagates through sext: lowered as two <2 x i8>
; -> <2 x i16> column casts.
define void @sext_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @sext_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i8>, ptr [[IN:%.*]], align 4
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i8, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i8>, ptr [[VEC_GEP]], align 2
; CHECK-NEXT:    [[TMP1:%.*]] = sext <2 x i8> [[COL_LOAD]] to <2 x i16>
; CHECK-NEXT:    [[TMP2:%.*]] = sext <2 x i8> [[COL_LOAD1]] to <2 x i16>
; CHECK-NEXT:    store <2 x i16> [[TMP1]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i16, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x i16> [[TMP2]], ptr [[VEC_GEP2]], align 4
; CHECK-NEXT:    ret void
;
  %inv = load <4 x i8>, ptr %in
  %op = sext <4 x i8> %inv to <4 x i16>
  %opt = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %op, i32 2, i32 2)
  %optt = call <4 x i16> @llvm.matrix.transpose(<4 x i16> %opt, i32 2, i32 2)
  store <4 x i16> %optt, ptr %out
  ret void
}
83+
84+
; Shape information propagates through fptoui: lowered as two <2 x float>
; -> <2 x i32> column casts.
define void @fptoui_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fptoui_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT:    [[TMP1:%.*]] = fptoui <2 x float> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT:    [[TMP2:%.*]] = fptoui <2 x float> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT:    ret void
;
  %inv = load <4 x float>, ptr %in
  %op = fptoui <4 x float> %inv to <4 x i32>
  %opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
  store <4 x i32> %optt, ptr %out
  ret void
}
103+
104+
; Shape information propagates through fptosi: lowered as two <2 x float>
; -> <2 x i32> column casts.
define void @fptosi_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fptosi_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT:    [[TMP1:%.*]] = fptosi <2 x float> [[COL_LOAD]] to <2 x i32>
; CHECK-NEXT:    [[TMP2:%.*]] = fptosi <2 x float> [[COL_LOAD1]] to <2 x i32>
; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT:    ret void
;
  %inv = load <4 x float>, ptr %in
  %op = fptosi <4 x float> %inv to <4 x i32>
  %opt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %op, i32 2, i32 2)
  %optt = call <4 x i32> @llvm.matrix.transpose(<4 x i32> %opt, i32 2, i32 2)
  store <4 x i32> %optt, ptr %out
  ret void
}
123+
124+
; Shape information propagates through uitofp: lowered as two <2 x i64>
; -> <2 x double> column casts.
define void @uitofp_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @uitofp_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT:    [[TMP1:%.*]] = uitofp <2 x i64> [[COL_LOAD]] to <2 x double>
; CHECK-NEXT:    [[TMP2:%.*]] = uitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
; CHECK-NEXT:    ret void
;
  %inv = load <4 x i64>, ptr %in
  %op = uitofp <4 x i64> %inv to <4 x double>
  %opt = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
  store <4 x double> %optt, ptr %out
  ret void
}
143+
144+
; Shape information propagates through sitofp: lowered as two <2 x i64>
; -> <2 x double> column casts.
define void @sitofp_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @sitofp_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT:    [[TMP1:%.*]] = sitofp <2 x i64> [[COL_LOAD]] to <2 x double>
; CHECK-NEXT:    [[TMP2:%.*]] = sitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
; CHECK-NEXT:    ret void
;
  %inv = load <4 x i64>, ptr %in
  %op = sitofp <4 x i64> %inv to <4 x double>
  %opt = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
  store <4 x double> %optt, ptr %out
  ret void
}
163+
164+
; Shape information propagates through fptrunc, and the fast-math flag
; (nnan) is preserved on each of the split column casts.
define void @fptrunc_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fptrunc_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT:    [[TMP1:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD]] to <2 x float>
; CHECK-NEXT:    [[TMP2:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD1]] to <2 x float>
; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
; CHECK-NEXT:    ret void
;
  %inv = load <4 x double>, ptr %in
  %op = fptrunc nnan <4 x double> %inv to <4 x float>
  %opt = call <4 x float> @llvm.matrix.transpose(<4 x float> %op, i32 2, i32 2)
  %optt = call <4 x float> @llvm.matrix.transpose(<4 x float> %opt, i32 2, i32 2)
  store <4 x float> %optt, ptr %out
  ret void
}
183+
184+
; Shape information propagates through fpext: lowered as two <2 x float>
; -> <2 x double> column casts.
define void @fpext_2x2(ptr %in, ptr %out) {
; CHECK-LABEL: @fpext_2x2(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
; CHECK-NEXT:    [[TMP1:%.*]] = fpext <2 x float> [[COL_LOAD]] to <2 x double>
; CHECK-NEXT:    [[TMP2:%.*]] = fpext <2 x float> [[COL_LOAD1]] to <2 x double>
; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 16
; CHECK-NEXT:    ret void
;
  %inv = load <4 x float>, ptr %in
  %op = fpext <4 x float> %inv to <4 x double>
  %opt = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
  store <4 x double> %optt, ptr %out
  ret void
}
203+
204+
; Vector-to-vector bitcast with matching element count: shape information
; propagates, lowered as two <2 x double> -> <2 x i64> column bitcasts.
define void @bitcast_2x2_v4f64_to_v4i64(ptr %in, ptr %out) {
; CHECK-LABEL: @bitcast_2x2_v4f64_to_v4i64(
; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x double> [[COL_LOAD]] to <2 x i64>
; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x double> [[COL_LOAD1]] to <2 x i64>
; CHECK-NEXT:    store <2 x i64> [[TMP1]], ptr [[OUT:%.*]], align 32
; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i64, ptr [[OUT]], i64 2
; CHECK-NEXT:    store <2 x i64> [[TMP2]], ptr [[VEC_GEP2]], align 16
; CHECK-NEXT:    ret void
;
  %inv = load <4 x double>, ptr %in
  %op = bitcast <4 x double> %inv to <4 x i64>
  %opt = call <4 x i64> @llvm.matrix.transpose(<4 x i64> %op, i32 2, i32 2)
  %optt = call <4 x i64> @llvm.matrix.transpose(<4 x i64> %opt, i32 2, i32 2)
  store <4 x i64> %optt, ptr %out
  ret void
}
223+
224+
; Negative test: a scalar-to-vector bitcast (i256 -> <4 x double>) is not
; uniform-shaped, so no shape propagates and the bitcast is left as a flat
; operation (no column loads/stores are emitted for it).
define void @bitcast_2x2_i256_to_v4i64(ptr %in, ptr %out) {
; CHECK-LABEL: @bitcast_2x2_i256_to_v4i64(
; CHECK-NEXT:    [[INV:%.*]] = load i256, ptr [[IN:%.*]], align 4
; CHECK-NEXT:    [[OP:%.*]] = bitcast i256 [[INV]] to <4 x double>
; CHECK-NEXT:    store <4 x double> [[OP]], ptr [[OUT:%.*]], align 32
; CHECK-NEXT:    ret void
;
  %inv = load i256, ptr %in
  %op = bitcast i256 %inv to <4 x double>
  %opt = call <4 x double> @llvm.matrix.transpose(<4 x double> %op, i32 2, i32 2)
  %optt = call <4 x double> @llvm.matrix.transpose(<4 x double> %opt, i32 2, i32 2)
  store <4 x double> %optt, ptr %out
  ret void
}

0 commit comments

Comments
 (0)