Skip to content

Commit 62ae876

Browse files
authored
[mlir][tosa] Fix conv op build functions (llvm#126321)
This patch fixes several issues: - buildConvOpWithQuantInfo: call buildConvOpResultTypeInfo to get final output type - buildTransConvOpWithQuantInfo: add input_zp and weight_zp operands remove input_zp/weight_zp attributes - createZeroPointTensor: add getElementTypeOrSelf to get element type just in case remove bad auto-merge lines Change-Id: Idbf88f500ce57a865da4b7be7b7b8bf2ba194b24 Signed-off-by: Tai Ly <tai.ly@arm.com>
1 parent 3706dfe commit 62ae876

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
510510
result.addAttribute("stride", stride);
511511
result.addAttribute("dilation", dilation);
512512
result.addAttribute("acc_type", accType);
513-
result.addTypes(outputType);
513+
Type finalOutputType = outputType;
514+
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
515+
if (quantAttr) {
516+
finalOutputType =
517+
buildConvOpResultTypeInfo(builder, outputType, input, weight);
518+
}
519+
result.addTypes(finalOutputType);
514520
}
515521

516522
/// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -519,25 +525,19 @@ static void buildTransConvOpWithQuantInfo(
519525
OpBuilder &builder, OperationState &result, Type outputType, Value input,
520526
Value weight, Value bias, DenseI64ArrayAttr outpad,
521527
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
522-
result.addOperands({input, weight, bias});
528+
auto zps = createZPsAsConst(builder, input, weight);
529+
result.addOperands({input, weight, bias, zps.first, zps.second});
523530
result.addAttribute("out_pad", outpad);
524531
result.addAttribute("stride", stride);
525532
result.addAttribute("out_shape", outputShape);
526533
result.addAttribute("acc_type", accType);
527-
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
528-
534+
Type finalOutputType = outputType;
535+
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
529536
if (quantAttr) {
530-
result.addAttribute("input_zp",
531-
builder.getI32IntegerAttr(
532-
static_cast<int32_t>(quantAttr.getInputZp())));
533-
result.addAttribute("weight_zp",
534-
builder.getI32IntegerAttr(
535-
static_cast<int32_t>(quantAttr.getWeightZp())));
536-
result.addTypes(
537-
buildConvOpResultTypeInfo(builder, outputType, input, weight));
538-
} else {
539-
result.addTypes(outputType);
537+
finalOutputType =
538+
buildConvOpResultTypeInfo(builder, outputType, input, weight);
540539
}
540+
result.addTypes(finalOutputType);
541541
}
542542

543543
/// The tosa.fully_connected op has its own builder as it does not have
@@ -2492,18 +2492,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
24922492
return failure();
24932493
}
24942494

2495-
// Create a rank-0 const tensor for zero point of the source tensor.
2495+
// Create a rank-1 const tensor for zero point of the source tensor.
24962496
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
24972497
Location loc,
24982498
Type srcElemType,
24992499
int64_t zp) {
2500-
if (auto quantType =
2501-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
2502-
srcElemType = quantType.getStorageType();
2503-
2504-
auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
2500+
srcElemType = getElementTypeOrSelf(srcElemType);
25052501
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
25062502
srcElemType = quantType.getStorageType();
2503+
auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
25072504
if (llvm::isa<FloatType>(srcElemType)) {
25082505
auto zpAttr = DenseElementsAttr::get(
25092506
zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));

0 commit comments

Comments
 (0)