Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/tvm/relay/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,25 @@ def structural_hash(value):
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)


def extract_fused_functions(mod):
"""Pass to extract IRModule of only fused primitive functions.

The ExtractFusedFunctions pass invokes SimplifyInference, FuseOps(3),
and ExtractFusedFunctions in that order

Parameters
----------
mod : tvm.relay.IRModule

Returns
-------
ret : Dict[int, tvm.relay.expr.Function]
A module containing only fused primitive functions
"""
ret_mod = _analysis.ExtractFusedFunctions()(mod)
ret = {}
for hash_, func in ret_mod.functions.items():
ret[hash_] = func
return ret
82 changes: 82 additions & 0 deletions src/relay/analysis/extract_fused_functions.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file extract_fused_functions.cc
* \brief Apply fusion and extract fused primitive functions from an IRModule
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

class FusedFunctionExtractorWrapper : private ExprVisitor {
public:
explicit FusedFunctionExtractorWrapper(const IRModule& mod) : mod_(mod) {}

IRModule Extract() {
VisitExpr(this->mod_->Lookup("main"));

auto functions = Map<GlobalVar, BaseFunc>();
for (auto pair : this->functions) {
functions.Set(GlobalVar(pair.first), pair.second);
}

this->mod_->functions = functions;
return this->mod_;
}

private:
const IRModule mod_;
// This is not simply Map<GlobalVar, Function> because GlobalVar doesn't
// have the desired equals property
Map<std::string, Function> functions;

void VisitExpr_(const FunctionNode* n) final {
if (n->HasNonzeroAttr(attr::kPrimitive)) {
// Add function to functions, keyed by function hash string
Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
size_t hash_ = StructuralHash()(func);
this->functions.Set(std::to_string(hash_), func);
}

ExprVisitor::VisitExpr_(n);
}
};

namespace transform {

Pass ExtractFusedFunctions() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return FusedFunctionExtractorWrapper(m).Extract(); };
auto fused_function_extractor_pass = CreateModulePass(pass_func, 1, "ExtractFusedFunctions", {});

return Sequential({SimplifyInference(), FuseOps(3), fused_function_extractor_pass},
"ExtractFusedFunctions");
}

TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);

} // namespace transform

} // namespace relay
} // namespace tvm
115 changes: 115 additions & 0 deletions tests/python/relay/test_analysis_extract_fused_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test function extraction"""
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload


def get_conv_net():
"""This gets the net for a case described in fuse_ops.cc:

conv2d
/ | \
/ | \
op op op
\ | /
\ | /
elemwise add
|
"""
dshape = (1, 1, 5, 1)
x = relay.var("x", shape=dshape)
y = relay.nn.conv2d(x, relay.var("w1"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)

x1 = relay.nn.conv2d(y, relay.var("w2"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)
x2 = relay.nn.conv2d(y, relay.var("w3"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)
x3 = relay.nn.conv2d(y, relay.var("w4"),
kernel_size=(3, 3),
padding=(1, 1),
channels=1)

z = relay.add(x1, x2)
z = relay.add(x3, z)

return tvm.IRModule.from_expr(z)


def get_conv2d():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NHWC',
kernel_layout='HWIO')
return tvm.IRModule.from_expr(y)


def test_extract_identity():
mod = get_conv2d()
items = relay.analysis.extract_fused_functions(mod)
assert len(items) == 1

mod["main"] = mod["main"].with_attr(
"Primitive", tvm.tir.IntImm("int32", 1))
relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])


def test_extract_conv_net():
mod = get_conv_net()
items = relay.analysis.extract_fused_functions(mod)
functions = list(items.values())
assert len(functions) == 2
x = functions[0]
y = functions[1]

def is_conv(func):
conv2d = relay.op.op.get("nn.conv2d")
call_node = func.body
return call_node.op == conv2d

def is_conv_add(func):
add = relay.op.op.get("add")
call_node = func.body
maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0])
return call_node.op == add and is_conv(maybe_conv_module["main"])

# Function traversal order isn't obvious, so checking both orders is more consistent
assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y))


def test_extract_resnet():
mod, _params = get_workload()
items = relay.analysis.extract_fused_functions(mod)
assert len(items) == 34


if __name__ == '__main__':
test_extract_identity()
test_extract_conv_net()
test_extract_resnet()