Skip to content

[Matrix] Hoist IRBuilder<> out of Visit* functions. NFC #144369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 16, 2025

Conversation

jroelofs
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144369.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+21-27)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 1e37f40fa9d52..ece0bb56fff01 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1146,24 +1146,24 @@ class LowerMatrixIntrinsics {
       Value *Op1;
       Value *Op2;
       MatrixTy Result;
+      IRBuilder<> Builder(Inst);
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
-        Result = VisitBinaryOperator(BinOp, SI);
+        Result = VisitBinaryOperator(BinOp, SI, Builder);
       else if (auto *Cast = dyn_cast<CastInst>(Inst))
-        Result = VisitCastInstruction(Cast, SI);
+        Result = VisitCastInstruction(Cast, SI, Builder);
       else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
-        Result = VisitUnaryOperator(UnOp, SI);
+        Result = VisitUnaryOperator(UnOp, SI, Builder);
       else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
-        Result = VisitIntrinsicInst(Intr, SI);
+        Result = VisitIntrinsicInst(Intr, SI, Builder);
       else if (auto *Select = dyn_cast<SelectInst>(Inst))
-        Result = VisitSelectInst(Select, SI);
+        Result = VisitSelectInst(Select, SI, Builder);
       else if (match(Inst, m_Load(m_Value(Op1))))
-        Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
+        Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
-        Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
+        Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
       else
         continue;
 
-      IRBuilder<> Builder(Inst);
       finalizeLowering(Inst, Result, Builder);
       Changed = true;
     }
@@ -1204,7 +1204,8 @@ class LowerMatrixIntrinsics {
   }
 
   /// Replace intrinsic calls.
-  MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI,
+                              IRBuilder<> &Builder) {
     assert(Inst->getCalledFunction() &&
            Inst->getCalledFunction()->isIntrinsic());
 
@@ -1219,7 +1220,6 @@ class LowerMatrixIntrinsics {
       return LowerColumnMajorStore(Inst);
     case Intrinsic::abs:
     case Intrinsic::fabs: {
-      IRBuilder<> Builder(Inst);
       MatrixTy Result;
       MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
       Builder.setFastMathFlags(getFastMathFlags(Inst));
@@ -1298,7 +1298,6 @@ class LowerMatrixIntrinsics {
                       ShapeInfo MatrixShape, Value *I, Value *J,
                       ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
-
     Value *Offset = Builder.CreateAdd(
         Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
 
@@ -2228,26 +2227,24 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower load instructions.
-  MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
-    IRBuilder<> Builder(Inst);
+  MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
+                     IRBuilder<> &Builder) {
     return LowerLoad(Inst, Ptr, Inst->getAlign(),
                      Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
   MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
-                      Value *Ptr) {
-    IRBuilder<> Builder(Inst);
+                      Value *Ptr, IRBuilder<> &Builder) {
     return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
                       Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
   /// Lower binary operators.
-  MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI,
+                               IRBuilder<> &Builder) {
     Value *Lhs = Inst->getOperand(0);
     Value *Rhs = Inst->getOperand(1);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy A = getMatrix(Lhs, SI, Builder);
     MatrixTy B = getMatrix(Rhs, SI, Builder);
@@ -2265,11 +2262,10 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower unary operators.
-  MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI,
+                              IRBuilder<> &Builder) {
     Value *Op = Inst->getOperand(0);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy M = getMatrix(Op, SI, Builder);
 
@@ -2293,11 +2289,10 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower cast instructions.
-  MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
+  MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape,
+                                IRBuilder<> &Builder) {
     Value *Op = Inst->getOperand(0);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy M = getMatrix(Op, Shape, Builder);
 
@@ -2315,13 +2310,12 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower selects.
-  MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
+  MatrixTy VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape,
+                           IRBuilder<> &Builder) {
     Value *Cond = Inst->getOperand(0);
     Value *OpA = Inst->getOperand(1);
     Value *OpB = Inst->getOperand(2);
 
-    IRBuilder<> Builder(Inst);
-
     MatrixTy Result;
     MatrixTy A = getMatrix(OpA, Shape, Builder);
     MatrixTy B = getMatrix(OpB, Shape, Builder);

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@jroelofs jroelofs merged commit 8ed43c4 into llvm:main Jun 16, 2025
9 checks passed
@jroelofs jroelofs deleted the jroelofs/lower-matrix-builder branch June 16, 2025 18:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants