implement lstm proj_size (#4263)
nihui committed Oct 14, 2022
1 parent 0f38cb2 commit 77eda4c
Showing 14 changed files with 1,133 additions and 606 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yml
@@ -227,7 +227,7 @@ jobs:
uses: actions/cache@v3
with:
path: lavapipe-install
key: lavapipe-linux-install-20211127-2
key: lavapipe-linux-install-20211127-3
- name: checkout-lavapipe
if: steps.cache-lavapipe.outputs.cache-hit != 'true'
uses: actions/checkout@v3
10 changes: 6 additions & 4 deletions docs/developer-guide/operators.md
@@ -1026,15 +1026,17 @@ y0, hidden y1, cell y2 = lstm(x0, hidden x1, cell x2)

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | num_output | int | 0 | hidden size of output |
| 0 | num_output | int | 0 | output size of output |
| 1 | weight_data_size| int | 0 | total size of IFOG weight matrix |
| 2 | direction | int | 0 | 0=forward, 1=reverse, 2=bidirectional |
| 3 | hidden_size | int | num_output| hidden size |

| weight | type | shape |
| ------------- | ----- | --------------------- |
| weight_xc_data| float/fp16/int8 | [input_size, num_output * 4, num_directions] |
| bias_c_data | float/fp16/int8 | [num_output, 4, num_directions] |
| weight_hc_data| float/fp16/int8 | [num_output, num_output * 4, num_directions] |
| weight_xc_data| float/fp16/int8 | [input_size, hidden_size * 4, num_directions] |
| bias_c_data | float/fp16/int8 | [hidden_size, 4, num_directions] |
| weight_hc_data| float/fp16/int8 | [num_output, hidden_size * 4, num_directions] |
| weight_hr_data| float/fp16/int8 | [hidden_size, num_output, num_directions] |
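
For reference, the shapes above fit together as sketched below (an illustrative note, not text from the ncnn docs; the helper and its arguments are assumed for the example). `weight_data_size` (param id 1) counts only the IFOG input-to-hidden weights, and `weight_hr_data` is present only when the output is projected, i.e. when `num_output` differs from `hidden_size`.

```cpp
// Sketch of the shape relationships implied by the table above (illustrative helper, not ncnn API).
static int lstm_weight_data_size(int input_size, int hidden_size, int direction)
{
    int num_directions = (direction == 2) ? 2 : 1;

    // param id 1 covers only the IFOG input-to-hidden weights;
    // weight_hc_data, bias_c_data and weight_hr_data are sized from the shapes listed above
    return input_size * hidden_size * 4 * num_directions;
}
```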

Direction flag:
- 0 = forward only
256 changes: 188 additions & 68 deletions src/layer/arm/lstm_arm.cpp

Large diffs are not rendered by default.

285 changes: 199 additions & 86 deletions src/layer/arm/lstm_arm_asimdhp.cpp

Large diffs are not rendered by default.

97 changes: 70 additions & 27 deletions src/layer/lstm.cpp
@@ -29,43 +29,60 @@ int LSTM::load_param(const ParamDict& pd)
num_output = pd.get(0, 0);
weight_data_size = pd.get(1, 0);
direction = pd.get(2, 0);
hidden_size = pd.get(3, num_output);
return 0;
}

int LSTM::load_model(const ModelBin& mb)
{
int num_directions = direction == 2 ? 2 : 1;

int size = weight_data_size / num_directions / num_output / 4;
int size = weight_data_size / num_directions / hidden_size / 4;

// raw weight data
weight_xc_data = mb.load(size, num_output * 4, num_directions, 0);
weight_xc_data = mb.load(size, hidden_size * 4, num_directions, 0);
if (weight_xc_data.empty())
return -100;

bias_c_data = mb.load(num_output, 4, num_directions, 0);
bias_c_data = mb.load(hidden_size, 4, num_directions, 0);
if (bias_c_data.empty())
return -100;

weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0);
weight_hc_data = mb.load(num_output, hidden_size * 4, num_directions, 0);
if (weight_hc_data.empty())
return -100;

if (num_output != hidden_size)
{
weight_hr_data = mb.load(hidden_size, num_output, num_directions, 0);
if (weight_hr_data.empty())
return -100;
}

return 0;
}

static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(4, num_output, 4u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -80,7 +97,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w

const float* x = bottom_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const float* bias_c_I = bias_c.row(0);
const float* bias_c_F = bias_c.row(1);
@@ -90,15 +107,15 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float* gates_data = gates.row(q);

// gate I F O G
const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

float I = bias_c_I[q];
float F = bias_c_F[q];
@@ -140,7 +157,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
// h_t := o_t .* tanh[c_t]
float* output_data = top_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

@@ -157,8 +174,34 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float cell2 = F * cell_state[q] + I * G;
float H = O * tanh(cell2);
cell_state[q] = cell2;
hidden_state[q] = H;
output_data[q] = H;

if (num_output == hidden_size)
{
hidden_state[q] = H;
output_data[q] = H;
}
else
{
tmp_hidden_state[q] = H;
}
}

if (num_output != hidden_size)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
{
const float* hr = weight_hr.row(q);

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_state[i] * hr[i];
}

hidden_state[q] = H;
output_data[q] = H;
}
}
}

@@ -177,7 +220,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
return -100;
hidden.fill(0.f);

Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -189,7 +232,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -204,14 +247,14 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.0f);
cell.fill(0.0f);

int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

@@ -251,7 +294,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -265,7 +308,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -282,13 +325,13 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;

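The core of the src/layer/lstm.cpp change above is the optional output projection. A condensed sketch of one time step is shown below, using plain float arrays instead of ncnn::Mat and assuming the gate values have already been activated; it summarizes the code above rather than reproducing it.

```cpp
#include <math.h>

// One LSTM time step with output projection (sketch, assumed helper, not ncnn code).
// I, F, O, G: activated gate values, hidden_size each.
// weight_hr: num_output rows of hidden_size values (the projection matrix).
static void lstm_step_proj(const float* I, const float* F, const float* O, const float* G,
                           const float* weight_hr, float* cell_state, float* tmp_hidden,
                           float* hidden_state, float* output,
                           int hidden_size, int num_output)
{
    for (int q = 0; q < hidden_size; q++)
    {
        float cell2 = F[q] * cell_state[q] + I[q] * G[q]; // c_t := f_t .* c_{t-1} + i_t .* g_t
        cell_state[q] = cell2;
        tmp_hidden[q] = O[q] * tanhf(cell2);              // un-projected h~_t := o_t .* tanh(c_t)
    }

    // h_t := W_hr * h~_t, reducing hidden_size values down to num_output
    for (int q = 0; q < num_output; q++)
    {
        const float* hr = weight_hr + q * hidden_size;

        float H = 0.f;
        for (int i = 0; i < hidden_size; i++)
            H += hr[i] * tmp_hidden[i];

        hidden_state[q] = H;
        output[q] = H;
    }
}
```

When num_output equals hidden_size the projection loop is skipped entirely and the un-projected hidden state is written out directly, which preserves the previous behaviour.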
2 changes: 2 additions & 0 deletions src/layer/lstm.h
@@ -36,10 +36,12 @@ class LSTM : public Layer
int num_output;
int weight_data_size;
int direction; // 0=forward 1=reverse 2=bidirectional
int hidden_size;

Mat weight_hc_data;
Mat weight_xc_data;
Mat bias_c_data;
Mat weight_hr_data;
};

} // namespace ncnn
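
With the new hidden_size member and weight_hr_data blob, a projected LSTM could be described in a param file roughly as below, assuming the usual ncnn param syntax; the blob names and sizes are invented for illustration. Here input_size is 64, hidden_size is 512, the layer is unidirectional, and the output is projected down to 128, so weight_data_size = 64 * 512 * 4 = 131072.

```
LSTM    lstm_proj    1 1 in_blob out_blob 0=128 1=131072 2=0 3=512
```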