Commit 1816fc2

search and fill slot_feature (PaddlePaddle#20)
* search and fill slot_feature

* search and fill slot_feature, fix compile error

* search and fill slot_feature, rename 8 as slot_num_

Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>
huwei02 and root committed Jun 8, 2022
1 parent 750e343 commit 1816fc2
Showing 12 changed files with 520 additions and 31 deletions.
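
In brief, the hunks below make four related changes: make_gpu_ps_graph_fea drops its ntype_id parameter, find_node gains an overload that searches every shard group of a storage kind (type_id 0 for edges, otherwise features), GraphTable and GraphShard gain helpers that gather all node or feature ids, and FeatureNode learns to parse its feature strings into uint64_t feasigns. Condensed from the header hunks (signatures only, class context elided):

    // common_graph_table.h
    virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
        std::vector<uint64_t> &node_ids, int slot_num);  // ntype_id removed
    Node *find_node(int type_id, uint64_t id);           // new overload
    std::vector<std::vector<uint64_t>> get_all_id(int type, int slice_num);
    std::vector<std::vector<uint64_t>> get_all_feature_ids(int type, int idx,
                                                           int slice_num);

    // graph_node.h
    virtual int get_feature_ids(std::vector<uint64_t> *res) const;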
75 changes: 67 additions & 8 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -45,7 +45,7 @@ int32_t GraphTable::Load_to_ssd(const std::string &path,
}

paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
- int ntype_id, std::vector<uint64_t> &node_ids, int slot_num) {
+ std::vector<uint64_t> &node_ids, int slot_num) {
std::vector<std::vector<uint64_t>> bags(task_pool_size_);
for (auto x : node_ids) {
int location = x % shard_num % task_pool_size_;
@@ -63,7 +63,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
std::vector<uint64_t> feature_ids;
for (size_t j = 0; j < bags[i].size(); j++) {
// TODO use FEATURE_TABLE instead
- Node *v = find_node(1, ntype_id, bags[i][j]);
+ Node *v = find_node(1, bags[i][j]);
x.node_id = bags[i][j];
if (v == NULL) {
x.feature_size = 0;
@@ -85,10 +85,6 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
}
x.feature_size = total_feature_size;
node_fea_array[i].push_back(x);
- VLOG(2) << "node_fea_array[i].size() = ["
- << node_fea_array[i].size() << "]";
- VLOG(2) << "feature_array[i].size() = [" << feature_array[i].size()
- << "]";
}
}
return 0;
@@ -102,8 +98,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
tot_len += feature_array[i].size();
}
VLOG(0) << "Loaded feature table on cpu, feature_list_size[" << tot_len
- << "] node_ids_size[" << node_ids.size() << "] ntype_id[" << ntype_id
- << "]";
+ << "] node_ids_size[" << node_ids.size() << "]";
res.init_on_cpu(tot_len, (unsigned int)node_ids.size(), slot_num);
unsigned int offset = 0, ind = 0;
for (int i = 0; i < task_pool_size_; i++) {
@@ -1240,6 +1235,24 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,

return 0;
}

Node *GraphTable::find_node(int type_id, uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
Node *node = nullptr;
size_t index = shard_id - shard_start;
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
for (auto &search_shard : search_shards) {
PADDLE_ENFORCE_NOT_NULL(search_shard[index]);
node = search_shard[index]->find_node(id);
if (node != nullptr) {
break;
}
}
return node;
}

Node *GraphTable::find_node(int type_id, int idx, uint64_t id) {
size_t shard_id = id % shard_num;
@@ -1537,6 +1550,30 @@ std::pair<int32_t, std::string> GraphTable::parse_feature(
return std::make_pair<int32_t, std::string>(-1, "");
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id,
                                                          int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
std::vector<std::future<std::vector<uint64_t>>> tasks;
for (int idx = 0; idx < search_shards.size(); idx++) {
for (int j = 0; j < search_shards[idx].size(); j++) {
tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
[&search_shards, idx, j]() -> std::vector<uint64_t> {
return search_shards[idx][j]->get_all_id();
}));
}
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) {
res[(uint64_t)(id) % slice_num].push_back(id);
}
}
return res;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
@@ -1559,6 +1596,28 @@ std::vector<std::vector<uint64_t>> GraphTable::get_all_id(int type_id, int idx,
}
return res;
}

std::vector<std::vector<uint64_t>> GraphTable::get_all_feature_ids(int type_id, int idx,
int slice_num) {
std::vector<std::vector<uint64_t>> res(slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<std::vector<uint64_t>>> tasks;
for (int i = 0; i < search_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&search_shards, i]() -> std::vector<uint64_t> {
return search_shards[i]->get_all_feature_ids();
}));
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
for (size_t i = 0; i < tasks.size(); i++) {
auto ids = tasks[i].get();
for (auto &id : ids) res[id % slice_num].push_back(id);
}
return res;
}

int32_t GraphTable::pull_graph_list(int type_id, int idx, int start,
int total_size,
std::unique_ptr<char[]> &buffer,
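Both gather helpers added above follow one pattern: enqueue a scan of each shard on the task pool, wait on the futures, then hash every id into one of slice_num buckets. A minimal self-contained sketch of that pattern, with std::async standing in for Paddle's _shards_task_pool and an invented Shard type (illustration only, not part of the commit):

    #include <cstdint>
    #include <future>
    #include <vector>

    // Toy stand-in for a GraphShard: it only owns ids.
    struct Shard {
      std::vector<uint64_t> ids;
      std::vector<uint64_t> get_all_id() const { return ids; }
    };

    // Scan every shard concurrently, then bucket each id with the same
    // slicing rule as res[id % slice_num] in the diff above.
    std::vector<std::vector<uint64_t>> gather_and_slice(
        const std::vector<Shard> &shards, int slice_num) {
      std::vector<std::future<std::vector<uint64_t>>> tasks;
      for (const auto &shard : shards) {
        tasks.push_back(std::async(std::launch::async,
                                   [&shard] { return shard.get_all_id(); }));
      }
      std::vector<std::vector<uint64_t>> res(slice_num);
      for (auto &task : tasks) {
        for (uint64_t id : task.get()) {
          res[id % slice_num].push_back(id);
        }
      }
      return res;
    }

    int main() {
      std::vector<Shard> shards = {{{1, 2, 3}}, {{10, 11}}, {{42}}};
      auto buckets = gather_and_slice(shards, 4);
      return buckets[2].size() == 3 ? 0 : 1;  // 2, 10, 42 all land in bucket 2
    }
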
16 changes: 15 additions & 1 deletion paddle/fluid/distributed/ps/table/common_graph_table.h
@@ -70,6 +70,16 @@ class GraphShard {
}
return res;
}
std::vector<uint64_t> get_all_feature_ids() {
// TODO by huwei02, dedup
std::vector<uint64_t> total_res;
for (int i = 0; i < (int)bucket.size(); i++) {
std::vector<uint64_t> res;
bucket[i]->get_feature_ids(&res);
total_res.insert(total_res.end(), res.begin(), res.end());
}
return total_res;
}
GraphNode *add_graph_node(uint64_t id);
GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(uint64_t id);
@@ -475,8 +485,11 @@ class GraphTable : public Table {
int32_t load_edges(const std::string &path, bool reverse,
const std::string &edge_type);

std::vector<std::vector<uint64_t>> get_all_id(int type, int slice_num);
std::vector<std::vector<uint64_t>> get_all_id(int type, int idx,
int slice_num);
std::vector<std::vector<uint64_t>> get_all_feature_ids(int type, int idx,
int slice_num);
int32_t load_nodes(const std::string &path, std::string node_type);

int32_t add_graph_node(int idx, std::vector<uint64_t> &id_list,
@@ -486,6 +499,7 @@

int32_t get_server_index_by_id(uint64_t id);
Node *find_node(int type_id, int idx, uint64_t id);
Node *find_node(int type_id, uint64_t id);

virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
@@ -561,7 +575,7 @@ class GraphTable : public Table {
virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
int idx, std::vector<uint64_t> ids);
virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
- int ntype_id, std::vector<uint64_t> &node_ids, int slot_num);
+ std::vector<uint64_t> &node_ids, int slot_num);
int32_t Load_to_ssd(const std::string &path, const std::string &param);
int64_t load_graph_to_memory_from_ssd(int idx, std::vector<uint64_t> &ids);
int32_t make_complementary_graph(int idx, int64_t byte_size);
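One loose end kept by the header: GraphShard::get_all_feature_ids concatenates per-node feasign lists without deduplication (the "TODO by huwei02, dedup" above). If unique ids were required, the standard sort/unique/erase idiom would close that TODO; a standalone sketch, not part of the commit:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Deduplicate feature ids in place: sort, collapse adjacent duplicates,
    // then erase the tail that std::unique leaves behind.
    void dedup(std::vector<uint64_t> &ids) {
      std::sort(ids.begin(), ids.end());
      ids.erase(std::unique(ids.begin(), ids.end()), ids.end());
    }

    int main() {
      std::vector<uint64_t> ids = {7, 3, 7, 1, 3};
      dedup(ids);  // -> {1, 3, 7}
      return ids.size() == 3 ? 0 : 1;
    }
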
22 changes: 22 additions & 0 deletions paddle/fluid/distributed/ps/table/graph/graph_node.h
@@ -50,6 +50,9 @@ class Node {
virtual void to_buffer(char *buffer, bool need_feature);
virtual void recover_from_buffer(char *buffer);
virtual std::string get_feature(int idx) { return std::string(""); }
virtual int get_feature_ids(std::vector<uint64_t> *res) const {
return 0;
}
virtual int get_feature_ids(int slot_idx, std::vector<uint64_t> *res) const {
return 0;
}
@@ -102,6 +105,25 @@ class FeatureNode : public Node {
}
}

virtual int get_feature_ids(std::vector<uint64_t> *res) const {
PADDLE_ENFORCE_NOT_NULL(res);
res->clear();
errno = 0;
for (auto &feature_item : feature) {
const char *feat_str = feature_item.c_str();
auto fields = paddle::string::split_string<std::string>(feat_str, " ");
char *head_ptr = NULL;
for (auto &field : fields) {
PADDLE_ENFORCE_EQ(field.empty(), false);
uint64_t feasign = strtoull(field.c_str(), &head_ptr, 10);
PADDLE_ENFORCE_EQ(field.c_str() + field.length(), head_ptr);
res->push_back(feasign);
}
}
PADDLE_ENFORCE_EQ(errno, 0);
return 0;
}

virtual int get_feature_ids(int slot_idx, std::vector<uint64_t> *res) const {
PADDLE_ENFORCE_NOT_NULL(res);
res->clear();
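For reference, the parsing contract of the new FeatureNode::get_feature_ids: each feature string holds space-separated base-10 feasigns, strtoull must consume every byte of a field, and errno is checked once at the end for overflow. The same core logic as a standalone sketch, with std::istringstream standing in for paddle::string::split_string and assert for PADDLE_ENFORCE (illustration only):

    #include <cassert>
    #include <cerrno>
    #include <cstdint>
    #include <cstdlib>
    #include <sstream>
    #include <string>
    #include <vector>

    // Parse one feature string ("101 202 303") into feasigns, mirroring the
    // checks in the diff: full consumption by strtoull and a clean errno.
    std::vector<uint64_t> parse_feasigns(const std::string &feature) {
      std::vector<uint64_t> res;
      std::istringstream fields(feature);
      std::string field;
      errno = 0;
      while (fields >> field) {
        char *head_ptr = nullptr;
        uint64_t feasign = std::strtoull(field.c_str(), &head_ptr, 10);
        // The whole field must be numeric, else head_ptr stops early.
        assert(head_ptr == field.c_str() + field.length());
        res.push_back(feasign);
      }
      assert(errno == 0);  // strtoull sets ERANGE on overflow
      return res;
    }

    int main() {
      auto ids = parse_feasigns("101 202 303");
      return (ids.size() == 3 && ids[2] == 303) ? 0 : 1;
    }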
