Skip to content

Commit 27521e9

Browse files
committed
second ci
1 parent 0ae0466 commit 27521e9

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

src/fnn.cc

+16-5
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,28 @@ int main(int argc, char* argv[]) {
5656
// when using `bazel run` since the cwd isn't where you call
5757
// `bazel run` but from inside a temp folder.)
5858
GraphDef graph_def;
59-
status = ReadBinaryProto(Env::Default(), "../demo/model/graph.pb", &graph_def);
59+
std::string graph_path = argv[1];
60+
status = ReadBinaryProto(Env::Default(), graph_path, &graph_def);
6061
if (!status.ok()) {
61-
std::cout << status.ToString() << "\n";
62-
return 1;
62+
throw runtime_error("Error loading graph from " + graph_path + ": " + status.ToString());
6363
}
6464

6565
// Add the graph to the session
6666
status = session->Create(graph_def);
6767
if (!status.ok()) {
68-
std::cout << status.ToString() << "\n";
69-
return 1;
68+
throw runtime_error("Error set graph to session: " + status.ToString());
69+
}
70+
71+
// Read parameters from the saved checkpoint
72+
Tensor checkpointPathTensor(DT_STRING, TensorShape());
73+
checkpointPathTensor.scalar<std::string>()() = "../demo/model";
74+
status = session->Run(
75+
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
76+
{},
77+
{graph_def.saver_def().restore_op_name()},
78+
nullptr);
79+
if (!status.ok()) {
80+
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
7081
}
7182

7283
// Setup inputs and outputs:

0 commit comments

Comments
 (0)