
[BYOC-DNNL] add support for more ops and fusion patterns #9995

Merged (1 commit) on Feb 22, 2022
python/tvm/relay/op/contrib/dnnl.py (50 changes: 39 additions & 11 deletions)
@@ -67,22 +67,39 @@ def _func_wrapper(expr):


_register_external_op_helper("nn.batch_norm")
_register_external_op_helper("nn.conv1d")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.conv3d")
_register_external_op_helper("nn.conv2d_transpose")
_register_external_op_helper("nn.conv3d_transpose")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.max_pool2d")
_register_external_op_helper("nn.avg_pool2d")
_register_external_op_helper("nn.max_pool3d")
_register_external_op_helper("nn.avg_pool3d")
_register_external_op_helper("abs")
_register_external_op_helper("clip")
_register_external_op_helper("exp")
_register_external_op_helper("log")
_register_external_op_helper("sqrt")
_register_external_op_helper("round")
_register_external_op_helper("logsumexp")
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.leaky_relu")
_register_external_op_helper("tanh")
_register_external_op_helper("sigmoid")
_register_external_op_helper("nn.softmax")
_register_external_op_helper("add")
_register_external_op_helper("multiply")


-def make_conv_pattern(with_bias=True, with_eltwise=None):
-    """Create patterns related to nn.conv2d.
+def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
+    """Create patterns related to conv and deconv.

    Parameters
    ----------
    with_bias : bool
-        Whether attach `bias_add` to `nn.conv2d`.
+        Whether attach `bias_add` to `conv / deconv`.
    with_eltwise : str
        The attached elementwise post-op name.
    Returns
@@ -93,7 +110,7 @@ def make_conv_pattern(with_bias=True, with_eltwise=None):
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
-    conv = is_op("nn.conv2d")(data, weight)
+    conv = is_op(conv_name)(data, weight)
    if with_bias:
        conv_out = is_op("add")(conv, bias)
    else:
@@ -146,15 +163,19 @@ def make_dnnl_pattern(op, with_bias, with_eltwise):
    pattern : Tuple(pattern_name, CallPattern)
        Created pattern name, along with its CallPattern.
    """
-    pat_name = "dnnl." + op
+    pat_name = op.replace("nn", "dnnl")
    pat_name += "_bias" if with_bias else ""
    pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
-    if op == "conv2d":
-        dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
-    elif op == "dense":
+    if "conv" in op:
+        dnnl_pattern = (pat_name, make_conv_pattern(op, with_bias, with_eltwise))
+    elif op == "nn.dense":
        dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
    else:
-        logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
+        logger.warning(
+            "Currently, only conv1d, conv2d, conv3d, conv2d_transpose, conv3d_transpose "
+            "and dense ops are supported, but got %s.",
+            op,
+        )
        dnnl_pattern = ()
    return dnnl_pattern
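
To make the naming scheme concrete, a small illustration of how `make_dnnl_pattern` composes pattern names (the values are hypothetical; this snippet is not part of the diff):

# Illustration: compose a pattern name the way make_dnnl_pattern now does.
op, with_bias, with_eltwise = "nn.conv2d", True, "nn.relu"
pat_name = op.replace("nn", "dnnl")  # "dnnl.conv2d"
pat_name += "_bias" if with_bias else ""  # "dnnl.conv2d_bias"
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
print(pat_name)  # dnnl.conv2d_bias_relu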

@@ -174,8 +195,15 @@ def pattern_table():
        for elt in elt_list:
            if not with_bias and not elt:
                return dnnl_patterns
-            dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
-            dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
+            for conv_name in [
+                "nn.conv1d",
+                "nn.conv2d",
+                "nn.conv3d",
+                "nn.conv2d_transpose",
+                "nn.conv3d_transpose",
+            ]:
+                dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
+            dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
    return dnnl_patterns


src/relay/backend/contrib/dnnl/codegen.cc (53 changes: 36 additions & 17 deletions)
@@ -31,6 +31,7 @@

#include <fstream>
#include <numeric>
+#include <regex>
#include <sstream>

#include "../../utils.h"
@@ -439,6 +440,23 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
  using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
  using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;

+  std::map<std::string, std::string> op_map{
+      {"bias", "add"},
+      {"relu", "nn.relu"},
+      {"tanh", "tanh"},
+      {"sigmoid", "sigmoid"},
+  };
+
+  std::vector<std::string> ParsingOpList(std::string op, std::string pattern_name) {
+    std::vector<std::string> op_list = {"nn." + op};
+    for (auto& t : op_map) {
+      if (pattern_name.find(t.first) != std::string::npos) {
+        op_list.push_back(t.second);
+      }
+    }
+    return op_list;
+  }

public:
DNNLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}

@@ -453,28 +471,29 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
    ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions.";
    name = comp.value();

-    if (name == "dnnl.conv2d_bias_relu") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
-    } else if (name == "dnnl.conv2d_bias_tanh") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
-      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_bias_sigmoid") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
+    if (name.find("dnnl.conv2d_transpose") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv2d_transpose", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_bias") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
+    } else if (name.find("dnnl.conv3d_transpose") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv3d_transpose", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_relu") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
+    } else if (name.find("dnnl.conv1d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv1d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_tanh") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
+    } else if (name.find("dnnl.conv2d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv2d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_sigmoid") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
+    } else if (name.find("dnnl.conv3d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv3d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.dense_bias") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+    } else if (name.find("dnnl.dense") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("dense", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
+      ICHECK(call->op.as<OpNode>()) << "Not op node";
    } else {
      LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
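
To make the helper's behavior concrete, a Python mirror of ParsingOpList and the op chains it produces for two pattern names (illustration only; the resulting lists are what GetRootCall then matches against):

# Illustration: Python rendering of the C++ ParsingOpList helper above.
op_map = {"bias": "add", "relu": "nn.relu", "tanh": "tanh", "sigmoid": "sigmoid"}

def parsing_op_list(op, pattern_name):
    op_list = ["nn." + op]
    for key in sorted(op_map):  # std::map iterates its keys in sorted order
        if key in pattern_name:
            op_list.append(op_map[key])
    return op_list

print(parsing_op_list("conv2d", "dnnl.conv2d_bias_relu"))  # ['nn.conv2d', 'add', 'nn.relu']
print(parsing_op_list("dense", "dnnl.dense_bias"))         # ['nn.dense', 'add']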