From 4fa463cd71366c2854090b38f25f94f3765d54b0 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Fri, 5 May 2017 16:39:43 -0700 Subject: [PATCH] Corrected SetEnableTraining for recovery from a recognize-only model. --- lstm/fullyconnected.cpp | 12 ++++++++---- lstm/lstm.cpp | 12 ++++++++---- lstm/network.cpp | 6 +++++- lstm/network.h | 3 ++- lstm/weightmatrix.cpp | 10 +++++----- lstm/weightmatrix.h | 2 +- 6 files changed, 29 insertions(+), 16 deletions(-) diff --git a/lstm/fullyconnected.cpp b/lstm/fullyconnected.cpp index ecf43db192..e73add6cd4 100644 --- a/lstm/fullyconnected.cpp +++ b/lstm/fullyconnected.cpp @@ -56,13 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const { return result; } -// Suspends/Enables training by setting the training_ flag. Serialize and -// DeSerialize only operate on the run-time data if state is false. +// Suspends/Enables training by setting the training_ flag. void FullyConnected::SetEnableTraining(TrainingState state) { if (state == TS_RE_ENABLE) { - if (training_ == TS_DISABLED) weights_.InitBackward(false); - training_ = TS_ENABLED; + // Enable only from temp disabled. + if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED; + } else if (state == TS_TEMP_DISABLE) { + // Temp disable only from enabled. + if (training_ == TS_ENABLED) training_ = state; } else { + if (state == TS_ENABLED && training_ == TS_DISABLED) + weights_.InitBackward(); training_ = state; } } diff --git a/lstm/lstm.cpp b/lstm/lstm.cpp index 3b9ca87c2f..f94cebee4e 100644 --- a/lstm/lstm.cpp +++ b/lstm/lstm.cpp @@ -107,14 +107,18 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const { // DeSerialize only operate on the run-time data if state is false. void LSTM::SetEnableTraining(TrainingState state) { if (state == TS_RE_ENABLE) { - if (training_ == TS_DISABLED) { + // Enable only from temp disabled. + if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED; + } else if (state == TS_TEMP_DISABLE) { + // Temp disable only from enabled. + if (training_ == TS_ENABLED) training_ = state; + } else { + if (state == TS_ENABLED && training_ == TS_DISABLED) { for (int w = 0; w < WT_COUNT; ++w) { if (w == GFS && !Is2D()) continue; - gate_weights_[w].InitBackward(false); + gate_weights_[w].InitBackward(); } } - training_ = TS_ENABLED; - } else { training_ = state; } if (softmax_ != NULL) softmax_->SetEnableTraining(state); diff --git a/lstm/network.cpp b/lstm/network.cpp index ee3289e247..a7140a4731 100644 --- a/lstm/network.cpp +++ b/lstm/network.cpp @@ -111,7 +111,11 @@ Network::~Network() { // recognizer can be converted back to a trainer. void Network::SetEnableTraining(TrainingState state) { if (state == TS_RE_ENABLE) { - training_ = TS_ENABLED; + // Enable only from temp disabled. + if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED; + } else if (state == TS_TEMP_DISABLE) { + // Temp disable only from enabled. + if (training_ == TS_ENABLED) training_ = state; } else { training_ = state; } diff --git a/lstm/network.h b/lstm/network.h index 951af3fb30..ca0e306f25 100644 --- a/lstm/network.h +++ b/lstm/network.h @@ -93,9 +93,10 @@ enum TrainingState { // Valid states of training_. TS_DISABLED, // Disabled permanently. TS_ENABLED, // Enabled for backprop and to write a training dump. + // Re-enable from ANY disabled state. TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump. // Valid only for SetEnableTraining. - TS_RE_ENABLE, // Re-Enable whatever the current state. + TS_RE_ENABLE, // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED. }; // Base class for network types. Not quite an abstract base class, but almost. diff --git a/lstm/weightmatrix.cpp b/lstm/weightmatrix.cpp index 77b8d824f4..1241665ee7 100644 --- a/lstm/weightmatrix.cpp +++ b/lstm/weightmatrix.cpp @@ -47,7 +47,8 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad, } } } - InitBackward(ada_grad); + use_ada_grad_ = ada_grad; + InitBackward(); return ni * no; } @@ -83,10 +84,9 @@ void WeightMatrix::ConvertToInt() { // Allocates any needed memory for running Backward, and zeroes the deltas, // thus eliminating any existing momentum. -void WeightMatrix::InitBackward(bool ada_grad) { +void WeightMatrix::InitBackward() { int no = int_mode_ ? wi_.dim1() : wf_.dim1(); int ni = int_mode_ ? wi_.dim2() : wf_.dim2(); - use_ada_grad_ = ada_grad; dw_.Resize(no, ni, 0.0); updates_.Resize(no, ni, 0.0); wf_t_.Transpose(wf_); @@ -134,7 +134,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) { } else { if (!wf_.DeSerialize(fp)) return false; if (training) { - InitBackward(use_ada_grad_); + InitBackward(); if (!updates_.DeSerialize(fp)) return false; if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(fp)) return false; } @@ -157,7 +157,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile* fp) { FloatToDouble(float_array, &wf_); } if (training) { - InitBackward(use_ada_grad_); + InitBackward(); if (!float_array.DeSerialize(fp)) return false; FloatToDouble(float_array, &updates_); // Errs was only used in int training, which is now dead. diff --git a/lstm/weightmatrix.h b/lstm/weightmatrix.h index e1b04c37d6..76dea68f7c 100644 --- a/lstm/weightmatrix.h +++ b/lstm/weightmatrix.h @@ -92,7 +92,7 @@ class WeightMatrix { // Allocates any needed memory for running Backward, and zeroes the deltas, // thus eliminating any existing momentum. - void InitBackward(bool ada_grad); + void InitBackward(); // Writes to the given file. Returns false in case of error. bool Serialize(bool training, TFile* fp) const;