diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 67839fed02a2..ee5f3e1dd43c 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -51,5 +51,4 @@ # Make the disco module optional. disco = None # type: ignore[assignment] -from .support import _regex_match from tvm_ffi import Shape as ShapeTuple diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index f6591b28717e..d9762ef57116 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -17,59 +17,8 @@ """Runtime support infra of TVM.""" -import re from typing import TypeVar -import tvm_ffi - - -@tvm_ffi.register_global_func("tvm.runtime.regex_match") -def _regex_match(regex_pattern: str, match_against: str) -> bool: - """Check if a pattern matches a regular expression - - This function should be used instead of `std::regex` within C++ - call sites, to avoid ABI incompatibilities with pytorch. - - Currently, the pytorch wheels available through pip install use - the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to - user the pre-C++11 ABI, this would cause breakages with - dynamically-linked LLVM environments. - - Use of the `` header in TVM should be avoided, as its - implementation is not supported by gcc's dual ABI. This ABI - incompatibility results in runtime errors either when `std::regex` - is called from TVM, or when `std::regex` is called from pytorch, - depending on which library was loaded first. This restriction can - be removed when a version of pytorch compiled using - `-DUSE_CXX11_ABI=1` is available from PyPI. - - This is exposed as part of `libtvm_runtime.so` as it is used by - the DNNL runtime. - - [0] https://github.com/pytorch/pytorch/issues/51039 - - Parameters - ---------- - regex_pattern: str - - The regular expression - - match_against: str - - The string against which to match the regular expression - - Returns - ------- - match_result: bool - - True if `match_against` matches the pattern defined by - `regex_pattern`, and False otherwise. - - """ - match = re.match(regex_pattern, match_against) - return match is not None - - T = TypeVar("T") diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index 02b2c70b7fb3..031a552a00e7 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -32,7 +32,6 @@ #include #include -#include "../../runtime/regex.h" #include "utils.h" namespace tvm { diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 2b07b6f9e554..a6440952cdd2 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -31,7 +31,6 @@ #include #include -#include "../../../runtime/regex.h" #include "../json/json_node.h" #include "../json/json_runtime.h" @@ -49,6 +48,16 @@ namespace contrib { using namespace tvm::runtime; using namespace tvm::runtime::json; +namespace { +inline bool contains(const std::string& s, const std::string& sub) { + return s.find(sub) != std::string::npos; +} +template +inline bool contains_any(const std::string& s, const Args&... args) { + return (contains(s, args) || ...); +} +} // namespace + class DNNLJSONRuntime : public JSONRuntimeBase { public: DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json, @@ -189,46 +198,35 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr; - // Define RegExp. - std::string bias_add_pat(".*_bias.*"); - std::string relu_pat(".*_relu.*"); - std::string tanh_pat(".*_tanh.*"); - std::string sigmoid_pat(".*_sigmoid.*"); - std::string clip_pat(".*_clip.*"); - std::string gelu_pat(".*_gelu.*"); - std::string swish_pat(".*_swish.*"); - std::string sum_pat(".*_sum.*"); - std::string mish_pat(".*_mish.*"); - // parsing of name to extract attributes auto op_name = nodes_[nid].GetOpName(); // Parsing post-ops. dnnl::post_ops ops; - if (tvm::runtime::regex_match(op_name, sum_pat)) { + if (contains(op_name, "_sum")) { ops.append_sum(1.f); } - if (tvm::runtime::regex_match(op_name, relu_pat)) { + if (contains(op_name, "_relu")) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f); } - if (tvm::runtime::regex_match(op_name, tanh_pat)) { + if (contains(op_name, "_tanh")) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f); } - if (tvm::runtime::regex_match(op_name, clip_pat)) { + if (contains(op_name, "_clip")) { float a_min = GetNodeAttr(nodes_[nid], "a_min"); float a_max = GetNodeAttr(nodes_[nid], "a_max"); ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max); } - if (tvm::runtime::regex_match(op_name, sigmoid_pat)) { + if (contains(op_name, "_sigmoid")) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f); } - if (tvm::runtime::regex_match(op_name, swish_pat)) { + if (contains(op_name, "_swish")) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f); } - if (tvm::runtime::regex_match(op_name, gelu_pat)) { + if (contains(op_name, "_gelu")) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); } - if (tvm::runtime::regex_match(op_name, mish_pat)) { + if (contains(op_name, "_mish")) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f); } if (ops.len() != 0) { @@ -236,8 +234,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } // Parsing bias_add. - *bias_tr = - tvm::runtime::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{}; + *bias_tr = contains(op_name, "_bias") ? GetInput(nid, 2) : TensorRequisite{}; return attr; } @@ -250,31 +247,24 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::set io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end()); tensor_registry_ = TensorRegistry(engine_, io_eid_set); - std::string conv_pat(".*conv[1-3]d.*"); - std::string deconv_pat(".*deconv[1-3]d.*"); - std::string conv_transpose_pat(".*conv[1-3]d_transpose.*"); - std::string dense_pat(".*dense.*"); - std::string max_pool_pat(".*max_pool[1-3]d"); - std::string avg_pool_pat(".*avg_pool[1-3]d"); - // Build subgraph engine. for (size_t nid = 0; nid < nodes_.size(); ++nid) { const auto& node = nodes_[nid]; if (node.GetOpType() == "kernel") { TVM_FFI_ICHECK_EQ(node.GetOpType(), "kernel"); auto op_name = node.GetOpName(); - if (tvm::runtime::regex_match(op_name, deconv_pat) || - tvm::runtime::regex_match(op_name, conv_transpose_pat)) { + if (contains_any(op_name, "deconv1d", "deconv2d", "deconv3d", "conv1d_transpose", + "conv2d_transpose", "conv3d_transpose")) { Deconvolution(nid); - } else if (tvm::runtime::regex_match(op_name, conv_pat)) { + } else if (contains_any(op_name, "conv1d", "conv2d", "conv3d")) { Convolution(nid); - } else if (tvm::runtime::regex_match(op_name, dense_pat)) { + } else if (contains(op_name, "dense")) { Dense(nid); } else if ("nn.batch_norm" == op_name) { BatchNorm(nid); - } else if (tvm::runtime::regex_match(op_name, max_pool_pat)) { + } else if (contains_any(op_name, "max_pool1d", "max_pool2d", "max_pool3d")) { Pooling(nid, dnnl::algorithm::pooling_max); - } else if (tvm::runtime::regex_match(op_name, avg_pool_pat)) { + } else if (contains_any(op_name, "avg_pool1d", "avg_pool2d", "avg_pool3d")) { Pooling(nid, dnnl::algorithm::pooling_avg); } else if (elt_name2algo.count(op_name)) { Eltwise(nid); diff --git a/src/runtime/regex.cc b/src/runtime/regex.cc deleted file mode 100644 index a91bf479ce4b..000000000000 --- a/src/runtime/regex.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/runtime/regex.cc - * \brief Exposes calls to python's `re` library. - */ - -#include "./regex.h" - -#include - -namespace tvm { -namespace runtime { - -bool regex_match(const std::string& match_against, const std::string& regex_pattern) { - const auto regex_match_func = tvm::ffi::Function::GetGlobal("tvm.runtime.regex_match"); - if (!regex_match_func.has_value()) { - TVM_FFI_THROW(RuntimeError) - << "The ffi::Function 'tvm.runtime.regex_match' has not been registered. " - << "This can occur if the TVM Python library has not yet been imported."; - } - return (*regex_match_func)(regex_pattern, match_against).cast(); -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/regex.h b/src/runtime/regex.h deleted file mode 100644 index d8a62e72d387..000000000000 --- a/src/runtime/regex.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file regex.h - * \brief Exposes calls to python's `re` library. - */ -#ifndef TVM_RUNTIME_REGEX_H_ -#define TVM_RUNTIME_REGEX_H_ - -#include - -#include - -namespace tvm { -namespace runtime { - -/* \brief Check if a pattern matches a regular expression - * - * This function should be used instead of `std::regex` within C++ - * call sites, to avoid ABI incompatibilities with pytorch. - * - * Currently, the pytorch wheels available through pip install use - * the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to - * user the pre-C++11 ABI, this would cause breakages with - * dynamically-linked LLVM environments. - * - * Use of the `` header in TVM should be avoided, as its - * implementation is not supported by gcc's dual ABI. This ABI - * incompatibility results in runtime errors either when `std::regex` - * is called from TVM, or when `std::regex` is called from pytorch, - * depending on which library was loaded first. This restriction can - * be removed when a version of pytorch compiled using - * `-DUSE_CXX11_ABI=1` is available from PyPI. - * - * [0] https://github.com/pytorch/pytorch/issues/51039 - * - * \param match_against The string against which to match the regular expression - * - * \param regex_pattern The regular expression - * - * \returns match_result True if `match_against` matches the pattern - * defined by `regex_pattern`, and False otherwise. - */ - -TVM_RUNTIME_DLL bool regex_match(const std::string& match_against, - const std::string& regex_pattern); - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_REGEX_H_