Skip to content

[Matrix] Propagate shape information through (f)abs insts #141704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 9, 2025
Merged
58 changes: 46 additions & 12 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ static bool isUniformShape(Value *V) {
if (I->isBinaryOp())
return true;

if (auto *II = dyn_cast<IntrinsicInst>(V))
switch (II->getIntrinsicID()) {
case Intrinsic::abs:
case Intrinsic::fabs:
return true;
default:
return false;
}

switch (I->getOpcode()) {
case Instruction::FNeg:
return true;
Expand Down Expand Up @@ -618,7 +627,7 @@ class LowerMatrixIntrinsics {
case Intrinsic::matrix_column_major_store:
return true;
default:
return false;
return isUniformShape(II);
}
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
}
Expand Down Expand Up @@ -1064,8 +1073,8 @@ class LowerMatrixIntrinsics {
VisitBinaryOperator(BinOp, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp, SI);
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
VisitCallInst(CInst);
else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
VisitIntrinsicInst(Intr, SI);
else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Expand Down Expand Up @@ -1111,23 +1120,48 @@ class LowerMatrixIntrinsics {
}

/// Replace intrinsic calls.
void VisitCallInst(CallInst *Inst) {
assert(Inst->getCalledFunction() &&
Inst->getCalledFunction()->isIntrinsic());

switch (Inst->getCalledFunction()->getIntrinsicID()) {
void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
LowerMultiply(Inst);
break;
return;
case Intrinsic::matrix_transpose:
LowerTranspose(Inst);
break;
return;
case Intrinsic::matrix_column_major_load:
LowerColumnMajorLoad(Inst);
break;
return;
case Intrinsic::matrix_column_major_store:
LowerColumnMajorStore(Inst);
break;
return;
case Intrinsic::abs:
case Intrinsic::fabs: {
IRBuilder<> Builder(Inst);
MatrixTy Result;
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);
Builder.setFastMathFlags(getFastMathFlags(Inst));

for (auto &Vector : M.vectors()) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::abs:
Result.addVector(Builder.CreateBinaryIntrinsic(Intrinsic::abs, Vector,
Inst->getOperand(1)));
continue;
case Intrinsic::fabs:
Result.addVector(
Builder.CreateUnaryIntrinsic(Inst->getIntrinsicID(), Vector));
continue;
default:
llvm_unreachable("unexpected intrinsic");
}
}

finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
return;
}
default:
llvm_unreachable(
"only intrinsics supporting shape info should be seen here");
Expand Down
Loading
Loading