14 changes: 8 additions & 6 deletions tensorflow/core/framework/embedding/cpu_hash_map_kv.h
@@ -138,7 +138,8 @@ class LocklessHashMap : public KVInterface<K, V> {
}

Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) override {
std::pair<const K, void*> *hash_map_dump;
int64 bucket_count;
@@ -147,11 +148,12 @@ class LocklessHashMap : public KVInterface<K, V> {
bucket_count = it.second;
for (int64 j = 0; j < bucket_count; j++) {
if (hash_map_dump[j].first != LocklessHashMap<K, V>::EMPTY_KEY_
- && hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_
- && hash_map_dump[j].first % kSavedPartitionNum
- % partition_nums != partition_id) {
- key_list->emplace_back(hash_map_dump[j].first);
- value_ptr_list->emplace_back(hash_map_dump[j].second);
+ && hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_) {
+ int part_id = hash_map_dump[j].first % kSavedPartitionNum % partition_nums;
+ if (part_id != partition_id) {
+ key_list[part_id].emplace_back(hash_map_dump[j].first);
+ value_ptr_list[part_id].emplace_back(hash_map_dump[j].second);
+ }
}
}

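The change above replaces the flat output vectors with one bucket per destination partition, selected by part_id = key % kSavedPartitionNum % partition_nums. A minimal standalone sketch of that routing rule (hypothetical helper, not code from this PR; assumes non-negative keys and a placeholder value for kSavedPartitionNum):

#include <cstdint>
#include <vector>

// Placeholder for illustration; the real kSavedPartitionNum is defined
// elsewhere in the embedding framework.
constexpr int64_t kSavedPartitionNum = 1000;

// Route each key into the bucket of its destination partition, skipping keys
// that already belong to the local partition -- the same filter used in
// LocklessHashMap::GetShardedSnapshot above.
void ShardKeys(const std::vector<int64_t>& keys,
               int partition_id, int partition_nums,
               std::vector<std::vector<int64_t>>& key_list) {
  key_list.assign(partition_nums, {});
  for (int64_t key : keys) {
    const int part_id = key % kSavedPartitionNum % partition_nums;
    if (part_id != partition_id) {
      key_list[part_id].push_back(key);
    }
  }
}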
10 changes: 6 additions & 4 deletions tensorflow/core/framework/embedding/dense_hash_map_kv.h
@@ -122,7 +122,8 @@ class DenseHashMap : public KVInterface<K, V> {
}

Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) override {
dense_hash_map hash_map_dump[partition_num_];
for (int i = 0; i< partition_num_; i++) {
@@ -131,9 +132,10 @@ class DenseHashMap : public KVInterface<K, V> {
}
for (int i = 0; i< partition_num_; i++) {
for (const auto it : hash_map_dump[i].hash_map) {
- if (it.first % kSavedPartitionNum % partition_nums != partition_id) {
- key_list->push_back(it.first);
- value_ptr_list->push_back(it.second);
+ int part_id = it.first % kSavedPartitionNum % partition_nums;
+ if (part_id != partition_id) {
+ key_list[part_id].emplace_back(it.first);
+ value_ptr_list[part_id].emplace_back(it.second);
}
}
}
9 changes: 5 additions & 4 deletions tensorflow/core/framework/embedding/embedding_var.h
@@ -520,8 +520,8 @@ class EmbeddingVar : public ResourceBase {
}
}

- Status GetShardedSnapshot(std::vector<K>* key_list,
- std::vector<void*>* value_ptr_list,
+ Status GetShardedSnapshot(std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_num) {
return storage_->GetShardedSnapshot(key_list, value_ptr_list,
partition_id, partition_num);
@@ -546,7 +546,7 @@ class EmbeddingVar : public ResourceBase {
bool is_admit = feat_desc_->IsAdmit(value_ptr);
bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0);

- if (!is_admit) {
+ if (is_admit) {
key_list[i] = tot_keys_list[i];

if (!is_in_dram) {
@@ -571,7 +571,7 @@
}
} else {
if (!save_unfiltered_features)
- return;
+ continue;
//TODO(JUNQI) : currently not export filtered keys
}

@@ -584,6 +584,7 @@
feat_desc_->Deallocate(value_ptr);
}
}
+ return;
}

Status RestoreFromKeysAndValues(int64 key_num, int partition_id,
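Besides the signature change, the EmbeddingVar hunks above fix the export loop: the admit check is inverted so admitted features are kept, and the early return inside the loop becomes continue so a single filtered feature no longer ends the whole export. A self-contained toy illustration of that return-versus-continue distinction (stand-in types only, not the EmbeddingVar code):

#include <cstdint>
#include <vector>

// Keep only admitted ids. Returning from the function here, as the old code
// effectively did, would stop at the first filtered id and silently drop every
// id after it; `continue` skips just the current id.
std::vector<int64_t> FilterAdmitted(const std::vector<int64_t>& ids,
                                    const std::vector<bool>& admitted) {
  std::vector<int64_t> kept;
  for (size_t i = 0; i < ids.size(); ++i) {
    if (!admitted[i]) {
      continue;  // skip this id only; later ids are still processed
    }
    kept.push_back(ids[i]);
  }
  return kept;
}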
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/gpu_hash_map_kv.h
@@ -253,7 +253,8 @@ class GPUHashMapKV : public KVInterface<K, V> {
}

Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) override {
LOG(INFO) << "GPUHashMapKV do not support GetShardedSnapshot";
return Status::OK();
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/kv_interface.h
@@ -91,7 +91,8 @@ class KVInterface {
std::vector<void*>* value_ptr_list) = 0;

virtual Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) = 0;

virtual std::string DebugString() const = 0;
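Across all backends, the updated KVInterface contract is the same: the caller passes the nested vectors by reference, one inner vector per partition, and each implementation appends into key_list[part_id] and value_ptr_list[part_id]. A hypothetical, simplified mirror of that contract (the pre-sizing of the outer vectors is assumed here to be the caller's job, and the real interface also returns Status):

#include <cstdint>
#include <vector>

// Simplified stand-in for the updated interface shape; not the TensorFlow class.
class FakeKV {
 public:
  virtual ~FakeKV() = default;
  virtual void GetShardedSnapshot(std::vector<std::vector<int64_t>>& key_list,
                                  std::vector<std::vector<void*>>& value_ptr_list,
                                  int partition_id, int partition_nums) = 0;
};

void CollectShards(FakeKV& kv, int partition_id, int partition_nums) {
  // One inner vector per destination partition, sized up front by the caller.
  std::vector<std::vector<int64_t>> key_list(partition_nums);
  std::vector<std::vector<void*>> value_ptr_list(partition_nums);
  kv.GetShardedSnapshot(key_list, value_ptr_list, partition_id, partition_nums);
  // key_list[p] and value_ptr_list[p] now hold entries destined for partition p.
}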
10 changes: 6 additions & 4 deletions tensorflow/core/framework/embedding/leveldb_kv.h
@@ -194,7 +194,8 @@ class LevelDBKV : public KVInterface<K, V> {
}

Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) override {
ReadOptions options;
options.snapshot = db_->GetSnapshot();
@@ -203,8 +204,9 @@ class LevelDBKV : public KVInterface<K, V> {
for (it->SeekToFirst(); it->Valid(); it->Next()) {
K key;
memcpy((char*)&key, it->key().ToString().data(), sizeof(K));
- if (key % kSavedPartitionNum % partition_nums == partition_id) continue;
- key_list->emplace_back(key);
+ int part_id = key % kSavedPartitionNum % partition_nums;
+ if (part_id == partition_id) continue;
+ key_list[part_id].emplace_back(key);
FeatureDescriptor<V> hbm_feat_desc(
1, 1, ev_allocator()/*useless*/,
StorageType::HBM_DRAM, true, true,
@@ -218,7 +220,7 @@ class LevelDBKV : public KVInterface<K, V> {
value_ptr, feat_desc_->GetFreq(dram_value_ptr));
hbm_feat_desc.UpdateVersion(
value_ptr, feat_desc_->GetVersion(dram_value_ptr));
- value_ptr_list->emplace_back(value_ptr);
+ value_ptr_list[part_id].emplace_back(value_ptr);
}
delete it;
feat_desc_->Deallocate(dram_value_ptr);
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -91,7 +91,8 @@ class MultiTierStorage : public Storage<K, V> {
}

Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) override {
LOG(FATAL)<<"Can't get sharded snapshot of MultiTierStorage.";
return Status::OK();
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/single_tier_storage.h
@@ -224,7 +224,8 @@ class SingleTierStorage : public Storage<K, V> {
}

Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) override {
mutex_lock l(Storage<K, V>::mu_);
return kv_->GetShardedSnapshot(key_list, value_ptr_list,
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/ssd_hash_kv.h
@@ -350,7 +350,8 @@ class SSDHashKV : public KVInterface<K, V> {
}

Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) override {
return Status::OK();
}
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/storage.h
@@ -96,7 +96,8 @@ class Storage {
virtual Status GetSnapshot(std::vector<K>* key_list,
std::vector<void*>* value_ptr_list) = 0;
virtual Status GetShardedSnapshot(
- std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+ std::vector<std::vector<K>>& key_list,
+ std::vector<std::vector<void*>>& value_ptr_list,
int partition_id, int partition_nums) = 0;
virtual Status Save(
const string& tensor_name,