@@ -510,7 +510,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
510
510
result.addAttribute (" stride" , stride);
511
511
result.addAttribute (" dilation" , dilation);
512
512
result.addAttribute (" acc_type" , accType);
513
- result.addTypes (outputType);
513
+ Type finalOutputType = outputType;
514
+ auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
515
+ if (quantAttr) {
516
+ finalOutputType =
517
+ buildConvOpResultTypeInfo (builder, outputType, input, weight);
518
+ }
519
+ result.addTypes (finalOutputType);
514
520
}
515
521
516
522
/// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -519,25 +525,19 @@ static void buildTransConvOpWithQuantInfo(
519
525
OpBuilder &builder, OperationState &result, Type outputType, Value input,
520
526
Value weight, Value bias, DenseI64ArrayAttr outpad,
521
527
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
522
- result.addOperands ({input, weight, bias});
528
+ auto zps = createZPsAsConst (builder, input, weight);
529
+ result.addOperands ({input, weight, bias, zps.first , zps.second });
523
530
result.addAttribute (" out_pad" , outpad);
524
531
result.addAttribute (" stride" , stride);
525
532
result.addAttribute (" out_shape" , outputShape);
526
533
result.addAttribute (" acc_type" , accType);
527
- auto quantAttr = :: buildConvOpQuantizationAttr (builder, input, weight) ;
528
-
534
+ Type finalOutputType = outputType ;
535
+ auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
529
536
if (quantAttr) {
530
- result.addAttribute (" input_zp" ,
531
- builder.getI32IntegerAttr (
532
- static_cast <int32_t >(quantAttr.getInputZp ())));
533
- result.addAttribute (" weight_zp" ,
534
- builder.getI32IntegerAttr (
535
- static_cast <int32_t >(quantAttr.getWeightZp ())));
536
- result.addTypes (
537
- buildConvOpResultTypeInfo (builder, outputType, input, weight));
538
- } else {
539
- result.addTypes (outputType);
537
+ finalOutputType =
538
+ buildConvOpResultTypeInfo (builder, outputType, input, weight);
540
539
}
540
+ result.addTypes (finalOutputType);
541
541
}
542
542
543
543
/// The tosa.fully_connected op has its own builder as it does not have
@@ -2492,18 +2492,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
2492
2492
return failure ();
2493
2493
}
2494
2494
2495
- // Create a rank-0 const tensor for zero point of the source tensor.
2495
+ // Create a rank-1 const tensor for zero point of the source tensor.
2496
2496
std::optional<Value> mlir::tosa::createZeroPointTensor (OpBuilder &builder,
2497
2497
Location loc,
2498
2498
Type srcElemType,
2499
2499
int64_t zp) {
2500
- if (auto quantType =
2501
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
2502
- srcElemType = quantType.getStorageType ();
2503
-
2504
- auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
2500
+ srcElemType = getElementTypeOrSelf (srcElemType);
2505
2501
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
2506
2502
srcElemType = quantType.getStorageType ();
2503
+ auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
2507
2504
if (llvm::isa<FloatType>(srcElemType)) {
2508
2505
auto zpAttr = DenseElementsAttr::get (
2509
2506
zpType, builder.getFloatAttr (srcElemType, static_cast <double >(zp)));
0 commit comments