-
Notifications
You must be signed in to change notification settings - Fork 1
/
tf_cpp_binding.hh
71 lines (58 loc) · 2.14 KB
/
tf_cpp_binding.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/*
** C++ binding of the TensorFlow native C API to import a model, feed it
** unidimensional input vectors and retrieve unidimensional output vectors.
**
** Tested with TensorFlow 2.3.1
**
** You can check the input names of the model by executing in a terminal:
**
** $ saved_model_cli show --dir <path_to_model_dir> --tag_set serve --signature_def serving_default
**
**
** Copyright (C) 2020 Arthur BOUTON
**
** This program is free software: you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation, version 3.
**
** This program is distributed in the hope that it will be useful, but
** WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
** General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TF_CPP_BINDING_HH
#define TF_CPP_BINDING_HH
#include "tensorflow/c/c_api.h"
#include <memory>
#include <vector>
//------------------------------------------------------------------------------------------//
// Template class to import a tensorflow model and feed and retrieve unidimensional vectors //
//------------------------------------------------------------------------------------------//
template <class T>
class TF_model
{
	public:

	// Convenience alias for sharing one loaded model between several owners
	// (the model itself is non-copyable, see below).
	typedef std::shared_ptr<TF_model> ptr_t;
	// A set of unidimensional vectors: one inner vector per model input (or output).
	typedef std::vector<std::vector<T>> vector_set_t;

	// Loads the model found in path_to_model_dir.
	// dim_inputs and dim_outputs list the dimension of each expected input and output vector,
	// one entry per model input / output.
	TF_model( const char* path_to_model_dir, const std::vector<int>& dim_inputs,
	          const std::vector<int>& dim_outputs );

	// The instance holds raw TensorFlow C handles (graph, session, status, tensors) that are
	// presumably released by ~TF_model() (defined out of line — confirm). An implicit
	// member-wise copy would share those handles and release them twice, so copying is
	// forbidden; share an instance through ptr_t instead.
	TF_model( const TF_model& ) = delete;
	TF_model& operator=( const TF_model& ) = delete;

	// Infer the output(s) one set of input(s) at a time.
	// NOTE(review): presumably inputs.size() must equal dim_inputs.size() and inputs[i].size()
	// must equal dim_inputs[i] — the definition is not visible here, confirm.
	vector_set_t infer( vector_set_t inputs );

	~TF_model();

	protected:

	// Number of model inputs / outputs and the dimension of each of them.
	int _n_inputs = 0, _n_outputs = 0;
	std::vector<int> _dim_inputs, _dim_outputs;

	// TensorFlow C API handles. They are default-initialised to nullptr so that an instance is
	// never left holding indeterminate pointers.
	TF_Graph* _graph = nullptr;
	TF_Status* _status = nullptr;
	TF_SessionOptions* _sessionOpts = nullptr;
	TF_Session* _session = nullptr;
	TF_Output* _model_inputs = nullptr;
	TF_Output* _model_outputs = nullptr;
	TF_Tensor** _input_tensors = nullptr;
	TF_Tensor** _output_tensors = nullptr;
};
#endif