Skip to content

Commit

Permalink
Corrected SetEnableTraining for recovery from a recognize-only model.
Browse files Browse the repository at this point in the history
  • Loading branch information
theraysmith committed May 5, 2017
1 parent 006a56c commit 4fa463c
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 16 deletions.
12 changes: 8 additions & 4 deletions lstm/fullyconnected.cpp
Expand Up @@ -56,13 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
return result;
}

// Suspends/Enables/Permanently disables training by setting the training_
// flag. Requests are filtered against the current state so that a temporary
// suspension cannot be confused with a permanent one:
//  - TS_RE_ENABLE only takes effect from TS_TEMP_DISABLE (it undoes a
//    temporary suspension, but must NOT resurrect a permanently disabled,
//    recognize-only model).
//  - TS_TEMP_DISABLE only takes effect from TS_ENABLED.
//  - An explicit TS_ENABLED request can recover a recognize-only model
//    (training_ == TS_DISABLED); in that case the backward-pass data in the
//    weights must be re-allocated before any training step runs.
void FullyConnected::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) training_ = state;
  } else {
    // Recovering from a recognize-only model: re-create the deltas/updates
    // storage that a permanently disabled model does not carry.
    if (state == TS_ENABLED && training_ == TS_DISABLED)
      weights_.InitBackward();
    training_ = state;
  }
}
Expand Down
12 changes: 8 additions & 4 deletions lstm/lstm.cpp
Expand Up @@ -107,14 +107,18 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
// DeSerialize only operate on the run-time data if state is false.
void LSTM::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
if (training_ == TS_DISABLED) {
// Enable only from temp disabled.
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
} else if (state == TS_TEMP_DISABLE) {
// Temp disable only from enabled.
if (training_ == TS_ENABLED) training_ = state;
} else {
if (state == TS_ENABLED && training_ == TS_DISABLED) {
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
gate_weights_[w].InitBackward(false);
gate_weights_[w].InitBackward();
}
}
training_ = TS_ENABLED;
} else {
training_ = state;
}
if (softmax_ != NULL) softmax_->SetEnableTraining(state);
Expand Down
6 changes: 5 additions & 1 deletion lstm/network.cpp
Expand Up @@ -111,7 +111,11 @@ Network::~Network() {
// recognizer can be converted back to a trainer.
void Network::SetEnableTraining(TrainingState state) {
if (state == TS_RE_ENABLE) {
training_ = TS_ENABLED;
// Enable only from temp disabled.
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
} else if (state == TS_TEMP_DISABLE) {
// Temp disable only from enabled.
if (training_ == TS_ENABLED) training_ = state;
} else {
training_ = state;
}
Expand Down
3 changes: 2 additions & 1 deletion lstm/network.h
Expand Up @@ -93,9 +93,10 @@ enum TrainingState {
// Valid states of training_.
TS_DISABLED, // Disabled permanently.
TS_ENABLED, // Enabled for backprop and to write a training dump.
// Re-enable from ANY disabled state.
TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
// Valid only for SetEnableTraining.
TS_RE_ENABLE, // Re-Enable whatever the current state.
TS_RE_ENABLE, // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};

// Base class for network types. Not quite an abstract base class, but almost.
Expand Down
10 changes: 5 additions & 5 deletions lstm/weightmatrix.cpp
Expand Up @@ -47,7 +47,8 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
}
}
}
InitBackward(ada_grad);
use_ada_grad_ = ada_grad;
InitBackward();
return ni * no;
}

Expand Down Expand Up @@ -83,10 +84,9 @@ void WeightMatrix::ConvertToInt() {

// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void WeightMatrix::InitBackward(bool ada_grad) {
void WeightMatrix::InitBackward() {
int no = int_mode_ ? wi_.dim1() : wf_.dim1();
int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
use_ada_grad_ = ada_grad;
dw_.Resize(no, ni, 0.0);
updates_.Resize(no, ni, 0.0);
wf_t_.Transpose(wf_);
Expand Down Expand Up @@ -134,7 +134,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
} else {
if (!wf_.DeSerialize(fp)) return false;
if (training) {
InitBackward(use_ada_grad_);
InitBackward();
if (!updates_.DeSerialize(fp)) return false;
if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(fp)) return false;
}
Expand All @@ -157,7 +157,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile* fp) {
FloatToDouble(float_array, &wf_);
}
if (training) {
InitBackward(use_ada_grad_);
InitBackward();
if (!float_array.DeSerialize(fp)) return false;
FloatToDouble(float_array, &updates_);
// Errs was only used in int training, which is now dead.
Expand Down
2 changes: 1 addition & 1 deletion lstm/weightmatrix.h
Expand Up @@ -92,7 +92,7 @@ class WeightMatrix {

// Allocates any needed memory for running Backward, and zeroes the deltas,
// thus eliminating any existing momentum.
void InitBackward(bool ada_grad);
void InitBackward();

// Writes to the given file. Returns false in case of error.
bool Serialize(bool training, TFile* fp) const;
Expand Down

0 comments on commit 4fa463c

Please sign in to comment.