From b86b4fa06ba4d2afa00c53470a19f6630e638f66 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Mon, 8 May 2017 14:26:09 -0700 Subject: [PATCH] Better fix for re-enabling training --- lstm/fullyconnected.cpp | 2 +- lstm/lstm.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lstm/fullyconnected.cpp b/lstm/fullyconnected.cpp index e73add6cd4..f91a4fe156 100644 --- a/lstm/fullyconnected.cpp +++ b/lstm/fullyconnected.cpp @@ -65,7 +65,7 @@ void FullyConnected::SetEnableTraining(TrainingState state) { // Temp disable only from enabled. if (training_ == TS_ENABLED) training_ = state; } else { - if (state == TS_ENABLED && training_ == TS_DISABLED) + if (state == TS_ENABLED && training_ != TS_ENABLED) weights_.InitBackward(); training_ = state; } diff --git a/lstm/lstm.cpp b/lstm/lstm.cpp index f94cebee4e..3864153bb5 100644 --- a/lstm/lstm.cpp +++ b/lstm/lstm.cpp @@ -113,7 +113,7 @@ void LSTM::SetEnableTraining(TrainingState state) { // Temp disable only from enabled. if (training_ == TS_ENABLED) training_ = state; } else { - if (state == TS_ENABLED && training_ == TS_DISABLED) { + if (state == TS_ENABLED && training_ != TS_ENABLED) { for (int w = 0; w < WT_COUNT; ++w) { if (w == GFS && !Is2D()) continue; gate_weights_[w].InitBackward();