diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 39212d8fd1e5..7fd9bee9d893 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -228,6 +228,17 @@ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
  */
 TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
 
+/*!
+ * \brief Combine parallel batch_matmul ops into a single batch_matmul
+ * if the number of branches of this batch_matmul operator is not less
+ * than `min_num_branches`.
+ *
+ * \param min_num_branches The minimum number of branches.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
+
 /*!
  * \brief Backward fold axis scaling into weights of conv/dense operators.
  *
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 8f4ec1046500..a490d6f00a71 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -58,6 +58,7 @@ def build_config(opt_level=2,
                 "EliminateCommonSubexpr": 3,
                 "CombineParallelConv2D": 4,
                 "CombineParallelDense": 4,
+                "CombineParallelBatchMatmul": 4,
                 "FastMath": 4
             }
 
@@ -307,6 +308,39 @@ def CombineParallelDense(min_num_branches=3):
     """
     return _ffi_api.CombineParallelDense(min_num_branches)
 
+def CombineParallelBatchMatmul(min_num_branches=3):
+    """Combine multiple batch_matmul operators into one. For example:
+
+    .. code-block:: text
+
+                        data (1, 2, 3)
+                       /              \
+        batch_matmul(data, (1, 4, 3))   batch_matmul(data, (1, 5, 3))
+                   |                             |
+        elemwise/bcast (1, 2, 4)        elemwise/bcast (1, 2, 5)
+
+    Would become:
+
+    .. code-block:: text
+
+                  data (1, 2, 3)
+                        |
+        batch_matmul(data, (1, 4+5, 3))
+                        |
+           elemwise/bcast (1, 2, 4+5)
+
+    Parameters
+    ----------
+    min_num_branches : int
+        The minimum number of required parallel branches for performing this
+        optimization.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass that combines parallel batch_matmul operators.
+    """
+    return _ffi_api.CombineParallelBatchMatmul(min_num_branches)
+
 
 def AlterOpLayout():
     """Alternate the layouts of operators or replace primitive operators with
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index f9ce24d410b7..69ad8785dbca 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -277,6 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
     pass_seqs.push_back(transform::CombineParallelConv2D(3));
     pass_seqs.push_back(transform::CombineParallelDense(3));
+    pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
     pass_seqs.push_back(transform::FoldConstant());
     pass_seqs.push_back(transform::FoldScaleAxis());
     pass_seqs.push_back(transform::CanonicalizeCast());
diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc
new file mode 100644
index 000000000000..1529631d5ec1
--- /dev/null
+++ b/src/relay/transforms/combine_parallel_batch_matmul.cc
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file combine_parallel_batch_matmul.cc
+ * \brief Combine parallel batch matmuls into a single one.
+ *
+ * This pass replaces batch_matmul ops that share the same lhs node with a
+ * single batch_matmul. Elemwise and broadcast ops following batch_matmul
+ * are also combined if possible.
+ *
+ * This prevents launching multiple kernels in networks with multiple
+ * parallel batch_matmul branches, such as an Inception block.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+#include <string>
+#include <unordered_map>
+
+#include "./combine_parallel_op.h"
+#include "./expr_subst.h"
+#include "pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
+ public:
+  explicit ParallelBatchMatmulCombiner(uint64_t min_num_branches)
+      : ParallelOpCombiner("nn.batch_matmul", min_num_branches) {}
+
+ protected:
+  bool IsSupportedOp(const CallNode* n) { return true; }
+
+  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
+    StructuralEqual eq;
+    const auto* rhs_a = a->args[1]->type_as<TensorTypeNode>();
+    const auto* rhs_b = b->args[1]->type_as<TensorTypeNode>();
+    const auto* restype_a = a->type_as<TensorTypeNode>();
+    const auto* restype_b = b->type_as<TensorTypeNode>();
+    // shape[2] is the contraction axis; it is automatically consistent
+    // if both are valid batch_matmul ops.
+    auto res = eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) &&
+               (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) &&
+               eq(rhs_a->shape[0], rhs_b->shape[0]);
+    return res;
+  }
+
+  Call MakeCombinedOp(const Group& branches) {
+    const Op& batch_matmul = Op::Get("nn.batch_matmul");
+    Expr data = branches[0][0]->args[0];
+
+    Array<Expr> weights;
+    for (const auto& branch : branches) {
+      auto batch_matmul = branch[0];
+      weights.push_back(batch_matmul->args[1]);
+    }
+    // Concatenate the rhs operands along the feature axis (axis 1).
+    Expr new_weight = MakeConcatenate(Tuple(weights), 1);
+    return Call(batch_matmul, {data, new_weight}, {}, {});
+  }
+
+  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; }
+
+  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
+                                        size_t parent_index) {
+    Array<Expr> new_args;
+    const CallNode* call = branches[0][depth];
+
+    for (size_t i = 0; i < call->args.size(); i++) {
+      if (i == parent_index) {
+        new_args.push_back(data);
+        continue;
+      }
+
+      Array<Expr> tuple;
+      for (const auto& branch : branches) {
+        tuple.push_back(branch[depth]->args[i]);
+      }
+
+      auto concat = MakeConcatenate(Tuple(tuple), -1);
+      new_args.push_back(std::move(concat));
+    }
+
+    return Call(call->op, new_args, call->attrs, {});
+  }
+
+  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
+                         ExprSubstMap* subst_map) {
+    int64_t index = 0;
+
+    for (const auto& branch : branches) {
+      const CallNode* batch_matmul = branch[0];
+      auto feature_dim = batch_matmul->args[1]->type_as<TensorTypeNode>()->shape[1];
+      auto fpp = tir::as_const_int(feature_dim);
+      int64_t features = *fpp;
+      // Slice this branch's features back out of the fused output:
+      // begin = [0, 0, index], end = [-1, -1, features] with slice_mode="size".
+      std::vector<int64_t> begin;
+      std::vector<int64_t> end;
+      for (size_t i = 0; i < 2; i++) {
+        begin.push_back(0);
+        end.push_back(-1);
+      }
+      begin.push_back(index);
+      index += features;
+      end.push_back(features);
+      std::vector<int64_t> strides(begin.size(), 1);
+      std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
+      Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
+      Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
+      Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
+      auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
+      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
+    }
+  }
+};
+
+/*! \brief Combine parallel batch_matmul ops if the number of branches >= min_num_branches. */
+Expr CombineParallelBatchMatmul(const Expr& expr, uint64_t min_num_branches) {
+  return ParallelBatchMatmulCombiner(min_num_branches).Combine(expr);
+}
+
+namespace transform {
+
+Pass CombineParallelBatchMatmul(uint64_t min_num_branches) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(CombineParallelBatchMatmul(f, min_num_branches));
+      };
+  return CreateFunctionPass(pass_func, 4, "CombineParallelBatchMatmul", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.CombineParallelBatchMatmul")
+    .set_body_typed(CombineParallelBatchMatmul);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_combine_parallel_batch_matmul.py b/tests/python/relay/test_pass_combine_parallel_batch_matmul.py
new file mode 100644
index 000000000000..00d8ac40a129
--- /dev/null
+++ b/tests/python/relay/test_pass_combine_parallel_batch_matmul.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,too-many-arguments,missing-module-docstring
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+
+
+def run_opt_pass(expr, opt_pass):
+    """Run opt_pass on expr and return the transformed main function."""
+    assert isinstance(opt_pass, tvm.transform.Pass)
+    mod = tvm.IRModule.from_expr(expr)
+    mod = opt_pass(mod)
+    return mod["main"]
+
+def test_combine_parallel_batch_matmul():
+    """Simple testcase."""
+    def before(x, w1, w2, w3):
+        args = [x, w1, w2, w3]
+        y1 = relay.nn.batch_matmul(x, w1)
+        y2 = relay.nn.batch_matmul(x, w2)
+        y3 = relay.nn.batch_matmul(x, w3)
+        y = relay.Tuple((y1, y2, y3))
+        return relay.Function(args, y)
+
+    def expected(x, w1, w2, w3):
+        # use a fixed order of args so the structural equality check can pass
+        s1 = w1.type_annotation.shape[1]
+        s2 = w2.type_annotation.shape[1]
+        s3 = w3.type_annotation.shape[1]
+        args = [x, w1, w2, w3]
+        w = relay.concatenate((w1, w2, w3), axis=1)
+        y = relay.nn.batch_matmul(x, w)
+        y1 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0, 0], "int64"),
+                                 end=relay.const([-1, -1, s1], "int64"),
+                                 strides=relay.const([1, 1, 1], "int64"),
+                                 slice_mode="size")
+        y2 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0, s1], "int64"),
+                                 end=relay.const([-1, -1, s2], "int64"),
+                                 strides=relay.const([1, 1, 1], "int64"),
+                                 slice_mode="size")
+        y3 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0, s1 + s2], "int64"),
+                                 end=relay.const([-1, -1, s3], "int64"),
+                                 strides=relay.const([1, 1, 1], "int64"),
+                                 slice_mode="size")
+        y = relay.Tuple((y1, y2, y3))
+        return relay.Function(args, y)
+
+    def check(b, i, j, k):
+        x = relay.var("x", shape=(b, i, k))
+        w1 = relay.var("w1", shape=(b, j, k))
+        w2 = relay.var("w2", shape=(b, j, k))
+        w3 = relay.var("w3", shape=(b, j, k))
+
+        y_before = before(x, w1, w2, w3)
+        y = run_opt_pass(y_before,
+                         transform.CombineParallelBatchMatmul(min_num_branches=2))
+        y_expected = expected(x, w1, w2, w3)
+        y_expected = run_opt_pass(y_expected, transform.InferType())
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
+
+    check(2, 3, 5, 4)
+    check(1, 100, 200, 300)
+
+def test_combine_parallel_batch_matmul_biasadd():
+    """Simple testcase with bias"""
+    def before(x, w1, w2, w3, b1, b2, b3):
+        args = [x, w1, w2, w3, b1, b2, b3]
+        y1 = relay.nn.batch_matmul(x, w1)
+        y2 = relay.nn.batch_matmul(x, w2)
+        y3 = relay.nn.batch_matmul(x, w3)
+        y1 = relay.add(y1, b1)
+        y2 = relay.add(y2, b2)
+        y3 = relay.add(y3, b3)
+        y = relay.Tuple((y1, y2, y3))
+        return relay.Function(args, y)
+
+    def expected(x, w1, w2, w3, b1, b2, b3):
+        # use a fixed order of args so the structural equality check can pass
+        s1 = w1.type_annotation.shape[1]
+        s2 = w2.type_annotation.shape[1]
+        s3 = w3.type_annotation.shape[1]
+        args = [x, w1, w2, w3, b1, b2, b3]
+        w = relay.concatenate((w1, w2, w3), axis=1)
+        b = relay.concatenate((b1, b2, b3), axis=-1)
+        y = relay.nn.batch_matmul(x, w)
+        y = relay.add(y, b)
+        y1 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0, 0], "int64"),
+                                 end=relay.const([-1, -1, s1], "int64"),
+                                 strides=relay.const([1, 1, 1], "int64"),
+                                 slice_mode="size")
+        y2 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0, s1], "int64"),
+                                 end=relay.const([-1, -1, s2], "int64"),
+                                 strides=relay.const([1, 1, 1], "int64"),
+                                 slice_mode="size")
+        y3 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0, s1 + s2], "int64"),
+                                 end=relay.const([-1, -1, s3], "int64"),
+                                 strides=relay.const([1, 1, 1], "int64"),
+                                 slice_mode="size")
+        y = relay.Tuple((y1, y2, y3))
+        return relay.Function(args, y)
+
+    def check(b, i, j, k):
+        x = relay.var("x", shape=(b, i, k))
+        w1 = relay.var("w1", shape=(b, j, k))
+        w2 = relay.var("w2", shape=(b, j, k))
+        w3 = relay.var("w3", shape=(b, j, k))
+        b1 = relay.var("b1", shape=(j,))
+        b2 = relay.var("b2", shape=(j,))
+        b3 = relay.var("b3", shape=(j,))
+
+        y_before = before(x, w1, w2, w3, b1, b2, b3)
+        y = run_opt_pass(y_before,
+                         transform.CombineParallelBatchMatmul(min_num_branches=2))
+        y_expected = expected(x, w1, w2, w3, b1, b2, b3)
+        y_expected = run_opt_pass(y_expected, transform.InferType())
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
+
+    check(2, 3, 5, 4)
+    check(1, 100, 200, 300)
+
+
+if __name__ == "__main__":
+    test_combine_parallel_batch_matmul()
+    test_combine_parallel_batch_matmul_biasadd()
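
For reference, a minimal usage sketch (not part of the patch; the variable names and shapes below are illustrative) showing how the new pass can be exercised end to end, mirroring the tests above:

    import tvm
    from tvm import relay
    from tvm.relay import transform

    # Three batch_matmuls that share the same lhs operand "x".
    x = relay.var("x", shape=(1, 16, 32))
    ws = [relay.var("w%d" % i, shape=(1, 8, 32)) for i in range(3)]
    y = relay.Tuple([relay.nn.batch_matmul(x, w) for w in ws])
    func = relay.Function([x] + ws, y)

    mod = tvm.IRModule.from_expr(func)
    # min_num_branches=2 fuses any group of two or more parallel branches.
    mod = transform.CombineParallelBatchMatmul(min_num_branches=2)(mod)
    # The rewritten main should contain a single batch_matmul on the
    # concatenated rhs, followed by one strided_slice per original branch.
    print(mod["main"])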