Skip to content

Commit

Permalink
Support build with Tensorflow
Browse files Browse the repository at this point in the history
It expects include files in /usr/include/tensorflow.

* Add configure option --with-tensorflow (disabled by default)
* Fix data type tensorflow::int64
* Remove "third_party/" in include statements
* Add dummy implementations for Backward and DebugWeights in TFNetwork
* Add files generated with protoc from tfnetwork.proto
  (so the Tensorflow sources are not needed for the build)
* Update Makefiles

Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed May 24, 2019
1 parent 3f74da5 commit 4382ab1
Show file tree
Hide file tree
Showing 9 changed files with 1,717 additions and 8 deletions.
9 changes: 9 additions & 0 deletions configure.ac
Expand Up @@ -195,6 +195,15 @@ if test "$enable_opencl" = "yes"; then
])
fi

# Check whether to build with support for Tensorflow.
AC_MSG_CHECKING([--with-tensorflow])
AC_ARG_WITH([tensorflow],
AS_HELP_STRING([--with-tensorflow],
[support Tensorflow @<:@default=check@:>@]),
[], [with_tensorflow=check])
AC_MSG_RESULT([$with_tensorflow])
AM_CONDITIONAL([TENSORFLOW], [test "$with_tensorflow" != "no"])

# https://lists.apple.com/archives/unix-porting/2009/Jan/msg00026.html
m4_define([MY_CHECK_FRAMEWORK],
[AC_CACHE_CHECK([if -framework $1 works],[my_cv_framework_$1],
Expand Down
5 changes: 5 additions & 0 deletions src/api/Makefile.am
Expand Up @@ -98,3 +98,8 @@ endif
if ADD_RT
tesseract_LDADD += -lrt
endif

if TENSORFLOW
tesseract_LDADD += -lprotobuf
tesseract_LDADD += -ltensorflow_cc
endif
9 changes: 9 additions & 0 deletions src/lstm/Makefile.am
Expand Up @@ -10,6 +10,11 @@ AM_CPPFLAGS += \

AM_CXXFLAGS = $(OPENMP_CXXFLAGS)

if TENSORFLOW
AM_CPPFLAGS += -DINCLUDE_TENSORFLOW
AM_CPPFLAGS += -I/usr/include/tensorflow
endif

if !NO_TESSDATA_PREFIX
AM_CXXFLAGS += -DTESSDATA_PREFIX=@datadir@
endif
Expand Down Expand Up @@ -37,3 +42,7 @@ libtesseract_lstm_la_SOURCES = \
networkbuilder.cpp network.cpp networkio.cpp \
parallel.cpp plumbing.cpp recodebeam.cpp reconfig.cpp reversed.cpp \
series.cpp stridemap.cpp tfnetwork.cpp weightmatrix.cpp

if TENSORFLOW
libtesseract_lstm_la_SOURCES += tfnetwork.pb.cc
endif
5 changes: 2 additions & 3 deletions src/lstm/tfnetwork.cpp
Expand Up @@ -3,7 +3,6 @@
// Description: Encapsulation of an entire tensorflow graph as a
// Tesseract Network.
// Author: Ray Smith
// Created: Fri Feb 26 09:35:29 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -90,14 +89,14 @@ void TFNetwork::Forward(bool debug, const NetworkIO& input,
if (!model_proto_.image_widths().empty()) {
TensorShape size_shape{1};
Tensor width_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_wtensor = width_tensor.flat<int64>();
auto eigen_wtensor = width_tensor.flat<tensorflow::int64>();
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
}
if (!model_proto_.image_heights().empty()) {
TensorShape size_shape{1};
Tensor height_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_htensor = height_tensor.flat<int64>();
auto eigen_htensor = height_tensor.flat<tensorflow::int64>();
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
}
Expand Down
18 changes: 15 additions & 3 deletions src/lstm/tfnetwork.h
Expand Up @@ -27,9 +27,9 @@

#include "network.h"
#include "static_shape.h"
#include "tfnetwork.proto.h"
#include "third_party/tensorflow/core/framework/graph.pb.h"
#include "third_party/tensorflow/core/public/session.h"
#include "tfnetwork.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/public/session.h"

namespace tesseract {

Expand Down Expand Up @@ -69,6 +69,18 @@ class TFNetwork : public Network {
NetworkScratch* scratch, NetworkIO* output) override;

private:
// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}

void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}

int InitFromProto();

// The original network definition for reference.
Expand Down

0 comments on commit 4382ab1

Please sign in to comment.