Skip to content

Commit

Permalink
Qnn fully connected (#3910)
Browse files Browse the repository at this point in the history
* Qnn Dense layer.

* Reformatting code.

* Reformatting code and making the test case more readable.

* Fixing lint issues.

* Fixing test method names to pass the nose related configurations.

* Aligning the code for code style.
  • Loading branch information
shoubhik authored and zhiics committed Sep 22, 2019
1 parent 16d4da4 commit 43f54a5
Show file tree
Hide file tree
Showing 9 changed files with 443 additions and 25 deletions.
32 changes: 20 additions & 12 deletions include/tvm/relay/qnn/attrs.h
Expand Up @@ -74,10 +74,8 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
TVM_ATTR_FIELD(out_dtype)
.describe("Output data type, can be one of [int8 or uint8].");

TVM_ATTR_FIELD(output_zero_point)
.describe("The zero_point for the activation of this op.");

TVM_ATTR_FIELD(output_scale)
.describe("The scale for the activation of this op.");
}
Expand All @@ -91,7 +89,6 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero_point for the input tensor of this op.");

TVM_ATTR_FIELD(input_scale)
.describe("The scale for the input tensor of this op.");
}
Expand All @@ -108,16 +105,12 @@ struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
TVM_ATTR_FIELD(input_scales)
.describe("The list of scales of input quantized tensors.");

TVM_ATTR_FIELD(input_zero_points)
.describe("The list of zero points of input quantized tensors.");

TVM_ATTR_FIELD(output_zero_point)
.describe("The zero_point for the output tensor.");

TVM_ATTR_FIELD(output_scale)
.describe("The scale for the output tensor.");

TVM_ATTR_FIELD(axis)
.describe("The axis at which the input arrays are concatenated."
"Should lie in range `[-ndim, ndim)`.")
Expand Down Expand Up @@ -199,24 +192,39 @@ struct QnnBinaryOpAttrs : public tvm::AttrsNode<QnnBinaryOpAttrs> {
TVM_DECLARE_ATTRS(QnnBinaryOpAttrs, "relay.attrs.QnnBinaryOpAttrs") {
TVM_ATTR_FIELD(lhs_zero_point)
.describe("The zero_point for the lhs input tensor of this op.");

TVM_ATTR_FIELD(lhs_scale)
.describe("The scale for the lhs input tensor of this op.");

TVM_ATTR_FIELD(rhs_zero_point)
.describe("The zero_point for the rhs input tensor of this op.");

TVM_ATTR_FIELD(rhs_scale)
.describe("The scale for the rhs input tensor of this op.");

TVM_ATTR_FIELD(output_zero_point)
.describe("The zero_point for the activation of this op.");

TVM_ATTR_FIELD(output_scale)
.describe("The scale for the activation of this op.");
}
};

/*! \brief Attributes for qnn dense operator */
struct QnnDenseAttrs : public tvm::AttrsNode<QnnDenseAttrs> {
IndexExpr units;
DataType out_dtype;
// Quantization related attributes.
int32_t input_zero_point;
int32_t kernel_zero_point;

TVM_DECLARE_ATTRS(QnnDenseAttrs, "relay.attrs.qnn.QnnDenseAttrs") {
TVM_ATTR_FIELD(units)
.describe("Number of hidden units of the dense transformation.");
TVM_ATTR_FIELD(out_dtype)
.describe("Output data type, set to explicit type under mixed precision setting");
TVM_ATTR_FIELD(input_zero_point)
.describe("The zero point of the input tensor.");
TVM_ATTR_FIELD(kernel_zero_point)
.describe("The zero point of the kernel tensor.");
}
};

} // namespace qnn
} // namespace relay
} // namespace tvm
Expand Down
48 changes: 46 additions & 2 deletions python/tvm/relay/qnn/op/qnn.py
Expand Up @@ -96,7 +96,7 @@ def quantize(data,
The output zero_point.
output_scale : float
The output scale.
input_dtype : str, optional
out_dtype : str, optional
The data type of the input tensor. Can be [int8, uint8]
Returns
-------
Expand Down Expand Up @@ -265,7 +265,13 @@ def conv2d(data,
data_layout, kernel_layout, out_layout, out_dtype)


def add(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
def add(lhs,
rhs,
lhs_scale,
lhs_zero_point,
rhs_scale,
rhs_zero_point,
output_scale,
output_zero_point):
"""Quantized addition with numpy-style broadcasting.
Expand Down Expand Up @@ -305,3 +311,41 @@ def add(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_s
lhs_scale, lhs_zero_point,
rhs_scale, rhs_zero_point,
output_scale, output_zero_point)


def quantized_dense(data,
weight,
input_zero_point,
kernel_zero_point,
units=None,
out_dtype="int32"):
"""Qnn Dense operator.
Applies a quantized linear transformation
.. math::
`Y = X * W`
Parameters
----------
data : tvm.relay.Expr
The quantized input data to the operator.
weight : tvm.relay.Expr
The quantized weight expressions.
units : int, optional
Number of hidden units of the dense transformation.
out_dtype : str, optional
Specifies the output data type for mixed precision dense can be int32 or int16.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""

return _make.dense(data,
weight,
units,
input_zero_point,
kernel_zero_point,
out_dtype)
4 changes: 2 additions & 2 deletions python/tvm/relay/qnn/transform.py
Expand Up @@ -22,7 +22,7 @@

def CanonicalizeOps():
"""Converts/Lowers an expression containing QNN ops to an expression containing only core
(non-Dialect) Relay ops. Each QNN op is lowered to a sequence of exisiting Relay ops. This is a
(non-Dialect) Relay ops. Each QNN op is lowered to a sequence of existing Relay ops. This is a
target-independent pass. One can register the lowering/transformation function for this op using
FTVMQnnCanonicalize attr_name for FTVMLegalize op attribute. An example of this transformation
is below
Expand All @@ -40,7 +40,7 @@ def CanonicalizeOps():
output_zero_point=0,
out_dtype='int8')
# We want to utilize all the existing Relay infrastucture. So, instead of supporting this
# We want to utilize all the existing Relay infrastructure. So, instead of supporting this
# QNN requantize op, we convert it into a sequence of existing Relay operators.
mod = relay.Module.from_expr(qnn_expr)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
Expand Down
1 change: 1 addition & 0 deletions src/relay/op/nn/convolution.h
Expand Up @@ -25,6 +25,7 @@
#ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_
#define TVM_RELAY_OP_NN_CONVOLUTION_H_

#include <tvm/ir_pass.h>
#include <string>
#include <utility>

Expand Down
11 changes: 11 additions & 0 deletions src/relay/pass/pattern_util.h
Expand Up @@ -434,6 +434,17 @@ static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}

static inline Expr Dense(Expr data,
Expr weight,
IndexExpr units,
DataType out_dtype) {
auto attrs = make_node<DenseAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.dense");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}

static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
auto attrs = make_node<ReduceAttrs>();
attrs->axis = std::move(axis);
Expand Down
15 changes: 7 additions & 8 deletions src/relay/qnn/op/convolution.cc
Expand Up @@ -23,7 +23,6 @@
* \brief Property def of qnn convolution operator.
*/
#include <tvm/data_layout.h>
#include <tvm/ir_pass.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/base.h>
#include <tvm/relay/op.h>
Expand Down Expand Up @@ -178,7 +177,7 @@ Expr Conv2DPadInput(const Expr& data, const QnnConv2DAttrs* param) {
* \param data The input expr.
* \param weight The weight expr.
* \param param The qnn conv2d attributes.
* \return The sequence of Relay operatos for term1.
* \return The sequence of Relay operators for term1.
* \note The term1 is
* Sigma(c,r,s) QW(k, c, r, s) * QA(n, c, h + r, w + s)
* This is just conv2d on int tensors.
Expand All @@ -198,12 +197,12 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
* \param param The qnn conv2d attributes.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \return The sequence of Relay operatos for term2.
* \return The sequence of Relay operators for term2.
* \note The term2 looks like this
*
* Sigma(c,r,s) zp_w * QA(n, c, h + r, w + s)
*
* Second term is not directly represetable by one Relay operator.
* Second term is not directly representable by one Relay operator.
* However, deeper analysis shows that we can reduce r,s using avg_pool2d,
* followed by a reduce on the C axis. Using avg_pool2d also gives an
* opportunity to reuse alter_op_layout infrastructure.
Expand Down Expand Up @@ -313,7 +312,7 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& zp_data, const QnnConv2DAtt
* \param in_channels The number of input channels.
* \param kernel_h The height of kernel.
* \param kernel_w The width of kernel.
* \return The sequence of Relay operatos for term4.
* \return The sequence of Relay operators for term4.
* \note The term4 looks like this
*
* Sigma(c,r,s) zp_a * zp_w
Expand Down Expand Up @@ -373,7 +372,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
* where QA is quantized tensor, scale_a and zp_A are quantizations
* params.
*
* Quantized convlution convolves two quantized tensors and returns a
* Quantized convolution will convolve two quantized tensors and returns a
* quantized tensor of default dtype of int32, with scale equaling to the
* product of scales of input tensors, and a zero point of zero.
*
Expand All @@ -399,7 +398,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
* zero point. This might leave some performance opportunity at the
* table. Can be avoided by modifying conv2d API to accept the
* pad_const_value.
* 2) Second term is not directly represetable by one Relay operator.
* 2) Second term is not directly representable by one Relay operator.
* However, deeper analysis shows that we can reduce r,s using
* avg_pool2d, followed by a reduce on the C axis. Using avg_pool2d also
* gives an opportunity to reuse alter_op_layout infrastructure.
Expand All @@ -408,7 +407,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
* the conv is dilated. We fallback also in case of depthwise conv.
*
* The whole process can be broken down into following steps
* * Assertion checks for exisiting support, fallback if necessary
* * Assertion checks for existing support, fallback if necessary
* * Pad the input.
* * Get Term1.
* * Get Term2.
Expand Down
131 changes: 131 additions & 0 deletions src/relay/qnn/op/dense.cc
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/qnn/op/dense.cc
* \brief Property def of qnn dense operator.
*/

#include <tvm/relay/base.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../op/nn/nn.h"
#include "../../pass/pattern_util.h"

namespace tvm {
namespace relay {
namespace qnn {

// relay.op.qnn.dense
TVM_REGISTER_NODE_TYPE(QnnDenseAttrs);

bool QnnDenseRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<QnnDenseAttrs>();
CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr.";
CHECK(data->dtype == Int(8) || data->dtype == UInt(8))
<< "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8))
<< "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
CHECK(param->out_dtype == Int(32))
<< "Expected quantized dense type(int32) for output but was " << param->out_dtype;
CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
return DenseRel<QnnDenseAttrs>(types, num_inputs, attrs, reporter);
}

// Positional relay function to create quantized dense operator used by frontend FFI.
Expr MakeQuantizedDense(Expr data,
Expr weight,
IndexExpr units,
int32_t input_zero_point,
int32_t kernel_zero_point,
DataType out_dtype) {
auto attrs = make_node<QnnDenseAttrs>();
attrs->units = std::move(units);
attrs->out_dtype = out_dtype;
attrs->input_zero_point = input_zero_point;
attrs->kernel_zero_point = kernel_zero_point;
static const Op& op = Op::Get("qnn.dense");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}

/**
* \brief Lowers Qnn convolution in terms of core operators in relay.
* Mathematically it is equals to -
* Dense((quantized_input - input_zero_point;int32), (quantized_kernel - kernel_zero_point; int32))
*
* \param attrs QnnDenseAttrs for Qnn Dense layer.
* \param new_args The new mutated args to the call node.
* \param arg_types The data types of input and output.
* \reutrn The sequence of Relay ops for qnn cov2d op.
*/
Expr QnnDenseCanonicalize(const Attrs& attrs,
const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
CHECK_EQ(new_args.size(), 2);
Expr quantized_data = new_args[0];
Expr quantized_kernel = new_args[1];
const auto* qnn_dense_attrs = attrs.as<QnnDenseAttrs>();
Expr quantized_data_int32 = Cast(quantized_data, Int(32));
if (qnn_dense_attrs->input_zero_point != 0) {
quantized_data_int32 = Subtract(quantized_data_int32,
MakeConstantScalar(Int(32),
qnn_dense_attrs->input_zero_point));
}
Expr quantized_kernel_int32 = Cast(quantized_kernel, Int(32));
if (qnn_dense_attrs->kernel_zero_point != 0) {
quantized_kernel_int32 = Subtract(quantized_kernel_int32,
MakeConstantScalar(Int(32),
qnn_dense_attrs->kernel_zero_point));
}
Expr int32_dense = Dense(quantized_data_int32,
quantized_kernel_int32,
qnn_dense_attrs->units,
qnn_dense_attrs->out_dtype);
return int32_dense;
}

RELAY_REGISTER_OP("qnn.dense")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)`
- **weight**: quantized(int8, unit8) `(units, input_dim)`
- **out**: quantized(int32) `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.qnn.QnnDenseAttrs")
.set_num_inputs(2)
.add_argument("data", "quantized nD Tensor", "Input data.")
.add_argument("weight", "quantized 2D Tensor", "Weight matrix.")
.set_support_level(11)
.add_type_rel("QDense", DenseRel<QnnDenseAttrs>)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);

TVM_REGISTER_API("relay.qnn.op._make.dense")
.set_body_typed(MakeQuantizedDense);

} // namespace qnn
} // namespace relay
} // namespace tvm
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_conv2d.py
Expand Up @@ -19,7 +19,6 @@
import numpy as np
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import create_workload
from tvm.relay.testing import run_infer_type
from tvm.contrib import graph_runtime

Expand Down

0 comments on commit 43f54a5

Please sign in to comment.