Skip to content

Commit

Permalink
[Embedding] Fix HBM_DRAM Restore crash (core dump) when id_num > cache size
Browse files Browse the repository at this point in the history
Signed-off-by: RobertLou <2874395462@qq.com>
  • Loading branch information
RobertLou committed Apr 22, 2024
1 parent 04413cf commit 4e7c352
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tensorflow/core/framework/embedding/hbm_dram_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class HbmDramStorage : public MultiTierStorage<K, V> {

~HbmDramStorage() override {
MultiTierStorage<K, V>::DeleteFromEvictionManager();
//delete restore_cache_;
delete hbm_;
delete dram_;
delete dram_feat_desc_;
Expand Down Expand Up @@ -227,7 +228,7 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
}

void BatchEviction() override {
constexpr int EvictionSize = 10000;
constexpr int EvictionSize = 5000;
K evic_ids[EvictionSize];
if (!MultiTierStorage<K, V>::ready_eviction_) {
return;
Expand Down Expand Up @@ -287,16 +288,18 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
partition_id, partition_num,
is_incr, reset_version, reader);

restore_cache_.reset(CacheFactory::Create<K>(CacheStrategy::LFU, "ads"));
restorer.RestoreCkpt(emb_config, device);

int64 num_of_hbm_ids =
std::min(MultiTierStorage<K, V>::cache_capacity_,
(int64)MultiTierStorage<K, V>::cache_->size());
(int64)restore_cache_->size());

if (num_of_hbm_ids > 0) {
K* hbm_ids = new K[num_of_hbm_ids];
int64* hbm_freqs = new int64[num_of_hbm_ids];
int64* hbm_versions = nullptr;
MultiTierStorage<K, V>::cache_->get_cached_ids(hbm_ids, num_of_hbm_ids,
restore_cache_->get_cached_ids(hbm_ids, num_of_hbm_ids,
hbm_versions, hbm_freqs);
ImportToHbm(hbm_ids, num_of_hbm_ids, value_len, emb_config.emb_index);
MultiTierStorage<K, V>::cache_thread_pool_->Schedule(
Expand Down Expand Up @@ -329,10 +332,10 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
Status s = filter->Restore(key_num, bucket_num, partition_id,
partition_num, value_len, is_filter,
true/*to_dram*/, is_incr, restore_buff);

MultiTierStorage<K, V>::cache_->update((K*)restore_buff.key_buffer, key_num,
(int64*)restore_buff.version_buffer,
(int64*)restore_buff.freq_buffer);
restore_cache_->update((K*)restore_buff.key_buffer, key_num,
(int64*)restore_buff.version_buffer,
(int64*)restore_buff.freq_buffer);
return s;
}

Expand Down Expand Up @@ -574,6 +577,7 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
DramStorage<K, V>* dram_ = nullptr;
FeatureDescriptor<V>* hbm_feat_desc_ = nullptr;
FeatureDescriptor<V>* dram_feat_desc_ = nullptr;
std::unique_ptr<BatchCache<K>> restore_cache_ = nullptr;
Allocator* gpu_alloc_;
const int copyback_flag_offset_bits_ = 60;
};
Expand Down

0 comments on commit 4e7c352

Please sign in to comment.