Skip to content

Commit

Permalink
Remove oneDNN dependency from tf_runtime since it clashes with Tensor…
Browse files Browse the repository at this point in the history
…Flow's oneDNN dependency when tf_runtime is included in TensorFlow.

PiperOrigin-RevId: 627873560
  • Loading branch information
penpornk authored and Copybara-Service committed Apr 24, 2024
1 parent 2b13a98 commit 17c8497
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 53 deletions.
35 changes: 2 additions & 33 deletions backends/common/BUILD
@@ -1,4 +1,3 @@
load("@bazel_skylib//lib:selects.bzl", "selects")
load(
"@tf_runtime//:build_defs.bzl",
"if_google",
Expand All @@ -13,24 +12,6 @@ package(

licenses(["notice"])

config_setting(
name = "disable_mkldnn",
flag_values = {"@tf_runtime//:eigen_mkldnn_contraction_kernel": "False"},
)

config_setting(
name = "enable_mkldnn",
flag_values = {"@tf_runtime//:eigen_mkldnn_contraction_kernel": "True"},
)

selects.config_setting_group(
name = "use_mkldnn",
match_all = [
":enable_mkldnn",
"@tf_runtime//:linux_x86_64",
],
)

tfrt_cc_library(
name = "tf_metadata_functions",
srcs = ["lib/ops/tf/metadata_functions.cc"],
Expand Down Expand Up @@ -135,14 +116,7 @@ tfrt_cc_library(
"EIGEN_MUTEX_LOCK=std::unique_lock<std::mutex>",
"EIGEN_CONDVAR=std::condition_variable",
"EIGEN_AVOID_STL_ARRAY",
]) + select({
":use_mkldnn": [
# Custom contraction kernel defines.
"TFRT_EIGEN_USE_CUSTOM_CONTRACTION_KERNEL",
"TFRT_EIGEN_USE_MKLDNN_CONTRACTION_KERNEL",
],
"//conditions:default": [],
}),
]),
visibility = ["//visibility:public"],
deps = [
"@eigen_archive//:eigen3",
Expand All @@ -155,12 +129,7 @@ tfrt_cc_library(
# TODO(b/161569340): Short-term fix. Remove.
"//third_party/tensorflow/core/platform:types",
"//third_party/tensorflow/core/platform:mutex",
]) + select({
":use_mkldnn": [
"@dnnl//:dnnl_single_threaded",
],
"//conditions:default": [],
}),
]),
)

tfrt_cc_library(
Expand Down
13 changes: 9 additions & 4 deletions backends/cpu/BUILD
@@ -1,10 +1,18 @@
load("@tf_runtime//:build_defs.bzl", "if_google", "tfrt_cc_library")
# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license")
# copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library")

package(
default_visibility = [":__subpackages__"],
)

# copybara:uncomment_begin(Internal license rules)
# license(
# name = "license",
# package_name = "cpu",
# )
# copybara:uncomment_end

licenses(["notice"])

tfrt_cc_library(
Expand Down Expand Up @@ -159,10 +167,7 @@ tfrt_cc_library(
"@tf_runtime//:tensor",
"@tf_runtime//backends/common:eigencompat",
"@tf_runtime//backends/common:tf_bcast",
] + select({
"@tf_runtime//:linux_x86_64": ["@dnnl//:dnnl_single_threaded"],
"//conditions:default": [],
}),
],
)

# copybara:uncomment_begin
Expand Down
4 changes: 0 additions & 4 deletions backends/cpu/lib/kernels/cpu_kernels.h
Expand Up @@ -37,10 +37,6 @@
#include "tfrt/tensor/scalar_host_tensor.h"
#include "tfrt/tensor/tensor_shape.h"

#ifdef __x86_64__
#include "dnnl.h" // from @dnnl
#endif

namespace tfrt {
namespace cpu {

Expand Down
12 changes: 0 additions & 12 deletions dependencies.bzl
Expand Up @@ -18,18 +18,6 @@ load("@tf_runtime//third_party:repo.bzl", "tfrt_http_archive")
def tfrt_dependencies():
"""Loads TFRT external dependencies into WORKSPACE."""

tfrt_http_archive(
name = "dnnl",
build_file = "//third_party/dnnl:BUILD",
link_files = {"//third_party/dnnl:expand_template.bzl": "expand_template.bzl"},
sha256 = "5369f7b2f0b52b40890da50c0632c3a5d1082d98325d0f2bff125d19d0dcaa1d",
strip_prefix = "oneDNN-1.6.4",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz",
"https://github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz",
],
)

tfrt_http_archive(
name = "py-cpuinfo",
strip_prefix = "py-cpuinfo-0.2.3",
Expand Down

0 comments on commit 17c8497

Please sign in to comment.