Skip to content

Commit

Permalink
add embedding 2.0 (#26649) (#26903)
Browse files Browse the repository at this point in the history
* add embedding 2.0

* add embedding support input int32
  • Loading branch information
seiriosPlus committed Sep 2, 2020
1 parent 89ef291 commit 1b60f7f
Show file tree
Hide file tree
Showing 12 changed files with 548 additions and 125 deletions.
13 changes: 12 additions & 1 deletion paddle/fluid/operators/lookup_table_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/lookup_table_v2_op.h"

#include <memory>

#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"

namespace paddle {
Expand Down Expand Up @@ -196,3 +196,14 @@ REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad,
ops::LookupTableV2GradKernel<float>,
ops::LookupTableV2GradKernel<double>);

/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(lookup_table_v2)
.AddCheckpoint(
R"ROC(fix lookup_table_v2, add input type `int32`)ROC",
paddle::framework::compatible::OpVersionDesc()
.BugfixWithBehaviorChanged("lookup_table_v2 support input type "
"`int64`; after support input type "
"`int32/int64`"));

/* ========================================================================== */
73 changes: 59 additions & 14 deletions paddle/fluid/operators/lookup_table_v2_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
}
}

template <typename T>
__global__ void InputTypeCovert(const T *in_ids, const int64_t K,
int64_t *out_ids) {
for (int i = 0; i < K; i++) {
out_ids[i] = (int64_t)(in_ids[i]);
}
}

template <typename T>
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public:
Expand All @@ -101,23 +109,37 @@ class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();

auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());

dim3 threads(256, 4);
dim3 grids(80, 1);

// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);

const int64_t *ids_p = nullptr;

if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else {
ids_p = ids_t->data<int64_t>();
}

auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());

if (padding_idx == -1)
LookupTableV2<
T, 256, 4, 80,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
output, table, ids_p, N, K, D, padding_idx);
else
LookupTableV2<
T, 256, 4, 80,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
output, table, ids_p, N, K, D, padding_idx);
}
};

Expand All @@ -139,16 +161,24 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {

auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();

dim3 threads(128, 8);
dim3 grids(8, 1);
auto stream = dev_ctx.stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num);
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace());

// TODO(yuyang18): Strange code here.
memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
if (ids->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids->data<int>(), ids_num,
new_rows.MutableData(context.GetPlace()));
} else {
memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
}

d_table->set_rows(new_rows);

auto *d_table_value = d_table->mutable_value();
Expand Down Expand Up @@ -177,17 +207,32 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t *ids = ids_t->data<int64_t>();

dim3 threads(128, 8);
dim3 grids(8, 1);
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);

const int64_t *ids_p = nullptr;

if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else {
ids_p = ids_t->data<int64_t>();
}

const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableV2Grad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids, N, K, D);
d_table, d_output, ids_p, N, K, D);
}
}
};
Expand Down
173 changes: 89 additions & 84 deletions paddle/fluid/operators/lookup_table_v2_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>

Expand Down Expand Up @@ -45,84 +46,70 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");

auto id_name = context.InputNames("Ids").front();
auto embedding_name = context.InputNames("W").front();
auto out_name = context.OutputNames("Out").front();

// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t ids_numel = ids_t->numel();

if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter server
std::vector<int64_t> ids;
ids.reserve(ids_numel);

#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, embedding_name, false,
table_names, epmap, context,
context.scope());
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_numel,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();

if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];

auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());

for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i], row_number,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
framework::TensorToVector(*ids_t, &ids);
}

if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];

auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());

for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i], row_number,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());

auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(
id_index, 0, "the input key should be exists. But received %d.",
id_index);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());

auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0,
"the input key should be exists. But received %d.",
id_index);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
}
Expand Down Expand Up @@ -151,17 +138,23 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();

std::vector<int64_t> ids;
ids.reserve(ids_num);

auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}

std::vector<int64_t> new_rows;
new_rows.resize(ids_num);
std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
d_table->set_rows(new_rows);
d_table->set_rows(ids);

auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
Expand All @@ -185,11 +178,23 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());

} else {
auto *ids = context.Input<LoDTensor>("Ids");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();

std::vector<int64_t> ids;
ids.reserve(ids_num);

if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}

auto *ids_data = ids->data<int64_t>();
auto *ids_data = ids.data();

int64_t N = table_dim[0];
int64_t D = table_dim[1];
Expand All @@ -199,7 +204,7 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {

memset(d_table_data, 0, d_table->numel() * sizeof(T));

for (int64_t i = 0; i < ids->numel(); ++i) {
for (int64_t i = 0; i < ids_num; ++i) {
if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def one_hot(input, depth, allow_out_of_range=False):
return one_hot_out


@deprecated(since='2.0.0', update_to='paddle.nn.functional.embedding')
def embedding(input,
size,
is_sparse=False,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def fc(input,
return helper.append_activation(pre_activation)


@deprecated(since="2.0.0", update_to="paddle.nn.functional.embedding")
def embedding(input,
size,
is_sparse=False,
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_adam_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def test_adam_op_with_state_dict(self):

import paddle
paddle.disable_static()
emb = paddle.nn.Embedding([10, 10])
emb = paddle.nn.Embedding(10, 10)

adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters())
state_dict = adam.state_dict()
Expand Down
Loading

0 comments on commit 1b60f7f

Please sign in to comment.