diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 43a33cbea6e1..cc6d2d97bf3f 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "absl/strings/strip.h" + #include "tensorflow/cc/framework/cc_op_gen.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -44,22 +46,15 @@ const int kRightMargin = 79; // Converts: // bazel-out/.../genfiles/(external/YYY/)?XX // to: XX. -string GetPath(const string& dot_h_fname) { - auto pos = dot_h_fname.find("/genfiles/"); - string result = dot_h_fname; - if (pos != string::npos) { - // - 1 account for the terminating null character (\0) in "/genfiles/". - result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1); - } - if (result.size() > sizeof("external/") && - result.compare(0, sizeof("external/") - 1, "external/") == 0) { - result = result.substr(sizeof("external/") - 1); - pos = result.find("/"); +string GetPath(const string& dot_h_fname, const string& genfiles_dir) { + absl::string_view result = absl::StripPrefix(dot_h_fname, genfiles_dir); + if (absl::ConsumePrefix(&result, "external/")) { + auto pos = result.find("/"); if (pos != string::npos) { result = result.substr(pos + 1); } } - return result; + return string(result); } // Converts: some/path/to/file.xx @@ -76,7 +71,7 @@ string GetFilename(const string& path) { // cc/ops/gen_foo_ops.h // to: // CC_OPS_GEN_FOO_OPS_H_ -string ToGuard(const string& path) { +string ToGuard(absl::string_view path) { string guard; guard.reserve(path.size() + 1); // + 1 -> trailing _ for (const char c : path) { @@ -1011,7 +1006,7 @@ void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def, } void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h, - WritableFile* cc, string* op_header_guard) { + WritableFile* cc, string* op_header_guard, const string& genfiles_dir) { const string header = R"header(// This file is MACHINE GENERATED! Do not edit. @@ -1038,7 +1033,7 @@ namespace ops { )namespace"; - const string op_header = GetPath(dot_h_fname); + const string op_header = GetPath(dot_h_fname, genfiles_dir); *op_header_guard = ToGuard(op_header); const string cc_header = strings::StrCat( R"include(// This file is MACHINE GENERATED! Do not edit. @@ -1100,7 +1095,8 @@ string MakeInternal(const string& fname) { } // namespace void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, - const string& dot_h_fname, const string& dot_cc_fname) { + const string& dot_h_fname, const string& dot_cc_fname, + const string& genfiles_dir) { Env* env = Env::Default(); // Write the initial boilerplate to the .h and .cc files. @@ -1109,7 +1105,7 @@ void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h)); TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc)); string op_header_guard; - StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard); + StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard, genfiles_dir); // Create the internal versions of these files for the hidden ops. std::unique_ptr internal_h = nullptr; @@ -1119,7 +1115,7 @@ void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc)); string internal_op_header_guard; StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(), - internal_cc.get(), &internal_op_header_guard); + internal_cc.get(), &internal_op_header_guard, genfiles_dir); for (const auto& graph_op_def : ops.op()) { // Skip deprecated ops. diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index c7256a7dc384..6b239262905c 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -24,7 +24,8 @@ namespace tensorflow { /// Result is written to files dot_h and dot_cc. void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, - const string& dot_h_fname, const string& dot_cc_fname); + const string& dot_h_fname, const string& dot_cc_fname, + const string& genfiles_dir); } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index 3157792e15a0..51de1a6c97d9 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -29,7 +29,8 @@ namespace { void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, bool include_internal, - const std::vector& api_def_dirs) { + const std::vector& api_def_dirs, + const std::string& genfiles_dir) { OpList ops; OpRegistry::Global()->Export(include_internal, &ops); ApiDefMap api_def_map(ops); @@ -49,7 +50,7 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, api_def_map.UpdateDocs(); - WriteCCOps(ops, api_def_map, dot_h, dot_cc); + WriteCCOps(ops, api_def_map, dot_h, dot_cc, genfiles_dir); } } // namespace @@ -57,13 +58,13 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); - if (argc != 5) { + if (argc != 6) { for (int i = 1; i < argc; ++i) { fprintf(stderr, "Arg %d = %s\n", i, argv[i]); } fprintf(stderr, "Usage: %s out.h out.cc include_internal " - "api_def_dirs1,api_def_dir2 ...\n" + "api_def_dirs1,api_def_dir2... genfiles_dir\n" " include_internal: 1 means include internal ops\n", argv[0]); exit(1); @@ -72,6 +73,6 @@ int main(int argc, char* argv[]) { bool include_internal = tensorflow::StringPiece("1") == argv[3]; std::vector api_def_dirs = tensorflow::str_util::Split( argv[4], ",", tensorflow::str_util::SkipEmpty()); - tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal, api_def_dirs); + tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal, api_def_dirs, argv[5]); return 0; } diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 5d9dfd95a553..68c436f5bc5a 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -93,7 +93,7 @@ void GenerateCcOpFiles(Env* env, const OpList& ops, const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h"); const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc"); - WriteCCOps(ops, api_def_map, h_file_path, cc_file_path); + WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, ""); TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text)); TF_ASSERT_OK( diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 08fccaa91d32..c9813080fff2 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -700,7 +700,7 @@ def tf_gen_op_wrapper_cc( tools = [":" + tool] + tf_binary_additional_srcs(), cmd = ("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " + "$(location :" + out_ops_file + ".cc) " + - str(include_internal_ops) + " " + api_def_args_str), + str(include_internal_ops) + " " + api_def_args_str + " $(GENDIR)/"), ) # Given a list of "op_lib_names" (a list of files in the ops directory