Skip to content

Commit

Permalink
Refactor class Network
Browse files Browse the repository at this point in the history
That class is an abstract class with several pure virtual functions.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Feb 26, 2019
1 parent cf85054 commit 98dd3b6
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 59 deletions.
7 changes: 5 additions & 2 deletions src/lstm/convolve.h
Expand Up @@ -4,7 +4,6 @@
// and pulls in random data to fill out-of-input inputs.
// Output is therefore same size as its input, but deeper.
// Author: Ray Smith
// Created: Tue Mar 18 16:45:34 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -61,6 +60,11 @@ class Convolve : public Network {
NetworkScratch* scratch,
NetworkIO* back_deltas) override;

private:
void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}

protected:
// Serialized data.
int32_t half_x_;
Expand All @@ -69,5 +73,4 @@ class Convolve : public Network {

} // namespace tesseract.


#endif // TESSERACT_LSTM_SUBSAMPLE_H_
5 changes: 4 additions & 1 deletion src/lstm/input.h
Expand Up @@ -2,7 +2,6 @@
// File: input.h
// Description: Input layer class for neural network implementations.
// Author: Ray Smith
// Created: Thu Mar 13 08:56:26 PDT 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -93,6 +92,10 @@ class Input : public Network {
TRand* randomizer, NetworkIO* input);

private:
void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}

// Input shape determines how images are dealt with.
StaticShape shape_;
// Cached total network x scale factor for scaling bounding boxes.
Expand Down
80 changes: 43 additions & 37 deletions src/lstm/network.cpp
Expand Up @@ -2,7 +2,6 @@
// File: network.cpp
// Description: Base class for neural network implementations.
// Author: Ray Smith
// Created: Wed May 01 17:25:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -53,10 +52,11 @@ const int kMaxWinSize = 2000;
const int kXWinFrameSize = 30;
const int kYWinFrameSize = 80;

// String names corresponding to the NetworkType enum. Keep in sync.
// String names corresponding to the NetworkType enum.
// Keep in sync with NetworkType.
// Names used in Serialization to allow re-ordering/addition/deletion of
// layer types in NetworkType without invalidating existing network files.
char const* const Network::kTypeNames[NT_COUNT] = {
static char const* const kTypeNames[NT_COUNT] = {
"Invalid", "Input",
"Convolve", "Maxpool",
"Parallel", "Replicated",
Expand Down Expand Up @@ -165,81 +165,87 @@ bool Network::Serialize(TFile* fp) const {
return true;
}

// Reads from the given file. Returns false in case of error.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
bool Network::DeSerialize(TFile* fp) {
static NetworkType getNetworkType(TFile* fp) {
int8_t data;
if (!fp->DeSerialize(&data)) return false;
if (!fp->DeSerialize(&data)) return NT_NONE;
if (data == NT_NONE) {
STRING type_name;
if (!type_name.DeSerialize(fp)) return false;
if (!type_name.DeSerialize(fp)) return NT_NONE;
for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
}
if (data == NT_COUNT) {
tprintf("Invalid network layer type:%s\n", type_name.string());
return false;
return NT_NONE;
}
}
type_ = static_cast<NetworkType>(data);
if (!fp->DeSerialize(&data)) return false;
training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
if (!fp->DeSerialize(&data)) return false;
needs_to_backprop_ = data != 0;
if (!fp->DeSerialize(&network_flags_)) return false;
if (!fp->DeSerialize(&ni_)) return false;
if (!fp->DeSerialize(&no_)) return false;
if (!fp->DeSerialize(&num_weights_)) return false;
if (!name_.DeSerialize(fp)) return false;
return true;
return static_cast<NetworkType>(data);
}

// Reads from the given file. Returns nullptr in case of error.
// Determines the type of the serialized class and calls its DeSerialize
// on a new object of the appropriate type, which is returned.
Network* Network::CreateFromFile(TFile* fp) {
Network stub;
if (!stub.DeSerialize(fp)) return nullptr;
NetworkType type; // Type of the derived network class.
TrainingState training; // Are we currently training?
bool needs_to_backprop; // This network needs to output back_deltas.
int32_t network_flags; // Behavior control flags in NetworkFlags.
int32_t ni; // Number of input values.
int32_t no; // Number of output values.
int32_t num_weights; // Number of weights in this and sub-network.
STRING name; // A unique name for this layer.
int8_t data;
Network* network = nullptr;
switch (stub.type_) {
type = getNetworkType(fp);
if (!fp->DeSerialize(&data)) return nullptr;
training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
if (!fp->DeSerialize(&data)) return nullptr;
needs_to_backprop = data != 0;
if (!fp->DeSerialize(&network_flags)) return nullptr;
if (!fp->DeSerialize(&ni)) return nullptr;
if (!fp->DeSerialize(&no)) return nullptr;
if (!fp->DeSerialize(&num_weights)) return nullptr;
if (!name.DeSerialize(fp)) return nullptr;

switch (type) {
case NT_CONVOLVE:
network = new Convolve(stub.name_, stub.ni_, 0, 0);
network = new Convolve(name, ni, 0, 0);
break;
case NT_INPUT:
network = new Input(stub.name_, stub.ni_, stub.no_);
network = new Input(name, ni, no);
break;
case NT_LSTM:
case NT_LSTM_SOFTMAX:
case NT_LSTM_SOFTMAX_ENCODED:
case NT_LSTM_SUMMARY:
network =
new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
new LSTM(name, ni, no, no, false, type);
break;
case NT_MAXPOOL:
network = new Maxpool(stub.name_, stub.ni_, 0, 0);
network = new Maxpool(name, ni, 0, 0);
break;
// All variants of Parallel.
case NT_PARALLEL:
case NT_REPLICATED:
case NT_PAR_RL_LSTM:
case NT_PAR_UD_LSTM:
case NT_PAR_2D_LSTM:
network = new Parallel(stub.name_, stub.type_);
network = new Parallel(name, type);
break;
case NT_RECONFIG:
network = new Reconfig(stub.name_, stub.ni_, 0, 0);
network = new Reconfig(name, ni, 0, 0);
break;
// All variants of reversed.
case NT_XREVERSED:
case NT_YREVERSED:
case NT_XYTRANSPOSE:
network = new Reversed(stub.name_, stub.type_);
network = new Reversed(name, type);
break;
case NT_SERIES:
network = new Series(stub.name_);
network = new Series(name);
break;
case NT_TENSORFLOW:
#ifdef INCLUDE_TENSORFLOW
network = new TFNetwork(stub.name_);
network = new TFNetwork(name);
#else
tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
#endif
Expand All @@ -253,16 +259,16 @@ Network* Network::CreateFromFile(TFile* fp) {
case NT_LOGISTIC:
case NT_POSCLIP:
case NT_SYMCLIP:
network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
network = new FullyConnected(name, ni, no, type);
break;
default:
break;
}
if (network) {
network->training_ = stub.training_;
network->needs_to_backprop_ = stub.needs_to_backprop_;
network->network_flags_ = stub.network_flags_;
network->num_weights_ = stub.num_weights_;
network->training_ = training;
network->needs_to_backprop_ = needs_to_backprop;
network->network_flags_ = network_flags;
network->num_weights_ = num_weights;
if (!network->DeSerialize(fp)) {
delete network;
network = nullptr;
Expand Down
21 changes: 5 additions & 16 deletions src/lstm/network.h
Expand Up @@ -2,7 +2,6 @@
// File: network.h
// Description: Base class for neural network implementations.
// Author: Ray Smith
// Created: Wed May 01 16:38:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -215,17 +214,16 @@ class Network {
virtual void CacheXScaleFactor(int factor) {}

// Provides debug output on the weights.
virtual void DebugWeights() {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}
virtual void DebugWeights() = 0;

// Writes to the given file. Returns false in case of error.
// Should be overridden by subclasses, but called by their Serialize.
virtual bool Serialize(TFile* fp) const;
// Reads from the given file. Returns false in case of error.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
virtual bool DeSerialize(TFile* fp);
virtual bool DeSerialize(TFile* fp) = 0;

public:
// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
virtual void Update(float learning_rate, float momentum, float adam_beta,
Expand Down Expand Up @@ -261,9 +259,7 @@ class Network {
// instead of all the replicated networks having to do it.
virtual void Forward(bool debug, const NetworkIO& input,
const TransposedArray* input_transpose,
NetworkScratch* scratch, NetworkIO* output) {
tprintf("Must override Network::Forward for type %d\n", type_);
}
NetworkScratch* scratch, NetworkIO* output) = 0;

// Runs backward propagation of errors on fwdX_deltas.
// Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
Expand All @@ -272,10 +268,7 @@ class Network {
// return false from Backward!
virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
NetworkScratch* scratch,
NetworkIO* back_deltas) {
tprintf("Must override Network::Backward for type %d\n", type_);
return false;
}
NetworkIO* back_deltas) = 0;

// === Debug image display methods. ===
// Displays the image of the matrix to the forward window.
Expand Down Expand Up @@ -309,12 +302,8 @@ class Network {
ScrollView* forward_win_; // Recognition debug display window.
ScrollView* backward_win_; // Training debug display window.
TRand* randomizer_; // Random number generator.

// Static serialized name/type_ mapping. Keep in sync with NetworkType.
static char const* const kTypeNames[NT_COUNT];
};


} // namespace tesseract.

#endif // TESSERACT_LSTM_NETWORK_H_
9 changes: 6 additions & 3 deletions src/lstm/reconfig.h
Expand Up @@ -3,7 +3,6 @@
// Description: Network layer that reconfigures the scaling vs feature
// depth.
// Author: Ray Smith
// Created: Wed Feb 26 15:37:42 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,10 +15,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_RECONFIG_H_
#define TESSERACT_LSTM_RECONFIG_H_


#include "genericvector.h"
#include "matrix.h"
#include "network.h"
Expand Down Expand Up @@ -71,6 +70,11 @@ class Reconfig : public Network {
NetworkScratch* scratch,
NetworkIO* back_deltas) override;

private:
void DebugWeights() override {
tprintf("Must override Network::DebugWeights for type %d\n", type_);
}

protected:
// Non-serialized data used to store parameters between forward and back.
StrideMap back_map_;
Expand All @@ -81,5 +85,4 @@ class Reconfig : public Network {

} // namespace tesseract.


#endif // TESSERACT_LSTM_SUBSAMPLE_H_

0 comments on commit 98dd3b6

Please sign in to comment.