Skip to content

[Matrix] Propagate shape information through (f)abs insts #141704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 9, 2025
86 changes: 34 additions & 52 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1073,8 +1073,8 @@ class LowerMatrixIntrinsics {
VisitBinaryOperator(BinOp, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp, SI);
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
VisitCallInst(CInst);
else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
VisitIntrinsicInst(Intr, SI);
else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Expand Down Expand Up @@ -1120,11 +1120,8 @@ class LowerMatrixIntrinsics {
}

/// Replace intrinsic calls.
void VisitCallInst(CallInst *Inst) {
assert(Inst->getCalledFunction() &&
Inst->getCalledFunction()->isIntrinsic());

switch (Inst->getCalledFunction()->getIntrinsicID()) {
void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
LowerMultiply(Inst);
break;
Expand All @@ -1138,8 +1135,36 @@ class LowerMatrixIntrinsics {
LowerColumnMajorStore(Inst);
break;
case Intrinsic::abs:
case Intrinsic::fabs:
return VisitUniformIntrinsic(cast<IntrinsicInst>(Inst));
case Intrinsic::fabs: {
IRBuilder<> Builder(Inst);

MatrixTy Result;

MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);

Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors()) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
break;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
break;
default:
llvm_unreachable("unexpected intrinsic");
}
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return;
}
default:
llvm_unreachable(
"only intrinsics supporting shape info should be seen here");
Expand Down Expand Up @@ -2189,49 +2214,6 @@ class LowerMatrixIntrinsics {
Builder);
}

/// Lower uniform shape intrinsics, if shape information is available.
bool VisitUniformIntrinsic(IntrinsicInst *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");

IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;

MatrixTy Result;

switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs: {
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);

Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors())
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
break;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
break;
default:
llvm_unreachable("unexpected intrinsic");
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return true;
}
default:
llvm_unreachable("unexpected intrinsic");
}
}

/// Helper to linearize a matrix expression tree into a string. Currently
/// matrix expressions are linarized by starting at an expression leaf and
/// linearizing bottom up.
Expand Down
Loading