Skip to content

Commit

Permalink
Merge pull request #3 from reyoung/pr/4929
Browse files Browse the repository at this point in the history
Several Enhancement
  • Loading branch information
qingqing01 committed Oct 23, 2017
2 parents 694bc64 + 65906ef commit 34aac18
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 97 deletions.
16 changes: 8 additions & 8 deletions paddle/operators/lstm_op.cc
Expand Up @@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel {
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if diable peepholes connection",
"4 * %d if disable peepholes connection",
frame_size);
}
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
Expand All @@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 4D), where, T is the "
"this LoDTensor is a matrix with shape (T X 4D), where, T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
Expand All @@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate aftern the nonlinear computation. This "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
Expand All @@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(false);
AddAttr<std::string>(
"gateActivation",
"(string, defalut: sigmoid)"
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by defalut.")
"gate, `sigmoid` by default.")
.SetDefault("sigmoid");
AddAttr<std::string>("cellActivation",
"(string, defalut: tanh)"
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh");
AddAttr<std::string>("candidateActivation",
"(string, defalut: tanh)"
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by defalut.")
"`tanh` by default.")
.SetDefault("tanh");
AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator
Expand Down
18 changes: 9 additions & 9 deletions paddle/operators/lstm_op.h
Expand Up @@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel<T> {
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);

auto in_dims = input->dims();
int frame_size = in_dims[1] / 4;
int frame_size = static_cast<int>(in_dims[1] / 4);
framework::DDim dims({in_dims[0], frame_size});

if (bias) {
Expand All @@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel<T> {

math::LstmMetaValue<T> lstm_value;
T* bias_data = const_cast<T*>(bias->data<T>());
// the code styple in LstmMetaValue will be updated later.
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
Expand All @@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel<T> {
framework::LoDTensor batch_cell_pre_act;
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());

auto batch_lod = batch_gate->lod()[0];
int num_batch = batch_lod.size() - 1;
auto& batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");

for (int n = 0; n < num_batch; n++) {
int bstart = batch_lod[n];
int bend = batch_lod[n + 1];
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);

Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
Tensor out_t = batch_out.Slice<T>(bstart, bend);
Expand All @@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart;

if (n != 0) {
int pre_h_start = batch_lod[n - 1];
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice<T>(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
*weight, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
// else if : support the initial hidden and cell
// else if : FIXME support the initial hidden and cell

lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>();
Expand Down
83 changes: 42 additions & 41 deletions paddle/operators/math/detail/lstm_kernel.h
Expand Up @@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"

#ifdef __CUDA_ARCH__
#define INLINE __device__ inline
#else
#define INLINE inline
#endif
#include <type_traits>

namespace paddle {
namespace operators {
Expand All @@ -30,12 +27,12 @@ namespace forward {
template <class T>
class lstm {
public:
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO,
typename hppl::ForwardActType<T>::type actInput,
typename hppl::ForwardActType<T>::type actGate,
typename hppl::ForwardActType<T>::type actState) {
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO,
typename hppl::ForwardActType<T>::type actInput,
typename hppl::ForwardActType<T>::type actGate,
typename hppl::ForwardActType<T>::type actState) {
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
Expand All @@ -45,17 +42,19 @@ class lstm {
output = valueOg * stateAtv;
}
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
static const bool avx = true;
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
hppl::Active<__m256>::forward actInput,
hppl::Active<__m256>::forward actGate,
hppl::Active<__m256>::forward actState) {
// Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value;

HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
hppl::Active<__m256>::forward actInput,
hppl::Active<__m256>::forward actGate,
hppl::Active<__m256>::forward actState) {
valueIn = actInput(valueIn);
valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
Expand All @@ -76,14 +75,15 @@ namespace backward {
template <class T>
class lstm {
public:
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state, T &stateGrad,
T &stateAtv, T &outputGrad, T &checkI, T &checkF,
T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad,
typename hppl::BackwardActType<T>::type actInput,
typename hppl::BackwardActType<T>::type actGate,
typename hppl::BackwardActType<T>::type actState) {
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
typename hppl::BackwardActType<T>::type actInput,
typename hppl::BackwardActType<T>::type actGate,
typename hppl::BackwardActType<T>::type actState) {
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
Expand All @@ -95,21 +95,22 @@ class lstm {
checkOGrad = gradOg * state;
}
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
static const bool avx = true;
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
__m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv,
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
__m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
__m256 &checkOGrad,
hppl::Active<__m256>::backward actInput,
hppl::Active<__m256>::backward actGate,
hppl::Active<__m256>::backward actState) {
// Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
__m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv,
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
__m256 &checkO, __m256 &checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad,
hppl::Active<__m256>::backward actInput,
hppl::Active<__m256>::backward actGate,
hppl::Active<__m256>::backward actState) {
gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
stateGrad = _mm256_add_ps(
actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
Expand Down
9 changes: 5 additions & 4 deletions paddle/operators/math/lstm_compute.cc
Expand Up @@ -24,8 +24,8 @@ template <class T>
struct LstmUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act,
std::string cand_act) {
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act),
Expand All @@ -45,8 +45,9 @@ template <class T>
struct LstmUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, std::string gate_act,
std::string cell_act, std::string cand_act) {
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, ActiveType(cand_act),
Expand Down
9 changes: 5 additions & 4 deletions paddle/operators/math/lstm_compute.cu
Expand Up @@ -24,8 +24,8 @@ template <class T>
struct LstmUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act,
std::string cand_act) {
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
Expand All @@ -36,8 +36,9 @@ template <class T>
struct LstmUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, std::string gate_act,
std::string cell_act, std::string cand_act) {
int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
Expand Down
9 changes: 5 additions & 4 deletions paddle/operators/math/lstm_compute.h
Expand Up @@ -72,17 +72,18 @@ class LstmUnitFunctor {
public:
static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act,
std::string cand_act);
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
};

template <typename Place, typename T>
class LstmUnitGradFunctor {
public:
static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, std::string gate_act,
std::string cell_act, std::string cand_act);
int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
};

} // namespace math
Expand Down
2 changes: 0 additions & 2 deletions paddle/operators/math/sequence2batch.cc
Expand Up @@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;

template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;

Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/math/sequence2batch.cu
Expand Up @@ -21,7 +21,7 @@ namespace math {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
int64_t height, int64_t width,
const bool is_src_index) {
bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int id = blockIdx.x + idy * GridDimX;
Expand Down

0 comments on commit 34aac18

Please sign in to comment.