Skip to content

Commit ae67157

Browse files
hehhjiangtensorflower-gardener
authored andcommitted
Extract the model loading logic as a ModelLoader class.
Also move the Validator::CheckModel() functionality to this new ModelLoader class. PiperOrigin-RevId: 460166696
1 parent ceb1bf0 commit ae67157

File tree

11 files changed

+379
-164
lines changed

11 files changed

+379
-164
lines changed

tensorflow/lite/experimental/acceleration/mini_benchmark/BUILD

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,11 @@ cc_library(
454454
deps = [
455455
":call",
456456
":decode_jpeg",
457+
":model_loader",
457458
":status_codes",
458459
"//tensorflow/lite:framework",
459460
"//tensorflow/lite:minimal_logging",
461+
"//tensorflow/lite/c:common",
460462
"//tensorflow/lite/core/api",
461463
"//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
462464
"//tensorflow/lite/experimental/acceleration/configuration:delegate_registry",
@@ -471,15 +473,19 @@ cc_library(
471473
hdrs = ["validator_runner.h"],
472474
deps = [
473475
":fb_storage",
476+
":model_loader",
474477
":runner",
475478
":status_codes",
476479
":validator",
480+
"@flatbuffers",
477481
"//tensorflow/lite:minimal_logging",
478482
"//tensorflow/lite/core/api",
479483
"//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
480-
"//tensorflow/lite/nnapi/sl:nnapi_support_library",
484+
# For NNAPI support library, the headears and source files are defined
485+
# as two separate targets. We need to include both targets for NNAPI to
486+
# be invoked.
487+
"//tensorflow/lite/nnapi/sl:nnapi_support_library", # buildcleaner: keep
481488
"//tensorflow/lite/nnapi/sl:nnapi_support_library_headers",
482-
"@flatbuffers",
483489
],
484490
)
485491

@@ -493,17 +499,20 @@ cc_library(
493499
srcs = ["validator_runner_entrypoint.cc"],
494500
deps = [
495501
":fb_storage",
496-
":runner",
502+
":model_loader",
497503
":set_big_core_affinity_h",
498504
":status_codes",
499505
":validator",
500506
":validator_runner",
507+
"@com_google_absl//absl/strings",
508+
"@flatbuffers",
501509
"//tensorflow/lite/core/api",
502510
"//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
503-
"//tensorflow/lite/nnapi/sl:nnapi_support_library",
511+
# For NNAPI support library, the headears and source files are defined
512+
# as two separate targets. We need to include both targets for NNAPI to
513+
# be invoked.
514+
"//tensorflow/lite/nnapi/sl:nnapi_support_library", # buildcleaner: keep
504515
"//tensorflow/lite/nnapi/sl:nnapi_support_library_headers",
505-
"@com_google_absl//absl/strings",
506-
"@flatbuffers",
507516
],
508517
)
509518

@@ -709,6 +718,32 @@ cc_test(
709718
],
710719
)
711720

721+
cc_library(
722+
name = "model_loader",
723+
srcs = ["model_loader.cc"],
724+
hdrs = ["model_loader.h"],
725+
deps = [
726+
":status_codes",
727+
"//tensorflow/lite:allocation",
728+
"//tensorflow/lite:model_builder",
729+
"//tensorflow/lite:stderr_reporter",
730+
"@com_google_absl//absl/strings",
731+
],
732+
)
733+
734+
cc_test(
735+
name = "model_loader_test",
736+
srcs = ["model_loader_test.cc"],
737+
deps = [
738+
":embedded_mobilenet_model",
739+
":mini_benchmark_test_helper",
740+
":model_loader",
741+
":status_codes",
742+
"@com_google_absl//absl/strings:str_format",
743+
"@com_google_googletest//:gtest_main",
744+
],
745+
)
746+
712747
#
713748
# Test targets for separate process.
714749
# Unit tests using cc_test and turned into Android tests with tflite_portable_test_suite().
@@ -770,11 +805,13 @@ cc_binary(
770805
linkshared = True,
771806
deps = [
772807
":fb_storage",
808+
":model_loader",
773809
":runner",
774810
":status_codes",
775811
":set_big_core_affinity",
776812
":validator",
777813
":validator_runner",
814+
"@com_google_absl//absl/strings",
778815
"@flatbuffers",
779816
"//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin",
780817
"//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
@@ -844,6 +881,7 @@ cc_test(
844881
":embedded_mobilenet_validation_model",
845882
":embedded_mobilenet_model",
846883
":mini_benchmark_test_helper",
884+
":model_loader",
847885
":status_codes",
848886
":validator",
849887
"@com_google_googletest//:gtest_main",
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h"
16+
17+
#include <memory>
18+
#include <string>
19+
#include <utility>
20+
#include <vector>
21+
22+
#include "absl/strings/numbers.h"
23+
#include "absl/strings/str_split.h"
24+
#include "tensorflow/lite/allocation.h"
25+
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
26+
#include "tensorflow/lite/model_builder.h"
27+
#include "tensorflow/lite/stderr_reporter.h"
28+
29+
namespace tflite {
30+
namespace acceleration {
31+
32+
std::unique_ptr<ModelLoader> ModelLoader::CreateFromFdOrPath(
33+
absl::string_view fd_or_path) {
34+
if (!absl::StartsWith(fd_or_path, "fd:")) {
35+
return std::make_unique<ModelLoader>(fd_or_path);
36+
}
37+
38+
std::vector<std::string> parts = absl::StrSplit(fd_or_path, ':');
39+
int model_fd;
40+
size_t model_offset, model_size;
41+
if (parts.size() != 4 || !absl::SimpleAtoi(parts[1], &model_fd) ||
42+
!absl::SimpleAtoi(parts[2], &model_offset) ||
43+
!absl::SimpleAtoi(parts[3], &model_size)) {
44+
return nullptr;
45+
}
46+
return std::make_unique<ModelLoader>(model_fd, model_offset, model_size);
47+
}
48+
49+
MinibenchmarkStatus ModelLoader::Init() {
50+
if (model_) {
51+
// Already done.
52+
return kMinibenchmarkSuccess;
53+
}
54+
if (model_path_.empty() && model_fd_ <= 0) {
55+
return kMinibenchmarkPreconditionNotMet;
56+
}
57+
if (!model_path_.empty()) {
58+
model_ = FlatBufferModel::VerifyAndBuildFromFile(model_path_.c_str());
59+
} else if (MMAPAllocation::IsSupported()) {
60+
auto allocation = std::make_unique<MMAPAllocation>(
61+
model_fd_, model_offset_, model_size_, tflite::DefaultErrorReporter());
62+
if (!allocation->valid()) {
63+
return kMinibenchmarkModelReadFailed;
64+
}
65+
model_ =
66+
FlatBufferModel::VerifyAndBuildFromAllocation(std::move(allocation));
67+
} else {
68+
return kMinibenchmarkUnsupportedPlatform;
69+
}
70+
if (!model_) {
71+
return kMinibenchmarkModelBuildFailed;
72+
}
73+
return kMinibenchmarkSuccess;
74+
}
75+
76+
} // namespace acceleration
77+
} // namespace tflite
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
16+
#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
17+
18+
#include <stddef.h>
19+
#include <unistd.h>
20+
21+
#include <memory>
22+
#include <string>
23+
24+
#include "absl/strings/string_view.h"
25+
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
26+
#include "tensorflow/lite/model_builder.h"
27+
28+
namespace tflite {
29+
namespace acceleration {
30+
31+
// Class to load the Model.
32+
class ModelLoader {
33+
public:
34+
// Create the model loader from a model_path or a file descriptor. File
35+
// descriptor path must be in the format of
36+
// "fd:%model_fd%:%model_offset%:%model_size%". Return nullptr if the path
37+
// starts with "fd:" but cannot be parsed with the given format.
38+
static std::unique_ptr<ModelLoader> CreateFromFdOrPath(
39+
absl::string_view fd_or_path);
40+
41+
// Create the model loader from model_path.
42+
explicit ModelLoader(absl::string_view model_path)
43+
: model_path_(model_path) {}
44+
45+
#ifndef _WIN32
46+
// Create the model loader from file descriptor. The model_fd only has to be
47+
// valid for the duration of the constructor (it's dup'ed inside). This
48+
// constructor is not available on Windows.
49+
ModelLoader(int model_fd, size_t model_offset, size_t model_size)
50+
: model_fd_(dup(model_fd)),
51+
model_offset_(model_offset),
52+
model_size_(model_size) {}
53+
#endif // !_WIN32
54+
55+
~ModelLoader() {
56+
if (model_fd_ >= 0) {
57+
close(model_fd_);
58+
}
59+
}
60+
61+
// Return whether the model is loaded successfully.
62+
MinibenchmarkStatus Init();
63+
64+
const FlatBufferModel* GetModel() const { return model_.get(); }
65+
66+
private:
67+
const std::string model_path_;
68+
const int model_fd_ = -1;
69+
const size_t model_offset_ = 0;
70+
const size_t model_size_ = 0;
71+
std::unique_ptr<FlatBufferModel> model_;
72+
};
73+
74+
} // namespace acceleration
75+
76+
} // namespace tflite
77+
78+
#endif // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/model_loader.h"
16+
17+
#include <fcntl.h>
18+
#include <sys/stat.h>
19+
20+
#include <memory>
21+
#include <string>
22+
23+
#include <gmock/gmock.h>
24+
#include <gtest/gtest.h>
25+
#include "absl/strings/str_format.h"
26+
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/embedded_mobilenet_model.h"
27+
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark_test_helper.h"
28+
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
29+
30+
namespace tflite {
31+
namespace acceleration {
32+
namespace {
33+
34+
class ModelLoaderTest : public ::testing::Test {
35+
protected:
36+
void SetUp() override {
37+
model_path_ = MiniBenchmarkTestHelper::DumpToTempFile(
38+
"mobilenet_quant.tflite",
39+
g_tflite_acceleration_embedded_mobilenet_model,
40+
g_tflite_acceleration_embedded_mobilenet_model_len);
41+
}
42+
std::string model_path_;
43+
};
44+
45+
TEST_F(ModelLoaderTest, CreateFromModelPath) {
46+
std::unique_ptr<ModelLoader> model_loader =
47+
ModelLoader::CreateFromFdOrPath(model_path_);
48+
ASSERT_NE(model_loader, nullptr);
49+
EXPECT_THAT(model_loader->Init(), kMinibenchmarkSuccess);
50+
}
51+
52+
TEST_F(ModelLoaderTest, CreateFromFdPath) {
53+
int fd = open(model_path_.c_str(), O_RDONLY);
54+
ASSERT_GE(fd, 0);
55+
struct stat stat_buf = {0};
56+
ASSERT_EQ(fstat(fd, &stat_buf), 0);
57+
auto model_loader = std::make_unique<ModelLoader>(fd, 0, stat_buf.st_size);
58+
close(fd);
59+
60+
ASSERT_NE(model_loader, nullptr);
61+
EXPECT_THAT(model_loader->Init(), kMinibenchmarkSuccess);
62+
}
63+
64+
TEST_F(ModelLoaderTest, CreateFromFdOrModelPath) {
65+
int fd = open(model_path_.c_str(), O_RDONLY);
66+
ASSERT_GE(fd, 0);
67+
struct stat stat_buf = {0};
68+
ASSERT_EQ(fstat(fd, &stat_buf), 0);
69+
std::string path = absl::StrFormat("fd:%d:%zu:%zu", fd, 0, stat_buf.st_size);
70+
auto model_loader = ModelLoader::CreateFromFdOrPath(path);
71+
close(fd);
72+
73+
ASSERT_NE(model_loader, nullptr);
74+
EXPECT_THAT(model_loader->Init(), kMinibenchmarkSuccess);
75+
}
76+
77+
TEST_F(ModelLoaderTest, InvalidFdPath) {
78+
int fd = open(model_path_.c_str(), O_RDONLY);
79+
ASSERT_GE(fd, 0);
80+
struct stat stat_buf = {0};
81+
ASSERT_EQ(fstat(fd, &stat_buf), 0);
82+
std::string path = absl::StrFormat("fd:%d:%zu", fd, 0);
83+
auto model_loader = ModelLoader::CreateFromFdOrPath(path);
84+
close(fd);
85+
86+
EXPECT_EQ(model_loader, nullptr);
87+
}
88+
89+
} // namespace
90+
} // namespace acceleration
91+
} // namespace tflite

tensorflow/lite/experimental/acceleration/mini_benchmark/model_validation_test.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ class LocalizerValidationRegressionTest : public ::testing::Test {
105105
ASSERT_GE(fd, 0);
106106
struct stat stat_buf = {0};
107107
ASSERT_EQ(fstat(fd, &stat_buf), 0);
108-
auto validator =
109-
std::make_unique<Validator>(fd, 0, stat_buf.st_size, settings);
108+
auto validator = std::make_unique<Validator>(
109+
std::make_unique<ModelLoader>(fd, /*offset=*/0, stat_buf.st_size),
110+
settings);
110111
close(fd);
111112

112113
Validator::Results results;

0 commit comments

Comments
 (0)