Fixes to training process to allow incremental training from a recognition model
theraysmith committed Nov 30, 2016
1 parent 9d90567 commit ce76d1c
Showing 31 changed files with 650 additions and 122 deletions.
10 changes: 8 additions & 2 deletions ccmain/linerec.cpp
@@ -64,6 +64,7 @@ void Tesseract::TrainLineRecognizer(const STRING& input_imagename,
     return;
   }
   TrainFromBoxes(boxes, texts, block_list, &images);
+  images.Shuffle();
   if (!images.SaveDocument(lstmf_name.string(), NULL)) {
     tprintf("Failed to write training data to %s!\n", lstmf_name.string());
   }
@@ -79,7 +80,10 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
   int box_count = boxes.size();
   // Process all the text lines in this page, as defined by the boxes.
   int end_box = 0;
-  for (int start_box = 0; start_box < box_count; start_box = end_box) {
+  // Don't let \t, which marks newlines in the box file, get into the line
+  // content, as that makes the line unusable in training.
+  while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
+  for (int start_box = end_box; start_box < box_count; start_box = end_box) {
     // Find the textline of boxes starting at start and their bounding box.
     TBOX line_box = boxes[start_box];
     STRING line_str = texts[start_box];
@@ -115,7 +119,9 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
     }
     if (imagedata != NULL)
       training_data->AddPageToDocument(imagedata);
-    if (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
+    // Don't let \t, which marks newlines in the box file, get into the line
+    // content, as that makes the line unusable in training.
+    while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
   }
 }
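For context on the change above: in a Tesseract box file, an entry whose text is a single tab is not a real character; it marks the line break between two text lines. The new while-loops make sure a run of these markers is skipped both before the first line and between lines. A minimal standalone sketch of that skipping logic, using std types in place of Tesseract's GenericVector and STRING (the helper name is ours, not from the commit):

    #include <string>
    #include <vector>

    // Advances past any run of "\t" marker entries so a tab never ends up in
    // a training line's text, mirroring the while-loops added above.
    static int SkipLineBreakMarkers(const std::vector<std::string>& texts,
                                    int i) {
      while (i < static_cast<int>(texts.size()) && texts[i] == "\t") ++i;
      return i;
    }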
2 changes: 2 additions & 0 deletions ccstruct/boxread.cpp
@@ -55,6 +55,8 @@ bool ReadAllBoxes(int target_page, bool skip_blanks, const STRING& filename,
   GenericVector<char> box_data;
   if (!tesseract::LoadDataFromFile(BoxFileName(filename), &box_data))
     return false;
+  // Convert the array of bytes to a string, so it can be used by the parser.
+  box_data.push_back('\0');
   return ReadMemBoxes(target_page, skip_blanks, &box_data[0], boxes, texts,
                       box_texts, pages);
 }
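The explicit terminator matters because the box-file parser consumes the loaded bytes as a NUL-terminated C string, and LoadDataFromFile no longer pads the buffer with a trailing 0 (see the genericvector.h change below). A standalone sketch of the load-then-terminate pattern with std types, not the Tesseract implementation:

    #include <cstdio>
    #include <vector>

    // Reads a whole file into a byte vector and NUL-terminates it so that
    // string-oriented parsers can safely treat &(*data)[0] as a C string.
    bool LoadAsCString(const char* path, std::vector<char>* data) {
      FILE* fp = std::fopen(path, "rb");
      if (fp == nullptr) return false;
      std::fseek(fp, 0, SEEK_END);
      long size = std::ftell(fp);
      std::fseek(fp, 0, SEEK_SET);
      data->resize(size);
      bool ok = size == 0 ||
                std::fread(&(*data)[0], 1, size, fp) ==
                    static_cast<size_t>(size);
      std::fclose(fp);
      data->push_back('\0');  // The terminator the parser relies on.
      return ok;
    }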
27 changes: 21 additions & 6 deletions ccstruct/imagedata.cpp
@@ -24,18 +24,18 @@
 
 #include "imagedata.h"
 
-#if defined(__MINGW32__)
-#include <unistd.h>
-#else
-#include <thread>
-#endif
-
 #include "allheaders.h"
 #include "boxread.h"
 #include "callcpp.h"
 #include "helpers.h"
 #include "tprintf.h"
 
+#if defined(__MINGW32__)
+# include <unistd.h>
+#else
+# include <thread>
+#endif
+
 // Number of documents to read ahead while training. Doesn't need to be very
 // large.
 const int kMaxReadAhead = 8;
@@ -496,6 +496,21 @@ inT64 DocumentData::UnCache() {
   return memory_saved;
 }
 
+// Shuffles all the pages in the document.
+void DocumentData::Shuffle() {
+  TRand random;
+  // Different documents get shuffled differently, but the same for the same
+  // name.
+  random.set_seed(document_name_.string());
+  int num_pages = pages_.size();
+  // Execute one random swap for each page in the document.
+  for (int i = 0; i < num_pages; ++i) {
+    int src = random.IntRand() % num_pages;
+    int dest = random.IntRand() % num_pages;
+    std::swap(pages_[src], pages_[dest]);
+  }
+}
+
 // Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
 // starting at index pages_offset_.
 bool DocumentData::ReCachePages() {
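Two properties of the new Shuffle are worth noting: it is deterministic per document, because the seed is a hash of the document name, so a given .lstmf file is permuted the same way on every run; and it performs one random swap per page rather than a Fisher-Yates pass. A standalone sketch of the same pattern (std types only; the small LCG is a stand-in for Tesseract's TRand, whose generator differs):

    #include <cstdint>
    #include <functional>
    #include <string>
    #include <utility>
    #include <vector>

    // Name-seeded deterministic shuffle: one random swap per element.
    void SeededSwapShuffle(const std::string& name, std::vector<int>* pages) {
      uint64_t seed = std::hash<std::string>()(name);
      auto next = [&seed]() {  // Simple LCG, for illustration only.
        seed = seed * 6364136223846793005ULL + 1442695040888963407ULL;
        return static_cast<int>(seed >> 33);  // Always non-negative.
      };
      int n = pages->size();
      for (int i = 0; i < n; ++i) {
        int src = next() % n;
        int dest = next() % n;
        std::swap((*pages)[src], (*pages)[dest]);
      }
    }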
2 changes: 2 additions & 0 deletions ccstruct/imagedata.h
@@ -266,6 +266,8 @@ class DocumentData {
   // Removes all pages from memory and frees the memory, but does not forget
   // the document metadata. Returns the memory saved.
   inT64 UnCache();
+  // Shuffles all the pages in the document.
+  void Shuffle();
 
  private:
   // Sets the value of total_pages_ behind a mutex.
13 changes: 6 additions & 7 deletions ccstruct/pageres.cpp
@@ -529,13 +529,12 @@ void WERD_RES::FilterWordChoices(int debug_level) {
       if (choice->unichar_id(i) != best_choice->unichar_id(j) &&
           choice->certainty(i) - best_choice->certainty(j) < threshold) {
         if (debug_level >= 2) {
-          STRING label;
-          label.add_str_int("\nDiscarding bad choice #", index);
-          choice->print(label.string());
-          tprintf("i %d j %d Chunk %d Choice->Blob[i].Certainty %.4g"
-                  " BestChoice->ChunkCertainty[Chunk] %g Threshold %g\n",
-                  i, j, chunk, choice->certainty(i),
-                  best_choice->certainty(j), threshold);
+          choice->print("WorstCertaintyDiffWorseThan");
+          tprintf(
+              "i %d j %d Choice->Blob[i].Certainty %.4g"
+              " WorstOtherChoiceCertainty %g Threshold %g\n",
+              i, j, choice->certainty(i), best_choice->certainty(j), threshold);
+          tprintf("Discarding bad choice #%d\n", index);
         }
         delete it.extract();
         break;
14 changes: 12 additions & 2 deletions ccutil/genericvector.h
@@ -363,8 +363,7 @@ inline bool LoadDataFromFile(const STRING& filename,
   fseek(fp, 0, SEEK_END);
   size_t size = ftell(fp);
   fseek(fp, 0, SEEK_SET);
-  // Pad with a 0, just in case we treat the result as a string.
-  data->init_to_size(static_cast<int>(size) + 1, 0);
+  data->init_to_size(static_cast<int>(size), 0);
   bool result = fread(&(*data)[0], 1, size, fp) == size;
   fclose(fp);
   return result;
@@ -380,6 +379,17 @@ inline bool SaveDataToFile(const GenericVector<char>& data,
   fclose(fp);
   return result;
 }
+// Reads a file as a vector of STRING.
+inline bool LoadFileLinesToStrings(const STRING& filename,
+                                   GenericVector<STRING>* lines) {
+  GenericVector<char> data;
+  if (!LoadDataFromFile(filename.string(), &data)) {
+    return false;
+  }
+  STRING lines_str(&data[0], data.size());
+  lines_str.split('\n', lines);
+  return true;
+}
 
 template <typename T>
 bool cmp_eq(T const & t1, T const & t2) {
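A standalone equivalent of the new LoadFileLinesToStrings helper, shown with std types rather than Tesseract's STRING and GenericVector (a sketch of the same read-then-split contract, not the implementation above):

    #include <fstream>
    #include <string>
    #include <vector>

    // Reads a file and splits it into lines on '\n'.
    bool LoadFileLines(const std::string& filename,
                       std::vector<std::string>* lines) {
      std::ifstream file(filename.c_str(), std::ios::binary);
      if (!file) return false;
      std::string line;
      while (std::getline(file, line)) lines->push_back(line);
      return true;
    }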
7 changes: 7 additions & 0 deletions ccutil/helpers.h
@@ -27,6 +27,8 @@
 
 #include <stdio.h>
 #include <string.h>
+#include <functional>
+#include <string>
 
 #include "host.h"
 
@@ -43,6 +45,11 @@ class TRand {
   void set_seed(uinT64 seed) {
     seed_ = seed;
   }
+  // Sets the seed using a hash of a string.
+  void set_seed(const std::string& str) {
+    std::hash<std::string> hasher;
+    set_seed(static_cast<uinT64>(hasher(str)));
+  }
 
   // Returns an integer in the range 0 to MAX_INT32.
   inT32 IntRand() {
23 changes: 17 additions & 6 deletions lstm/fullyconnected.cpp
@@ -56,6 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
   return result;
 }
 
+// Suspends/Enables training by setting the training_ flag. Serialize and
+// DeSerialize only operate on the run-time data if state is false.
+void FullyConnected::SetEnableTraining(TrainingState state) {
+  if (state == TS_RE_ENABLE) {
+    if (training_ == TS_DISABLED) weights_.InitBackward(false);
+    training_ = TS_ENABLED;
+  } else {
+    training_ = state;
+  }
+}
+
 // Sets up the network for training. Initializes weights using weights of
 // scale `range` picked according to the random number generator `randomizer`.
 int FullyConnected::InitWeights(float range, TRand* randomizer) {
@@ -78,14 +89,14 @@ void FullyConnected::DebugWeights() {
 // Writes to the given file. Returns false in case of error.
 bool FullyConnected::Serialize(TFile* fp) const {
   if (!Network::Serialize(fp)) return false;
-  if (!weights_.Serialize(training_, fp)) return false;
+  if (!weights_.Serialize(IsTraining(), fp)) return false;
   return true;
 }
 
 // Reads from the given file. Returns false in case of error.
 // If swap is true, assumes a big/little-endian swap is needed.
 bool FullyConnected::DeSerialize(bool swap, TFile* fp) {
-  if (!weights_.DeSerialize(training_, swap, fp)) return false;
+  if (!weights_.DeSerialize(IsTraining(), swap, fp)) return false;
   return true;
 }
 
@@ -129,14 +140,14 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
     }
     ForwardTimeStep(d_input, i_input, t, temp_line);
     output->WriteTimeStep(t, temp_line);
-    if (training() && type_ != NT_SOFTMAX) {
+    if (IsTraining() && type_ != NT_SOFTMAX) {
      acts_.CopyTimeStepFrom(t, *output, t);
     }
   }
   // Zero all the elements that are in the padding around images that allows
   // multiple different-sized images to exist in a single array.
   // acts_ is only used if this is not a softmax op.
-  if (training() && type_ != NT_SOFTMAX) {
+  if (IsTraining() && type_ != NT_SOFTMAX) {
     acts_.ZeroInvalidElements();
   }
   output->ZeroInvalidElements();
@@ -152,7 +163,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
                                   const TransposedArray* input_transpose) {
   // Softmax output is always float, so save the input type.
   int_mode_ = input.int_mode();
-  if (training()) {
+  if (IsTraining()) {
     acts_.Resize(input, no_);
     // Source_ is a transposed copy of input. It isn't needed if provided.
     external_source_ = input_transpose;
@@ -163,7 +174,7 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
                                      int t, double* output_line) {
   // input is copied to source_ line-by-line for cache coherency.
-  if (training() && external_source_ == NULL && d_input != NULL)
+  if (IsTraining() && external_source_ == NULL && d_input != NULL)
     source_t_.WriteStrided(t, d_input);
   if (d_input != NULL)
     weights_.MatrixDotVector(d_input, output_line);
4 changes: 4 additions & 0 deletions lstm/fullyconnected.h
@@ -61,6 +61,10 @@ class FullyConnected : public Network {
     type_ = type;
   }
 
+  // Suspends/Enables training by setting the training_ flag. Serialize and
+  // DeSerialize only operate on the run-time data if state is false.
+  virtual void SetEnableTraining(TrainingState state);
+
   // Sets up the network for training. Initializes weights using weights of
   // scale `range` picked according to the random number generator `randomizer`.
   virtual int InitWeights(float range, TRand* randomizer);
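The TS_RE_ENABLE path in the SetEnableTraining overrides above is what makes incremental training from a recognition-only model work: a model deserialized for inference carries no delta buffers, so they must be rebuilt before training flips back on. A self-contained sketch of that transition; the enum values follow this commit's TrainingState (defined in network.h, which is not shown in this diff), while Weights and Layer are hypothetical stand-ins:

    // Sketch of the training-state transition, not Tesseract's actual classes.
    enum TrainingState { TS_DISABLED, TS_ENABLED, TS_TEMP_DISABLE, TS_RE_ENABLE };

    struct Weights {
      void InitBackward(bool /*adam*/) { /* allocate delta buffers here */ }
    };

    struct Layer {
      TrainingState training_;
      Weights weights_;
      Layer() : training_(TS_ENABLED) {}
      void SetEnableTraining(TrainingState state) {
        if (state == TS_RE_ENABLE) {
          // Deltas were freed when training was disabled; rebuild them first.
          if (training_ == TS_DISABLED) weights_.InitBackward(false);
          training_ = TS_ENABLED;
        } else {
          training_ = state;
        }
      }
    };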
27 changes: 22 additions & 5 deletions lstm/lstm.cpp
@@ -102,6 +102,23 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
   return result;
 }
 
+// Suspends/Enables training by setting the training_ flag. Serialize and
+// DeSerialize only operate on the run-time data if state is false.
+void LSTM::SetEnableTraining(TrainingState state) {
+  if (state == TS_RE_ENABLE) {
+    if (training_ == TS_DISABLED) {
+      for (int w = 0; w < WT_COUNT; ++w) {
+        if (w == GFS && !Is2D()) continue;
+        gate_weights_[w].InitBackward(false);
+      }
+    }
+    training_ = TS_ENABLED;
+  } else {
+    training_ = state;
+  }
+  if (softmax_ != NULL) softmax_->SetEnableTraining(state);
+}
+
 // Sets up the network for training. Initializes weights using weights of
 // scale `range` picked according to the random number generator `randomizer`.
 int LSTM::InitWeights(float range, TRand* randomizer) {
@@ -148,7 +165,7 @@ bool LSTM::Serialize(TFile* fp) const {
   if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
   for (int w = 0; w < WT_COUNT; ++w) {
     if (w == GFS && !Is2D()) continue;
-    if (!gate_weights_[w].Serialize(training_, fp)) return false;
+    if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
   }
   if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
   return true;
@@ -169,7 +186,7 @@ bool LSTM::DeSerialize(bool swap, TFile* fp) {
   is_2d_ = false;
   for (int w = 0; w < WT_COUNT; ++w) {
     if (w == GFS && !Is2D()) continue;
-    if (!gate_weights_[w].DeSerialize(training_, swap, fp)) return false;
+    if (!gate_weights_[w].DeSerialize(IsTraining(), swap, fp)) return false;
     if (w == CI) {
       ns_ = gate_weights_[CI].NumOutputs();
       is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
@@ -322,7 +339,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
     MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
     // Clip curr_state to a sane range.
     ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
-    if (training_) {
+    if (IsTraining()) {
       // Save the gate node values.
       node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
       node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
@@ -331,7 +348,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
       if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
     }
     FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
-    if (training_) state_.WriteTimeStep(t, curr_state);
+    if (IsTraining()) state_.WriteTimeStep(t, curr_state);
     if (softmax_ != NULL) {
       if (input.int_mode()) {
         int_output->WriteTimeStep(0, curr_output);
@@ -697,7 +714,7 @@ void LSTM::PrintDW() {
 void LSTM::ResizeForward(const NetworkIO& input) {
   source_.Resize(input, na_);
   which_fg_.ResizeNoInit(input.Width(), ns_);
-  if (training_) {
+  if (IsTraining()) {
     state_.ResizeFloat(input, ns_);
     for (int w = 0; w < WT_COUNT; ++w) {
       if (w == GFS && !Is2D()) continue;
4 changes: 4 additions & 0 deletions lstm/lstm.h
@@ -69,6 +69,10 @@ class LSTM : public Network {
     return spec;
   }
 
+  // Suspends/Enables training by setting the training_ flag. Serialize and
+  // DeSerialize only operate on the run-time data if state is false.
+  virtual void SetEnableTraining(TrainingState state);
+
   // Sets up the network for training. Initializes weights using weights of
   // scale `range` picked according to the random number generator `randomizer`.
   virtual int InitWeights(float range, TRand* randomizer);
4 changes: 2 additions & 2 deletions lstm/lstmrecognizer.cpp
@@ -253,7 +253,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
                                    float label_threshold, float* scale_factor,
                                    NetworkIO* inputs, NetworkIO* outputs) {
   // Maximum width of image to train on.
-  const int kMaxImageWidth = 2048;
+  const int kMaxImageWidth = 2560;
   // This ensures consistent recognition results.
   SetRandomSeed();
   int min_width = network_->XScaleFactor();
@@ -263,7 +263,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
     tprintf("Line cannot be recognized!!\n");
     return false;
   }
-  if (network_->training() && pixGetWidth(pix) > kMaxImageWidth) {
+  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
     tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
             pixGetHeight(pix));
     pixDestroy(&pix);
