Add support for vecmat/matvec in Set_Encoding and Materialize_Encoding
NatashaKnk committed Oct 20, 2023
1 parent 3323519 commit 0f539cb
Showing 7 changed files with 1,017 additions and 165 deletions.
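
Only the SetEncoding.cpp diff is shown below. For orientation, here is a minimal sketch of the kind of IR the new pattern matches (function name and shapes are illustrative, and the encoding attribute syntax is deliberately elided; the test changes in this commit are the authoritative reference):

// Before SetEncoding: a plain vector-times-matrix contraction.
func.func @vecmat(%v: tensor<250xf32>, %m: tensor<250x100xf32>,
                  %acc: tensor<100xf32>) -> tensor<100xf32> {
  %0 = linalg.vecmat ins(%v, %m : tensor<250xf32>, tensor<250x100xf32>)
                     outs(%acc : tensor<100xf32>) -> tensor<100xf32>
  return %0 : tensor<100xf32>
}
// After the pass, each operand is padded and wrapped in
// iree_linalg_ext.set_encoding with user = VECMAT and role = LHS/RHS/RESULT,
// the vecmat runs on the encoded tensors, and the original result is
// recovered via iree_linalg_ext.unset_encoding plus tensor.extract_slice.
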
169 changes: 73 additions & 96 deletions compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -146,25 +146,29 @@ static Value unsetEncodingAndExtractSlice(OpBuilder &builder, Location loc,
 
 namespace {
 
-/// Rewrites the matmul op to work on tensors with encoding. Optionally
+/// Rewrites contraction ops to work on tensors with encoding. Optionally
 /// also pads the operands.
-struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
-  SetMatmulEncoding(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<linalg::MatmulOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+/// Currently works on matmul, batch_matmul, vecmat, matvec and batch_matvec.
+struct SetContractionOpEncoding
+    : public OpInterfaceRewritePattern<linalg::ContractionOpInterface> {
+  using OpInterfaceRewritePattern<
+      linalg::ContractionOpInterface>::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(linalg::ContractionOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    if (!matmulOp.hasTensorSemantics())
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
+    if (!linalgOp || !linalgOp.hasTensorSemantics())
       return failure();
-    auto inputs = matmulOp.getDpsInputs();
-    auto outputs = matmulOp.getDpsInits();
+
+    auto inputs = linalgOp.getDpsInputs();
+    auto outputs = linalgOp.getDpsInits();
     auto hasEncoding = [](Value operand) -> bool {
       auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
       return type && type.getEncoding();
     };
     if (llvm::any_of(inputs, hasEncoding) ||
         llvm::any_of(outputs, hasEncoding)) {
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "input/output already has encoding");
     }
 
     Value origLhs = inputs[0];
@@ -182,85 +186,26 @@ struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
     Type outElemType = getElemType(origOut);
 
     if (!lhsElemType || !rhsElemType || !outElemType) {
-      return failure();
+      return rewriter.notifyMatchFailure(op, "invalid input/output");
     }
 
-    IREE::LinalgExt::EncodingUser user = IREE::LinalgExt::EncodingUser::MATMUL;
-    Location loc = matmulOp.getLoc();
-    TypeRange operandTypes = matmulOp->getOperandTypes();
-    Value encodedLhs =
-        padAndSetEncoding(rewriter, loc, origLhs, user,
-                          IREE::LinalgExt::EncodingRole::LHS, operandTypes);
-    Value encodedRhs =
-        padAndSetEncoding(rewriter, loc, origRhs, user,
-                          IREE::LinalgExt::EncodingRole::RHS, operandTypes);
-    Value encodedOut =
-        padAndSetEncoding(rewriter, loc, origOut, user,
-                          IREE::LinalgExt::EncodingRole::RESULT, operandTypes);
-
-    Value matmulTiled = rewriter
-                            .create<linalg::MatmulOp>(
-                                loc, encodedOut.getType(),
-                                ValueRange{encodedLhs, encodedRhs}, encodedOut)
-                            .getResult(0);
-
-    // Sizes are computed by original output size.
-    FailureOr<SmallVector<OpFoldResult>> origOutSizes =
-        IREE::LinalgExt::getDims(rewriter, loc, origOut);
-    if (failed(origOutSizes)) {
-      return rewriter.notifyMatchFailure(matmulOp,
-                                         "failed to get shape of result");
+    IREE::LinalgExt::EncodingUser user;
+    if (op.isRowMajorMatmul() || op.isColumnMajorMatmul()) {
+      user = IREE::LinalgExt::EncodingUser::MATMUL;
+    } else if (op.isRowMajorBatchMatmul()) {
+      user = IREE::LinalgExt::EncodingUser::BATCH_MATMUL;
+    } else if (op.isVecmat()) {
+      user = IREE::LinalgExt::EncodingUser::VECMAT;
+    } else if (op.isMatvec()) {
+      user = IREE::LinalgExt::EncodingUser::MATVEC;
+    } else if (op.isBatchMatvec()) {
+      user = IREE::LinalgExt::EncodingUser::BATCH_MATVEC;
+    } else {
+      return rewriter.notifyMatchFailure(op, "unsupported contraction op");
     }
 
-    Value result = unsetEncodingAndExtractSlice(rewriter, loc, matmulTiled,
-                                                origOutSizes.value());
-
-    rewriter.replaceOp(matmulOp, result);
-    return success();
-  }
-};
-
-struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {
-  SetBatchMatmulEncoding(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<linalg::BatchMatmulOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(linalg::BatchMatmulOp matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (!matmulOp.hasTensorSemantics())
-      return failure();
-    auto inputs = matmulOp.getDpsInputs();
-    auto outputs = matmulOp.getDpsInits();
-    auto hasEncoding = [](Value operand) -> bool {
-      auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
-      return type && type.getEncoding();
-    };
-    if (llvm::any_of(inputs, hasEncoding) ||
-        llvm::any_of(outputs, hasEncoding)) {
-      return failure();
-    }
-
-    Value origLhs = inputs[0];
-    Value origRhs = inputs[1];
-    Value origOut = outputs[0];
-
-    auto getElemType = [](Value v) -> Type {
-      if (auto tensorType = llvm::dyn_cast<RankedTensorType>(v.getType())) {
-        return tensorType.getElementType();
-      }
-      return {};
-    };
-    Type lhsElemType = getElemType(origLhs);
-    Type rhsElemType = getElemType(origRhs);
-    Type outElemType = getElemType(origOut);
-
-    if (!lhsElemType || !rhsElemType || !outElemType) {
-      return failure();
-    }
-
-    IREE::LinalgExt::EncodingUser user =
-        IREE::LinalgExt::EncodingUser::BATCH_MATMUL;
-    Location loc = matmulOp.getLoc();
-    TypeRange operandTypes = matmulOp->getOperandTypes();
+    Location loc = op.getLoc();
+    TypeRange operandTypes = op->getOperandTypes();
     Value encodedLhs =
         padAndSetEncoding(rewriter, loc, origLhs, user,
                           IREE::LinalgExt::EncodingRole::LHS, operandTypes);
Expand All @@ -271,24 +216,56 @@ struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {
padAndSetEncoding(rewriter, loc, origOut, user,
IREE::LinalgExt::EncodingRole::RESULT, operandTypes);

Value matmulTiled = rewriter
.create<linalg::BatchMatmulOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
Value opTiled;
switch (user) {
case IREE::LinalgExt::EncodingUser::MATMUL:
opTiled = rewriter
.create<linalg::MatmulOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;

case IREE::LinalgExt::EncodingUser::BATCH_MATMUL:
opTiled = rewriter
.create<linalg::BatchMatmulOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
case IREE::LinalgExt::EncodingUser::VECMAT:
opTiled = rewriter
.create<linalg::VecmatOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
case IREE::LinalgExt::EncodingUser::MATVEC:
opTiled = rewriter
.create<linalg::MatvecOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
case IREE::LinalgExt::EncodingUser::BATCH_MATVEC:
opTiled = rewriter
.create<linalg::BatchMatvecOp>(
loc, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs}, encodedOut)
.getResult(0);
break;
}

// Sizes are computed by original output size.
FailureOr<SmallVector<OpFoldResult>> origOutSizes =
IREE::LinalgExt::getDims(rewriter, loc, origOut);
if (failed(origOutSizes)) {
return rewriter.notifyMatchFailure(matmulOp,
"failed to get shape of result");
return rewriter.notifyMatchFailure(op, "failed to get shape of result");
}

Value result = unsetEncodingAndExtractSlice(rewriter, loc, matmulTiled,
Value result = unsetEncodingAndExtractSlice(rewriter, loc, opTiled,
origOutSizes.value());

rewriter.replaceOp(matmulOp, result);
rewriter.replaceOp(linalgOp, result);
return success();
}
};
@@ -332,7 +309,7 @@ void SetEncodingPass::runOnOperation() {
   MLIRContext *context = &getContext();
   {
     RewritePatternSet patterns(context);
-    patterns.insert<SetBatchMatmulEncoding, SetMatmulEncoding>(context);
+    patterns.insert<SetContractionOpEncoding>(context);
     linalg::FillOp::getCanonicalizationPatterns(patterns, context);
    patterns.insert<FoldFillWithSetEncoding>(context);
    memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
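
A quick way to exercise the new pattern in isolation is a lit-style test such as the sketch below. The iree-opt flag name is an assumption based on the pass living under GlobalOptimization; the updated test files in this commit are authoritative.

// RUN: iree-opt --iree-global-opt-set-encoding %s | FileCheck %s
// (flag name assumed, not confirmed by this diff)
func.func @matvec(%m: tensor<100x250xf32>, %v: tensor<250xf32>,
                  %acc: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: iree_linalg_ext.set_encoding
  %0 = linalg.matvec ins(%m, %v : tensor<100x250xf32>, tensor<250xf32>)
                     outs(%acc : tensor<100xf32>) -> tensor<100xf32>
  return %0 : tensor<100xf32>
}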
