diff --git a/src/ccutil/tesscallback.h b/src/ccutil/tesscallback.h index 97e5c05d67..03e7797457 100644 --- a/src/ccutil/tesscallback.h +++ b/src/ccutil/tesscallback.h @@ -607,72 +607,6 @@ NewPermanentTessCallback(R (*function)(P1, A1), return new _TessFunctionResultCallback_1_1(function, p1); } -template -class _ConstTessMemberResultCallback_0_2 - : public TessResultCallback2 { - public: - typedef TessResultCallback2 base; - using MemberSignature = R (T::*)(A1, A2) const; - - private: - const T* object_; - MemberSignature member_; - - public: - inline _ConstTessMemberResultCallback_0_2(const T* object, - MemberSignature member) - : object_(object), member_(member) {} - - R Run(A1 a1, A2 a2) override { - if (!del) { - R result = (object_->*member_)(a1, a2); - return result; - } - R result = (object_->*member_)(a1, a2); - // zero out the pointer to ensure segfault if used again - member_ = nullptr; - delete this; - return result; - } -}; - -template -class _ConstTessMemberResultCallback_0_2 - : public TessCallback2 { - public: - typedef TessCallback2 base; - using MemberSignature = void (T::*)(A1, A2) const; - - private: - const T* object_; - MemberSignature member_; - - public: - inline _ConstTessMemberResultCallback_0_2(const T* object, - MemberSignature member) - : object_(object), member_(member) {} - - virtual void Run(A1 a1, A2 a2) { - if (!del) { - (object_->*member_)(a1, a2); - } else { - (object_->*member_)(a1, a2); - // zero out the pointer to ensure segfault if used again - member_ = nullptr; - delete this; - } - } -}; - -#ifndef SWIG -template -inline typename _ConstTessMemberResultCallback_0_2::base* -NewPermanentTessCallback(const T1* obj, R (T2::*member)(A1, A2) const) { - return new _ConstTessMemberResultCallback_0_2(obj, - member); -} -#endif - template class _TessMemberResultCallback_0_2 : public TessResultCallback2 { public: @@ -793,45 +727,6 @@ NewPermanentTessCallback(R (*function)(A1, A2)) { return new _TessFunctionResultCallback_0_2(function); } -template -class _ConstTessMemberResultCallback_0_3 - : public TessResultCallback3 { - public: - typedef TessResultCallback3 base; - using MemberSignature = R (T::*)(A1, A2, A3) const; - - private: - const T* object_; - MemberSignature member_; - - public: - inline _ConstTessMemberResultCallback_0_3(const T* object, - MemberSignature member) - : object_(object), member_(member) {} - - R Run(A1 a1, A2 a2, A3 a3) override { - if (!del) { - R result = (object_->*member_)(a1, a2, a3); - return result; - } - R result = (object_->*member_)(a1, a2, a3); - // zero out the pointer to ensure segfault if used again - member_ = nullptr; - delete this; - return result; - } -}; - -#ifndef SWIG -template -inline - typename _ConstTessMemberResultCallback_0_3::base* - NewPermanentTessCallback(const T1* obj, R (T2::*member)(A1, A2, A3) const) { - return new _ConstTessMemberResultCallback_0_3( - obj, member); -} -#endif - template class _TessMemberResultCallback_0_4 : public TessResultCallback4 { diff --git a/src/lstm/lstmtrainer.cpp b/src/lstm/lstmtrainer.cpp index fa4bb1cf84..434e52cb2e 100644 --- a/src/lstm/lstmtrainer.cpp +++ b/src/lstm/lstmtrainer.cpp @@ -74,41 +74,17 @@ const int kTargetYScale = 100; LSTMTrainer::LSTMTrainer() : randomly_rotate_(false), training_data_(0), - file_reader_(LoadDataFromFile), - file_writer_(SaveDataToFile), - checkpoint_reader_( - NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump)), - checkpoint_writer_( - NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump)), sub_trainer_(nullptr) { EmptyConstructor(); debug_interval_ = 0; } -LSTMTrainer::LSTMTrainer(FileReader file_reader, FileWriter file_writer, - CheckPointReader checkpoint_reader, - CheckPointWriter checkpoint_writer, - const char* model_base, const char* checkpoint_name, +LSTMTrainer::LSTMTrainer(const char* model_base, const char* checkpoint_name, int debug_interval, int64_t max_memory) : randomly_rotate_(false), training_data_(max_memory), - file_reader_(file_reader), - file_writer_(file_writer), - checkpoint_reader_(checkpoint_reader), - checkpoint_writer_(checkpoint_writer), - sub_trainer_(nullptr), - mgr_(file_reader) { + sub_trainer_(nullptr) { EmptyConstructor(); - if (file_reader_ == nullptr) file_reader_ = LoadDataFromFile; - if (file_writer_ == nullptr) file_writer_ = SaveDataToFile; - if (checkpoint_reader_ == nullptr) { - checkpoint_reader_ = - NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump); - } - if (checkpoint_writer_ == nullptr) { - checkpoint_writer_ = - NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump); - } debug_interval_ = debug_interval; model_base_ = model_base; checkpoint_name_ = checkpoint_name; @@ -119,8 +95,6 @@ LSTMTrainer::~LSTMTrainer() { delete target_win_; delete ctc_win_; delete recon_win_; - delete checkpoint_reader_; - delete checkpoint_writer_; delete sub_trainer_; } @@ -129,9 +103,9 @@ LSTMTrainer::~LSTMTrainer() { bool LSTMTrainer::TryLoadingCheckpoint(const char* filename, const char* old_traineddata) { GenericVector data; - if (!(*file_reader_)(filename, &data)) return false; + if (!LoadDataFromFile(filename, &data)) return false; tprintf("Loaded file %s, unpacking...\n", filename); - if (!checkpoint_reader_->Run(data, this)) return false; + if (!ReadTrainingDump(data, this)) return false; StaticShape shape = network_->OutputShape(network_->InputShape()); if (((old_traineddata == nullptr || *old_traineddata == '\0') && network_->NumOutputs() == recoder_.code_range()) || @@ -303,7 +277,8 @@ bool LSTMTrainer::LoadAllTrainingData(const GenericVector& filenames, bool randomly_rotate) { randomly_rotate_ = randomly_rotate; training_data_.Clear(); - return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_); + return training_data_.LoadDocuments(filenames, cache_strategy, + LoadDataFromFile); } // Keeps track of best and locally worst char error_rate and launches tests @@ -345,10 +320,10 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) { if (TransitionTrainingStage(kStageTransitionThreshold)) { log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage()); } - checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_); + SaveTrainingDump(NO_BEST_TRAINER, this, &best_trainer_); if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) { STRING best_model_name = DumpFilename(); - if (!(*file_writer_)(best_trainer_, best_model_name)) { + if (!SaveDataToFile(best_trainer_, best_model_name)) { *log_msg += " failed to write best model:"; } else { *log_msg += " wrote best model:"; @@ -366,7 +341,7 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) { *log_msg += "\nDivergence! "; // Copy best_trainer_ before reading it, as it will get overwritten. GenericVector revert_data(best_trainer_); - if (checkpoint_reader_->Run(revert_data, this)) { + if (ReadTrainingDump(revert_data, this)) { LogIterations("Reverted to", log_msg); ReduceLearningRates(this, log_msg); } else { @@ -376,18 +351,17 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) { stall_iteration_ = iteration + 2 * (iteration - learning_iteration()); // Re-save the best trainer with the new learning rates and stall // iteration. - checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_); + SaveTrainingDump(NO_BEST_TRAINER, this, &best_trainer_); } } else { // Something interesting happened only if the sub_trainer_ was trained. result = sub_trainer_result != STR_NONE; } - if (checkpoint_writer_ != nullptr && file_writer_ != nullptr && - checkpoint_name_.length() > 0) { + if (checkpoint_name_.length() > 0) { // Write a current checkpoint. GenericVector checkpoint; - if (!checkpoint_writer_->Run(FULL, this, &checkpoint) || - !(*file_writer_)(checkpoint, checkpoint_name_)) { + if (!SaveTrainingDump(FULL, this, &checkpoint) || + !SaveDataToFile(checkpoint, checkpoint_name_)) { *log_msg += " failed to write checkpoint."; } else { *log_msg += " wrote checkpoint."; @@ -518,7 +492,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) { void LSTMTrainer::StartSubtrainer(STRING* log_msg) { delete sub_trainer_; sub_trainer_ = new LSTMTrainer(); - if (!checkpoint_reader_->Run(best_trainer_, sub_trainer_)) { + if (!ReadTrainingDump(best_trainer_, sub_trainer_)) { *log_msg += " Failed to revert to previous best for trial!"; delete sub_trainer_; sub_trainer_ = nullptr; @@ -533,7 +507,7 @@ void LSTMTrainer::StartSubtrainer(STRING* log_msg) { stall_iteration_ = learning_iteration() + 2 * stall_offset; sub_trainer_->stall_iteration_ = stall_iteration_; // Re-save the best trainer with the new learning rates and stall iteration. - checkpoint_writer_->Run(NO_BEST_TRAINER, sub_trainer_, &best_trainer_); + SaveTrainingDump(NO_BEST_TRAINER, sub_trainer_, &best_trainer_); } } @@ -926,7 +900,7 @@ bool LSTMTrainer::SaveTraineddata(const STRING& filename) { SaveRecognitionDump(&recognizer_data); mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0], recognizer_data.size()); - return mgr_.SaveFile(filename, file_writer_); + return mgr_.SaveFile(filename, SaveDataToFile); } // Writes the recognizer to memory, so that it can be used for testing later. diff --git a/src/lstm/lstmtrainer.h b/src/lstm/lstmtrainer.h index 82a8c9b83c..a9e6bf5f19 100644 --- a/src/lstm/lstmtrainer.h +++ b/src/lstm/lstmtrainer.h @@ -2,7 +2,6 @@ // File: lstmtrainer.h // Description: Top-level line trainer class for LSTM-based networks. // Author: Ray Smith -// Created: Fri May 03 09:07:06 PST 2013 // // (C) Copyright 2013, Google Inc. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -67,15 +66,6 @@ enum SubTrainerResult { }; class LSTMTrainer; -// Function to restore the trainer state from a given checkpoint. -// Returns false on failure. -typedef TessResultCallback2&, LSTMTrainer*>* - CheckPointReader; -// Function to save a checkpoint of the current trainer state. -// Returns false on failure. SerializeAmount determines the amount of the -// trainer to serialize, typically used for saving the best state. -typedef TessResultCallback3*>* CheckPointWriter; // Function to compute and record error rates on some external test set(s). // Args are: iteration, mean errors, model, training stage. // Returns a STRING containing logging information about the tests. @@ -89,11 +79,7 @@ typedef TessResultCallback4(FLAGS_max_image_MB) * 1048576); trainer.InitCharSet(FLAGS_traineddata.c_str()); diff --git a/unittest/lstm_test.h b/unittest/lstm_test.h index 06c2320d46..0a22fddfd1 100644 --- a/unittest/lstm_test.h +++ b/unittest/lstm_test.h @@ -84,8 +84,7 @@ class LSTMTrainerTest : public testing::Test { nullptr, nullptr)); std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name); std::string checkpoint_path = model_path + "_checkpoint"; - trainer_.reset(new LSTMTrainer(nullptr, nullptr, nullptr, nullptr, - model_path.c_str(), checkpoint_path.c_str(), + trainer_.reset(new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(), 0, 0)); trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang, absl::StrCat(kLang, ".traineddata")));