Skip to content

Commit

Permalink
LSTMTrainer: Use new serialization API
Browse files Browse the repository at this point in the history
Also improve portability by using int32_t instead of int
for a serialized member variable.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Jul 18, 2018
1 parent 1dcda1a commit b7b8dba
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 62 deletions.
90 changes: 30 additions & 60 deletions src/lstm/lstmtrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,38 +431,25 @@ bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
const TessdataManager* mgr, TFile* fp) const {
if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
return false;
if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) !=
1)
return false;
if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false;
if (fp->FWrite(&last_perfect_training_iteration_,
sizeof(last_perfect_training_iteration_), 1) != 1)
return false;
if (!fp->Serialize(&learning_iteration_)) return false;
if (!fp->Serialize(&prev_sample_iteration_)) return false;
if (!fp->Serialize(&perfect_delay_)) return false;
if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
for (int i = 0; i < ET_COUNT; ++i) {
if (!error_buffers_[i].Serialize(fp)) return false;
}
if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1)
return false;
if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
if (!fp->Serialize(&training_stage_)) return false;
uint8_t amount = serialize_amount;
if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false;
if (!fp->Serialize(&amount)) return false;
if (serialize_amount == LIGHT) return true; // We are done.
if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
return false;
if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
return false;
if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1)
return false;
if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
return false;
if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
return false;
if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
return false;
if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
return false;
if (!fp->Serialize(&best_error_rate_)) return false;
if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
if (!fp->Serialize(&best_iteration_)) return false;
if (!fp->Serialize(&worst_error_rate_)) return false;
if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
if (!fp->Serialize(&worst_iteration_)) return false;
if (!fp->Serialize(&stall_iteration_)) return false;
if (!best_model_data_.Serialize(fp)) return false;
if (!worst_model_data_.Serialize(fp)) return false;
if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
Expand All @@ -473,16 +460,14 @@ bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
if (!sub_data.Serialize(fp)) return false;
if (!best_error_history_.Serialize(fp)) return false;
if (!best_error_iterations_.Serialize(fp)) return false;
if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
return false;
return true;
return fp->Serialize(&improvement_steps_);
}

// Reads from the given file. Returns false in case of error.
// NOTE: It is assumed that the trainer is never read cross-endian.
bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
if (!fp->DeSerialize(&learning_iteration_)) {
// Special case. If we successfully decoded the recognizer, but fail here
// then it means we were just given a recognizer, so issue a warning and
// allow it.
Expand All @@ -491,37 +476,24 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
network_->SetEnableTraining(TS_ENABLED);
return true;
}
if (fp->FReadEndian(&prev_sample_iteration_, sizeof(prev_sample_iteration_),
1) != 1)
return false;
if (fp->FReadEndian(&perfect_delay_, sizeof(perfect_delay_), 1) != 1)
return false;
if (fp->FReadEndian(&last_perfect_training_iteration_,
sizeof(last_perfect_training_iteration_), 1) != 1)
return false;
if (!fp->DeSerialize(&prev_sample_iteration_)) return false;
if (!fp->DeSerialize(&perfect_delay_)) return false;
if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
for (int i = 0; i < ET_COUNT; ++i) {
if (!error_buffers_[i].DeSerialize(fp)) return false;
}
if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
if (fp->FReadEndian(&training_stage_, sizeof(training_stage_), 1) != 1)
return false;
if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
if (!fp->DeSerialize(&training_stage_)) return false;
uint8_t amount;
if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false;
if (!fp->DeSerialize(&amount)) return false;
if (amount == LIGHT) return true; // Don't read the rest.
if (fp->FReadEndian(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
return false;
if (fp->FReadEndian(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
return false;
if (fp->FReadEndian(&best_iteration_, sizeof(best_iteration_), 1) != 1)
return false;
if (fp->FReadEndian(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
return false;
if (fp->FReadEndian(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
return false;
if (fp->FReadEndian(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
return false;
if (fp->FReadEndian(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
return false;
if (!fp->DeSerialize(&best_error_rate_)) return false;
if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
if (!fp->DeSerialize(&best_iteration_)) return false;
if (!fp->DeSerialize(&worst_error_rate_)) return false;
if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
if (!fp->DeSerialize(&worst_iteration_)) return false;
if (!fp->DeSerialize(&stall_iteration_)) return false;
if (!best_model_data_.DeSerialize(fp)) return false;
if (!worst_model_data_.DeSerialize(fp)) return false;
if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
Expand All @@ -536,9 +508,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
}
if (!best_error_history_.DeSerialize(fp)) return false;
if (!best_error_iterations_.DeSerialize(fp)) return false;
if (fp->FReadEndian(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
return false;
return true;
return fp->DeSerialize(&improvement_steps_);
}

// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
Expand Down
4 changes: 2 additions & 2 deletions src/lstm/lstmtrainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class LSTMTrainer : public LSTMRecognizer {
return best_iteration_;
}
int learning_iteration() const { return learning_iteration_; }
int improvement_steps() const { return improvement_steps_; }
int32_t improvement_steps() const { return improvement_steps_; }
void set_perfect_delay(int delay) { perfect_delay_ = delay; }
const GenericVector<char>& best_trainer() const { return best_trainer_; }
// Returns the error that was just calculated by PrepareForBackward.
Expand Down Expand Up @@ -457,7 +457,7 @@ class LSTMTrainer : public LSTMRecognizer {
GenericVector<double> best_error_history_;
GenericVector<int> best_error_iterations_;
// Number of iterations since the best_error_rate_ was 2% more than it is now.
int improvement_steps_;
int32_t improvement_steps_;
// Number of iterations that yielded a non-zero delta error and thus provided
// significant learning. learning_iteration_ <= training_iteration_.
// learning_iteration_ is used to measure rate of learning progress.
Expand Down

0 comments on commit b7b8dba

Please sign in to comment.