Skip to content

Commit

Permalink
LSTM char_whitelist/blacklist (6ac2ff0): multi-code chars
Browse files Browse the repository at this point in the history
- move the whitelist/blacklist decision from ComputeTopN to
  ContinueContext, where it belongs: block those context
  continuations whose final code translates to a disabled unichar_id.
  (The normal fallback logic, top2 > top_n > rest,
   will still apply.)
- pass UNICHARSET refs appropriately
  • Loading branch information
bertsky committed Mar 8, 2019
1 parent 8012d5e commit b459990
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
30 changes: 13 additions & 17 deletions src/lstm/recodebeam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio,
if (lstm_choice_mode)
timesteps.clear();
for (int t = 0; t < width; ++t) {
ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0], charset);
ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
charset);
if (lstm_choice_mode) {
Expand All @@ -102,7 +102,7 @@ void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
beam_size_ = 0;
int width = output.dim1();
for (int t = 0; t < width; ++t) {
ComputeTopN(output[t], output.dim2(), kBeamWidths[0], charset);
ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
}
}
Expand Down Expand Up @@ -456,19 +456,12 @@ WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
// Fills top_n_flags_ with bools that are true iff the corresponding output
// is one of the top_n.
void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
int top_n, const UNICHARSET* charset) {
int top_n) {
top_n_flags_.init_to_size(num_outputs, TN_ALSO_RAN);
top_code_ = -1;
second_code_ = -1;
top_heap_.clear();
for (int i = 0; i < num_outputs; ++i) {
// Decode label via recoder_.
RecodedCharID code;
code.Set(0, i);
int label = recoder_.DecodeUnichar(code);
if (label != INVALID_UNICHAR_ID && // not part of a bigger code.
!charset->get_enabled(label)) // disabled
continue;
if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) {
TopPair entry(outputs[i], i);
top_heap_.Push(&entry);
Expand Down Expand Up @@ -505,10 +498,10 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
if (t == 0) {
// The first step can only use singles and initials.
ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
dict_ratio, cert_offset, worst_dict_cert, step);
charset, dict_ratio, cert_offset, worst_dict_cert, step);
if (dict_ != nullptr) {
ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
TN_TOP2, dict_ratio, cert_offset, worst_dict_cert, step);
ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, TN_TOP2,
charset, dict_ratio, cert_offset, worst_dict_cert, step);
}
} else {
RecodeBeam* prev = beam_[t - 1];
Expand Down Expand Up @@ -540,9 +533,8 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
// best first, but it comes before a lot of the worst, so it is slightly
// more efficient than going forwards.
for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
ContinueContext(&prev->beams_[index].get(i).data, index, outputs,
top_n, dict_ratio, cert_offset, worst_dict_cert,
step);
ContinueContext(&prev->beams_[index].get(i).data, index, outputs, top_n,
charset, dict_ratio, cert_offset, worst_dict_cert, step);
}
}
for (int index = 0; index < kNumBeams; ++index) {
Expand All @@ -569,7 +561,9 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
// choices for which top_n_flags[index] == top_n_flag.
void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index,
const float* outputs,
TopNState top_n_flag, double dict_ratio,
TopNState top_n_flag,
const UNICHARSET* charset,
double dict_ratio,
double cert_offset,
double worst_dict_cert,
RecodeBeam* step) {
Expand Down Expand Up @@ -632,6 +626,8 @@ void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index,
int unichar_id = recoder_.DecodeUnichar(full_code);
// Map the null char to INVALID.
if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
if (unichar_id != INVALID_UNICHAR_ID && !charset->get_enabled(unichar_id))
continue; // disabled by whitelist/blacklist
ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
use_dawgs, NC_ANYTHING, prev, step);
if (top_n_flag == TN_TOP2 && code != null_char_) {
Expand Down
8 changes: 4 additions & 4 deletions src/lstm/recodebeam.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class RecodeBeamSearch {

// Fills top_n_flags_ with bools that are true iff the corresponding output
// is one of the top_n.
void ComputeTopN(const float* outputs, int num_outputs, int top_n, const UNICHARSET* unicharset);
void ComputeTopN(const float* outputs, int num_outputs, int top_n);

// Adds the computation for the current time-step to the beam. Call at each
// time-step in sequence from left to right. outputs is the activation vector
Expand All @@ -310,9 +310,9 @@ class RecodeBeamSearch {
// using the given network outputs to provide scores to the choices. Uses only
// those choices for which top_n_flags[code] == top_n_flag.
void ContinueContext(const RecodeNode* prev, int index, const float* outputs,
TopNState top_n_flag, double dict_ratio,
double cert_offset, double worst_dict_cert,
RecodeBeam* step);
TopNState top_n_flag, const UNICHARSET* unicharset,
double dict_ratio, double cert_offset,
double worst_dict_cert, RecodeBeam* step);
// Continues for a new unichar, using dawg or non-dawg as per flag.
void ContinueUnichar(int code, int unichar_id, float cert,
float worst_dict_cert, float dict_ratio, bool use_dawgs,
Expand Down

0 comments on commit b459990

Please sign in to comment.