Skip to content

Commit

Permalink
Added ability to randomly rotate images upside-down during training f…
Browse files Browse the repository at this point in the history
…or training OSD
  • Loading branch information
Ray Smith committed Sep 8, 2017
1 parent 3e63918 commit 4cf123e
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 14 deletions.
7 changes: 4 additions & 3 deletions lstm/lstmrecognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
NetworkIO outputs;
float scale_factor;
NetworkIO inputs;
if (!RecognizeLine(image_data, invert, debug, false, &scale_factor, &inputs,
&outputs))
if (!RecognizeLine(image_data, invert, debug, false, false, &scale_factor,
&inputs, &outputs))
return;
if (search_ == NULL) {
search_ =
Expand Down Expand Up @@ -227,7 +227,7 @@ void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
// Recognizes the image_data, returning the labels,
// scores, and corresponding pairs of start, end x-coords in coords.
bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
bool debug, bool re_invert,
bool debug, bool re_invert, bool upside_down,
float* scale_factor, NetworkIO* inputs,
NetworkIO* outputs) {
// Maximum width of image to train on.
Expand All @@ -247,6 +247,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
pixDestroy(&pix);
return false;
}
if (upside_down) pixRotate180(pix, pix);
// Reduction factor from image to coords.
*scale_factor = min_width / *scale_factor;
inputs->set_int_mode(IsIntMode());
Expand Down
4 changes: 2 additions & 2 deletions lstm/lstmrecognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ class LSTMRecognizer {
// forward outputs for the best photometric interpretation.
// inputs is filled with the used inputs to the network.
bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
bool re_invert, float* scale_factor, NetworkIO* inputs,
NetworkIO* outputs);
bool re_invert, bool upside_down, float* scale_factor,
NetworkIO* inputs, NetworkIO* outputs);

// Converts an array of labels to utf-8, whether or not the labels are
// augmented with character boundaries.
Expand Down
31 changes: 26 additions & 5 deletions lstm/lstmtrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ const int kTargetXScale = 5;
const int kTargetYScale = 100;

LSTMTrainer::LSTMTrainer()
: training_data_(0),
: randomly_rotate_(false),
training_data_(0),
file_reader_(LoadDataFromFile),
file_writer_(SaveDataToFile),
checkpoint_reader_(
Expand All @@ -88,7 +89,8 @@ LSTMTrainer::LSTMTrainer(FileReader file_reader, FileWriter file_writer,
CheckPointWriter checkpoint_writer,
const char* model_base, const char* checkpoint_name,
int debug_interval, inT64 max_memory)
: training_data_(max_memory),
: randomly_rotate_(false),
training_data_(max_memory),
file_reader_(file_reader),
file_writer_(file_writer),
checkpoint_reader_(checkpoint_reader),
Expand Down Expand Up @@ -296,7 +298,9 @@ void LSTMTrainer::DebugNetwork() {
// tesseract into memory ready for training. Returns false if nothing was
// loaded.
bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames,
CachingStrategy cache_strategy) {
CachingStrategy cache_strategy,
bool randomly_rotate) {
randomly_rotate_ = randomly_rotate;
training_data_.Clear();
return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
}
Expand Down Expand Up @@ -838,6 +842,23 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
trainingdata->language().string());
return UNENCODABLE;
}
bool upside_down = false;
if (randomly_rotate_) {
// This ensures consistent training results.
SetRandomSeed();
upside_down = randomizer_.SignedRand(1.0) > 0.0;
if (upside_down) {
// Modify the truth labels to match the rotation:
// Apart from space and null, increment the label. This is changes the
// script-id to the same script-id but upside-down.
// The labels need to be reversed in order, as the first is now the last.
for (int c = 0; c < truth_labels.size(); ++c) {
if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
++truth_labels[c];
}
truth_labels.reverse();
}
}
int w = 0;
while (w < truth_labels.size() &&
(truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
Expand All @@ -850,8 +871,8 @@ Trainability LSTMTrainer::PrepareForBackward(const ImageData* trainingdata,
float image_scale;
NetworkIO inputs;
bool invert = trainingdata->boxes().empty();
if (!RecognizeLine(*trainingdata, invert, debug, invert, &image_scale,
&inputs, fwd_outputs)) {
if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
&image_scale, &inputs, fwd_outputs)) {
tprintf("Image not trainable\n");
return UNENCODABLE;
}
Expand Down
4 changes: 3 additions & 1 deletion lstm/lstmtrainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ class LSTMTrainer : public LSTMRecognizer {
// tesseract into memory ready for training. Returns false if nothing was
// loaded.
bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
CachingStrategy cache_strategy);
CachingStrategy cache_strategy,
bool randomly_rotate);

// Keeps track of best and locally worst error rate, using internally computed
// values. See MaintainCheckpointsSpecific for more detail.
Expand Down Expand Up @@ -409,6 +410,7 @@ class LSTMTrainer : public LSTMRecognizer {
// Checkpoint filename.
STRING checkpoint_name_;
// Training data.
bool randomly_rotate_;
DocumentCache training_data_;
// Name to use when saving best_trainer_.
STRING best_model_name_;
Expand Down
10 changes: 7 additions & 3 deletions training/lstmtraining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ STRING_PARAM_FLAG(traineddata, "",
STRING_PARAM_FLAG(old_traineddata, "",
"When changing the character set, this specifies the old"
" character set that is to be replaced");
BOOL_PARAM_FLAG(randomly_rotate, false,
"Train OSD and randomly turn training samples upside-down");

// Number of training images to train between calls to MaintainCheckpoints.
const int kNumPagesPerBatch = 100;
Expand Down Expand Up @@ -167,9 +169,11 @@ int main(int argc, char **argv) {
trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
}
}
if (!trainer.LoadAllTrainingData(
filenames, FLAGS_sequential_training ? tesseract::CS_SEQUENTIAL
: tesseract::CS_ROUND_ROBIN)) {
if (!trainer.LoadAllTrainingData(filenames,
FLAGS_sequential_training
? tesseract::CS_SEQUENTIAL
: tesseract::CS_ROUND_ROBIN,
FLAGS_randomly_rotate)) {
tprintf("Load of images failed!!\n");
return 1;
}
Expand Down

0 comments on commit 4cf123e

Please sign in to comment.