Skip to content

Commit

Permalink
LSTMTrainer: Catch empty vectors
Browse files Browse the repository at this point in the history
The new test in LSTMTrainer::UpdateErrorGraph fixes an assertion
(see issues #644, #792).

The new test in LSTMTrainer::ReadTrainingDump was added to improve
the robustness of the code.

Signed-off-by: Stefan Weil <sw@weilnetz.de>
  • Loading branch information
stweil committed Jun 4, 2017
1 parent 1e5522d commit 34d1e73
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions lstm/lstmtrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,10 @@ bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount,
// Reads previously saved trainer from memory.
bool LSTMTrainer::ReadTrainingDump(const GenericVector<char>& data,
LSTMTrainer* trainer) {
if (data.size() == 0) {
tprintf("Warning: data size is zero in LSTMTrainer::ReadTrainingDump\n");
return false;
}
return trainer->ReadSizedTrainingDump(&data[0], data.size());
}

Expand Down Expand Up @@ -1298,8 +1302,9 @@ STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
if (error_rate < best_error_rate_) {
// This is a new (global) minimum.
if (tester != NULL) {
result = tester->Run(worst_iteration_, worst_error_rates_,
worst_model_data_, CurrentTrainingStage());
if (worst_model_data_.size() != 0)
result = tester->Run(worst_iteration_, worst_error_rates_,
worst_model_data_, CurrentTrainingStage());
worst_model_data_.truncate(0);
best_model_data_ = model_data;
}
Expand Down

0 comments on commit 34d1e73

Please sign in to comment.