Skip to content

Commit

Permalink
Merge 4693e5a into 4850eee
Browse files (browse the repository at this point in the history)
  • Loading branch information
dancingpipi committed Dec 19, 2018
2 parents 4850eee + 4693e5a commit e2fc8e3
Showing 1 changed file with 38 additions and 21 deletions.
59 changes: 38 additions & 21 deletions src/layer/lstm.cpp
Expand Up @@ -38,25 +38,28 @@ int LSTM::load_model(const ModelBin& mb)
int size = weight_data_size / num_output / 4;

// raw weight data
weight_hc_data = mb.load(size, num_output * 4, 0);
if (weight_hc_data.empty())
return -100;

weight_xc_data = mb.load(size, num_output * 4, 0);
if (weight_xc_data.empty())
return -100;


bias_c_data = mb.load(4, num_output, 0);
if (bias_c_data.empty())
return -100;

weight_hc_data = mb.load(num_output, num_output * 4, 0);
if (weight_hc_data.empty())
return -100;

return 0;
}

int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
fprintf(stdout, "lstm forward start!\n");
// size x T
const Mat& input_blob = bottom_blobs[0];

size_t elemsize = input_blob.elemsize;

// T, 0 or 1 each
Expand Down Expand Up @@ -98,31 +101,44 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
const float* x = input_blob.row(t);
for (int q=0; q<num_output; q++)
{
float h_cont = cont ? hidden[q] : 0.f;
//float h_cont = cont ? hidden[q] : 0.f;

const float* I_bias_c_data_ptr = (const float*)bias_c_data;
const float* F_bias_c_data_ptr = (const float*)bias_c_data + num_output;
const float* O_bias_c_data_ptr = (const float*)bias_c_data + 2 * num_output;
const float* G_bias_c_data_ptr = (const float*)bias_c_data + 3 * num_output;

const float* bias_c_data_ptr = (const float*)bias_c_data + 4 * q;
//const float* bias_c_data_ptr = (const float*)bias_c_data + 4 * q;
float* gates_data = (float*)gates + 4 * q;

// gate I F O G
const float* weight_hc_data_I = (const float*)weight_hc_data + weight_hc_data.w * q;
const float* weight_xc_data_I = (const float*)weight_xc_data + weight_xc_data.w * q;
const float* weight_hc_data_F = (const float*)weight_hc_data + weight_hc_data.w * q + size;
const float* weight_xc_data_F = (const float*)weight_xc_data + weight_xc_data.w * q + size;
const float* weight_hc_data_O = (const float*)weight_hc_data + weight_hc_data.w * q + size*2;
const float* weight_xc_data_O = (const float*)weight_xc_data + weight_xc_data.w * q + size*2;
const float* weight_hc_data_G = (const float*)weight_hc_data + weight_hc_data.w * q + size*3;
const float* weight_xc_data_G = (const float*)weight_xc_data + weight_xc_data.w * q + size*3;

float I = bias_c_data_ptr[0];
float F = bias_c_data_ptr[1];
float O = bias_c_data_ptr[2];
float G = bias_c_data_ptr[3];
const float* weight_hc_data_F = (const float*)weight_hc_data + weight_hc_data.w * q + num_output * num_output;
const float* weight_xc_data_F = (const float*)weight_xc_data + weight_xc_data.w * q + num_output * size;
const float* weight_hc_data_O = (const float*)weight_hc_data + weight_hc_data.w * q + num_output * num_output * 2;
const float* weight_xc_data_O = (const float*)weight_xc_data + weight_xc_data.w * q + num_output * size * 2;
const float* weight_hc_data_G = (const float*)weight_hc_data + weight_hc_data.w * q + num_output * num_output * 3;
const float* weight_xc_data_G = (const float*)weight_xc_data + weight_xc_data.w * q + num_output * size * 3;

float I = I_bias_c_data_ptr[q];
float F = F_bias_c_data_ptr[q];
float O = O_bias_c_data_ptr[q];
float G = G_bias_c_data_ptr[q];

for (int i=0; i<size; i++)
{
I += weight_hc_data_I[i] * h_cont + weight_xc_data_I[i] * x[i];
F += weight_hc_data_F[i] * h_cont + weight_xc_data_F[i] * x[i];
O += weight_hc_data_O[i] * h_cont + weight_xc_data_O[i] * x[i];
G += weight_hc_data_G[i] * h_cont + weight_xc_data_G[i] * x[i];
I += weight_xc_data_I[i] * x[i];
F += weight_xc_data_F[i] * x[i];
O += weight_xc_data_O[i] * x[i];
G += weight_xc_data_G[i] * x[i];
}

for (int i=0; i<num_output; ++i){
I += weight_hc_data_I[i] * (cont == 0? 0: hidden[i]);
F += weight_hc_data_F[i] * (cont == 0? 0: hidden[i]);
O += weight_hc_data_O[i] * (cont == 0? 0: hidden[i]);
G += weight_hc_data_G[i] * (cont == 0? 0: hidden[i]);
}

gates_data[0] = I;
Expand Down Expand Up @@ -163,6 +179,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl

// no cell output here
}
fprintf(stdout, "lstm forward end!\n");

return 0;
}
Expand Down

0 comments on commit e2fc8e3

Please sign in to comment.