[Embedding] Implement Multi-Level Embedding.
candyzone committed Mar 7, 2022
1 parent 00f3980 commit ccb8450
Showing 18 changed files with 1,013 additions and 749 deletions.
5 changes: 5 additions & 0 deletions modelzoo/features/pmem/benchmark.py
@@ -106,6 +106,11 @@ def main(_):
storage_type=config_pb2.StorageType.PMEM_LIBPMEM,
storage_path=FLAGS.ev_storage_path,
storage_size=FLAGS.ev_storage_size_gb * 1024 * 1024 * 1024))
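# Two-level storage: hot ids stay in DRAM while cold ids spill to PMEM.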
elif FLAGS.ev_storage == "dram_pmem":
ev_option = variables.EmbeddingVariableOption(storage_option=variables.StorageOption(
storage_type=config_pb2.StorageType.DRAM_PMEM,
storage_path=FLAGS.ev_storage_path,
storage_size=FLAGS.ev_storage_size_gb * 1024 * 1024 * 1024))
fm_w = tf.get_embedding_variable(
name='fm_w{}'.format(sidx),
embedding_dim=1,
176 changes: 176 additions & 0 deletions tensorflow/core/framework/embedding/cache.h
@@ -0,0 +1,176 @@
#ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_CACHE_H_
#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_CACHE_H_
#include <iostream>
#include <map>
#include <unordered_map>
#include <set>
#include <list>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace embedding {

template <class K>
class BatchCache {
public:
BatchCache() {}
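// Removes up to k_size of the coldest ids from the cache, writing them
// into evic_ids; returns how many ids were actually evicted.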
virtual size_t get_evic_ids(K* evic_ids, size_t k_size) = 0;
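// Marks each id in batch_ids as freshly accessed, inserting ids not yet tracked.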
virtual void add_to_rank(const K* batch_ids, size_t batch_size) = 0;
virtual size_t size() = 0;
};

template <class K>
class LRUCache : public BatchCache<K> {
private:
class LRUNode {
public:
K id;
LRUNode *pre, *next;
LRUNode(K id) : id(id), pre(nullptr), next(nullptr) {}
};
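// Sentinel head/tail: the most recently used id sits at head->next,
// the least recently used at tail->pre.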
LRUNode *head, *tail;
std::map<K, LRUNode *> mp;
mutex mu_;

public:
LRUCache() {
mp.clear();
head = new LRUNode(0);
tail = new LRUNode(0);
head->next = tail;
tail->pre = head;
}

size_t size() {
mutex_lock l(mu_);
return mp.size();
}

size_t get_evic_ids(K* evic_ids, size_t k_size) {
mutex_lock l(mu_);
size_t true_size = 0;
LRUNode *evic_node = tail->pre;
LRUNode *rm_node = evic_node;
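// Walk backward from the tail, unlinking and deleting up to k_size
// of the least recently used nodes.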
for (size_t i = 0; i < k_size && evic_node != head; ++i) {
evic_ids[i] = evic_node->id;
rm_node = evic_node;
evic_node = evic_node->pre;
mp.erase(rm_node->id);
delete rm_node;
true_size++;
}
evic_node->next = tail;
tail->pre = evic_node;
return true_size;
}

void add_to_rank(const K* batch_ids, size_t batch_size) {
mutex_lock l(mu_);
for (size_t i = 0; i < batch_size; ++i) {
K id = batch_ids[i];
typename std::map<K, LRUNode *>::iterator it = mp.find(id);
if (it != mp.end()) {
LRUNode *node = it->second;
node->pre->next = node->next;
node->next->pre = node->pre;
head->next->pre = node;
node->next = head->next;
head->next = node;
node->pre = head;
} else {
LRUNode *newNode = new LRUNode(id);
head->next->pre = newNode;
newNode->next = head->next;
head->next = newNode;
newNode->pre = head;
mp[id] = newNode;
}
}
}
};

template <class K>
class LFUCache : public BatchCache<K> {
private:
class LFUNode {
public:
K key;
size_t freq;
LFUNode(K key, size_t freq) : key(key), freq(freq) {}
};
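// key_table maps an id to its node; freq_table buckets nodes by access
// count, with the most recently touched node at the front of each bucket.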
size_t min_freq;
size_t max_freq;
std::unordered_map<K, typename std::list<LFUNode>::iterator> key_table;
std::unordered_map<size_t, std::list<LFUNode>> freq_table;
mutex mu_;

public:
LFUCache() {
min_freq = 0;
max_freq = 0;
key_table.clear();
freq_table.clear();
}

size_t size() {
mutex_lock l(mu_);
return key_table.size();
}

size_t get_evic_ids(K *evic_ids, size_t k_size) {
mutex_lock l(mu_);
size_t true_size = 0;
for (size_t i = 0; i < k_size && !key_table.empty(); ++i) {
LFUNode rm_node = freq_table[min_freq].back();
key_table.erase(rm_node.key);
evic_ids[i] = rm_node.key;
++true_size;
freq_table[min_freq].pop_back();
if (freq_table[min_freq].size() == 0) {
freq_table.erase(min_freq);
++min_freq;
while (min_freq <= max_freq) {
auto it = freq_table.find(min_freq);
if (it == freq_table.end() || it->second.size() == 0) {
++min_freq;
} else {
break;
}
}
}
}
return true_size;
}

void add_to_rank(const K *batch_ids, size_t batch_size) {
mutex_lock l(mu_);
for (size_t i = 0; i < batch_size; ++i) {
K id = batch_ids[i];
auto it = key_table.find(id);
if (it == key_table.end()) {
freq_table[1].push_front(LFUNode(id, 1));
key_table[id] = freq_table[1].begin();
min_freq = 1;
} else {
typename std::list<LFUNode>::iterator node = it->second;
size_t freq = node->freq;
freq_table[freq].erase(node);
if (freq_table[freq].size() == 0) {
freq_table.erase(freq);
if (min_freq == freq)
min_freq += 1;
}
max_freq = std::max(max_freq, freq + 1);
freq_table[freq + 1].push_front(LFUNode(id, freq + 1));
key_table[id] = freq_table[freq + 1].begin();
}
}
}
};

}  // namespace embedding
}  // namespace tensorflow

#endif // TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_CACHE_H_
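For context, a minimal usage sketch of the new caches. It is not part of the commit; it assumes a TensorFlow build where this header is on the include path, and the variable names are illustrative only.

#include "tensorflow/core/framework/embedding/cache.h"

#include <iostream>

int main() {
  tensorflow::embedding::LRUCache<tensorflow::int64> cache;

  // Rank a batch of ids; each id is moved to (or inserted at) the head.
  tensorflow::int64 batch[4] = {1, 2, 3, 4};
  cache.add_to_rank(batch, 4);

  // Touching id 1 again makes id 2 the least recently used entry.
  tensorflow::int64 hot = 1;
  cache.add_to_rank(&hot, 1);

  // Request two victims; LRU order yields 2, then 3.
  tensorflow::int64 evicted[2];
  size_t n = cache.get_evic_ids(evicted, 2);
  for (size_t i = 0; i < n; ++i) {
    std::cout << evicted[i] << std::endl;
  }
  return 0;
}

LFUCache is driven the same way; it evicts the lowest-frequency ids first and, within a frequency bucket, the ids touched least recently.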
6 changes: 3 additions & 3 deletions tensorflow/core/framework/embedding/config.proto
@@ -11,16 +11,16 @@ enum StorageType {
PMEM_MEMKIND = 2;
PMEM_LIBPMEM = 3;
SSD = 4;
LEVELDB = 5;

LEVELDB = 14;
/*
// two level
DRAM_PMEM = 11;
DRAM_SSD = 12;
HBM_DRAM = 13;
DRAM_LEVELDB = 14;

// three level
DRAM_PMEM_SSD = 101;
HBM_DRAM_SSD = 102;
*/

}
1 change: 0 additions & 1 deletion tensorflow/core/framework/embedding/dense_hash_map.h
@@ -38,7 +38,6 @@ class DenseHashMap : public KVInterface<K, V> {
hash_map_[i].hash_map.set_empty_key(-1);
hash_map_[i].hash_map.set_deleted_key(-2);
}
KVInterface<K, V>::total_dims_ = 0;
}

~DenseHashMap() {
15 changes: 11 additions & 4 deletions tensorflow/core/framework/embedding/embedding_config.h
@@ -24,12 +24,13 @@ struct EmbeddingConfig {
int64 storage_size;
int64 default_value_dim;
int normal_fix_flag;
bool is_multi_level;

EmbeddingConfig(int64 emb_index = 0, int64 primary_emb_index = 0,
int64 block_num = 1, int slot_num = 0,
const std::string& name = "", int64 steps_to_live = 0,
int64 filter_freq = 0, int64 max_freq = 999999,
float l2_weight_threshold = -1.0, const std::string& layout = "normal",
float l2_weight_threshold = -1.0, const std::string& layout = "normal_fix",
int64 max_element_size = 0, float false_positive_probability = -1.0,
DataType counter_type = DT_UINT64, embedding::StorageType storage_type = embedding::DRAM,
const std::string& storage_path = "", int64 storage_size = 0,
@@ -48,7 +49,8 @@
storage_path(storage_path),
storage_size(storage_size),
default_value_dim(default_value_dim),
normal_fix_flag(0) {
normal_fix_flag(0),
is_multi_level(false) {
if ("normal" == layout) {
layout_type = LayoutType::NORMAL;
} else if ("light" == layout) {
Expand All @@ -61,16 +63,21 @@ struct EmbeddingConfig {
}
if (max_element_size != 0 && false_positive_probability != -1.0){
kHashFunc = calc_num_hash_func(false_positive_probability);
num_counter = calc_num_counter(max_element_size, false_positive_probability);
} else {
kHashFunc = 0;
num_counter = 0;
}
if (layout_type == LayoutType::NORMAL_FIX) {
normal_fix_flag = 1;
}
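// PMEM-backed and two-level storage types all take the multi-level path.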
if (storage_type == embedding::PMEM_MEMKIND || storage_type == embedding::PMEM_LIBPMEM ||
storage_type == embedding::DRAM_PMEM || storage_type == embedding::DRAM_SSD ||
storage_type == embedding::HBM_DRAM || storage_type == embedding::DRAM_LEVELDB) {
is_multi_level = true;
}
}

int64 calc_num_counter(int64 max_element_size, float false_positive_probability) {
float loghpp = fabs(log(false_positive_probability));
float factor = log(2) * log(2);

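To illustrate the new flag, a minimal sketch (not from the commit; argument order follows the constructor signature in the hunk above, and "user_emb" and the PMEM path are made-up values):

#include "tensorflow/core/framework/embedding/embedding_config.h"

using namespace tensorflow;

int main() {
  // Default storage type (DRAM): single-level table, flag stays off.
  EmbeddingConfig plain;
  // plain.is_multi_level == false

  // A two-level storage type such as DRAM_PMEM flips the flag on.
  EmbeddingConfig tiered(0, 0, 1, 0, "user_emb", 0, 0, 999999, -1.0,
                         "normal_fix", 0, -1.0, DT_UINT64,
                         embedding::DRAM_PMEM, "/mnt/pmem0",
                         int64{1} << 30 /* 1 GiB */);
  // tiered.is_multi_level == true
  return 0;
}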