Skip to content

Commit

Permalink
Merge pull request #10164 from Yancey1989/lookup_sparse_table_op
Browse files Browse the repository at this point in the history
add lookup_sparse_table_op
  • Loading branch information
Yancey committed May 2, 2018
2 parents 1945b72 + 1a93253 commit ff99d94
Show file tree
Hide file tree
Showing 10 changed files with 319 additions and 22 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/lod_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ TEST(LoDTensor, RecordIO) {
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
}
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/framework/selected_rows.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ bool SelectedRows::HasKey(int64_t key) const {
: true;
}

std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
framework::Tensor* value) const {
std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
std::vector<int64_t> keys, framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
std::vector<int64_t> non_keys;
std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
Expand All @@ -133,15 +133,15 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]);
if (index == -1) {
non_keys.push_back(keys[i]);
non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
}
return non_keys;
return non_keys_pair;
}

bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/framework/selected_rows.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
Expand Down Expand Up @@ -78,10 +79,11 @@ class SelectedRows {
/*
* @brief Get value by the key list, if the
*
* @return a list of keys which does not exists in table
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
std::vector<int64_t> Get(std::vector<int64_t> keys,
framework::Tensor* tensor) const;
std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
framework::Tensor* value) const;

/*
* @brief Set a key-value pair into the table.
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/framework/selected_rows_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
}

TEST_F(SelectedRowsTester, Table) {
TEST_F(SelectedRowsTester, SparseTable) {
platform::CPUPlace cpu;
SelectedRows table;
// initialize a sparse table
Expand Down Expand Up @@ -87,11 +87,11 @@ TEST_F(SelectedRowsTester, Table) {
framework::Tensor get_value;
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
std::vector<int64_t> keys({non_key, key});
auto non_keys = table.Get(keys, &get_value);
auto non_key_pairs = table.Get(keys, &get_value);

ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
ASSERT_EQ(non_keys.size(), static_cast<size_t>(1));
ASSERT_EQ(non_keys[0], non_key);
ASSERT_EQ(non_key_pairs.size(), static_cast<size_t>(1));
ASSERT_EQ(non_key_pairs[0].first, non_key);
}

} // namespace framework
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i);
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
}
EXPECT_EQ(slr2->height(), 1000);
}
Expand Down
165 changes: 165 additions & 0 deletions paddle/fluid/operators/lookup_sparse_table_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

constexpr int64_t kNoPadding = -1;

class LookupSparseTableInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LookupSparseTableOp should not be null.");
auto shape_w = ctx->GetInputDim("W");
auto shape_ids = ctx->GetInputDim("Ids");
shape_w[0] = shape_ids.size();
ctx->SetOutputDim("Out", shape_w);
}
};

class LookupSparseTableOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;

private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto out_var = scope.FindVar(Output("Out"));
auto w_var = scope.FindVar(Input("W"));
auto ids_var = scope.FindVar(Input("Ids"));
unsigned int seed = static_cast<unsigned int>(Attr<int>("seed"));
float min = Attr<float>("min");
float max = Attr<float>("max");
bool auto_grown_table = Attr<bool>("auto_grown_table");

PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
"The type of Out var should be LodTensor.");
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
"The type of W var should be SelectedRows.");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
"The type of Ids var should be LoDTensor.");
auto &ids_t = ids_var->Get<framework::LoDTensor>();
auto out_t = out_var->GetMutable<framework::LoDTensor>();
auto w_t = w_var->GetMutable<framework::SelectedRows>();
std::vector<int64_t> keys;
keys.resize(ids_t.numel());
for (size_t i = 0; i < ids_t.numel(); ++i) {
keys[i] = ids_t.data<int64_t>()[i];
}

// TODO(Yancey1989): support CUDA Place for the sparse table
platform::CPUPlace cpu;
auto out_shape = w_t->value().dims();
out_shape[0] = keys.size();
out_t->Resize(out_shape);
out_t->mutable_data(cpu, w_t->value().type());
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
framework::proto::VarType::FP32,
"The sparse table only support FP32");
auto non_keys_pair = w_t->Get(keys, out_t);
if (!auto_grown_table) {
PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast<size_t>(0),
"there is some keys does exists in the sparse table.");
}
auto value_shape = w_t->value().dims();
value_shape[0] = 1;
for (const auto &it : non_keys_pair) {
const auto key = it.first;
const auto index = it.second;
framework::Tensor value;
value.Resize(value_shape);
auto data = value.mutable_data<float>(cpu);

std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<float> dist(min, max);
int64_t size = value.numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
w_t->Set(key, value);
memory::Copy(cpu, out_t->mutable_data<float>(cpu) + index * value.numel(),
cpu, value.data<float>(), value.numel() * sizeof(float));
}
}
};

class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupSparseTableOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"(SelectedRows) The input represents embedding table, "
"which is a learnable parameter.");
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.");
AddOutput("Out",
"(LoDTensor) The lookup results, which have the "
"same type as W.");
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
AddAttr<float>("min",
"(float, default -1.0) "
"Minimum value of uniform random")
.SetDefault(-1.0f);
AddAttr<float>("max",
"(float, default 1.0) "
"Maximun value of uniform random")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
"Random seed used for generating samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<bool>("auto_grown_table",
"(bool default false)"
"Whether create new value if for nonexistent key.")
.SetDefault(true);
AddComment(R"DOC(
Lookup Sprase Tablel Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lookup_sparse_table, ops::LookupSparseTableOp,
ops::LookupSparseTableInferShape,
ops::LookupSparseTableOpMaker,
paddle::framework::EmptyGradOpMaker);
21 changes: 20 additions & 1 deletion paddle/fluid/operators/sgd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
}
};

class SGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto input_var = op_desc.Input("Param")[0];
for (auto& out_var : op_desc.Output("ParamOut")) {
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
}
}
}
};

class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
Expand All @@ -74,5 +92,6 @@ This operator implements one step of the stochastic gradient descent algorithm.
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OPERATOR(sgd, ops::SGDOp, ops::SGDOpMaker,
paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
24 changes: 22 additions & 2 deletions paddle/fluid/operators/uniform_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,31 @@ uniform distribution.
.SetDefault(framework::proto::VarType::FP32);
}
};

class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("Out").front();
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker);
REGISTER_OPERATOR(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::UniformRandomOpVarTypeInference);

REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>);
Expand Down
13 changes: 9 additions & 4 deletions python/paddle/fluid/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
shape=trainer_out.shape,
dtype=trainer_out.dtype)
prefetch_block.append_op(
type=LOOKUP_TABLE_TYPE,
type="lookup_sparse_table",
inputs={'Ids': pserver_ids,
"W": table_var},
outputs={"Out": pserver_out},
Expand All @@ -685,9 +685,14 @@ def _clone_var(block, var, persistable=True):

# STEP: create table optimize block
# create table param and grad var in pserver program
param_var = _clone_var(
pserver_program.global_block(),
self.origin_program.global_block().vars[self.table_name])
origin_param_var = self.origin_program.global_block().vars[
self.table_name]
param_var = pserver_program.global_block().create_var(
name=origin_param_var.name,
shape=origin_param_var.shape,
dtype=origin_param_var.dtype,
type=core.VarDesc.VarType.SELECTED_ROWS,
persistable=True)
grad_var = _clone_var(
pserver_program.global_block(),
self.origin_program.global_block().vars[framework.grad_var_name(
Expand Down
Loading

0 comments on commit ff99d94

Please sign in to comment.