[mlir][linalg] Fix padding transform and update transform padding op. #144354

Draft · fabianmcg wants to merge 3 commits into main from users/fabian/pr-fix-padding

Conversation

fabianmcg
Contributor

@fabianmcg fabianmcg commented Jun 16, 2025

This patch makes the following changes:

  • Add a ValueRange typeDynDims argument to linalg::makeComposedPadHighOp, allowing tensors with dynamic dimensions to be padded via tensor::createPadHighOp.

  • Add a DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> padToSizeOf option to LinalgPaddingOptions. This option sets the size to use when padding a given dimension of an operand, so operands can be padded even when they don't have a constant upper bounding box.

  • Add the LinalgPaddingOptions::computeConstantUpperShapeBounds method, which populates padToSizeOf with the constant upper bounding box of each padded dimension, preserving the existing behavior (see the sketch after this list).

  • Add a use_prescribed_tensor_shapes option to transform.structured.pad. If set to false (the default), the constant upper bounding box is used, preserving existing behavior. If set to true, tensor.dim ops are used to compute the padded sizes.

  • This patch also changes how linalg::rewriteAsPaddedOp computes the padded shape, using the newly added options in LinalgPaddingOptions.
  • Finally, this patch adds tests verifying the new behavior.
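
To make the new LinalgPaddingOptions API concrete, here is a minimal, hedged C++ sketch of how it might be configured. It is not part of the patch: the configurePadding helper, the operand/dimension indices, and the size of 128 are hypothetical placeholders; only setPadToSizeOf and computeConstantUpperShapeBounds come from this change.

```c++
// Hypothetical usage sketch, not taken from the patch. Indices and sizes are
// placeholders; the two new entry points are setPadToSizeOf and
// computeConstantUpperShapeBounds.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

static LogicalResult configurePadding(linalg::LinalgOp opToPad, OpBuilder &b,
                                      linalg::LinalgPaddingOptions &options,
                                      bool useConstantBounds) {
  // Pad the first two iteration dimensions of the op.
  options.paddingDimensions = {0, 1};

  if (useConstantBounds) {
    // Old behavior: derive the padded sizes from the constant upper bounding
    // box of each padded dimension; fails if no constant bound exists.
    return options.computeConstantUpperShapeBounds(opToPad);
  }

  // New behavior: prescribe an explicit size for shape dim 1 of operand 0.
  // Any OpFoldResult works, e.g. an index attribute or an SSA value.
  options.setPadToSizeOf(/*operandIndex=*/0, /*dimIndex=*/1,
                         b.getIndexAttr(128));
  return success();
}
```

Per the new padToSizeOf documentation, any padded dimension without an entry falls back to the operand's tensor size, which is what produces dynamically shaped pads.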

@llvmbot
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tensor

Author: Fabian Mora (fabianmcg)

Changes

NOTE: This PR is not fully ready for review. Opening to test CI. Will update the commit doc shortly and add a test for the new transform op case.


Patch is 32.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144354.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+16)
  • (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+10-7)
  • (modified) mlir/include/mlir/Dialect/Tensor/Utils/Utils.h (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+9)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (+193-57)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+16-8)
  • (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+1-2)
  • (modified) mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir (-168)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..bf56b633d2872 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1134,7 +1134,8 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
          DefaultValuedAttr<
           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
           "{}">:$transpose_paddings,
-         DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op);
+         DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op,
+         UnitProp:$use_default_tensor_shapes);
   let results = (outs TransformHandleTypeInterface:$padded,
                       TransformHandleTypeInterface:$pad,
                       TransformHandleTypeInterface:$copy);
@@ -1142,6 +1143,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   let assemblyFormat = [{
     $target 
     (`pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
+    (`use_default_tensor_shapes` $use_default_tensor_shapes^)?
     attr-dict
     `:` functional-type(operands, results)
   }];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..7004d98c22475 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -295,6 +295,22 @@ struct LinalgPaddingOptions {
     padToMultipleOf.emplace(m.begin(), m.end());
     return *this;
   }
+  /// A mapping between an operand and shape dim, and a size for a padding
+  /// dimension. Each size is expected to be greater than or equal to the
+  /// corresponding shape dim. If no size is provided for a padding dim, then
+  /// the corresponding tensor size will be used when padding.
+  DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> padToSizeOf;
+  LinalgPaddingOptions &setPadToSizeOf(unsigned operandIndex, unsigned dimIndex,
+                                       OpFoldResult size) {
+    assert(size && "expected non-null size");
+    padToSizeOf[{operandIndex, dimIndex}] = size;
+    return *this;
+  }
+  /// Populates the `padToSizeOf` map with constant upper bounds for each padded
+  /// dim and operand of `opToPad`. Returns failure if any of the sizes cannot
+  /// be computed.
+  LogicalResult computeConstantUpperShapeBounds(linalg::LinalgOp opToPad);
+
   /// A flag for every operand to mark the PadOp as nofold which enables
   /// packing for statically shaped operands.
   SmallVector<bool> nofoldFlags;
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 80aa034d2199d..fc151d02ceef6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -71,12 +71,14 @@ bool isParallelIterator(utils::IteratorType iteratorType);
 /// Check if iterator type  has "reduction" semantics.
 bool isReductionIterator(utils::IteratorType iteratorType);
 
-/// Create a tensor::PadOp that pads `source` to the size of the statically
-/// sized `type` whose static sizes are assumed to be greater than the dynamic
-/// `source` size. The padding introduces trailing `pad` values until the
-/// target size is met. If `source` is defined by one or more LinalgOps that
-/// have been padded with the same value and sizes, return their padded result
-/// instead of creating a tensor::PadOp.
+/// Create a tensor::PadOp that pads `source` to the shape of `type` whose sizes
+/// are assumed to be greater than the dynamic `source` size. If `typeDynDims`
+/// is specified, then it must contain the sizes of all the dynamic dimensions
+/// in order of appearance in `type`, otherwise the function will pad those
+/// values to `0`. The padding introduces trailing `pad` values until the target
+/// size is met. If `source` is defined by one or more LinalgOps that have been
+/// padded with the same value and sizes, return their padded result instead of
+/// creating a tensor::PadOp.
 ///
 /// Example:
 /// ```
@@ -91,7 +93,8 @@ bool isReductionIterator(utils::IteratorType iteratorType);
 /// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
 /// ```
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
-                            Value source, Value pad, bool nofold);
+                            Value source, Value padding, bool nofold,
+                            ValueRange typeDynDims = std::nullopt);
 
 /// Returns GenericOp that copies an n-D memref. Unlike the current
 /// implementation of memref::CopyOp, this op can further tile, lower to loops
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 1a4733df3f187..a1ce4e252c2f4 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -30,7 +30,7 @@ namespace tensor {
 // for _static_ dimensions.
 PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
                       bool nofold, Location loc, OpBuilder &builder,
-                      SmallVector<Value> dynOutDims = {});
+                      ValueRange dynOutDims = std::nullopt);
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..45d83b9a8d3c8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2051,6 +2051,15 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     } else {
       llvm_unreachable("unsupported copy_back op");
     }
+    if (!getUseDefaultTensorShapes()) {
+      if (failed(options.computeConstantUpperShapeBounds(linalgTarget))) {
+        auto diag =
+            emitSilenceableError()
+            << "could not compute upper constant bounds for padded dims";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+    }
 
     SmallVector<Value> replacements;
     SmallVector<tensor::PadOp> newPadOps;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 9a685f6dc96ac..475f338ffbc90 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -8,10 +8,12 @@
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 #define DEBUG_TYPE "linalg-padding"
@@ -22,69 +24,147 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 #define DBGSNL() (llvm::dbgs() << "\n")
 
-/// Compute the padded shape of the given operand. The operand is padded to a
-/// static bounding box according to the specified padding options.
-static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
-                                        OpOperand *opOperand,
-                                        const LinalgPaddingOptions &options,
-                                        SmallVector<int64_t> &paddedShape,
-                                        bool &alreadyHasRequestedShape) {
+namespace {
+/// Helper class for storing padding information.
+struct PaddingInfo {
+  PaddingInfo(int64_t padToMultipleOf = 1, OpFoldResult size = {})
+      : padToMultipleOf(padToMultipleOf), size(size) {}
+  bool isTrivial() const { return padToMultipleOf == 1 && size.isNull(); }
+  /// Pad the tensor to a multiple of.
+  int64_t padToMultipleOf = 1;
+  /// The size used for padding.
+  OpFoldResult size = {};
+};
+
+/// Helper class for storing and computing the padded shape.
+struct PaddedShape {
+  /// Initializes the shape information and returns true if the operand
+  /// already has the requested shape.
+  bool initialize(linalg::LinalgOp opToPad, OpOperand *opOperand,
+                  const LinalgPaddingOptions &options);
+
+  /// Computes the padded shape.
+  void computePadding(OpBuilder &builder, Value operand);
+
+  /// Returns the new tensor type.
+  RankedTensorType getType(Type elemTy) {
+    return RankedTensorType::get(shape, elemTy);
+  }
+
+  /// Return the dynamic dimensions of the shape.
+  ValueRange getDynamicDims() { return dynDims; }
+
+private:
+  SmallVector<int64_t> shape;
+  SmallVector<Value> dynDims;
+  DenseMap<int64_t, PaddingInfo> dimToInfo;
+};
+} // namespace
+
+bool PaddedShape::initialize(linalg::LinalgOp opToPad, OpOperand *opOperand,
+                             const LinalgPaddingOptions &options) {
   AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
-  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
+
+  // Initialize the padded shape.
+  llvm::append_range(shape, opToPad.getShape(opOperand));
+
+  // Return early if there's no padding involved.
+  if (!options.padToMultipleOf && options.padToSizeOf.empty())
+    return true;
 
   // Collect the shape dimensions that are a function of "paddingDimensions",
   // along with the multiple that they should be padded to ("1" if none).
-  alreadyHasRequestedShape = true;
-  DenseMap<int64_t, int64_t> shapeDimToMultiple;
-  for (const auto &dimEn : enumerate(options.paddingDimensions)) {
-    for (const auto &en : enumerate(indexingMap.getResults())) {
-      if (en.value().isFunctionOfDim(dimEn.value())) {
-        int64_t dimSize = shape[en.index()];
+  bool alreadyHasRequestedShape = true;
+  for (const auto [shapeIndex, shapeExpr] :
+       enumerate(indexingMap.getResults())) {
+    PaddingInfo paddingInfo;
+
+    // Flag indicating whether the current dim in the operand has to be padded.
+    bool isPaddedDim = false;
+
+    // Construct the padding info according to the options.
+    for (const auto [dimIndex, dim] : enumerate(options.paddingDimensions)) {
+      if (shapeExpr.isFunctionOfDim(dim)) {
+        isPaddedDim = true;
         if (options.padToMultipleOf.has_value()) {
-          shapeDimToMultiple[en.index()] =
-              (*options.padToMultipleOf)[dimEn.index()];
-        } else {
-          shapeDimToMultiple[en.index()] = 1;
-        }
-        if (ShapedType::isDynamic(dimSize)) {
-          alreadyHasRequestedShape = false;
-        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
-          alreadyHasRequestedShape = false;
+          // We use the least common multiple as multiple dim iterators can
+          // appear in a shape dim, for example:
+          // `affine_map<(i, j, k) -> (i, k, i + j)>` could impose more than one
+          // multiple of constraint on the last dim of the shape.
+          paddingInfo.padToMultipleOf =
+              std::lcm((*options.padToMultipleOf)[dimIndex],
+                       paddingInfo.padToMultipleOf);
         }
       }
     }
+
+    // If the dimension is not being padded, continue.
+    if (!isPaddedDim)
+      continue;
+
+    // Check if the dim is being padded to a specified size.
+    if (auto it = options.padToSizeOf.find(
+            {opOperand->getOperandNumber(), shapeIndex});
+        it != options.padToSizeOf.end()) {
+      paddingInfo.size = it->second;
+      assert(paddingInfo.size && "expected non-null `OpFoldResult`");
+    }
+
+    int64_t dimSize = shape[shapeIndex];
+
+    // Skip if the padding information is trivial. Note that dynamic
+    // dimensions never have trivial padding information.
+    if (paddingInfo.isTrivial() && !ShapedType::isDynamic(dimSize))
+      continue;
+
+    // Set the padding info.
+    dimToInfo[shapeIndex] = paddingInfo;
+    if (ShapedType::isDynamic(dimSize) ||
+        dimSize % paddingInfo.padToMultipleOf != 0) {
+      alreadyHasRequestedShape = false;
+    }
   }
 
-  // Helper function to round a number up to a given multiple.
-  auto ceil = [](int64_t val, int64_t multiple) {
-    return ((val + multiple - 1) / multiple) * multiple;
-  };
+  return alreadyHasRequestedShape;
+}
+
+void PaddedShape::computePadding(OpBuilder &builder, Value operand) {
+  Location loc = operand.getLoc();
+  AffineExpr sizeSym = builder.getAffineSymbolExpr(0);
 
-  // Upper bound the sizes to obtain a static bounding box.
-  paddedShape.assign(shape.begin(), shape.end());
-  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+  // Compute the padding for each dimension.
+  for (auto &&[i, dim] : llvm::enumerate(shape)) {
     LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
+
     // Skip dimensions that do not require padding.
-    if (!shapeDimToMultiple.contains(i)) {
+    if (!dimToInfo.contains(i)) {
       LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
+      if (ShapedType::isDynamic(dim)) {
+        dynDims.push_back(
+            cast<Value>(tensor::getMixedSize(builder, loc, operand, i)));
+      }
       continue;
     }
-    // Otherwise, try to compute a constant upper bound for the size value.
-    FailureOr<int64_t> upperBound =
-        ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB,
-            {opOperand->get(),
-             /*dim=*/i},
-            /*stopCondition=*/nullptr, /*closedUB=*/true);
-    if (failed(upperBound)) {
-      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
-      return failure();
+    PaddingInfo paddingInfo = dimToInfo[i];
+    OpFoldResult size = paddingInfo.size;
+    // Get the tensor dim size if none was provided.
+    if (size.isNull())
+      size = tensor::getMixedSize(builder, loc, operand, i);
+
+    // Compute the padded size to be a multiple of `padToMultipleOf`.
+    AffineExpr szExpr = (sizeSym).ceilDiv(paddingInfo.padToMultipleOf) *
+                        paddingInfo.padToMultipleOf;
+    OpFoldResult paddedSize =
+        affine::makeComposedFoldedAffineApply(builder, loc, szExpr, size);
+    assert(paddedSize && "invalid arguments to affine apply");
+    if (auto cstSzAttr = dyn_cast<Attribute>(paddedSize)) {
+      dim = cast<IntegerAttr>(cstSzAttr).getValue().getZExtValue();
+    } else {
+      dim = ShapedType::kDynamic;
+      dynDims.push_back(cast<Value>(paddedSize));
     }
-    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
-    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
+    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedSize << "\n");
   }
-
-  return success();
 }
 
 /// Pad the `opOperand` in the "paddingDimensions" using the padding value and
@@ -107,13 +187,9 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
        options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
       "invalid number of elements in padToMultipleOf");
 
-  // Compute padded shape.
-  SmallVector<int64_t> paddedShape;
-  bool alreadyHasRequestedShape = false;
-  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
-                                alreadyHasRequestedShape)))
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "--failed to compute padded shape");
+  // Initialize the padded shape.
+  PaddedShape shape;
+  bool alreadyHasRequestedShape = shape.initialize(opToPad, opOperand, options);
 
   // Return the unpadded operand if padding to a static shape is not needed and
   // if the nofold flag is not set.
@@ -140,13 +216,73 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
         opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
   }
 
-  // Pad the operand to the bounding box defined by `paddedShape`.
-  auto paddedTensorType = RankedTensorType::get(
-      paddedShape, getElementTypeOrSelf(opOperand->get()));
+  // If needed, compute the padding for each dimension.
+  if (!alreadyHasRequestedShape)
+    shape.computePadding(rewriter, opOperand->get());
+
+  // Compute the new tensor type.
+  RankedTensorType paddedTensorType =
+      shape.getType(getElementTypeOrSelf(opOperand->get()));
   LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
-                    << paddedTensorType);
+                    << paddedTensorType << "\n");
+
+  // Pad the operand to the bounding box defined by `shape`.
   return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
-                               opOperand->get(), paddingValue, nofold);
+                               opOperand->get(), paddingValue, nofold,
+                               shape.getDynamicDims());
+}
+
+LogicalResult LinalgPaddingOptions::computeConstantUpperShapeBounds(
+    linalg::LinalgOp opToPad) {
+  LLVM_DEBUG(DBGS() << "-Computing constant upper bounds for " << opToPad
+                    << "\n");
+  if (!opToPad.hasPureTensorSemantics()) {
+    LLVM_DEBUG(DBGS() << "--FAILURE: op does not have pure tensor semantics"
+                      << "\n");
+    return failure();
+  }
+  Builder builder(opToPad.getContext());
+  // For each operand compute the sizes.
+  for (OpOperand &operand : opToPad->getOpOperands()) {
+    AffineMap indexingMap = opToPad.getMatchingIndexingMap(&operand);
+    LLVM_DEBUG(DBGS() << "--Computing constant upper bounds for operand "
+                      << operand.getOperandNumber() << "\n");
+
+    // Get the size for each dimension.
+    for (const auto [shapeIndex, shapeExpr] :
+         enumerate(indexingMap.getResults())) {
+
+      // Get whether the iterator dimension is being padded.
+      // TODO[c++20]: Remove the `expr = shapeExpr` copy, this was added as
+      // `captured structured bindings are a C++20 extension`.
+      AffineExpr expr = shapeExpr;
+      bool isPaddedDim = llvm::any_of(paddingDimensions, [expr](unsigned dim) {
+        return expr.isFunctionOfDim(dim);
+      });
+      if (!isPaddedDim)
+        continue;
+
+      // Compute the constant upper bound.
+      LLVM_DEBUG(DBGS() << "---compute upper bound size for shape dim "
+                        << shapeIndex << "\n");
+      FailureOr<int64_t> upperBound =
+          ValueBoundsConstraintSet::computeConstantBound(
+              presburger::BoundType::UB,
+              {operand.get(),
+               /*dim=*/static_cast<int64_t>(shapeIndex)},
+              /*stopCondition=*/nullptr, /*closedUB=*/true);
+      if (failed(upperBound)) {
+        LLVM_DEBUG(DBGS() << "---could not compute a bounding box for padding"
+                          << "\n");
+        return failure();
+      }
+
+      // Set the upper bound.
+      setPadToSizeOf(operand.getOperandNumber(), shapeIndex,
+                     builder.getIndexAttr(*upperBound));
+    }
+  }
+  return success();
 }
 
 LogicalResult
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2527d90cfa2e6..209309ddb413a 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -244,11 +244,13 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 }
 
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
-                            Value source, Value pad, bool nofold) {
+                            Value source, Value pad, bool nofold,
+                            ValueRange typeDynDims) {
   // Exit if `source` is not defined by an ExtractSliceOp.
   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)
-    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+                                   typeDynDims);
 
   // Search the `source` use-def chain for padded LinalgOps.
   Value current = sliceOp.getSource();
@@ -264,24 +266,28 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
   // Exit if the search fails to match a tensor::PadOp at the end of...
[truncated]
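
For the extended makeComposedPadHighOp shown in the (truncated) diff above, the following is a minimal, hedged C++ sketch of passing dynamic output sizes via the new typeDynDims argument. The padToDynamicShape helper, the ?x128 result shape, and the SSA values are illustrative assumptions, not code from the patch.

```c++
// Hypothetical sketch, not part of the patch: pad a tensor to a result type
// whose dim 0 stays dynamic, supplying its runtime size via typeDynDims.
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static Value padToDynamicShape(OpBuilder &b, Location loc, Value source,
                               Value dynDim0Size, Value padValue) {
  auto srcType = cast<RankedTensorType>(source.getType());
  // Result type: dim 0 dynamic, dim 1 fixed at 128 (placeholder sizes).
  auto resultType = RankedTensorType::get({ShapedType::kDynamic, 128},
                                          srcType.getElementType());
  // typeDynDims must list the sizes of all dynamic dims of resultType, in
  // order of appearance; here that is just dim 0.
  SmallVector<Value> dynDims = {dynDim0Size};
  return linalg::makeComposedPadHighOp(b, loc, resultType, source, padValue,
                                       /*nofold=*/false,
                                       /*typeDynDims=*/dynDims);
}
```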

@fabianmcg fabianmcg marked this pull request as draft June 16, 2025 13:41
@fabianmcg fabianmcg force-pushed the users/fabian/pr-fix-padding branch from 03078e9 to e59c3ff on June 16, 2025 13:43
@fabianmcg fabianmcg marked this pull request as ready for review June 16, 2025 17:04
@fabianmcg fabianmcg force-pushed the users/fabian/pr-fix-padding branch from dc095f4 to 59ade48 on June 16, 2025 18:10
@fabianmcg fabianmcg marked this pull request as draft June 16, 2025 22:44
@fabianmcg fabianmcg force-pushed the users/fabian/pr-fix-padding branch from f7da823 to 3e13fc2 on June 17, 2025 11:54