Skip to content

Commit

Permalink
feat: Handle empty schemas for unsupported ops
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Apr 8, 2022
1 parent 43a53ce commit bf6c929
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "core/conversion/evaluators/evaluators.h"
#include "core/conversion/var/Var.h"
#include "core/util/prelude.h"

#include <ATen/core/operator_name.h>
#include "c10/util/intrusive_ptr.h"
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
Expand Down Expand Up @@ -491,11 +491,20 @@ std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock(cons
auto schema = n->maybeSchema();
// Some ops like torch::jit::prim::Loop, torch::jit::prim::If, torch::jit::prim::DictConstruct don't have a schema but they are supported.
// torch::jit::prim::DictConstruct is supported via fallback only
if (schema && !OpSupported(n)) {
std::stringstream ss;
ss << *schema;
unsupported_ops[schema->operator_name()] = ss.str();
if (!OpSupported(n)) {
if (schema){
std::stringstream ss;
ss << *schema;
unsupported_ops[schema->operator_name()] = ss.str();
} else {
std::stringstream ss;
ss << util::node_info(n);
// operator.overload is a filler name just to call the constructor.
c10::OperatorName op(ss.str(), "operator.overload");
unsupported_ops[op] = ss.str();
}
}

for (const auto sub_b : n->blocks()) {
auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
unsupported_ops.insert(sub_b_unsupported_ops.begin(), sub_b_unsupported_ops.end());
Expand Down Expand Up @@ -530,7 +539,7 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
auto unsupported_ops = GetUnsupportedOpsInBlock(b);

LOG_DEBUG("======unsupported_ops size ===========: " << unsupported_ops.size());
if (unsupported_ops.size() != 0) {
std::stringstream unsupported_msg;
unsupported_msg
Expand Down

0 comments on commit bf6c929

Please sign in to comment.