diff --git a/ccmain/linerec.cpp b/ccmain/linerec.cpp
index 0a9e8fb89b..e31230840a 100644
--- a/ccmain/linerec.cpp
+++ b/ccmain/linerec.cpp
@@ -64,6 +64,7 @@ void Tesseract::TrainLineRecognizer(const STRING& input_imagename,
     return;
   }
   TrainFromBoxes(boxes, texts, block_list, &images);
+  images.Shuffle();
   if (!images.SaveDocument(lstmf_name.string(), NULL)) {
     tprintf("Failed to write training data to %s!\n", lstmf_name.string());
   }
@@ -79,7 +80,10 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
   int box_count = boxes.size();
   // Process all the text lines in this page, as defined by the boxes.
   int end_box = 0;
-  for (int start_box = 0; start_box < box_count; start_box = end_box) {
+  // Don't let \t, which marks newlines in the box file, get into the line
+  // content, as that makes the line unusable in training.
+  while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
+  for (int start_box = end_box; start_box < box_count; start_box = end_box) {
     // Find the textline of boxes starting at start and their bounding box.
     TBOX line_box = boxes[start_box];
     STRING line_str = texts[start_box];
@@ -115,7 +119,9 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
     }
     if (imagedata != NULL)
       training_data->AddPageToDocument(imagedata);
-    if (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
+    // Don't let \t, which marks newlines in the box file, get into the line
+    // content, as that makes the line unusable in training.
+    while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
   }
 }
 
diff --git a/ccstruct/boxread.cpp b/ccstruct/boxread.cpp
index fee0aa9aef..d6ceebb4db 100644
--- a/ccstruct/boxread.cpp
+++ b/ccstruct/boxread.cpp
@@ -55,6 +55,8 @@ bool ReadAllBoxes(int target_page, bool skip_blanks, const STRING& filename,
   GenericVector<char> box_data;
   if (!tesseract::LoadDataFromFile(BoxFileName(filename), &box_data))
     return false;
+  // Convert the array of bytes to a string, so it can be used by the parser.
+  box_data.push_back('\0');
   return ReadMemBoxes(target_page, skip_blanks, &box_data[0], boxes, texts,
                       box_texts, pages);
 }
diff --git a/ccstruct/imagedata.cpp b/ccstruct/imagedata.cpp
index 3f9ad33786..79de724c2c 100644
--- a/ccstruct/imagedata.cpp
+++ b/ccstruct/imagedata.cpp
@@ -24,18 +24,18 @@
 
 #include "imagedata.h"
 
+#if defined(__MINGW32__)
+#include <unistd.h>
+#else
+#include <thread>
+#endif
+
 #include "allheaders.h"
 #include "boxread.h"
 #include "callcpp.h"
 #include "helpers.h"
 #include "tprintf.h"
 
-#if defined(__MINGW32__)
-# include <unistd.h>
-#else
-# include <thread>
-#endif
-
 // Number of documents to read ahead while training. Doesn't need to be very
 // large.
 const int kMaxReadAhead = 8;
@@ -496,6 +496,21 @@ inT64 DocumentData::UnCache() {
   return memory_saved;
 }
 
+// Shuffles all the pages in the document.
+void DocumentData::Shuffle() {
+  TRand random;
+  // Different documents get shuffled differently, but the same for the same
+  // name.
+  random.set_seed(document_name_.string());
+  int num_pages = pages_.size();
+  // Execute one random swap for each page in the document.
+  for (int i = 0; i < num_pages; ++i) {
+    int src = random.IntRand() % num_pages;
+    int dest = random.IntRand() % num_pages;
+    std::swap(pages_[src], pages_[dest]);
+  }
+}
+
 // Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
 // starting at index pages_offset_.
 bool DocumentData::ReCachePages() {
diff --git a/ccstruct/imagedata.h b/ccstruct/imagedata.h
index ae6722934e..45cb65a6c5 100644
--- a/ccstruct/imagedata.h
+++ b/ccstruct/imagedata.h
@@ -266,6 +266,8 @@ class DocumentData {
   // Removes all pages from memory and frees the memory, but does not forget
   // the document metadata. Returns the memory saved.
   inT64 UnCache();
+  // Shuffles all the pages in the document.
+  void Shuffle();
 
  private:
   // Sets the value of total_pages_ behind a mutex.
diff --git a/ccstruct/pageres.cpp b/ccstruct/pageres.cpp
index 330dd22915..b981fb1a1c 100644
--- a/ccstruct/pageres.cpp
+++ b/ccstruct/pageres.cpp
@@ -529,13 +529,12 @@ void WERD_RES::FilterWordChoices(int debug_level) {
     if (choice->unichar_id(i) != best_choice->unichar_id(j) &&
         choice->certainty(i) - best_choice->certainty(j) < threshold) {
       if (debug_level >= 2) {
-        STRING label;
-        label.add_str_int("\nDiscarding bad choice #", index);
-        choice->print(label.string());
-        tprintf("i %d j %d Chunk %d Choice->Blob[i].Certainty %.4g"
-                " BestChoice->ChunkCertainty[Chunk] %g Threshold %g\n",
-                i, j, chunk, choice->certainty(i),
-                best_choice->certainty(j), threshold);
+        choice->print("WorstCertaintyDiffWorseThan");
+        tprintf(
+            "i %d j %d Choice->Blob[i].Certainty %.4g"
+            " WorstOtherChoiceCertainty %g Threshold %g\n",
+            i, j, choice->certainty(i), best_choice->certainty(j), threshold);
+        tprintf("Discarding bad choice #%d\n", index);
       }
       delete it.extract();
       break;
diff --git a/ccutil/genericvector.h b/ccutil/genericvector.h
index 3a70e21ce0..f8907ca040 100644
--- a/ccutil/genericvector.h
+++ b/ccutil/genericvector.h
@@ -363,8 +363,7 @@ inline bool LoadDataFromFile(const STRING& filename,
   fseek(fp, 0, SEEK_END);
   size_t size = ftell(fp);
   fseek(fp, 0, SEEK_SET);
-  // Pad with a 0, just in case we treat the result as a string.
-  data->init_to_size(static_cast<int>(size) + 1, 0);
+  data->init_to_size(static_cast<int>(size), 0);
   bool result = fread(&(*data)[0], 1, size, fp) == size;
   fclose(fp);
   return result;
@@ -380,6 +379,17 @@ inline bool SaveDataToFile(const GenericVector<char>& data,
   fclose(fp);
   return result;
 }
+// Reads a file as a vector of STRING.
+inline bool LoadFileLinesToStrings(const STRING& filename,
+                                   GenericVector<STRING>* lines) {
+  GenericVector<char> data;
+  if (!LoadDataFromFile(filename.string(), &data)) {
+    return false;
+  }
+  STRING lines_str(&data[0], data.size());
+  lines_str.split('\n', lines);
+  return true;
+}
 
 template <typename T>
 bool cmp_eq(T const & t1, T const & t2) {
diff --git a/ccutil/helpers.h b/ccutil/helpers.h
index a2276bc451..33ffd6c46f 100644
--- a/ccutil/helpers.h
+++ b/ccutil/helpers.h
@@ -27,6 +27,8 @@
 
 #include <stdio.h>
 #include <string.h>
+#include <functional>
+#include <string>
 
 #include "host.h"
 
@@ -43,6 +45,11 @@ class TRand {
   void set_seed(uinT64 seed) {
     seed_ = seed;
   }
+  // Sets the seed using a hash of a string.
+  void set_seed(const std::string& str) {
+    std::hash<std::string> hasher;
+    set_seed(static_cast<uinT64>(hasher(str)));
+  }
 
   // Returns an integer in the range 0 to MAX_INT32.
   inT32 IntRand() {
diff --git a/lstm/fullyconnected.cpp b/lstm/fullyconnected.cpp
index 77406b6208..c5b92768e5 100644
--- a/lstm/fullyconnected.cpp
+++ b/lstm/fullyconnected.cpp
@@ -56,6 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
   return result;
 }
 
+// Suspends/Enables training by setting the training_ flag. Serialize and
+// DeSerialize only operate on the run-time data if state is false.
+void FullyConnected::SetEnableTraining(TrainingState state) {
+  if (state == TS_RE_ENABLE) {
+    if (training_ == TS_DISABLED) weights_.InitBackward(false);
+    training_ = TS_ENABLED;
+  } else {
+    training_ = state;
+  }
+}
+
 // Sets up the network for training. Initializes weights using weights of
 // scale `range` picked according to the random number generator `randomizer`.
 int FullyConnected::InitWeights(float range, TRand* randomizer) {
@@ -78,14 +89,14 @@ void FullyConnected::DebugWeights() {
 // Writes to the given file. Returns false in case of error.
 bool FullyConnected::Serialize(TFile* fp) const {
   if (!Network::Serialize(fp)) return false;
-  if (!weights_.Serialize(training_, fp)) return false;
+  if (!weights_.Serialize(IsTraining(), fp)) return false;
   return true;
 }
 
 // Reads from the given file. Returns false in case of error.
 // If swap is true, assumes a big/little-endian swap is needed.
 bool FullyConnected::DeSerialize(bool swap, TFile* fp) {
-  if (!weights_.DeSerialize(training_, swap, fp)) return false;
+  if (!weights_.DeSerialize(IsTraining(), swap, fp)) return false;
   return true;
 }
 
@@ -129,14 +140,14 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
     }
     ForwardTimeStep(d_input, i_input, t, temp_line);
     output->WriteTimeStep(t, temp_line);
-    if (training() && type_ != NT_SOFTMAX) {
+    if (IsTraining() && type_ != NT_SOFTMAX) {
       acts_.CopyTimeStepFrom(t, *output, t);
     }
   }
   // Zero all the elements that are in the padding around images that allows
   // multiple different-sized images to exist in a single array.
   // acts_ is only used if this is not a softmax op.
-  if (training() && type_ != NT_SOFTMAX) {
+  if (IsTraining() && type_ != NT_SOFTMAX) {
     acts_.ZeroInvalidElements();
   }
   output->ZeroInvalidElements();
@@ -152,7 +163,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
                                   const TransposedArray* input_transpose) {
   // Softmax output is always float, so save the input type.
   int_mode_ = input.int_mode();
-  if (training()) {
+  if (IsTraining()) {
     acts_.Resize(input, no_);
     // Source_ is a transposed copy of input. It isn't needed if provided.
     external_source_ = input_transpose;
@@ -163,7 +174,7 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
                                      int t, double* output_line) {
   // input is copied to source_ line-by-line for cache coherency.
-  if (training() && external_source_ == NULL && d_input != NULL)
+  if (IsTraining() && external_source_ == NULL && d_input != NULL)
     source_t_.WriteStrided(t, d_input);
   if (d_input != NULL) weights_.MatrixDotVector(d_input, output_line);
diff --git a/lstm/fullyconnected.h b/lstm/fullyconnected.h
index d2d2b73ae8..f5a593906d 100644
--- a/lstm/fullyconnected.h
+++ b/lstm/fullyconnected.h
@@ -61,6 +61,10 @@ class FullyConnected : public Network {
     type_ = type;
   }
 
+  // Suspends/Enables training by setting the training_ flag. Serialize and
+  // DeSerialize only operate on the run-time data if state is false.
+  virtual void SetEnableTraining(TrainingState state);
+
   // Sets up the network for training. Initializes weights using weights of
   // scale `range` picked according to the random number generator `randomizer`.
   virtual int InitWeights(float range, TRand* randomizer);
diff --git a/lstm/lstm.cpp b/lstm/lstm.cpp
index cac5f64c93..3cc2b02aa7 100644
--- a/lstm/lstm.cpp
+++ b/lstm/lstm.cpp
@@ -102,6 +102,23 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
   return result;
 }
 
+// Suspends/Enables training by setting the training_ flag. Serialize and
+// DeSerialize only operate on the run-time data if state is false.
+void LSTM::SetEnableTraining(TrainingState state) {
+  if (state == TS_RE_ENABLE) {
+    if (training_ == TS_DISABLED) {
+      for (int w = 0; w < WT_COUNT; ++w) {
+        if (w == GFS && !Is2D()) continue;
+        gate_weights_[w].InitBackward(false);
+      }
+    }
+    training_ = TS_ENABLED;
+  } else {
+    training_ = state;
+  }
+  if (softmax_ != NULL) softmax_->SetEnableTraining(state);
+}
+
 // Sets up the network for training. Initializes weights using weights of
 // scale `range` picked according to the random number generator `randomizer`.
 int LSTM::InitWeights(float range, TRand* randomizer) {
@@ -148,7 +165,7 @@ bool LSTM::Serialize(TFile* fp) const {
   if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
   for (int w = 0; w < WT_COUNT; ++w) {
     if (w == GFS && !Is2D()) continue;
-    if (!gate_weights_[w].Serialize(training_, fp)) return false;
+    if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
   }
   if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
   return true;
@@ -169,7 +186,7 @@ bool LSTM::DeSerialize(bool swap, TFile* fp) {
   is_2d_ = false;
   for (int w = 0; w < WT_COUNT; ++w) {
     if (w == GFS && !Is2D()) continue;
-    if (!gate_weights_[w].DeSerialize(training_, swap, fp)) return false;
+    if (!gate_weights_[w].DeSerialize(IsTraining(), swap, fp)) return false;
     if (w == CI) {
       ns_ = gate_weights_[CI].NumOutputs();
       is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
@@ -322,7 +339,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
     MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
     // Clip curr_state to a sane range.
     ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
-    if (training_) {
+    if (IsTraining()) {
       // Save the gate node values.
       node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
       node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
@@ -331,7 +348,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
       if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
     }
     FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
-    if (training_) state_.WriteTimeStep(t, curr_state);
+    if (IsTraining()) state_.WriteTimeStep(t, curr_state);
     if (softmax_ != NULL) {
       if (input.int_mode()) {
         int_output->WriteTimeStep(0, curr_output);
@@ -697,7 +714,7 @@ void LSTM::PrintDW() {
 void LSTM::ResizeForward(const NetworkIO& input) {
   source_.Resize(input, na_);
   which_fg_.ResizeNoInit(input.Width(), ns_);
-  if (training_) {
+  if (IsTraining()) {
     state_.ResizeFloat(input, ns_);
     for (int w = 0; w < WT_COUNT; ++w) {
       if (w == GFS && !Is2D()) continue;
diff --git a/lstm/lstm.h b/lstm/lstm.h
index c62a846013..f87fa68118 100644
--- a/lstm/lstm.h
+++ b/lstm/lstm.h
@@ -69,6 +69,10 @@ class LSTM : public Network {
     return spec;
   }
 
+  // Suspends/Enables training by setting the training_ flag. Serialize and
+  // DeSerialize only operate on the run-time data if state is false.
+  virtual void SetEnableTraining(TrainingState state);
+
   // Sets up the network for training. Initializes weights using weights of
   // scale `range` picked according to the random number generator `randomizer`.
   virtual int InitWeights(float range, TRand* randomizer);
diff --git a/lstm/lstmrecognizer.cpp b/lstm/lstmrecognizer.cpp
index 236455b411..1d4f0f39d5 100644
--- a/lstm/lstmrecognizer.cpp
+++ b/lstm/lstmrecognizer.cpp
@@ -253,7 +253,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
                                    float label_threshold, float* scale_factor,
                                    NetworkIO* inputs, NetworkIO* outputs) {
   // Maximum width of image to train on.
-  const int kMaxImageWidth = 2048;
+  const int kMaxImageWidth = 2560;
   // This ensures consistent recognition results.
   SetRandomSeed();
   int min_width = network_->XScaleFactor();
@@ -263,7 +263,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
     tprintf("Line cannot be recognized!!\n");
     return false;
   }
-  if (network_->training() && pixGetWidth(pix) > kMaxImageWidth) {
+  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
     tprintf("Image too large to learn!! Size = %dx%d\n",
             pixGetWidth(pix), pixGetHeight(pix));
     pixDestroy(&pix);
diff --git a/lstm/lstmtrainer.cpp b/lstm/lstmtrainer.cpp
index 5865bc45b1..42e43162ce 100644
--- a/lstm/lstmtrainer.cpp
+++ b/lstm/lstmtrainer.cpp
@@ -134,8 +134,6 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
 // Note: Call before InitNetwork!
 void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
                               const STRING& script_dir, int train_flags) {
-  // Call before InitNetwork.
-  ASSERT_HOST(network_ == NULL);
   EmptyConstructor();
   training_flags_ = train_flags;
   ccutil_.unicharset.CopyFrom(unicharset);
@@ -150,8 +148,6 @@ void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
 // Note: Call before InitNetwork!
 void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
                               const UnicharCompress recoder) {
-  // Call before InitNetwork.
-  ASSERT_HOST(network_ == NULL);
   EmptyConstructor();
   int flags = TF_COMPRESS_UNICHARSET;
   training_flags_ = static_cast<inT32>(flags);
@@ -219,6 +215,30 @@ int LSTMTrainer::InitTensorFlowNetwork(const std::string& tf_proto) {
 #endif
 }
 
+// Resets all the iteration counters for fine tuning or training a head,
+// where we want the error reporting to reset.
+void LSTMTrainer::InitIterations() {
+  sample_iteration_ = 0;
+  training_iteration_ = 0;
+  learning_iteration_ = 0;
+  prev_sample_iteration_ = 0;
+  best_error_rate_ = 100.0;
+  best_iteration_ = 0;
+  worst_error_rate_ = 0.0;
+  worst_iteration_ = 0;
+  stall_iteration_ = kMinStallIterations;
+  improvement_steps_ = kMinStallIterations;
+  perfect_delay_ = 0;
+  last_perfect_training_iteration_ = 0;
+  for (int i = 0; i < ET_COUNT; ++i) {
+    best_error_rates_[i] = 100.0;
+    worst_error_rates_[i] = 0.0;
+    error_buffers_[i].init_to_size(kRollingBufferSize_, 0.0);
+    error_rates_[i] = 100.0;
+  }
+  error_rate_of_last_saved_best_ = kMinStartedErrorRate;
+}
+
 // If the training sample is usable, grid searches for the optimal
 // dict_ratio/cert_offset, and returns the results in a string of space-
 // separated triplets of ratio,offset=worderr.
@@ -460,8 +480,15 @@ bool LSTMTrainer::Serialize(TFile* fp) const {
 // If swap is true, assumes a big/little-endian swap is needed.
 bool LSTMTrainer::DeSerialize(bool swap, TFile* fp) {
   if (!LSTMRecognizer::DeSerialize(swap, fp)) return false;
-  if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
-    return false;
+  if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
+    // Special case. If we successfully decoded the recognizer, but fail here
+    // then it means we were just given a recognizer, so issue a warning and
+    // allow it.
+ tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n"); + learning_iteration_ = 0; + network_->SetEnableTraining(TS_RE_ENABLE); + return true; + } if (fp->FRead(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) != 1) return false; @@ -629,7 +656,7 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, SaveTrainingDump(LIGHT, this, &orig_trainer); for (int i = 0; i < num_layers; ++i) { Network* layer = GetLayer(layers[i]); - num_weights[i] = layer->training() ? layer->num_weights() : 0; + num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0; } int iteration = sample_iteration(); for (int s = 0; s < num_samples; ++s) { @@ -773,7 +800,7 @@ Trainability LSTMTrainer::TrainOnLine(const ImageData* trainingdata, training_iteration() % debug_interval_ == 0; // Run backprop on the output. NetworkIO bp_deltas; - if (network_->training() && + if (network_->IsTraining() && (trainable != PERFECT || training_iteration() > last_perfect_training_iteration_ + perfect_delay_)) { @@ -827,6 +854,7 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata, return UNENCODABLE; } targets->Resize(*fwd_outputs, network_->NumOutputs()); + double text_error = 100.0; LossType loss_type = OutputLossType(); if (loss_type == LT_SOFTMAX) { if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) { @@ -900,9 +928,9 @@ bool LSTMTrainer::ReadSizedTrainingDump(const char* data, int size) { void LSTMTrainer::SaveRecognitionDump(GenericVector* data) const { TFile fp; fp.OpenWrite(data); - network_->SetEnableTraining(false); + network_->SetEnableTraining(TS_TEMP_DISABLE); ASSERT_HOST(LSTMRecognizer::Serialize(&fp)); - network_->SetEnableTraining(true); + network_->SetEnableTraining(TS_RE_ENABLE); } // Reads and returns a previously saved recognizer from memory. @@ -942,25 +970,7 @@ void LSTMTrainer::EmptyConstructor() { serialize_amount_ = FULL; training_stage_ = 0; num_training_stages_ = 2; - prev_sample_iteration_ = 0; - best_error_rate_ = 100.0; - best_iteration_ = 0; - worst_error_rate_ = 0.0; - worst_iteration_ = 0; - stall_iteration_ = kMinStallIterations; - learning_iteration_ = 0; - improvement_steps_ = kMinStallIterations; - perfect_delay_ = 0; - last_perfect_training_iteration_ = 0; - for (int i = 0; i < ET_COUNT; ++i) { - best_error_rates_[i] = 100.0; - worst_error_rates_[i] = 0.0; - error_buffers_[i].init_to_size(kRollingBufferSize_, 0.0); - error_rates_[i] = 100.0; - } - sample_iteration_ = 0; - training_iteration_ = 0; - error_rate_of_last_saved_best_ = kMinStartedErrorRate; + InitIterations(); } // Sets the unicharset properties using the given script_dir as a source of diff --git a/lstm/lstmtrainer.h b/lstm/lstmtrainer.h index e6a7c43f2e..918c0381be 100644 --- a/lstm/lstmtrainer.h +++ b/lstm/lstmtrainer.h @@ -127,6 +127,9 @@ class LSTMTrainer : public LSTMRecognizer { // Returns the global step of TensorFlow graph or 0 if failed. // Building a compatible TF graph: See tfnetwork.proto. int InitTensorFlowNetwork(const std::string& tf_proto); + // Resets all the iteration counters for fine tuning or training a head, + // where we want the error reporting to reset. + void InitIterations(); // Accessors. 
   double ActivationError() const {
diff --git a/lstm/network.cpp b/lstm/network.cpp
index 3120a3f70a..795d4a5b7c 100644
--- a/lstm/network.cpp
+++ b/lstm/network.cpp
@@ -69,23 +69,47 @@ char const* const Network::kTypeNames[NT_COUNT] = {
 };
 
 Network::Network()
-  : type_(NT_NONE), training_(true), needs_to_backprop_(true),
-    network_flags_(0), ni_(0), no_(0), num_weights_(0),
-    forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
-}
+    : type_(NT_NONE),
+      training_(TS_ENABLED),
+      needs_to_backprop_(true),
+      network_flags_(0),
+      ni_(0),
+      no_(0),
+      num_weights_(0),
+      forward_win_(NULL),
+      backward_win_(NULL),
+      randomizer_(NULL) {}
 Network::Network(NetworkType type, const STRING& name, int ni, int no)
-  : type_(type), training_(true), needs_to_backprop_(true),
-    network_flags_(0), ni_(ni), no_(no), num_weights_(0),
-    name_(name), forward_win_(NULL), backward_win_(NULL), randomizer_(NULL) {
-}
+    : type_(type),
+      training_(TS_ENABLED),
+      needs_to_backprop_(true),
+      network_flags_(0),
+      ni_(ni),
+      no_(no),
+      num_weights_(0),
+      name_(name),
+      forward_win_(NULL),
+      backward_win_(NULL),
+      randomizer_(NULL) {}
 
 Network::~Network() {
 }
 
-// Ends training by setting the training_ flag to false. Serialize and
-// DeSerialize will now only operate on the run-time data.
-void Network::SetEnableTraining(bool state) {
-  training_ = state;
+// Suspends/Enables/Permanently disables training by setting the training_
+// flag. Serialize and DeSerialize only operate on the run-time data if state
+// is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
+// temporarily disable layers in state TS_ENABLED, allowing a trainer to
+// serialize as if it were a recognizer.
+// TS_RE_ENABLE will re-enable layers that were previously in any disabled
+// state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
+// TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
+// recognizer can be converted back to a trainer.
+void Network::SetEnableTraining(TrainingState state) {
+  if (state == TS_RE_ENABLE) {
+    training_ = TS_ENABLED;
+  } else {
+    training_ = state;
+  }
 }
 
 // Sets flags that control the action of the network. See NetworkFlags enum
@@ -152,7 +176,7 @@ bool Network::DeSerialize(bool swap, TFile* fp) {
   }
   type_ = static_cast<NetworkType>(data);
   if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
-  training_ = data != 0;
+  training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
   if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
   needs_to_backprop_ = data != 0;
   if (fp->FRead(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
diff --git a/lstm/network.h b/lstm/network.h
index edd04b4f6d..db38b1821e 100644
--- a/lstm/network.h
+++ b/lstm/network.h
@@ -88,6 +88,16 @@ enum NetworkFlags {
   NF_ADA_GRAD = 128,  // Weight-specific learning rate.
 };
 
+// State of training and desired state used in SetEnableTraining.
+enum TrainingState {
+  // Valid states of training_.
+  TS_DISABLED,      // Disabled permanently.
+  TS_ENABLED,       // Enabled for backprop and to write a training dump.
+  TS_TEMP_DISABLE,  // Temporarily disabled to write a recognition dump.
+  // Valid only for SetEnableTraining.
+  TS_RE_ENABLE,     // Re-Enable whatever the current state.
+};
+
 // Base class for network types. Not quite an abstract base class, but almost.
 // Most of the time no isolated Network exists, except prior to
 // deserialization.
@@ -101,9 +111,7 @@ class Network {
   NetworkType type() const {
     return type_;
   }
-  bool training() const {
-    return training_;
-  }
+  bool IsTraining() const { return training_ == TS_ENABLED; }
   bool needs_to_backprop() const {
     return needs_to_backprop_;
   }
@@ -142,9 +150,16 @@ class Network {
   // multiple sub-networks that can have their own learning rate.
   virtual bool IsPlumbingType() const { return false; }
 
-  // Suspends/Enables training by setting the training_ flag. Serialize and
-  // DeSerialize only operate on the run-time data if state is false.
-  virtual void SetEnableTraining(bool state);
+  // Suspends/Enables/Permanently disables training by setting the training_
+  // flag. Serialize and DeSerialize only operate on the run-time data if state
+  // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
+  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
+  // serialize as if it were a recognizer.
+  // TS_RE_ENABLE will re-enable layers that were previously in any disabled
+  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
+  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
+  // recognizer can be converted back to a trainer.
+  virtual void SetEnableTraining(TrainingState state);
 
   // Sets flags that control the action of the network. See NetworkFlags enum
   // for bit values.
@@ -269,7 +284,7 @@ class Network {
 
  protected:
   NetworkType type_;        // Type of the derived network class.
-  bool training_;           // Are we currently training?
+  TrainingState training_;  // Are we currently training?
   bool needs_to_backprop_;  // This network needs to output back_deltas.
   inT32 network_flags_;     // Behavior control flags in NetworkFlags.
   inT32 ni_;                // Number of input values.
diff --git a/lstm/parallel.cpp b/lstm/parallel.cpp
index 516fe33a1a..2c4d5fb6b2 100644
--- a/lstm/parallel.cpp
+++ b/lstm/parallel.cpp
@@ -83,7 +83,7 @@ void Parallel::Forward(bool debug, const NetworkIO& input,
   // Source for divided replicated.
   NetworkScratch::IO source_part;
   TransposedArray* src_transpose = NULL;
-  if (training() && type_ == NT_REPLICATED) {
+  if (IsTraining() && type_ == NT_REPLICATED) {
     // Make a transposed copy of the input.
     input.Transpose(&transposed_input_);
     src_transpose = &transposed_input_;
diff --git a/lstm/plumbing.cpp b/lstm/plumbing.cpp
index 01abdb91f9..bfb582541f 100644
--- a/lstm/plumbing.cpp
+++ b/lstm/plumbing.cpp
@@ -31,7 +31,7 @@ Plumbing::~Plumbing() {
 
 // Suspends/Enables training by setting the training_ flag. Serialize and
 // DeSerialize only operate on the run-time data if state is false.
-void Plumbing::SetEnableTraining(bool state) {
+void Plumbing::SetEnableTraining(TrainingState state) {
   Network::SetEnableTraining(state);
   for (int i = 0; i < stack_.size(); ++i)
     stack_[i]->SetEnableTraining(state);
@@ -91,13 +91,17 @@ void Plumbing::AddToStack(Network* network) {
 // Sets needs_to_backprop_ to needs_backprop and calls on sub-network
 // according to needs_backprop || any weights in this network.
 bool Plumbing::SetupNeedsBackprop(bool needs_backprop) {
-  needs_to_backprop_ = needs_backprop;
-  bool retval = needs_backprop;
-  for (int i = 0; i < stack_.size(); ++i) {
-    if (stack_[i]->SetupNeedsBackprop(needs_backprop))
-      retval = true;
+  if (IsTraining()) {
+    needs_to_backprop_ = needs_backprop;
+    bool retval = needs_backprop;
+    for (int i = 0; i < stack_.size(); ++i) {
+      if (stack_[i]->SetupNeedsBackprop(needs_backprop)) retval = true;
+    }
+    return retval;
   }
-  return retval;
+  // Frozen networks don't do backprop.
+  needs_to_backprop_ = false;
+  return false;
 }
 
 // Returns an integer reduction factor that the network applies to the
@@ -212,8 +216,9 @@ void Plumbing::Update(float learning_rate, float momentum, int num_samples) {
       else
        learning_rates_.push_back(learning_rate);
     }
-    if (stack_[i]->training())
+    if (stack_[i]->IsTraining()) {
       stack_[i]->Update(learning_rate, momentum, num_samples);
+    }
   }
 }
 
diff --git a/lstm/plumbing.h b/lstm/plumbing.h
index 1a2185c333..bda855e09f 100644
--- a/lstm/plumbing.h
+++ b/lstm/plumbing.h
@@ -45,7 +45,7 @@ class Plumbing : public Network {
 
   // Suspends/Enables training by setting the training_ flag. Serialize and
   // DeSerialize only operate on the run-time data if state is false.
-  virtual void SetEnableTraining(bool state);
+  virtual void SetEnableTraining(TrainingState state);
 
   // Sets flags that control the action of the network. See NetworkFlags enum
   // for bit values.
diff --git a/lstm/series.cpp b/lstm/series.cpp
index 96ff4704db..83d26cbf44 100644
--- a/lstm/series.cpp
+++ b/lstm/series.cpp
@@ -116,7 +116,7 @@ void Series::Forward(bool debug, const NetworkIO& input,
 bool Series::Backward(bool debug, const NetworkIO& fwd_deltas,
                       NetworkScratch* scratch,
                       NetworkIO* back_deltas) {
-  if (!training()) return false;
+  if (!IsTraining()) return false;
   int stack_size = stack_.size();
   ASSERT_HOST(stack_size > 1);
   // Revolving intermediate buffers.
@@ -124,16 +124,16 @@ bool Series::Backward(bool debug, const NetworkIO& fwd_deltas,
   NetworkScratch::IO buffer2(fwd_deltas, scratch);
   // Run each network in reverse order, giving the back_deltas output of n as
   // the fwd_deltas input to n-1, with the 0 network providing the real output.
-  if (!stack_.back()->training() ||
+  if (!stack_.back()->IsTraining() ||
       !stack_.back()->Backward(debug, fwd_deltas, scratch, buffer1))
     return false;
   for (int i = stack_size - 2; i >= 0; i -= 2) {
-    if (!stack_[i]->training() ||
+    if (!stack_[i]->IsTraining() ||
         !stack_[i]->Backward(debug, *buffer1, scratch,
                              i > 0 ? buffer2 : back_deltas))
      return false;
     if (i == 0) return needs_to_backprop_;
-    if (!stack_[i - 1]->training() ||
+    if (!stack_[i - 1]->IsTraining() ||
         !stack_[i - 1]->Backward(debug, *buffer2, scratch,
                                  i > 1 ? buffer1 : back_deltas))
       return false;
diff --git a/lstm/stridemap.h b/lstm/stridemap.h
index cd7d17a7d8..2dd9e49b7d 100644
--- a/lstm/stridemap.h
+++ b/lstm/stridemap.h
@@ -69,8 +69,8 @@ class StrideMap {
     bool IsValid() const;
     // Returns true if the index of the given dimension is the last.
     bool IsLast(FlexDimensions dimension) const;
-    // Given that the dimensions up to and including dim-1 are valid, returns the
-    // maximum index for dimension dim.
+    // Given that the dimensions up to and including dim-1 are valid, returns
+    // the maximum index for dimension dim.
     int MaxIndexOfDim(FlexDimensions dim) const;
     // Adds the given offset to the given dimension. Returns true if the result
     // makes a valid index.
diff --git a/lstm/weightmatrix.cpp b/lstm/weightmatrix.cpp
index 3ee95961d1..e29eccb60b 100644
--- a/lstm/weightmatrix.cpp
+++ b/lstm/weightmatrix.cpp
@@ -98,8 +98,6 @@ void TransposedArray::Transpose(const GENERIC_2D_ARRAY<double>& input) {
 int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
                                    float weight_range, TRand* randomizer) {
   int_mode_ = false;
-  use_ada_grad_ = ada_grad;
-  if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
   wf_.Resize(no, ni, 0.0);
   if (randomizer != NULL) {
     for (int i = 0; i < no; ++i) {
@@ -108,7 +106,7 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
       }
     }
   }
-  InitBackward();
+  InitBackward(ada_grad);
   return ni * no;
 }
 
@@ -144,12 +142,14 @@ void WeightMatrix::ConvertToInt() {
 
 // Allocates any needed memory for running Backward, and zeroes the deltas,
 // thus eliminating any existing momentum.
-void WeightMatrix::InitBackward() {
+void WeightMatrix::InitBackward(bool ada_grad) {
   int no = int_mode_ ? wi_.dim1() : wf_.dim1();
   int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
+  use_ada_grad_ = ada_grad;
   dw_.Resize(no, ni, 0.0);
   updates_.Resize(no, ni, 0.0);
   wf_t_.Transpose(wf_);
+  if (use_ada_grad_) dw_sq_sum_.Resize(no, ni, 0.0);
 }
 
 // Flag on mode to indicate that this weightmatrix uses inT8.
@@ -193,7 +193,7 @@ bool WeightMatrix::DeSerialize(bool training, bool swap, TFile* fp) {
   } else {
     if (!wf_.DeSerialize(swap, fp)) return false;
     if (training) {
-      InitBackward();
+      InitBackward(use_ada_grad_);
       if (!updates_.DeSerialize(swap, fp)) return false;
       if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(swap, fp)) return false;
     }
@@ -216,7 +216,7 @@ bool WeightMatrix::DeSerializeOld(bool training, bool swap, TFile* fp) {
     FloatToDouble(float_array, &wf_);
   }
   if (training) {
-    InitBackward();
+    InitBackward(use_ada_grad_);
     if (!float_array.DeSerialize(swap, fp)) return false;
     FloatToDouble(float_array, &updates_);
     // Errs was only used in int training, which is now dead.
diff --git a/lstm/weightmatrix.h b/lstm/weightmatrix.h
index 24fd5c10e3..635c66188c 100644
--- a/lstm/weightmatrix.h
+++ b/lstm/weightmatrix.h
@@ -92,7 +92,7 @@ class WeightMatrix {
 
   // Allocates any needed memory for running Backward, and zeroes the deltas,
   // thus eliminating any existing momentum.
-  void InitBackward();
+  void InitBackward(bool ada_grad);
 
   // Writes to the given file. Returns false in case of error.
   bool Serialize(bool training, TFile* fp) const;
diff --git a/training/Makefile.am b/training/Makefile.am
index da505be724..5440bfed7c 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -27,7 +27,7 @@ endif
 
 noinst_HEADERS = \
     boxchar.h commandlineflags.h commontraining.h degradeimage.h \
-    fileio.h icuerrorcode.h ligature_table.h normstrngs.h \
+    fileio.h icuerrorcode.h ligature_table.h lstmtester.h normstrngs.h \
    mergenf.h pango_font_info.h stringrenderer.h \
    tessopt.h tlog.h unicharset_training_utils.h util.h
 
@@ -39,14 +39,14 @@ libtesseract_training_la_LIBADD = \
 
 libtesseract_training_la_SOURCES = \
     boxchar.cpp commandlineflags.cpp commontraining.cpp degradeimage.cpp \
-    fileio.cpp ligature_table.cpp normstrngs.cpp pango_font_info.cpp \
+    fileio.cpp ligature_table.cpp lstmtester.cpp normstrngs.cpp pango_font_info.cpp \
     stringrenderer.cpp tlog.cpp unicharset_training_utils.cpp
 
 libtesseract_tessopt_la_SOURCES = \
     tessopt.cpp
 
 bin_PROGRAMS = ambiguous_words classifier_tester cntraining combine_tessdata \
-    dawg2wordlist lstmtraining mftraining set_unicharset_properties shapeclustering \
+    dawg2wordlist lstmeval lstmtraining mftraining set_unicharset_properties shapeclustering \
     text2image unicharset_extractor wordlist2dawg
 
 ambiguous_words_SOURCES = ambiguous_words.cpp
@@ -163,6 +163,33 @@ dawg2wordlist_LDADD += \
     ../api/libtesseract.la
 endif
 
+lstmeval_SOURCES = lstmeval.cpp
+#lstmeval_LDFLAGS = -static
+lstmeval_LDADD = \
+    libtesseract_training.la \
+    libtesseract_tessopt.la \
+    $(libicu)
+if USING_MULTIPLELIBS
+lstmeval_LDADD += \
+    ../textord/libtesseract_textord.la \
+    ../classify/libtesseract_classify.la \
+    ../dict/libtesseract_dict.la \
+    ../arch/libtesseract_avx.la \
+    ../arch/libtesseract_sse.la \
+    ../lstm/libtesseract_lstm.la \
+    ../ccstruct/libtesseract_ccstruct.la \
+    ../cutil/libtesseract_cutil.la \
+    ../viewer/libtesseract_viewer.la \
+    ../ccmain/libtesseract_main.la \
+    ../cube/libtesseract_cube.la \
+    ../neural_networks/runtime/libtesseract_neural.la \
+    ../wordrec/libtesseract_wordrec.la \
+    ../ccutil/libtesseract_ccutil.la
+else
+lstmeval_LDADD += \
+    ../api/libtesseract.la
+endif
+
 lstmtraining_SOURCES = lstmtraining.cpp
 #lstmtraining_LDFLAGS = -static
 lstmtraining_LDADD = \
diff --git a/training/lstmeval.cpp b/training/lstmeval.cpp
new file mode 100644
index 0000000000..aa990e2325
--- /dev/null
+++ b/training/lstmeval.cpp
@@ -0,0 +1,58 @@
+///////////////////////////////////////////////////////////////////////
+// File:        lstmeval.cpp
+// Description: Evaluation program for LSTM-based networks.
+// Author:      Ray Smith
+// Created:     Wed Nov 23 12:20:06 PST 2016
+//
+// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifndef USE_STD_NAMESPACE
+#include "base/commandlineflags.h"
+#endif
+#include "commontraining.h"
+#include "genericvector.h"
+#include "lstmtester.h"
+#include "strngs.h"
+#include "tprintf.h"
+
+STRING_PARAM_FLAG(model, "", "Name of model file (training or recognition)");
+STRING_PARAM_FLAG(eval_listfile, "",
+                  "File listing sample files in lstmf training format.");
+INT_PARAM_FLAG(max_image_MB, 2000, "Max memory to use for images.");
+
+int main(int argc, char **argv) {
+  ParseArguments(&argc, &argv);
+  if (FLAGS_model.empty()) {
+    tprintf("Must provide a --model!\n");
+    return 1;
+  }
+  if (FLAGS_eval_listfile.empty()) {
+    tprintf("Must provide a --eval_listfile!\n");
+    return 1;
+  }
+  GenericVector<char> model_data;
+  if (!tesseract::LoadDataFromFile(FLAGS_model.c_str(), &model_data)) {
+    tprintf("Failed to load model from: %s\n", FLAGS_model.c_str());
+    return 1;
+  }
+  tesseract::LSTMTester tester(static_cast<inT64>(FLAGS_max_image_MB) *
+                               1048576);
+  if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
+    tprintf("Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str());
+    return 1;
+  }
+  double errs = 0.0;
+  STRING result = tester.RunEvalSync(0, &errs, model_data, 0);
+  tprintf("%s\n", result.string());
+  return 0;
+}  /* main */
diff --git a/training/lstmtester.cpp b/training/lstmtester.cpp
new file mode 100644
index 0000000000..df37ebd7ea
--- /dev/null
+++ b/training/lstmtester.cpp
@@ -0,0 +1,146 @@
+///////////////////////////////////////////////////////////////////////
+// File:        lstmtester.cpp
+// Description: Top-level line evaluation class for LSTM-based networks.
+// Author:      Ray Smith
+// Created:     Wed Nov 23 11:18:06 PST 2016
+//
+// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#include "lstmtester.h"
+#include "genericvector.h"
+
+namespace tesseract {
+
+LSTMTester::LSTMTester(inT64 max_memory)
+    : test_data_(max_memory), total_pages_(0), async_running_(false) {}
+
+// Loads a set of lstmf files that were created using the lstm.train config to
+// tesseract into memory ready for testing. Returns false if nothing was
+// loaded. The arg is a filename of a file that lists the filenames.
+bool LSTMTester::LoadAllEvalData(const STRING& filenames_file) {
+  GenericVector<STRING> filenames;
+  if (!LoadFileLinesToStrings(filenames_file, &filenames)) {
+    tprintf("Failed to load list of eval filenames from %s\n",
+            filenames_file.string());
+    return false;
+  }
+  return LoadAllEvalData(filenames);
+}
+
+// Loads a set of lstmf files that were created using the lstm.train config to
+// tesseract into memory ready for testing. Returns false if nothing was
+// loaded.
+bool LSTMTester::LoadAllEvalData(const GenericVector<STRING>& filenames) {
+  test_data_.Clear();
+  bool result =
+      test_data_.LoadDocuments(filenames, "eng", CS_SEQUENTIAL, nullptr);
+  total_pages_ = test_data_.TotalPages();
+  return result;
+}
+
+// Runs an evaluation asynchronously on the stored data and returns a string
+// describing the results of the previous test.
+STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors,
+                                const GenericVector<char>& model_data,
+                                int training_stage) {
+  STRING result;
+  if (total_pages_ == 0) {
+    result.add_str_int("No test data at iteration", iteration);
+    return result;
+  }
+  if (!LockIfNotRunning()) {
+    result.add_str_int("Previous test incomplete, skipping test at iteration",
+                       iteration);
+    return result;
+  }
+  // Save the args.
+  STRING prev_result = test_result_;
+  test_result_ = "";
+  if (training_errors != nullptr) {
+    test_iteration_ = iteration;
+    test_training_errors_ = training_errors;
+    test_model_data_ = model_data;
+    test_training_stage_ = training_stage;
+    SVSync::StartThread(&LSTMTester::ThreadFunc, this);
+  } else {
+    UnlockRunning();
+  }
+  return prev_result;
+}
+
+// Runs an evaluation synchronously on the stored data and returns a string
+// describing the results.
+STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors,
+                               const GenericVector<char>& model_data,
+                               int training_stage) {
+  LSTMTrainer trainer;
+  if (!trainer.ReadTrainingDump(model_data, &trainer)) {
+    return "Deserialize failed";
+  }
+  int eval_iteration = 0;
+  double char_error = 0.0;
+  double word_error = 0.0;
+  int error_count = 0;
+  while (error_count < total_pages_) {
+    const ImageData* trainingdata = test_data_.GetPageBySerial(eval_iteration);
+    trainer.SetIteration(++eval_iteration);
+    NetworkIO fwd_outputs, targets;
+    if (trainer.PrepareForBackward(trainingdata, &fwd_outputs, &targets) !=
+        UNENCODABLE) {
+      char_error += trainer.NewSingleError(tesseract::ET_CHAR_ERROR);
+      word_error += trainer.NewSingleError(tesseract::ET_WORD_RECERR);
+      ++error_count;
+    }
+  }
+  char_error *= 100.0 / total_pages_;
+  word_error *= 100.0 / total_pages_;
+  STRING result;
+  result.add_str_int("At iteration ", iteration);
+  result.add_str_int(", stage ", training_stage);
+  result.add_str_double(", Eval Char error rate=", char_error);
+  result.add_str_double(", Word error rate=", word_error);
+  return result;
+}
+
+// Static helper thread function for RunEvalAsync, with a specific signature
+// required by SVSync::StartThread. Actually a member function pretending to
+// be static, its arg is a this pointer that it will cast back to LSTMTester*
+// to call RunEvalSync using the stored args that RunEvalAsync saves in *this.
+// LockIfNotRunning must have returned true before calling ThreadFunc, and
+// it will call UnlockRunning to release the lock after RunEvalSync completes.
+/* static */
+void* LSTMTester::ThreadFunc(void* lstmtester_void) {
+  LSTMTester* lstmtester = reinterpret_cast<LSTMTester*>(lstmtester_void);
+  lstmtester->test_result_ = lstmtester->RunEvalSync(
+      lstmtester->test_iteration_, lstmtester->test_training_errors_,
+      lstmtester->test_model_data_, lstmtester->test_training_stage_);
+  lstmtester->UnlockRunning();
+  return lstmtester_void;
+}
+
+// Returns true if there is currently nothing running, and takes the lock
+// if there is nothing running.
+bool LSTMTester::LockIfNotRunning() {
+  SVAutoLock lock(&running_mutex_);
+  if (async_running_) return false;
+  async_running_ = true;
+  return true;
+}
+
+// Releases the running lock.
+void LSTMTester::UnlockRunning() {
+  SVAutoLock lock(&running_mutex_);
+  async_running_ = false;
+}
+
+}  // namespace tesseract
diff --git a/training/lstmtester.h b/training/lstmtester.h
new file mode 100644
index 0000000000..3b4cb05e78
--- /dev/null
+++ b/training/lstmtester.h
@@ -0,0 +1,94 @@
+///////////////////////////////////////////////////////////////////////
+// File:        lstmtester.h
+// Description: Top-level line evaluation class for LSTM-based networks.
+// Author:      Ray Smith
+// Created:     Wed Nov 23 11:05:06 PST 2016
+//
+// (C) Copyright 2016, Google Inc.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+///////////////////////////////////////////////////////////////////////
+
+#ifndef TESSERACT_TRAINING_LSTMTESTER_H_
+#define TESSERACT_TRAINING_LSTMTESTER_H_
+
+#include "genericvector.h"
+#include "lstmtrainer.h"
+#include "strngs.h"
+#include "svutil.h"
+
+namespace tesseract {
+
+class LSTMTester {
+ public:
+  LSTMTester(inT64 max_memory);
+
+  // Loads a set of lstmf files that were created using the lstm.train config
+  // to tesseract into memory ready for testing. Returns false if nothing was
+  // loaded. The arg is a filename of a file that lists the filenames, with one
+  // name per line. Conveniently, tesstrain.sh generates such a file, along
+  // with the files themselves.
+  bool LoadAllEvalData(const STRING& filenames_file);
+  // Loads a set of lstmf files that were created using the lstm.train config
+  // to tesseract into memory ready for testing. Returns false if nothing was
+  // loaded.
+  bool LoadAllEvalData(const GenericVector<STRING>& filenames);
+
+  // Runs an evaluation asynchronously on the stored eval data and returns a
+  // string describing the results of the previous test. Args match
+  // TestCallback declared in lstmtrainer.h:
+  // iteration: Current learning iteration number.
+  // training_errors: If not null, is an array of size ET_COUNT, indexed by
+  //   the ErrorTypes enum and indicates the current errors measured by the
+  //   trainer, and this is a serious request to run an evaluation. If null,
+  //   then the caller is just polling for the results of the previous eval.
+  // model_data: is the model to evaluate, which should be a serialized
+  //   LSTMTrainer.
+  // training_stage: an arbitrary number on the progress of training.
+  STRING RunEvalAsync(int iteration, const double* training_errors,
+                      const GenericVector<char>& model_data,
+                      int training_stage);
+  // Runs an evaluation synchronously on the stored eval data and returns a
+  // string describing the results. Args as RunEvalAsync.
+  STRING RunEvalSync(int iteration, const double* training_errors,
+                     const GenericVector<char>& model_data,
+                     int training_stage);
+
+ private:
+  // Static helper thread function for RunEvalAsync, with a specific signature
+  // required by SVSync::StartThread. Actually a member function pretending to
+  // be static, its arg is a this pointer that it will cast back to LSTMTester*
+  // to call RunEvalSync using the stored args that RunEvalAsync saves in *this.
+  // LockIfNotRunning must have returned true before calling ThreadFunc, and
+  // it will call UnlockRunning to release the lock after RunEvalSync completes.
+  static void* ThreadFunc(void* lstmtester_void);
+  // Returns true if there is currently nothing running, and takes the lock
+  // if there is nothing running.
+  bool LockIfNotRunning();
+  // Releases the running lock.
+  void UnlockRunning();
+
+  // The data to test with.
+  DocumentCache test_data_;
+  int total_pages_;
+  // Flag that indicates an asynchronous test is currently running.
+  // Protected by running_mutex_.
+  bool async_running_;
+  SVMutex running_mutex_;
+  // Stored copies of the args for use while running asynchronously.
+  int test_iteration_;
+  const double* test_training_errors_;
+  GenericVector<char> test_model_data_;
+  int test_training_stage_;
+  STRING test_result_;
+};
+
+}  // namespace tesseract
+
+#endif  // TESSERACT_TRAINING_LSTMTESTER_H_
diff --git a/training/lstmtraining.cpp b/training/lstmtraining.cpp
index f4d46cf9c4..e8551c498d 100644
--- a/training/lstmtraining.cpp
+++ b/training/lstmtraining.cpp
@@ -20,6 +20,7 @@
 #include "base/commandlineflags.h"
 #endif
 #include "commontraining.h"
+#include "lstmtester.h"
 #include "lstmtrainer.h"
 #include "params.h"
 #include "strngs.h"
@@ -27,8 +28,8 @@
 #include "unicharset_training_utils.h"
 
 INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
-STRING_PARAM_FLAG(net_spec, "[I1,48Lt1,100O]", "Network specification");
-INT_PARAM_FLAG(train_mode, 64, "Controls gross training behavior.");
+STRING_PARAM_FLAG(net_spec, "", "Network specification");
+INT_PARAM_FLAG(train_mode, 80, "Controls gross training behavior.");
 INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
 INT_PARAM_FLAG(perfect_sample_delay, 4,
                "How many imperfect samples between perfect ones.");
@@ -42,6 +43,10 @@ STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
 STRING_PARAM_FLAG(script_dir, "",
                   "Required to set unicharset properties or"
                   " use unicharset compression.");
+STRING_PARAM_FLAG(train_listfile, "",
+                  "File listing training files in lstmf training format.");
+STRING_PARAM_FLAG(eval_listfile, "",
+                  "File listing eval files in lstmf training format.");
 BOOL_PARAM_FLAG(stop_training, false,
                 "Just convert the training model to a runtime model.");
 INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to"
@@ -106,9 +111,16 @@ int main(int argc, char **argv) {
   }
 
   // Get the list of files to process.
+  if (FLAGS_train_listfile.empty()) {
+    tprintf("Must supply a list of training filenames! --train_listfile\n");
+    return 1;
+  }
   GenericVector<STRING> filenames;
-  for (int arg = 1; arg < argc; ++arg) {
-    filenames.push_back(STRING(argv[arg]));
+  if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(),
+                                         &filenames)) {
+    tprintf("Failed to load list of training filenames from %s\n",
+            FLAGS_train_listfile.c_str());
+    return 1;
   }
 
   UNICHARSET unicharset;
@@ -125,6 +137,7 @@ int main(int argc, char **argv) {
       return 1;
     }
     tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
+    trainer.InitIterations();
   }
   if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
     // We need a unicharset to start from scratch or append.
@@ -164,6 +177,18 @@ int main(int argc, char **argv) {
   char* best_model_dump = NULL;
   size_t best_model_size = 0;
   STRING best_model_name;
+  tesseract::LSTMTester tester(static_cast<inT64>(FLAGS_max_image_MB) *
+                               1048576);
+  tesseract::TestCallback tester_callback = nullptr;
+  if (!FLAGS_eval_listfile.empty()) {
+    if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
+      tprintf("Failed to load eval data from: %s\n",
+              FLAGS_eval_listfile.c_str());
+      return 1;
+    }
+    tester_callback =
+        NewPermanentTessCallback(&tester, &tesseract::LSTMTester::RunEvalAsync);
+  }
   do {
     // Train a few.
     int iteration = trainer.training_iteration();
@@ -173,11 +198,12 @@ int main(int argc, char **argv) {
       trainer.TrainOnLine(&trainer, false);
     }
     STRING log_str;
-    trainer.MaintainCheckpoints(NULL, &log_str);
+    trainer.MaintainCheckpoints(tester_callback, &log_str);
     tprintf("%s\n", log_str.string());
   } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
            (trainer.training_iteration() < FLAGS_max_iterations ||
             FLAGS_max_iterations == 0));
+  delete tester_callback;
   tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
   return 0;
 }  /* main */
diff --git a/training/tesstrain.sh b/training/tesstrain.sh
index 231b5360fb..e01f99fdd1 100755
--- a/training/tesstrain.sh
+++ b/training/tesstrain.sh
@@ -23,6 +23,7 @@
 #     --langdata_dir DATADIR  # Path to tesseract/training/langdata directory.
 #     --output_dir OUTPUTDIR  # Location of output traineddata file.
 #     --overwrite             # Safe to overwrite files in output_dir.
+#     --linedata_only         # Only generate training data for lstmtraining.
 #     --run_shape_clustering  # Run shape clustering (use for Indic langs).
 #     --exposures EXPOSURES   # A list of exposure levels to use (e.g. "-1 0 1").
 #
@@ -60,13 +61,18 @@ initialize_fontconfig
 phase_I_generate_image 8
 phase_UP_generate_unicharset
 phase_D_generate_dawg
-phase_E_extract_features "box.train" 8
-phase_C_cluster_prototypes "${TRAINING_DIR}/${LANG_CODE}.normproto"
-if [[ "${ENABLE_SHAPE_CLUSTERING}" == "y" ]]; then
-  phase_S_cluster_shapes
+if (( ${LINEDATA} )); then
+  phase_E_extract_features "lstm.train" 8 "lstmf"
+  make__lstmdata
+else
+  phase_E_extract_features "box.train" 8 "tr"
+  phase_C_cluster_prototypes "${TRAINING_DIR}/${LANG_CODE}.normproto"
+  if [[ "${ENABLE_SHAPE_CLUSTERING}" == "y" ]]; then
+    phase_S_cluster_shapes
+  fi
+  phase_M_cluster_microfeatures
+  phase_B_generate_ambiguities
+  make__traineddata
 fi
-phase_M_cluster_microfeatures
-phase_B_generate_ambiguities
-make__traineddata
 
 tlog "\nCompleted training for language '${LANG_CODE}'\n"
diff --git a/training/tesstrain_utils.sh b/training/tesstrain_utils.sh
index 906a20ac4f..e6ec80e5ee 100755
--- a/training/tesstrain_utils.sh
+++ b/training/tesstrain_utils.sh
@@ -23,6 +23,7 @@ else
 fi
 OUTPUT_DIR="/tmp/tesstrain/tessdata"
 OVERWRITE=0
+LINEDATA=0
 RUN_SHAPE_CLUSTERING=0
 EXTRACT_FONT_PROPERTIES=1
 WORKSPACE_DIR=`mktemp -d`
@@ -90,8 +91,8 @@ parse_flags() {
             --)
                 break;;
             --fontlist)
-               fn=0
-               FONTS=""
+                fn=0
+                FONTS=""
                 while test $j -lt ${#ARGV[@]}; do
                     test -z "${ARGV[$j]}" && break
                     test `echo ${ARGV[$j]} | cut -c -2` = "--" && break
@@ -124,6 +125,8 @@ parse_flags() {
                 i=$j ;;
             --overwrite)
                 OVERWRITE=1 ;;
+            --linedata_only)
+                LINEDATA=1 ;;
             --extract_font_properties)
                 EXTRACT_FONT_PROPERTIES=1 ;;
             --noextract_font_properties)
@@ -368,10 +371,11 @@ phase_D_generate_dawg() {
 phase_E_extract_features() {
     local box_config=$1
     local par_factor=$2
+    local ext=$3
     if [[ -z ${par_factor} || ${par_factor} -le 0 ]]; then
         par_factor=1
     fi
-    tlog "\n=== Phase E: Extracting features ==="
+    tlog "\n=== Phase E: Generating ${ext} files ==="
"\n=== Phase E: Generating ${ext} files ===" local img_files="" for exposure in ${EXPOSURES}; do @@ -401,7 +405,7 @@ phase_E_extract_features() { export TESSDATA_PREFIX=${OLD_TESSDATA_PREFIX} # Check that all the output files were produced. for img_file in ${img_files}; do - check_file_readable ${img_file%.*}.tr + check_file_readable "${img_file%.*}.${ext}" done } @@ -484,6 +488,39 @@ phase_B_generate_ambiguities() { # TODO: Add support for generating ambiguities automatically. } +make__lstmdata() { + tlog "\n=== Constructing LSTM training data ===" + local lang_prefix=${LANGDATA_ROOT}/${LANG_CODE}/${LANG_CODE} + if [[ ! -d ${OUTPUT_DIR} ]]; then + tlog "Creating new directory ${OUTPUT_DIR}" + mkdir -p ${OUTPUT_DIR} + fi + + # Copy available files for this language from the langdata dir. + if [[ -r ${lang_prefix}.config ]]; then + tlog "Copying ${lang_prefix}.config to ${OUTPUT_DIR}" + cp ${lang_prefix}.config ${OUTPUT_DIR} + chmod u+w ${OUTPUT_DIR}/${LANG_CODE}.config + fi + if [[ -r "${TRAINING_DIR}/${LANG_CODE}.unicharset" ]]; then + tlog "Moving ${TRAINING_DIR}/${LANG_CODE}.unicharset to ${OUTPUT_DIR}" + mv "${TRAINING_DIR}/${LANG_CODE}.unicharset" "${OUTPUT_DIR}" + fi + for ext in number-dawg punc-dawg word-dawg; do + local src="${TRAINING_DIR}/${LANG_CODE}.${ext}" + if [[ -r "${src}" ]]; then + dest="${OUTPUT_DIR}/${LANG_CODE}.lstm-${ext}" + tlog "Moving ${src} to ${dest}" + mv "${src}" "${dest}" + fi + done + for f in "${TRAINING_DIR}/${LANG_CODE}".*.lstmf; do + tlog "Moving ${f} to ${OUTPUT_DIR}" + mv "${f}" "${OUTPUT_DIR}" + done + local lstm_list="${OUTPUT_DIR}/${LANG_CODE}.training_files.txt" + ls -1 "${OUTPUT_DIR}"/*.lstmf > "${lstm_list}" +} make__traineddata() { tlog "\n=== Making final traineddata file ===" diff --git a/viewer/svutil.h b/viewer/svutil.h index 667c052083..03d0b9147f 100644 --- a/viewer/svutil.h +++ b/viewer/svutil.h @@ -26,8 +26,8 @@ #ifdef _WIN32 #ifndef __GNUC__ -#include "platform.h" #include +#include "platform.h" #if defined(_MSC_VER) && _MSC_VER < 1900 #define snprintf _snprintf #endif