Skip to content

Commit

Permalink
Simplify class LSTMTrainer
Browse files Browse the repository at this point in the history
The function pointers and callbacks file_reader_, file_writer_,
checkpoint_reader_ and checkpoint_writer_ are always set to
the same values. Replacing them with direct function calls
simplifies the code and allows removing more code from tesscallback.h.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Jun 22, 2019
1 parent dff33d6 commit bd13069
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 172 deletions.
105 changes: 0 additions & 105 deletions src/ccutil/tesscallback.h
Expand Up @@ -607,72 +607,6 @@ NewPermanentTessCallback(R (*function)(P1, A1),
return new _TessFunctionResultCallback_1_1<false, R, P1, A1>(function, p1);
}

// Result-returning callback that binds a const object to one of its const
// member functions taking two arguments (A1, A2) and returning R.
//
// Template parameter `del` selects the lifetime policy:
//   - del == false: Run() may be invoked any number of times; the callback is
//     never deleted by Run().
//   - del == true:  Run() invokes the member once, then clears the member
//     pointer and deletes the callback object itself.
template <bool del, class R, class T, class A1, class A2>
class _ConstTessMemberResultCallback_0_2
    : public TessResultCallback2<R, A1, A2> {
 public:
  using base = TessResultCallback2<R, A1, A2>;
  using MemberSignature = R (T::*)(A1, A2) const;

  // Stores the target object and the member function to call on it.
  // The object is held by raw pointer and is not owned by the callback.
  inline _ConstTessMemberResultCallback_0_2(const T* object,
                                            MemberSignature member)
      : object_(object), member_(member) {}

  // Calls the bound member with (a1, a2) and returns its result.
  R Run(A1 a1, A2 a2) override {
    R result = (object_->*member_)(a1, a2);
    if (del) {
      // zero out the pointer to ensure segfault if used again
      member_ = nullptr;
      // Self-deleting variant: no member may be touched after this line.
      delete this;
    }
    return result;
  }

 private:
  const T* object_;
  MemberSignature member_;
};

// Specialization of _ConstTessMemberResultCallback_0_2 for const member
// functions that return void. It derives from TessCallback2 rather than
// TessResultCallback2.
//
// As in the primary template, `del == true` makes the callback delete itself
// after its first Run(); `del == false` leaves it alive.
template <bool del, class T, class A1, class A2>
class _ConstTessMemberResultCallback_0_2<del, void, T, A1, A2>
    : public TessCallback2<A1, A2> {
 public:
  using base = TessCallback2<A1, A2>;
  using MemberSignature = void (T::*)(A1, A2) const;

  // Stores the target object and the member function to call on it.
  // The object is held by raw pointer and is not owned by the callback.
  inline _ConstTessMemberResultCallback_0_2(const T* object,
                                            MemberSignature member)
      : object_(object), member_(member) {}

  // Calls the bound member with (a1, a2).
  virtual void Run(A1 a1, A2 a2) {
    (object_->*member_)(a1, a2);
    if (del) {
      // zero out the pointer to ensure segfault if used again
      member_ = nullptr;
      // Self-deleting variant: no member may be touched after this line.
      delete this;
    }
  }

 private:
  const T* object_;
  MemberSignature member_;
};

#ifndef SWIG
// Builds a heap-allocated callback that invokes the two-argument const member
// function `member` on `obj`. The callback is created with del == false, so
// Run() never deletes it; whoever holds the returned pointer must delete it.
template <class T1, class T2, class R, class A1, class A2>
inline typename _ConstTessMemberResultCallback_0_2<false, R, T1, A1, A2>::base*
NewPermanentTessCallback(const T1* obj, R (T2::*member)(A1, A2) const) {
  using PermanentCallback =
      _ConstTessMemberResultCallback_0_2<false, R, T1, A1, A2>;
  return new PermanentCallback(obj, member);
}
#endif

template <bool del, class R, class T, class A1, class A2>
class _TessMemberResultCallback_0_2 : public TessResultCallback2<R, A1, A2> {
public:
Expand Down Expand Up @@ -793,45 +727,6 @@ NewPermanentTessCallback(R (*function)(A1, A2)) {
return new _TessFunctionResultCallback_0_2<false, R, A1, A2>(function);
}

// Result-returning callback that binds a const object to one of its const
// member functions taking three arguments (A1, A2, A3) and returning R.
//
// Template parameter `del` selects the lifetime policy:
//   - del == false: Run() may be invoked any number of times; the callback is
//     never deleted by Run().
//   - del == true:  Run() invokes the member once, then clears the member
//     pointer and deletes the callback object itself.
template <bool del, class R, class T, class A1, class A2, class A3>
class _ConstTessMemberResultCallback_0_3
    : public TessResultCallback3<R, A1, A2, A3> {
 public:
  using base = TessResultCallback3<R, A1, A2, A3>;
  using MemberSignature = R (T::*)(A1, A2, A3) const;

  // Stores the target object and the member function to call on it.
  // The object is held by raw pointer and is not owned by the callback.
  inline _ConstTessMemberResultCallback_0_3(const T* object,
                                            MemberSignature member)
      : object_(object), member_(member) {}

  // Calls the bound member with (a1, a2, a3) and returns its result.
  R Run(A1 a1, A2 a2, A3 a3) override {
    R result = (object_->*member_)(a1, a2, a3);
    if (del) {
      // zero out the pointer to ensure segfault if used again
      member_ = nullptr;
      // Self-deleting variant: no member may be touched after this line.
      delete this;
    }
    return result;
  }

 private:
  const T* object_;
  MemberSignature member_;
};

#ifndef SWIG
// Builds a heap-allocated callback that invokes the three-argument const
// member function `member` on `obj`. The callback is created with
// del == false, so Run() never deletes it; whoever holds the returned pointer
// must delete it.
template <class T1, class T2, class R, class A1, class A2, class A3>
inline
    typename _ConstTessMemberResultCallback_0_3<false, R, T1, A1, A2, A3>::base*
    NewPermanentTessCallback(const T1* obj,
                             R (T2::*member)(A1, A2, A3) const) {
  using PermanentCallback =
      _ConstTessMemberResultCallback_0_3<false, R, T1, A1, A2, A3>;
  return new PermanentCallback(obj, member);
}
#endif

template <bool del, class R, class T, class A1, class A2, class A3, class A4>
class _TessMemberResultCallback_0_4
: public TessResultCallback4<R, A1, A2, A3, A4> {
Expand Down
58 changes: 16 additions & 42 deletions src/lstm/lstmtrainer.cpp
Expand Up @@ -74,41 +74,17 @@ const int kTargetYScale = 100;
// Default constructor (shown here as it was before this commit).
// Disables random rotation, constructs training_data_ with 0 (the argument
// the other constructor fills with max_memory), and leaves sub_trainer_ null.
// The file reader/writer default to LoadDataFromFile / SaveDataToFile, and the
// checkpoint callbacks are heap-allocated permanent callbacks bound to this
// object's ReadTrainingDump / SaveTrainingDump.
// NOTE(review): EmptyConstructor() is defined elsewhere; debug_interval_ is
// assigned after it, presumably so the 0 here is the final value -- confirm.
LSTMTrainer::LSTMTrainer()
: randomly_rotate_(false),
training_data_(0),
file_reader_(LoadDataFromFile),
file_writer_(SaveDataToFile),
checkpoint_reader_(
NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump)),
checkpoint_writer_(
NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump)),
sub_trainer_(nullptr) {
EmptyConstructor();
debug_interval_ = 0;
}

LSTMTrainer::LSTMTrainer(FileReader file_reader, FileWriter file_writer,
CheckPointReader checkpoint_reader,
CheckPointWriter checkpoint_writer,
const char* model_base, const char* checkpoint_name,
LSTMTrainer::LSTMTrainer(const char* model_base, const char* checkpoint_name,
int debug_interval, int64_t max_memory)
: randomly_rotate_(false),
training_data_(max_memory),
file_reader_(file_reader),
file_writer_(file_writer),
checkpoint_reader_(checkpoint_reader),
checkpoint_writer_(checkpoint_writer),
sub_trainer_(nullptr),
mgr_(file_reader) {
sub_trainer_(nullptr) {
EmptyConstructor();
if (file_reader_ == nullptr) file_reader_ = LoadDataFromFile;
if (file_writer_ == nullptr) file_writer_ = SaveDataToFile;
if (checkpoint_reader_ == nullptr) {
checkpoint_reader_ =
NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump);
}
if (checkpoint_writer_ == nullptr) {
checkpoint_writer_ =
NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump);
}
debug_interval_ = debug_interval;
model_base_ = model_base;
checkpoint_name_ = checkpoint_name;
Expand All @@ -119,8 +95,6 @@ LSTMTrainer::~LSTMTrainer() {
delete target_win_;
delete ctc_win_;
delete recon_win_;
delete checkpoint_reader_;
delete checkpoint_writer_;
delete sub_trainer_;
}

Expand All @@ -129,9 +103,9 @@ LSTMTrainer::~LSTMTrainer() {
bool LSTMTrainer::TryLoadingCheckpoint(const char* filename,
const char* old_traineddata) {
GenericVector<char> data;
if (!(*file_reader_)(filename, &data)) return false;
if (!LoadDataFromFile(filename, &data)) return false;
tprintf("Loaded file %s, unpacking...\n", filename);
if (!checkpoint_reader_->Run(data, this)) return false;
if (!ReadTrainingDump(data, this)) return false;
StaticShape shape = network_->OutputShape(network_->InputShape());
if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
network_->NumOutputs() == recoder_.code_range()) ||
Expand Down Expand Up @@ -303,7 +277,8 @@ bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames,
bool randomly_rotate) {
randomly_rotate_ = randomly_rotate;
training_data_.Clear();
return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
return training_data_.LoadDocuments(filenames, cache_strategy,
LoadDataFromFile);
}

// Keeps track of best and locally worst char error_rate and launches tests
Expand Down Expand Up @@ -345,10 +320,10 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) {
if (TransitionTrainingStage(kStageTransitionThreshold)) {
log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
}
checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_);
SaveTrainingDump(NO_BEST_TRAINER, this, &best_trainer_);
if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
STRING best_model_name = DumpFilename();
if (!(*file_writer_)(best_trainer_, best_model_name)) {
if (!SaveDataToFile(best_trainer_, best_model_name)) {
*log_msg += " failed to write best model:";
} else {
*log_msg += " wrote best model:";
Expand All @@ -366,7 +341,7 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) {
*log_msg += "\nDivergence! ";
// Copy best_trainer_ before reading it, as it will get overwritten.
GenericVector<char> revert_data(best_trainer_);
if (checkpoint_reader_->Run(revert_data, this)) {
if (ReadTrainingDump(revert_data, this)) {
LogIterations("Reverted to", log_msg);
ReduceLearningRates(this, log_msg);
} else {
Expand All @@ -376,18 +351,17 @@ bool LSTMTrainer::MaintainCheckpoints(TestCallback tester, STRING* log_msg) {
stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
// Re-save the best trainer with the new learning rates and stall
// iteration.
checkpoint_writer_->Run(NO_BEST_TRAINER, this, &best_trainer_);
SaveTrainingDump(NO_BEST_TRAINER, this, &best_trainer_);
}
} else {
// Something interesting happened only if the sub_trainer_ was trained.
result = sub_trainer_result != STR_NONE;
}
if (checkpoint_writer_ != nullptr && file_writer_ != nullptr &&
checkpoint_name_.length() > 0) {
if (checkpoint_name_.length() > 0) {
// Write a current checkpoint.
GenericVector<char> checkpoint;
if (!checkpoint_writer_->Run(FULL, this, &checkpoint) ||
!(*file_writer_)(checkpoint, checkpoint_name_)) {
if (!SaveTrainingDump(FULL, this, &checkpoint) ||
!SaveDataToFile(checkpoint, checkpoint_name_)) {
*log_msg += " failed to write checkpoint.";
} else {
*log_msg += " wrote checkpoint.";
Expand Down Expand Up @@ -518,7 +492,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
void LSTMTrainer::StartSubtrainer(STRING* log_msg) {
delete sub_trainer_;
sub_trainer_ = new LSTMTrainer();
if (!checkpoint_reader_->Run(best_trainer_, sub_trainer_)) {
if (!ReadTrainingDump(best_trainer_, sub_trainer_)) {
*log_msg += " Failed to revert to previous best for trial!";
delete sub_trainer_;
sub_trainer_ = nullptr;
Expand All @@ -533,7 +507,7 @@ void LSTMTrainer::StartSubtrainer(STRING* log_msg) {
stall_iteration_ = learning_iteration() + 2 * stall_offset;
sub_trainer_->stall_iteration_ = stall_iteration_;
// Re-save the best trainer with the new learning rates and stall iteration.
checkpoint_writer_->Run(NO_BEST_TRAINER, sub_trainer_, &best_trainer_);
SaveTrainingDump(NO_BEST_TRAINER, sub_trainer_, &best_trainer_);
}
}

Expand Down Expand Up @@ -926,7 +900,7 @@ bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
SaveRecognitionDump(&recognizer_data);
mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
recognizer_data.size());
return mgr_.SaveFile(filename, file_writer_);
return mgr_.SaveFile(filename, SaveDataToFile);
}

// Writes the recognizer to memory, so that it can be used for testing later.
Expand Down
23 changes: 1 addition & 22 deletions src/lstm/lstmtrainer.h
Expand Up @@ -2,7 +2,6 @@
// File: lstmtrainer.h
// Description: Top-level line trainer class for LSTM-based networks.
// Author: Ray Smith
// Created: Fri May 03 09:07:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -67,15 +66,6 @@ enum SubTrainerResult {
};

class LSTMTrainer;
// Function to restore the trainer state from a given checkpoint.
// Returns false on failure.
typedef TessResultCallback2<bool, const GenericVector<char>&, LSTMTrainer*>*
CheckPointReader;
// Function to save a checkpoint of the current trainer state.
// Returns false on failure. SerializeAmount determines the amount of the
// trainer to serialize, typically used for saving the best state.
typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
GenericVector<char>*>* CheckPointWriter;
// Function to compute and record error rates on some external test set(s).
// Args are: iteration, mean errors, model, training stage.
// Returns a STRING containing logging information about the tests.
Expand All @@ -89,11 +79,7 @@ typedef TessResultCallback4<STRING, int, const double*, const TessdataManager&,
class LSTMTrainer : public LSTMRecognizer {
public:
LSTMTrainer();
// Callbacks may be null, in which case defaults are used.
LSTMTrainer(FileReader file_reader, FileWriter file_writer,
CheckPointReader checkpoint_reader,
CheckPointWriter checkpoint_writer,
const char* model_base, const char* checkpoint_name,
LSTMTrainer(const char* model_base, const char* checkpoint_name,
int debug_interval, int64_t max_memory);
virtual ~LSTMTrainer();

Expand Down Expand Up @@ -416,13 +402,6 @@ class LSTMTrainer : public LSTMRecognizer {
STRING best_model_name_;
// Number of available training stages.
int num_training_stages_;
// Checkpointing callbacks.
FileReader file_reader_;
FileWriter file_writer_;
// TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
// when we can commit to c++11.
CheckPointReader checkpoint_reader_;
CheckPointWriter checkpoint_writer_;

// ===Serialized data to ensure that a restart produces the same results.===
// These members are only serialized when serialize_amount != LIGHT.
Expand Down
2 changes: 1 addition & 1 deletion src/training/lstmtraining.cpp
Expand Up @@ -103,7 +103,7 @@ int main(int argc, char **argv) {
checkpoint_file += "_checkpoint";
STRING checkpoint_bak = checkpoint_file + ".bak";
tesseract::LSTMTrainer trainer(
nullptr, nullptr, nullptr, nullptr, FLAGS_model_output.c_str(),
FLAGS_model_output.c_str(),
checkpoint_file.c_str(), FLAGS_debug_interval,
static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
trainer.InitCharSet(FLAGS_traineddata.c_str());
Expand Down
3 changes: 1 addition & 2 deletions unittest/lstm_test.h
Expand Up @@ -84,8 +84,7 @@ class LSTMTrainerTest : public testing::Test {
nullptr, nullptr));
std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
std::string checkpoint_path = model_path + "_checkpoint";
trainer_.reset(new LSTMTrainer(nullptr, nullptr, nullptr, nullptr,
model_path.c_str(), checkpoint_path.c_str(),
trainer_.reset(new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(),
0, 0));
trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang,
absl::StrCat(kLang, ".traineddata")));
Expand Down

0 comments on commit bd13069

Please sign in to comment.