Skip to content

Commit

Permalink
Explicitly pass the genfiles directory to cc_op_gen.
Browse files Browse the repository at this point in the history
This makes path computation code shorter and more robust. In particular, it makes things work under Bazel flag --incompatible_merge_genfiles_directory (bazelbuild/bazel#6761).
  • Loading branch information
benjaminp committed Apr 9, 2019
1 parent 8639637 commit 8511013
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 26 deletions.
32 changes: 14 additions & 18 deletions tensorflow/cc/framework/cc_op_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
#include <unordered_set>
#include <vector>

#include "absl/strings/strip.h"

#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
Expand Down Expand Up @@ -44,22 +46,15 @@ const int kRightMargin = 79;
// Converts:
// bazel-out/.../genfiles/(external/YYY/)?XX
// to: XX.
string GetPath(const string& dot_h_fname) {
auto pos = dot_h_fname.find("/genfiles/");
string result = dot_h_fname;
if (pos != string::npos) {
// - 1 account for the terminating null character (\0) in "/genfiles/".
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
}
if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
result = result.substr(sizeof("external/") - 1);
pos = result.find("/");
string GetPath(const string& dot_h_fname, const string& genfiles_dir) {
absl::string_view result = absl::StripPrefix(dot_h_fname, genfiles_dir);
if (absl::ConsumePrefix(&result, "external/")) {
auto pos = result.find("/");
if (pos != string::npos) {
result = result.substr(pos + 1);
}
}
return result;
return string(result);
}

// Converts: some/path/to/file.xx
Expand All @@ -76,7 +71,7 @@ string GetFilename(const string& path) {
// cc/ops/gen_foo_ops.h
// to:
// CC_OPS_GEN_FOO_OPS_H_
string ToGuard(const string& path) {
string ToGuard(absl::string_view path) {
string guard;
guard.reserve(path.size() + 1); // + 1 -> trailing _
for (const char c : path) {
Expand Down Expand Up @@ -1011,7 +1006,7 @@ void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
}

void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h,
WritableFile* cc, string* op_header_guard) {
WritableFile* cc, string* op_header_guard, const string& genfiles_dir) {
const string header =
R"header(// This file is MACHINE GENERATED! Do not edit.
Expand All @@ -1038,7 +1033,7 @@ namespace ops {
)namespace";

const string op_header = GetPath(dot_h_fname);
const string op_header = GetPath(dot_h_fname, genfiles_dir);
*op_header_guard = ToGuard(op_header);
const string cc_header = strings::StrCat(
R"include(// This file is MACHINE GENERATED! Do not edit.
Expand Down Expand Up @@ -1100,7 +1095,8 @@ string MakeInternal(const string& fname) {
} // namespace

void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname) {
const string& dot_h_fname, const string& dot_cc_fname,
const string& genfiles_dir) {
Env* env = Env::Default();

// Write the initial boilerplate to the .h and .cc files.
Expand All @@ -1109,7 +1105,7 @@ void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
string op_header_guard;
StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard);
StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard, genfiles_dir);

// Create the internal versions of these files for the hidden ops.
std::unique_ptr<WritableFile> internal_h = nullptr;
Expand All @@ -1119,7 +1115,7 @@ void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc));
string internal_op_header_guard;
StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(),
internal_cc.get(), &internal_op_header_guard);
internal_cc.get(), &internal_op_header_guard, genfiles_dir);

for (const auto& graph_op_def : ops.op()) {
// Skip deprecated ops.
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/cc/framework/cc_op_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace tensorflow {

/// Result is written to files dot_h and dot_cc.
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname);
const string& dot_h_fname, const string& dot_cc_fname,
const string& genfiles_dir);

} // namespace tensorflow

Expand Down
11 changes: 6 additions & 5 deletions tensorflow/cc/framework/cc_op_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace {

void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
bool include_internal,
const std::vector<string>& api_def_dirs) {
const std::vector<string>& api_def_dirs,
const std::string& genfiles_dir) {
OpList ops;
OpRegistry::Global()->Export(include_internal, &ops);
ApiDefMap api_def_map(ops);
Expand All @@ -49,21 +50,21 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,

api_def_map.UpdateDocs();

WriteCCOps(ops, api_def_map, dot_h, dot_cc);
WriteCCOps(ops, api_def_map, dot_h, dot_cc, genfiles_dir);
}

} // namespace
} // namespace tensorflow

int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc != 5) {
if (argc != 6) {
for (int i = 1; i < argc; ++i) {
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
}
fprintf(stderr,
"Usage: %s out.h out.cc include_internal "
"api_def_dirs1,api_def_dir2 ...\n"
"api_def_dirs1,api_def_dir2... genfiles_dir\n"
" include_internal: 1 means include internal ops\n",
argv[0]);
exit(1);
Expand All @@ -72,6 +73,6 @@ int main(int argc, char* argv[]) {
bool include_internal = tensorflow::StringPiece("1") == argv[3];
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
argv[4], ",", tensorflow::str_util::SkipEmpty());
tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal, api_def_dirs);
tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal, api_def_dirs, argv[5]);
return 0;
}
2 changes: 1 addition & 1 deletion tensorflow/cc/framework/cc_op_gen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void GenerateCcOpFiles(Env* env, const OpList& ops,
const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");

WriteCCOps(ops, api_def_map, h_file_path, cc_file_path);
WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, "");

TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
TF_ASSERT_OK(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tensorflow.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def tf_gen_op_wrapper_cc(
tools = [":" + tool] + tf_binary_additional_srcs(),
cmd = ("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " +
"$(location :" + out_ops_file + ".cc) " +
str(include_internal_ops) + " " + api_def_args_str),
str(include_internal_ops) + " " + api_def_args_str + " $(GENDIR)/"),
)

# Given a list of "op_lib_names" (a list of files in the ops directory
Expand Down

0 comments on commit 8511013

Please sign in to comment.