Skip to content

[Matrix] Propagate shape information through (f)abs insts #141704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 9, 2025
Merged
57 changes: 56 additions & 1 deletion llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
if (I->isBinaryOp())
return true;

if (auto *II = dyn_cast<IntrinsicInst>(V))
switch (II->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs:
return true;
default:
return false;
}

switch (I->getOpcode()) {
case Instruction::FNeg:
return true;
Expand Down Expand Up @@ -618,7 +627,7 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
return true;
default:
return false;
return isUniformShape(II);
}
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
}
Expand Down Expand Up @@ -1124,6 +1133,9 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
LowerColumnMajorStore(Inst);
break;
case Intrinsic::abs:
case Intrinsic::fabs:
return VisitUniformIntrinsic(cast<IntrinsicInst>(Inst));
default:
return false;
}
Expand Down Expand Up @@ -2194,6 +2206,49 @@ class LowerMatrixIntrinsics {
return true;
}

/// Lower uniform shape intrinsics, if shape information is available.
bool VisitUniformIntrinsic(IntrinsicInst *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");

IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;

MatrixTy Result;

switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs: {
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);

Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors())
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
break;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
break;
default:
llvm_unreachable("unexpected intrinsic");
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return true;
}
default:
llvm_unreachable("unexpected intrinsic");
}
}

/// Helper to linearize a matrix expression tree into a string. Currently
/// matrix expressions are linarized by starting at an expression leaf and
/// linearizing bottom up.
Expand Down
Loading
Loading