diff --git a/api/apitypes.h b/api/apitypes.h index 2c0e85c9dd..f2fed7a86c 100644 --- a/api/apitypes.h +++ b/api/apitypes.h @@ -21,6 +21,7 @@ #define TESSERACT_API_APITYPES_H_ #include "publictypes.h" +#include "version.h" // The types used by the API and Page/ResultIterator can be found in: // ccstruct/publictypes.h diff --git a/api/baseapi.h b/api/baseapi.h index 36b7527e79..a15ea30828 100644 --- a/api/baseapi.h +++ b/api/baseapi.h @@ -20,10 +20,6 @@ #ifndef TESSERACT_API_BASEAPI_H_ #define TESSERACT_API_BASEAPI_H_ -#define TESSERACT_VERSION_STR "4.00.00alpha" -#define TESSERACT_VERSION 0x040000 -#define MAKE_VERSION(major, minor, patch) (((major) << 16) | ((minor) << 8) | \ - (patch)) #include // To avoid collision with other typenames include the ABSOLUTE MINIMUM // complexity of includes here. Use forward declarations wherever possible diff --git a/ccmain/tessedit.cpp b/ccmain/tessedit.cpp index e239c464b6..193858fed8 100644 --- a/ccmain/tessedit.cpp +++ b/ccmain/tessedit.cpp @@ -188,10 +188,10 @@ bool Tesseract::init_tesseract_lang_data( #ifndef ANDROID_BUILD if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY || tessedit_ocr_engine_mode == OEM_TESSERACT_LSTM_COMBINED) { - if (mgr->GetComponent(TESSDATA_LSTM, &fp)) { + if (mgr->IsComponentAvailable(TESSDATA_LSTM)) { lstm_recognizer_ = new LSTMRecognizer; - ASSERT_HOST(lstm_recognizer_->DeSerialize(&fp)); - if (lstm_use_matrix) lstm_recognizer_->LoadDictionary(language, mgr); + ASSERT_HOST( + lstm_recognizer_->Load(lstm_use_matrix ? language : nullptr, mgr)); } else { tprintf("Error: LSTM requested, but not present!! Loading tesseract.\n"); tessedit_ocr_engine_mode.set_value(OEM_TESSERACT_ONLY); diff --git a/ccutil/Makefile.am b/ccutil/Makefile.am index 0b5f39f004..b2baa06f3d 100644 --- a/ccutil/Makefile.am +++ b/ccutil/Makefile.am @@ -15,7 +15,8 @@ endif include_HEADERS = \ basedir.h errcode.h fileerr.h genericvector.h helpers.h host.h memry.h \ ndminx.h params.h ocrclass.h platform.h serialis.h strngs.h \ - tesscallback.h unichar.h unicharcompress.h unicharmap.h unicharset.h + tesscallback.h unichar.h unicharcompress.h unicharmap.h unicharset.h \ + version.h noinst_HEADERS = \ ambigs.h bits16.h bitvector.h ccutil.h clst.h doubleptr.h elst2.h \ diff --git a/ccutil/tessdatamanager.cpp b/ccutil/tessdatamanager.cpp index 048ff15824..c11a2d4fbe 100644 --- a/ccutil/tessdatamanager.cpp +++ b/ccutil/tessdatamanager.cpp @@ -78,6 +78,9 @@ bool TessdataManager::LoadMemBuffer(const char *name, const char *data, if (fp.FRead(&entries_[i][0], 1, entry_size) != entry_size) return false; } } + if (entries_[TESSDATA_VERSION].empty()) { + SetVersionString("Pre-4.0.0"); + } is_loaded_ = true; return true; } @@ -139,6 +142,7 @@ void TessdataManager::Clear() { // Prints a directory of contents. void TessdataManager::Directory() const { + tprintf("Version string:%s\n", VersionString().c_str()); int offset = TESSDATA_NUM_ENTRIES * sizeof(inT64); for (int i = 0; i < TESSDATA_NUM_ENTRIES; ++i) { if (!entries_[i].empty()) { @@ -153,12 +157,32 @@ void TessdataManager::Directory() const { // Returns false in case of failure. bool TessdataManager::GetComponent(TessdataType type, TFile *fp) { if (!is_loaded_ && !Init(data_file_name_.string())) return false; + const TessdataManager *const_this = this; + return const_this->GetComponent(type, fp); +} + +// As non-const version except it can't load the component if not already +// loaded. +bool TessdataManager::GetComponent(TessdataType type, TFile *fp) const { + ASSERT_HOST(is_loaded_); if (entries_[type].empty()) return false; fp->Open(&entries_[type][0], entries_[type].size()); fp->set_swap(swap_); return true; } +// Returns the current version string. +string TessdataManager::VersionString() const { + return string(&entries_[TESSDATA_VERSION][0], + entries_[TESSDATA_VERSION].size()); +} + +// Sets the version string to the given v_str. +void TessdataManager::SetVersionString(const string &v_str) { + entries_[TESSDATA_VERSION].resize_no_init(v_str.size()); + memcpy(&entries_[TESSDATA_VERSION][0], v_str.data(), v_str.size()); +} + bool TessdataManager::CombineDataFiles( const char *language_data_path_prefix, const char *output_filename) { diff --git a/ccutil/tessdatamanager.h b/ccutil/tessdatamanager.h index db9c5583f7..e2754fe2c7 100644 --- a/ccutil/tessdatamanager.h +++ b/ccutil/tessdatamanager.h @@ -25,6 +25,7 @@ #include "host.h" #include "strngs.h" #include "tprintf.h" +#include "version.h" static const char kTrainedDataSuffix[] = "traineddata"; @@ -51,6 +52,9 @@ static const char kLSTMModelFileSuffix[] = "lstm"; static const char kLSTMPuncDawgFileSuffix[] = "lstm-punc-dawg"; static const char kLSTMSystemDawgFileSuffix[] = "lstm-word-dawg"; static const char kLSTMNumberDawgFileSuffix[] = "lstm-number-dawg"; +static const char kLSTMUnicharsetFileSuffix[] = "lstm-unicharset"; +static const char kLSTMRecoderFileSuffix[] = "lstm-recoder"; +static const char kVersionFileSuffix[] = "version"; namespace tesseract { @@ -76,6 +80,9 @@ enum TessdataType { TESSDATA_LSTM_PUNC_DAWG, // 18 TESSDATA_LSTM_SYSTEM_DAWG, // 19 TESSDATA_LSTM_NUMBER_DAWG, // 20 + TESSDATA_LSTM_UNICHARSET, // 21 + TESSDATA_LSTM_RECODER, // 22 + TESSDATA_VERSION, // 23 TESSDATA_NUM_ENTRIES }; @@ -106,6 +113,9 @@ static const char *const kTessdataFileSuffixes[] = { kLSTMPuncDawgFileSuffix, // 18 kLSTMSystemDawgFileSuffix, // 19 kLSTMNumberDawgFileSuffix, // 20 + kLSTMUnicharsetFileSuffix, // 21 + kLSTMRecoderFileSuffix, // 22 + kVersionFileSuffix, // 23 }; /** @@ -120,9 +130,13 @@ static const int kMaxNumTessdataEntries = 1000; class TessdataManager { public: - TessdataManager() : reader_(nullptr), is_loaded_(false), swap_(false) {} + TessdataManager() : reader_(nullptr), is_loaded_(false), swap_(false) { + SetVersionString(TESSERACT_VERSION_STR); + } explicit TessdataManager(FileReader reader) - : reader_(reader), is_loaded_(false), swap_(false) {} + : reader_(reader), is_loaded_(false), swap_(false) { + SetVersionString(TESSERACT_VERSION_STR); + } ~TessdataManager() {} bool swap() const { return swap_; } @@ -152,9 +166,21 @@ class TessdataManager { // Prints a directory of contents. void Directory() const; + // Returns true if the component requested is present. + bool IsComponentAvailable(TessdataType type) const { + return !entries_[type].empty(); + } // Opens the given TFile pointer to the given component type. // Returns false in case of failure. bool GetComponent(TessdataType type, TFile *fp); + // As non-const version except it can't load the component if not already + // loaded. + bool GetComponent(TessdataType type, TFile *fp) const; + + // Returns the current version string. + string VersionString() const; + // Sets the version string to the given v_str. + void SetVersionString(const string &v_str); // Returns true if the base Tesseract components are present. bool IsBaseAvailable() const { diff --git a/ccutil/version.h b/ccutil/version.h new file mode 100644 index 0000000000..3eac67d050 --- /dev/null +++ b/ccutil/version.h @@ -0,0 +1,9 @@ +#ifndef TESSERACT_CCUTIL_VERSION_H_ +#define TESSERACT_CCUTIL_VERSION_H_ + +#define TESSERACT_VERSION_STR "4.00.00alpha" +#define TESSERACT_VERSION 0x040000 +#define MAKE_VERSION(major, minor, patch) \ + (((major) << 16) | ((minor) << 8) | (patch)) + +#endif // TESSERACT_CCUTIL_VERSION_H_ diff --git a/lstm/lstmrecognizer.cpp b/lstm/lstmrecognizer.cpp index 8016696dba..9c1ee74972 100644 --- a/lstm/lstmrecognizer.cpp +++ b/lstm/lstmrecognizer.cpp @@ -68,10 +68,24 @@ LSTMRecognizer::~LSTMRecognizer() { delete search_; } +// Loads a model from mgr, including the dictionary only if lang is not null. +bool LSTMRecognizer::Load(const char* lang, TessdataManager* mgr) { + TFile fp; + if (!mgr->GetComponent(TESSDATA_LSTM, &fp)) return false; + if (!DeSerialize(mgr, &fp)) return false; + if (lang == nullptr) return true; + // Allow it to run without a dictionary. + LoadDictionary(lang, mgr); + return true; +} + // Writes to the given file. Returns false in case of error. -bool LSTMRecognizer::Serialize(TFile* fp) const { +bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const { + bool include_charsets = mgr == nullptr || + !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) || + !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET); if (!network_->Serialize(fp)) return false; - if (!GetUnicharset().save_to_file(fp)) return false; + if (include_charsets && !GetUnicharset().save_to_file(fp)) return false; if (!network_str_.Serialize(fp)) return false; if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1) return false; @@ -83,16 +97,20 @@ bool LSTMRecognizer::Serialize(TFile* fp) const { if (fp->FWrite(&weight_range_, sizeof(weight_range_), 1) != 1) return false; if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false; - if (IsRecoding() && !recoder_.Serialize(fp)) return false; + if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false; return true; } // Reads from the given file. Returns false in case of error. -bool LSTMRecognizer::DeSerialize(TFile* fp) { +bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) { delete network_; network_ = Network::CreateFromFile(fp); if (network_ == NULL) return false; - if (!ccutil_.unicharset.load_from_file(fp, false)) return false; + bool include_charsets = mgr == nullptr || + !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) || + !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET); + if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) + return false; if (!network_str_.DeSerialize(fp)) return false; if (fp->FReadEndian(&training_flags_, sizeof(training_flags_), 1) != 1) return false; @@ -107,6 +125,25 @@ bool LSTMRecognizer::DeSerialize(TFile* fp) { if (fp->FReadEndian(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; if (fp->FReadEndian(&momentum_, sizeof(momentum_), 1) != 1) return false; + if (include_charsets && !LoadRecoder(fp)) return false; + if (!include_charsets && !LoadCharsets(mgr)) return false; + network_->SetRandomizer(&randomizer_); + network_->CacheXScaleFactor(network_->XScaleFactor()); + return true; +} + +// Loads the charsets from mgr. +bool LSTMRecognizer::LoadCharsets(const TessdataManager* mgr) { + TFile fp; + if (!mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false; + if (!ccutil_.unicharset.load_from_file(&fp, false)) return false; + if (!mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false; + if (!LoadRecoder(&fp)) return false; + return true; +} + +// Loads the Recoder. +bool LSTMRecognizer::LoadRecoder(TFile* fp) { if (IsRecoding()) { if (!recoder_.DeSerialize(fp)) return false; RecodedCharID code; @@ -119,8 +156,6 @@ bool LSTMRecognizer::DeSerialize(TFile* fp) { recoder_.SetupPassThrough(GetUnicharset()); training_flags_ |= TF_COMPRESS_UNICHARSET; } - network_->SetRandomizer(&randomizer_); - network_->CacheXScaleFactor(network_->XScaleFactor()); return true; } diff --git a/lstm/lstmrecognizer.h b/lstm/lstmrecognizer.h index 51dddcc068..da566dfebf 100644 --- a/lstm/lstmrecognizer.h +++ b/lstm/lstmrecognizer.h @@ -155,10 +155,20 @@ class LSTMRecognizer { } int null_char() const { return null_char_; } + // Loads a model from mgr, including the dictionary only if lang is not null. + bool Load(const char* lang, TessdataManager* mgr); + // Writes to the given file. Returns false in case of error. - bool Serialize(TFile* fp) const; + // If mgr contains a unicharset and recoder, then they are not encoded to fp. + bool Serialize(const TessdataManager* mgr, TFile* fp) const; // Reads from the given file. Returns false in case of error. - bool DeSerialize(TFile* fp); + // If mgr contains a unicharset and recoder, then they are taken from there, + // otherwise, they are part of the serialization in fp. + bool DeSerialize(const TessdataManager* mgr, TFile* fp); + // Loads the charsets from mgr. + bool LoadCharsets(const TessdataManager* mgr); + // Loads the Recoder. + bool LoadRecoder(TFile* fp); // Loads the dictionary if possible from the traineddata file. // Prints a warning message, and returns false but otherwise fails silently // and continues to work without it if loading fails. diff --git a/lstm/lstmtrainer.cpp b/lstm/lstmtrainer.cpp index afe07585b5..30e04723c5 100644 --- a/lstm/lstmtrainer.cpp +++ b/lstm/lstmtrainer.cpp @@ -93,7 +93,8 @@ LSTMTrainer::LSTMTrainer(FileReader file_reader, FileWriter file_writer, file_writer_(file_writer), checkpoint_reader_(checkpoint_reader), checkpoint_writer_(checkpoint_writer), - sub_trainer_(NULL) { + sub_trainer_(NULL), + mgr_(file_reader) { EmptyConstructor(); if (file_reader_ == NULL) file_reader_ = LoadDataFromFile; if (file_writer_ == NULL) file_writer_ = SaveDataToFile; @@ -145,27 +146,6 @@ void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset, SetUnicharsetProperties(script_dir); } -// Initializes the character set encode/decode mechanism directly from a -// previously setup UNICHARSET and UnicharCompress. -// ctc_mode controls how the truth text is mapped to the network targets. -// Note: Call before InitNetwork! -void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset, - const UnicharCompress& recoder) { - EmptyConstructor(); - int flags = TF_COMPRESS_UNICHARSET; - training_flags_ = static_cast(flags); - ccutil_.unicharset.CopyFrom(unicharset); - recoder_ = recoder; - null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN - : GetUnicharset().size(); - RecodedCharID code; - recoder_.EncodeUnichar(null_char_, &code); - null_char_ = code(0); - // Space should encode as itself. - recoder_.EncodeUnichar(UNICHAR_SPACE, &code); - ASSERT_HOST(code(0) == UNICHAR_SPACE); -} - // Initializes the trainer with a network_spec in the network description // net_flags control network behavior according to the NetworkFlags enum. // There isn't really much difference between them - only where the effects @@ -175,8 +155,7 @@ void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset, bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum) { - // Call after InitCharSet. - ASSERT_HOST(GetUnicharset().size() > SPECIAL_UNICHAR_CODES_COUNT); + mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec.string()); weight_range_ = weight_range; learning_rate_ = learning_rate; momentum_ = momentum; @@ -426,8 +405,9 @@ bool LSTMTrainer::TransitionTrainingStage(float error_threshold) { } // Writes to the given file. Returns false in case of error. -bool LSTMTrainer::Serialize(TFile* fp) const { - if (!LSTMRecognizer::Serialize(fp)) return false; +bool LSTMTrainer::Serialize(SerializeAmount serialize_amount, + const TessdataManager* mgr, TFile* fp) const { + if (!LSTMRecognizer::Serialize(mgr, fp)) return false; if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) return false; if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) != @@ -443,9 +423,9 @@ bool LSTMTrainer::Serialize(TFile* fp) const { if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false; if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1) return false; - uinT8 amount = serialize_amount_; + uinT8 amount = serialize_amount; if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false; - if (amount == LIGHT) return true; // We are done. + if (serialize_amount == LIGHT) return true; // We are done. if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1) return false; if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1) @@ -462,7 +442,8 @@ bool LSTMTrainer::Serialize(TFile* fp) const { return false; if (!best_model_data_.Serialize(fp)) return false; if (!worst_model_data_.Serialize(fp)) return false; - if (amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp)) return false; + if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp)) + return false; GenericVector sub_data; if (sub_trainer_ != NULL && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data)) return false; @@ -476,8 +457,8 @@ bool LSTMTrainer::Serialize(TFile* fp) const { // Reads from the given file. Returns false in case of error. // NOTE: It is assumed that the trainer is never read cross-endian. -bool LSTMTrainer::DeSerialize(TFile* fp) { - if (!LSTMRecognizer::DeSerialize(fp)) return false; +bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) { + if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false; if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) { // Special case. If we successfully decoded the recognizer, but fail here // then it means we were just given a recognizer, so issue a warning and @@ -653,7 +634,7 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, } double momentum_factor = 1.0 / (1.0 - momentum_); GenericVector orig_trainer; - SaveTrainingDump(LIGHT, this, &orig_trainer); + samples_trainer->SaveTrainingDump(LIGHT, this, &orig_trainer); for (int i = 0; i < num_layers; ++i) { Network* layer = GetLayer(layers[i]); num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0; @@ -667,7 +648,7 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, if (ww == LR_DOWN) ww_factor *= factor; // Make a copy of *this, so we can mess about without damaging anything. LSTMTrainer copy_trainer; - copy_trainer.ReadTrainingDump(orig_trainer, ©_trainer); + samples_trainer->ReadTrainingDump(orig_trainer, ©_trainer); // Clear the updates, doing nothing else. copy_trainer.network_->Update(0.0, 0.0, 0); // Adjust the learning rate in each layer. @@ -683,11 +664,11 @@ int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples, if (trainingdata == NULL) continue; // We'll now use this trainer again for each layer. GenericVector updated_trainer; - SaveTrainingDump(LIGHT, ©_trainer, &updated_trainer); + samples_trainer->SaveTrainingDump(LIGHT, ©_trainer, &updated_trainer); for (int i = 0; i < num_layers; ++i) { if (num_weights[i] == 0) continue; LSTMTrainer layer_trainer; - layer_trainer.ReadTrainingDump(updated_trainer, &layer_trainer); + samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer); Network* layer = layer_trainer.GetLayer(layers[i]); // Update the weights in just the layer, and also zero the updates // matrix (to epsilon). @@ -901,30 +882,27 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata, } // Writes the trainer to memory, so that the current training state can be -// restored. +// restored. *this must always be the master trainer that retains the only +// copy of the training data and language model. trainer is the model that is +// actually serialized. bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer* trainer, GenericVector* data) const { TFile fp; fp.OpenWrite(data); - trainer->serialize_amount_ = serialize_amount; - return trainer->Serialize(&fp); + return trainer->Serialize(serialize_amount, &mgr_, &fp); } -// Reads previously saved trainer from memory. -bool LSTMTrainer::ReadTrainingDump(const GenericVector& data, - LSTMTrainer* trainer) { - if (data.size() == 0) { - tprintf("Warning: data size is zero in LSTMTrainer::ReadTrainingDump\n"); +// Restores the model to *this. +bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager* mgr, + const char* data, int size) { + if (size == 0) { + tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n"); return false; } - return trainer->ReadSizedTrainingDump(&data[0], data.size()); -} - -bool LSTMTrainer::ReadSizedTrainingDump(const char* data, int size) { TFile fp; fp.Open(data, size); - return DeSerialize(&fp); + return DeSerialize(mgr, &fp); } // Writes the recognizer to memory, so that it can be used for testing later. @@ -932,20 +910,10 @@ void LSTMTrainer::SaveRecognitionDump(GenericVector* data) const { TFile fp; fp.OpenWrite(data); network_->SetEnableTraining(TS_TEMP_DISABLE); - ASSERT_HOST(LSTMRecognizer::Serialize(&fp)); + ASSERT_HOST(LSTMRecognizer::Serialize(&mgr_, &fp)); network_->SetEnableTraining(TS_RE_ENABLE); } -// Reads and returns a previously saved recognizer from memory. -LSTMRecognizer* LSTMTrainer::ReadRecognitionDump( - const GenericVector& data) { - TFile fp; - fp.Open(&data[0], data.size()); - LSTMRecognizer* recognizer = new LSTMRecognizer; - ASSERT_HOST(recognizer->DeSerialize(&fp)); - return recognizer; -} - // Returns a suitable filename for a training dump, based on the model_base_, // the iteration and the error rates. STRING LSTMTrainer::DumpFilename() const { @@ -963,6 +931,24 @@ void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) { error_rates_[type] = 100.0 * new_error; } +// Private version of InitCharSet above finishes the job after initializing +// the mgr_ data member. +void LSTMTrainer::InitCharSet() { + EmptyConstructor(); + training_flags_ = TF_COMPRESS_UNICHARSET; + // Initialize the unicharset and recoder. + if (!LoadCharsets(&mgr_)) { + ASSERT_HOST( + "Must provide a traineddata containing lstm_unicharset and" + " lstm_recoder!\n" != nullptr); + } + null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN + : GetUnicharset().size(); + RecodedCharID code; + recoder_.EncodeUnichar(null_char_, &code); + null_char_ = code(0); +} + // Factored sub-constructor sets up reasonable default values. void LSTMTrainer::EmptyConstructor() { align_win_ = NULL; @@ -970,7 +956,6 @@ void LSTMTrainer::EmptyConstructor() { ctc_win_ = NULL; recon_win_ = NULL; checkpoint_iteration_ = 0; - serialize_amount_ = FULL; training_stage_ = 0; num_training_stages_ = 2; InitIterations(); @@ -1283,11 +1268,13 @@ STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate, if (error_rate > best_error_rate_ && iteration < best_iteration_ + kErrorGraphInterval) { // Too soon to record a new point. - if (tester != NULL) - return tester->Run(worst_iteration_, NULL, worst_model_data_, - CurrentTrainingStage()); - else + if (tester != NULL) { + mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0], + worst_model_data_.size()); + return tester->Run(worst_iteration_, NULL, mgr_, CurrentTrainingStage()); + } else { return ""; + } } STRING result; // NOTE: there are 2 asymmetries here: @@ -1298,10 +1285,11 @@ STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate, // between very frequent minima. if (error_rate < best_error_rate_) { // This is a new (global) minimum. - if (tester != NULL) { - if (worst_model_data_.size() != 0) - result = tester->Run(worst_iteration_, worst_error_rates_, - worst_model_data_, CurrentTrainingStage()); + if (tester != nullptr && !worst_model_data_.empty()) { + mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0], + worst_model_data_.size()); + result = tester->Run(worst_iteration_, worst_error_rates_, mgr_, + CurrentTrainingStage()); worst_model_data_.truncate(0); best_model_data_ = model_data; } @@ -1324,13 +1312,17 @@ STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate, } else if (error_rate > best_error_rate_) { // This is a new (local) maximum. if (tester != NULL) { - if (best_model_data_.empty()) { + if (!best_model_data_.empty()) { + mgr_.OverwriteEntry(TESSDATA_LSTM, &best_model_data_[0], + best_model_data_.size()); + result = tester->Run(best_iteration_, best_error_rates_, mgr_, + CurrentTrainingStage()); + } else if (!worst_model_data_.empty()) { // Allow for multiple data points with "worst" error rate. - result = tester->Run(worst_iteration_, worst_error_rates_, - worst_model_data_, CurrentTrainingStage()); - } else { - result = tester->Run(best_iteration_, best_error_rates_, - best_model_data_, CurrentTrainingStage()); + mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0], + worst_model_data_.size()); + result = tester->Run(worst_iteration_, worst_error_rates_, mgr_, + CurrentTrainingStage()); } if (result.length() > 0) best_model_data_.truncate(0); diff --git a/lstm/lstmtrainer.h b/lstm/lstmtrainer.h index 484b75f966..41ea32d68c 100644 --- a/lstm/lstmtrainer.h +++ b/lstm/lstmtrainer.h @@ -79,8 +79,8 @@ typedef TessResultCallback3&, int>* TestCallback; +typedef TessResultCallback4* TestCallback; // Trainer class for LSTM networks. Most of the effort is in creating the // ideal target outputs from the transcription. A box file is used if it is @@ -110,11 +110,16 @@ class LSTMTrainer : public LSTMRecognizer { void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir, int train_flags); // Initializes the character set encode/decode mechanism directly from a - // previously setup UNICHARSET and UnicharCompress. - // ctc_mode controls how the truth text is mapped to the network targets. - // Note: Call before InitNetwork! - void InitCharSet(const UNICHARSET& unicharset, - const UnicharCompress& recoder); + // previously setup traineddata containing dawgs, UNICHARSET and + // UnicharCompress. Note: Call before InitNetwork! + void InitCharSet(const string& traineddata_path) { + ASSERT_HOST(mgr_.Init(traineddata_path.c_str())); + InitCharSet(); + } + void InitCharSet(const TessdataManager& mgr) { + mgr_ = mgr; + InitCharSet(); + } // Initializes the trainer with a network_spec in the network description // net_flags control network behavior according to the NetworkFlags enum. @@ -175,10 +180,6 @@ class LSTMTrainer : public LSTMRecognizer { double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING* results); - void SetSerializeMode(SerializeAmount serialize_amount) const { - serialize_amount_ = serialize_amount; - } - // Provides output on the distribution of weight values. void DebugNetwork(); @@ -213,9 +214,10 @@ class LSTMTrainer : public LSTMRecognizer { int CurrentTrainingStage() const { return training_stage_; } // Writes to the given file. Returns false in case of error. - virtual bool Serialize(TFile* fp) const; + virtual bool Serialize(SerializeAmount serialize_amount, + const TessdataManager* mgr, TFile* fp) const; // Reads from the given file. Returns false in case of error. - virtual bool DeSerialize(TFile* fp); + virtual bool DeSerialize(const TessdataManager* mgr, TFile* fp); // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the // learning rates (by scaling reduction, or layer specific, according to @@ -287,14 +289,27 @@ class LSTMTrainer : public LSTMRecognizer { NetworkIO* fwd_outputs, NetworkIO* targets); // Writes the trainer to memory, so that the current training state can be - // restored. + // restored. *this must always be the master trainer that retains the only + // copy of the training data and language model. trainer is the model that is + // actually serialized. bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer* trainer, GenericVector* data) const; - // Reads previously saved trainer from memory. - bool ReadTrainingDump(const GenericVector& data, LSTMTrainer* trainer); - bool ReadSizedTrainingDump(const char* data, int size); + // Reads previously saved trainer from memory. *this must always be the + // master trainer that retains the only copy of the training data and + // language model. trainer is the model that is restored. + bool ReadTrainingDump(const GenericVector& data, + LSTMTrainer* trainer) const { + return ReadSizedTrainingDump(&data[0], data.size(), trainer); + } + bool ReadSizedTrainingDump(const char* data, int size, + LSTMTrainer* trainer) const { + return trainer->ReadLocalTrainingDump(&mgr_, data, size); + } + // Restores the model to *this. + bool ReadLocalTrainingDump(const TessdataManager* mgr, const char* data, + int size); // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump. void SetupCheckpointInfo(); @@ -302,9 +317,6 @@ class LSTMTrainer : public LSTMRecognizer { // Writes the recognizer to memory, so that it can be used for testing later. void SaveRecognitionDump(GenericVector* data) const; - // Reads and returns a previously saved recognizer from memory. - static LSTMRecognizer* ReadRecognitionDump(const GenericVector& data); - // Writes current best model to a file, unless it has already been written. bool SaveBestModel(FileWriter writer) const; @@ -316,6 +328,10 @@ class LSTMTrainer : public LSTMRecognizer { void FillErrorBuffer(double new_error, ErrorTypes type); protected: + // Private version of InitCharSet above finishes the job after initializing + // the mgr_ data member. + void InitCharSet(); + // Factored sub-constructor sets up reasonable default values. void EmptyConstructor(); @@ -404,8 +420,6 @@ class LSTMTrainer : public LSTMRecognizer { STRING checkpoint_name_; // Training data. DocumentCache training_data_; - // A hack to serialize less data for batch training and record file version. - mutable SerializeAmount serialize_amount_; // Name to use when saving best_trainer_. STRING best_model_name_; // Number of available training stages. @@ -419,7 +433,7 @@ class LSTMTrainer : public LSTMRecognizer { CheckPointWriter checkpoint_writer_; // ===Serialized data to ensure that a restart produces the same results.=== - // These members are only serialized when serialize_amount_ != LIGHT. + // These members are only serialized when serialize_amount != LIGHT. // Best error rate so far. double best_error_rate_; // Snapshot of all error rates at best_iteration_. @@ -473,6 +487,8 @@ class LSTMTrainer : public LSTMRecognizer { GenericVector error_buffers_[ET_COUNT]; // Rounded mean percent trailing training errors in the buffers. double error_rates_[ET_COUNT]; // RMS training error. + // Traineddata file with optional dawgs + UNICHARSET and recoder. + TessdataManager mgr_; }; } // namespace tesseract. diff --git a/training/lstmeval.cpp b/training/lstmeval.cpp index aa990e2325..7f61adf27d 100644 --- a/training/lstmeval.cpp +++ b/training/lstmeval.cpp @@ -26,6 +26,9 @@ #include "tprintf.h" STRING_PARAM_FLAG(model, "", "Name of model file (training or recognition)"); +STRING_PARAM_FLAG(traineddata, "", + "If model is a training checkpoint, then traineddata must " + "be the traineddata file that was given to the trainer"); STRING_PARAM_FLAG(eval_listfile, "", "File listing sample files in lstmf training format."); INT_PARAM_FLAG(max_image_MB, 2000, "Max memory to use for images."); @@ -40,10 +43,22 @@ int main(int argc, char **argv) { tprintf("Must provide a --eval_listfile!\n"); return 1; } - GenericVector model_data; - if (!tesseract::LoadDataFromFile(FLAGS_model.c_str(), &model_data)) { - tprintf("Failed to load model from: %s\n", FLAGS_eval_listfile.c_str()); - return 1; + tesseract::TessdataManager mgr; + if (!mgr.Init(FLAGS_model.c_str())) { + tprintf("%s is not a recognition model, trying training checkpoint...\n", + FLAGS_model.c_str()); + if (!mgr.Init(FLAGS_traineddata.c_str())) { + tprintf("Failed to load language model from %s!\n", + FLAGS_traineddata.c_str()); + return 1; + } + GenericVector model_data; + if (!tesseract::LoadDataFromFile(FLAGS_model.c_str(), &model_data)) { + tprintf("Failed to load model from: %s\n", FLAGS_model.c_str()); + return 1; + } + mgr.OverwriteEntry(tesseract::TESSDATA_LSTM, &model_data[0], + model_data.size()); } tesseract::LSTMTester tester(static_cast(FLAGS_max_image_MB) * 1048576); @@ -52,7 +67,7 @@ int main(int argc, char **argv) { return 1; } double errs = 0.0; - STRING result = tester.RunEvalSync(0, &errs, model_data, 0); + STRING result = tester.RunEvalSync(0, &errs, mgr, 0); tprintf("%s\n", result.string()); return 0; } /* main */ diff --git a/training/lstmtester.cpp b/training/lstmtester.cpp index 9947dbd127..50e2c56282 100644 --- a/training/lstmtester.cpp +++ b/training/lstmtester.cpp @@ -50,7 +50,7 @@ bool LSTMTester::LoadAllEvalData(const GenericVector& filenames) { // Runs an evaluation asynchronously on the stored data and returns a string // describing the results of the previous test. STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors, - const GenericVector& model_data, + const TessdataManager& model_mgr, int training_stage) { STRING result; if (total_pages_ == 0) { @@ -68,7 +68,7 @@ STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors, if (training_errors != nullptr) { test_iteration_ = iteration; test_training_errors_ = training_errors; - test_model_data_ = model_data; + test_model_mgr_ = model_mgr; test_training_stage_ = training_stage; SVSync::StartThread(&LSTMTester::ThreadFunc, this); } else { @@ -80,10 +80,13 @@ STRING LSTMTester::RunEvalAsync(int iteration, const double* training_errors, // Runs an evaluation synchronously on the stored data and returns a string // describing the results. STRING LSTMTester::RunEvalSync(int iteration, const double* training_errors, - const GenericVector& model_data, + const TessdataManager& model_mgr, int training_stage) { LSTMTrainer trainer; - if (!trainer.ReadTrainingDump(model_data, &trainer)) { + trainer.InitCharSet(model_mgr); + TFile fp; + if (!model_mgr.GetComponent(TESSDATA_LSTM, &fp) || + !trainer.DeSerialize(&model_mgr, &fp)) { return "Deserialize failed"; } int eval_iteration = 0; @@ -122,7 +125,7 @@ void* LSTMTester::ThreadFunc(void* lstmtester_void) { LSTMTester* lstmtester = static_cast(lstmtester_void); lstmtester->test_result_ = lstmtester->RunEvalSync( lstmtester->test_iteration_, lstmtester->test_training_errors_, - lstmtester->test_model_data_, lstmtester->test_training_stage_); + lstmtester->test_model_mgr_, lstmtester->test_training_stage_); lstmtester->UnlockRunning(); return lstmtester_void; } diff --git a/training/lstmtester.h b/training/lstmtester.h index 3b4cb05e78..e43dd26e95 100644 --- a/training/lstmtester.h +++ b/training/lstmtester.h @@ -53,12 +53,11 @@ class LSTMTester { // LSTMTrainer. // training_stage: an arbitrary number on the progress of training. STRING RunEvalAsync(int iteration, const double* training_errors, - const GenericVector& model_data, - int training_stage); + const TessdataManager& model_mgr, int training_stage); // Runs an evaluation synchronously on the stored eval data and returns a // string describing the results. Args as RunEvalAsync. STRING RunEvalSync(int iteration, const double* training_errors, - const GenericVector& model_data, int training_stage); + const TessdataManager& model_mgr, int training_stage); private: // Static helper thread function for RunEvalAsync, with a specific signature @@ -84,7 +83,7 @@ class LSTMTester { // Stored copies of the args for use while running asynchronously. int test_iteration_; const double* test_training_errors_; - GenericVector test_model_data_; + TessdataManager test_model_mgr_; int test_training_stage_; STRING test_result_; };