diff --git a/oneflow/core/common/error.cpp b/oneflow/core/common/error.cpp index d9c07160bdb..9ad09366ab8 100644 --- a/oneflow/core/common/error.cpp +++ b/oneflow/core/common/error.cpp @@ -274,7 +274,7 @@ Error Error::InputDeviceNotMatchError() { auto error = std::make_shared(); auto* input_device_not_match_error = error->mutable_input_device_not_match_error(); input_device_not_match_error->add_info( - std::string("The devices of input tensors are inconsistent,please try to use tensor.to or " + std::string("The devices of input tensors are inconsistent, please try to use tensor.to or " "module.to to correct it.")); return error; } diff --git a/oneflow/core/common/error_util.cpp b/oneflow/core/common/error_util.cpp index ce173dd8132..180f423b212 100644 --- a/oneflow/core/common/error_util.cpp +++ b/oneflow/core/common/error_util.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "oneflow/core/common/error_util.h" +#include "oneflow/core/common/util.h" namespace oneflow { @@ -24,8 +25,8 @@ std::string StripSpace(std::string str) { if (str.size() == 0) { return ""; } size_t pos = str.find_first_not_of(" "); if (pos != std::string::npos) { str.erase(0, pos); } - pos = str.find_last_not_of(" ") + 1; - if (pos != std::string::npos) { str.erase(pos); } + pos = str.find_last_not_of(" "); + if (pos != std::string::npos) { str.erase(pos + 1); } return str; } @@ -41,66 +42,71 @@ std::string StripBrackets(std::string str) { return str; } -Maybe ShortenErrorMsg(std::string str) { +Maybe ShortenMsg(std::string str) { // 150 characters is the threshold - const int num_displayed_char = 150; + const int num_character_threshold = 150; + const int num_displayed_character = 50; if (str.size() == 0) { return str; } // strip space when JUST( xx ); str = StripSpace(str); - if (str.size() < num_displayed_char) { return str; } - - // Find first index where the number of characters from the start to the index is less than 50, - // last index is the same - int first_index = -1; - int last_index = -1; - int pre_index = 0; - CHECK_OR_RETURN(str.size() >= 1); - for (int i = 1; i < str.size(); i++) { - if (IsLetterNumberOrUnderline(str.at(i)) && !IsLetterNumberOrUnderline(str.at(i - 1))) { - if (first_index == -1 && i >= num_displayed_char / 3) { first_index = pre_index; } - if (last_index == -1 && str.size() - i <= num_displayed_char / 3) { last_index = i; } - pre_index = i; + if (str.size() < num_character_threshold) { return str; } + + // left part whose number of characters is just over 50 + int left_index = num_displayed_character; + bool pre_condition = IsLetterNumberOrUnderline(str.at(left_index)); + for (; left_index < str.size(); left_index++) { + bool cur_condition = IsLetterNumberOrUnderline(str.at(left_index)); + if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) { break; } + } + + // right part whose number of characters is just over 50 + int right_index = str.size() - num_displayed_character; + pre_condition = IsLetterNumberOrUnderline(str.at(right_index)); + for (; right_index >= 0; right_index--) { + bool cur_condition = IsLetterNumberOrUnderline(str.at(right_index)); + if ((pre_condition && !cur_condition) || (!pre_condition && cur_condition)) { + right_index++; + break; } } - // A string of more than 150 characters - if (first_index == -1 && last_index == -1) { return str; } - CHECK_OR_RETURN(first_index <= str.size()); - CHECK_OR_RETURN(last_index <= str.size()); + // a long word of more than 150 + if (right_index - left_index < 50) { return str; } std::stringstream ss; - // The number of characters before the first word exceeds 50 - if (first_index == -1) { - ss << " ... " << str.substr(last_index); - } - // The number of characters after the last word exceeds 50 - else if (last_index == -1) { - ss << str.substr(0, first_index) << " ... "; - } else { - ss << str.substr(0, first_index) << " ... " << str.substr(last_index); - } + CHECK_OR_RETURN(left_index >= 0); + CHECK_OR_RETURN(left_index < str.size()); + ss << str.substr(0, left_index); + ss << " ... "; + CHECK_OR_RETURN(right_index >= 0); + CHECK_OR_RETURN(right_index < str.size()); + ss << str.substr(right_index); return ss.str(); } -std::string FormatFile(const std::string& file) { +// file info in stack frame +std::string FormatFileOfStackFrame(const std::string& file) { std::stringstream ss; ss << "\n File \"" << file << "\", "; return ss.str(); } -std::string FormatLine(const int64_t& line) { +// line info in stack frame +std::string FormatLineOfStackFrame(const int64_t& line) { std::stringstream ss; ss << "line " << line << ","; return ss.str(); } -std::string FormatFunction(const std::string& function) { +// function info in stack frame +std::string FormatFunctionOfStackFrame(const std::string& function) { std::stringstream ss; ss << " in " << function; return ss.str(); } -Maybe FormatErrorMsg(std::string error_msg, bool is_last_stack_frame) { +// msg in stack frame +Maybe FormatMsgOfStackFrame(std::string error_msg, bool is_last_stack_frame) { error_msg = StripBrackets(error_msg); - if (!is_last_stack_frame) { error_msg = *JUST(ShortenErrorMsg(error_msg)); } + if (!is_last_stack_frame) { error_msg = *JUST(ShortenMsg(error_msg)); } // error_msg of last stack frame come from "<<" if (is_last_stack_frame) { error_msg = StripSpace(error_msg); } std::stringstream ss; @@ -108,10 +114,29 @@ Maybe FormatErrorMsg(std::string error_msg, bool is_last_stack_fram return ss.str(); } -std::string FormatErrorSummaryAndMsg(const std::shared_ptr& error) { +// the error_summary and msg in error proto +std::string FormatErrorSummaryAndMsgOfErrorProto(const std::shared_ptr& error) { std::stringstream ss; if (error->has_error_summary()) { ss << error->error_summary(); } - if (error->has_msg()) { ss << (ss.str().size() != 0 ? ", " + error->msg() : error->msg()); } + if (error->has_msg()) { ss << (ss.str().size() != 0 ? "\n" + error->msg() : error->msg()); } + return ss.str(); +} + +// the msg in error type instance. +Maybe FormatMsgOfErrorType(const std::shared_ptr& error) { + CHECK_NE_OR_RETURN(error->error_type_case(), cfg::ErrorProto::ERROR_TYPE_NOT_SET); + std::stringstream ss; + ErrorProto pb_error; + error->ToProto(&pb_error); + const google::protobuf::Descriptor* pb_error_des = pb_error.GetDescriptor(); + const google::protobuf::OneofDescriptor* oneof_field_des = + pb_error_des->FindOneofByName("error_type"); + const google::protobuf::Reflection* pb_error_ref = pb_error.GetReflection(); + const google::protobuf::FieldDescriptor* field_des = + pb_error_ref->GetOneofFieldDescriptor(pb_error, oneof_field_des); + CHECK_OR_RETURN(field_des != nullptr); + const google::protobuf::Message& error_type = pb_error_ref->GetMessage(pb_error, field_des); + ss << error_type.DebugString(); return ss.str(); } @@ -119,14 +144,23 @@ std::string FormatErrorSummaryAndMsg(const std::shared_ptr& err Maybe FormatErrorStr(const std::shared_ptr& error) { std::stringstream ss; + // Get msg from stack frame of error proto for (auto stack_frame = error->mutable_stack_frame()->rbegin(); stack_frame < error->mutable_stack_frame()->rend(); stack_frame++) { - ss << FormatFile(*stack_frame->mutable_file()) << FormatLine(*stack_frame->mutable_line()) - << FormatFunction(*stack_frame->mutable_function()) - << *JUST(FormatErrorMsg(*stack_frame->mutable_error_msg(), - stack_frame == error->mutable_stack_frame()->rend() - 1)); + ss << FormatFileOfStackFrame(*stack_frame->mutable_file()) + << FormatLineOfStackFrame(*stack_frame->mutable_line()) + << FormatFunctionOfStackFrame(*stack_frame->mutable_function()) + << *JUST(FormatMsgOfStackFrame(*stack_frame->mutable_error_msg(), + stack_frame == error->mutable_stack_frame()->rend() - 1)); + } + // Get msg from error summary and msg of error proto + std::string error_summary_and_msg_of_error_proto = FormatErrorSummaryAndMsgOfErrorProto(error); + if (error_summary_and_msg_of_error_proto.size() != 0) { + ss << "\n" << error_summary_and_msg_of_error_proto; } - ss << "\n" << FormatErrorSummaryAndMsg(error); + // Get msg from error type of error proto + std::string msg_of_error_type = *JUST(FormatMsgOfErrorType(error)); + if (msg_of_error_type.size() != 0) { ss << "\n" << msg_of_error_type; } return ss.str(); } diff --git a/oneflow/core/common/error_util.h b/oneflow/core/common/error_util.h index aa25886dd9a..d93de56c98d 100644 --- a/oneflow/core/common/error_util.h +++ b/oneflow/core/common/error_util.h @@ -18,14 +18,13 @@ limitations under the License. #include #include "oneflow/core/common/error.cfg.h" +#include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/maybe.h" namespace oneflow { namespace cfg { class ErrorProto; } -std::string* MutErrorStr(); -const std::string& GetErrorStr(); Maybe FormatErrorStr(const std::shared_ptr& error); } // namespace oneflow