diff --git a/lstm/fullyconnected.cpp b/lstm/fullyconnected.cpp
index 52c0cbf36b..ea368ca223 100644
--- a/lstm/fullyconnected.cpp
+++ b/lstm/fullyconnected.cpp
@@ -147,15 +147,12 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
     int thread_id = 0;
 #endif
     double* temp_line = temp_lines[thread_id];
-    const double* d_input = nullptr;
-    const int8_t* i_input = nullptr;
     if (input.int_mode()) {
-      i_input = input.i(t);
+      ForwardTimeStep(input.i(t), t, temp_line);
     } else {
       input.ReadTimeStep(t, curr_input[thread_id]);
-      d_input = curr_input[thread_id];
+      ForwardTimeStep(curr_input[thread_id], t, temp_line);
     }
-    ForwardTimeStep(d_input, i_input, t, temp_line);
     output->WriteTimeStep(t, temp_line);
     if (IsTraining() && type_ != NT_SOFTMAX) {
       acts_.CopyTimeStepFrom(t, *output, t);
@@ -188,15 +185,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
   }
 }
 
-void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_input,
-                                     int t, double* output_line) {
-  // input is copied to source_ line-by-line for cache coherency.
-  if (IsTraining() && external_source_ == nullptr && d_input != nullptr)
-    source_t_.WriteStrided(t, d_input);
-  if (d_input != nullptr)
-    weights_.MatrixDotVector(d_input, output_line);
-  else
-    weights_.MatrixDotVector(i_input, output_line);
+void FullyConnected::ForwardTimeStep(int t, double* output_line) {
   if (type_ == NT_TANH) {
     FuncInplace<GFunc>(no_, output_line);
   } else if (type_ == NT_LOGISTIC) {
@@ -214,6 +203,22 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_inpu
   }
 }
 
+void FullyConnected::ForwardTimeStep(const double* d_input,
+                                     int t, double* output_line) {
+  // input is copied to source_ line-by-line for cache coherency.
+  if (IsTraining() && external_source_ == NULL)
+    source_t_.WriteStrided(t, d_input);
+  weights_.MatrixDotVector(d_input, output_line);
+  ForwardTimeStep(t, output_line);
+}
+
+void FullyConnected::ForwardTimeStep(const int8_t* i_input,
+                                     int t, double* output_line) {
+  // input is copied to source_ line-by-line for cache coherency.
+  weights_.MatrixDotVector(i_input, output_line);
+  ForwardTimeStep(t, output_line);
+}
+
 // Runs backward propagation of errors on the deltas line.
 // See NetworkCpp for a detailed discussion of the arguments.
 bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
diff --git a/lstm/fullyconnected.h b/lstm/fullyconnected.h
index 6b9b22a9eb..2c886f9fd0 100644
--- a/lstm/fullyconnected.h
+++ b/lstm/fullyconnected.h
@@ -91,8 +91,9 @@ class FullyConnected : public Network {
   // Components of Forward so FullyConnected can be reused inside LSTM.
   void SetupForward(const NetworkIO& input,
                     const TransposedArray* input_transpose);
-  void ForwardTimeStep(const double* d_input, const int8_t* i_input, int t,
-                       double* output_line);
+  void ForwardTimeStep(int t, double* output_line);
+  void ForwardTimeStep(const double* d_input, int t, double* output_line);
+  void ForwardTimeStep(const int8_t* i_input, int t, double* output_line);
   // Runs backward propagation of errors on the deltas line.
   // See Network for a detailed discussion of the arguments.
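The patch above replaces the single ForwardTimeStep(d_input, i_input, t, output_line), which took both pointers and branched on nullptr at runtime, with three overloads: the double* and int8_t* entry points each run their own MatrixDotVector, then delegate to a shared ForwardTimeStep(int t, double*) that applies the activation function. A minimal standalone sketch of the same pattern (hypothetical TinyLayer class and toy weights, not Tesseract's actual API):

#include <cmath>
#include <cstdint>
#include <cstdio>

class TinyLayer {
 public:
  // Shared tail, like the new ForwardTimeStep(int t, double*): applies the
  // nonlinearity in place to the already-computed matrix product.
  void ForwardTimeStep(int t, double* output_line) {
    (void)t;  // the real code uses t to index the time step
    for (int i = 0; i < kNumOutputs; ++i)
      output_line[i] = std::tanh(output_line[i]);
  }
  // Float path: own matrix product, then the shared tail.
  void ForwardTimeStep(const double* d_input, int t, double* output_line) {
    MatrixDotVector(d_input, output_line);
    ForwardTimeStep(t, output_line);
  }
  // Quantized path: same shape, different input element type.
  void ForwardTimeStep(const int8_t* i_input, int t, double* output_line) {
    MatrixDotVector(i_input, output_line);
    ForwardTimeStep(t, output_line);
  }

 private:
  static const int kNumInputs = 2;
  static const int kNumOutputs = 2;
  // Toy weight matrix standing in for weights_.
  double weights_[kNumOutputs][kNumInputs] = {{0.5, -0.25}, {1.0, 0.75}};
  template <typename T>
  void MatrixDotVector(const T* input, double* output_line) {
    for (int i = 0; i < kNumOutputs; ++i) {
      output_line[i] = 0.0;
      for (int j = 0; j < kNumInputs; ++j)
        output_line[i] += weights_[i][j] * static_cast<double>(input[j]);
    }
  }
};

int main() {
  TinyLayer layer;
  double out[2];
  const double d_in[2] = {1.0, 2.0};
  layer.ForwardTimeStep(d_in, 0, out);  // picks the double* overload
  std::printf("float path: %f %f\n", out[0], out[1]);
  const int8_t i_in[2] = {1, 2};
  layer.ForwardTimeStep(i_in, 0, out);  // picks the int8_t* overload
  std::printf("int path:   %f %f\n", out[0], out[1]);
  return 0;
}

Note that, as in the patch, only the float overload captures the input into the transpose (source_t_.WriteStrided) for training, since that step only applies to double input.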
diff --git a/lstm/lstm.cpp b/lstm/lstm.cpp
index 516ad0ffae..f4b81ba0a4 100644
--- a/lstm/lstm.cpp
+++ b/lstm/lstm.cpp
@@ -396,9 +396,9 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
     if (softmax_ != nullptr) {
       if (input.int_mode()) {
         int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
-        softmax_->ForwardTimeStep(nullptr, int_output->i(0), t, softmax_output);
+        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
       } else {
-        softmax_->ForwardTimeStep(curr_output, nullptr, t, softmax_output);
+        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
       }
       output->WriteTimeStep(t, softmax_output);
       if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
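At the LSTM call sites the placeholder nullptr argument disappears: overload resolution on the pointer's static type now selects the path that the old code selected with runtime null checks, and a call supplying neither input is no longer expressible. A minimal sketch of that dispatch (hypothetical Step functions, not Tesseract code):

#include <cstdint>
#include <cstdio>

void Step(const double* /*d_input*/, int t, double* /*out*/) {
  std::printf("t=%d: double path\n", t);
}
void Step(const int8_t* /*i_input*/, int t, double* /*out*/) {
  std::printf("t=%d: int8 path\n", t);
}

int main() {
  double out[1];
  const double d_in[1] = {0.5};
  const int8_t i_in[1] = {3};
  Step(d_in, 0, out);  // double* overload, like the float branch above
  Step(i_in, 1, out);  // int8_t* overload, like the int_mode() branch
  return 0;
}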