Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 45 additions & 46 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,6 @@ class EmbeddingVar : public ResourceBase {
return storage_->Get(key, value_ptr);
}

// Batch-fetches the stored value pointers for `keys` into
// `value_ptr_list` (one slot per key), sharded lookup on the GPU path.
// Lookup-only: delegates to storage_->BatchGet, so miss behavior
// (e.g. null entries for absent keys) is defined by the storage
// implementation — NOTE(review): confirm against Storage::BatchGet.
void BatchLookupKey(const EmbeddingVarContext<GPUDevice>& ctx,
                    const K* keys,
                    void** value_ptr_list,
                    int64 num_of_keys) {
  storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys);
}

Status LookupOrCreateKey(K key, void** value_ptr,
bool* is_filter, bool indices_as_pointer,
int64 count = 1) {
Expand All @@ -167,45 +160,6 @@ class EmbeddingVar : public ResourceBase {
return Status::OK();
}

// Batch lookup-or-create for `num_of_keys` keys, filling `value_ptrs`.
//
// If `indices_as_pointer` is true the caller has already encoded the
// value pointers inside the key tensor, so each key is simply
// reinterpreted as the pointer it carries (sharded across worker
// threads). Otherwise the feature filter performs the real
// lookup-or-create.
//
// If `indices_counts` is non-null, the per-key access counts are folded
// into the feature descriptor's frequency statistics afterwards.
// Always returns Status::OK(); errors from the filter path are not
// surfaced here (matches existing caller expectations).
Status LookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& context,
                         const K* keys,
                         void** value_ptrs,
                         int64 num_of_keys,
                         int64* indices_counts,
                         bool indices_as_pointer = false) {
  // Very unreliable estimate for the per-element cost of each sharded
  // loop; shared by both Shard invocations below.
  const int64 unit_cost = 1000;
  auto worker_threads = context.worker_threads;
  if (indices_as_pointer) {
    auto lookup_key_and_set_version_fn = [keys, value_ptrs]
        (int64 start, int64 limit) {
      // int64 index: `limit` is int64, an int would truncate for
      // batches larger than 2^31 keys.
      for (int64 i = start; i < limit; i++) {
        value_ptrs[i] = reinterpret_cast<void*>(keys[i]);
      }
    };
    Shard(worker_threads->num_threads, worker_threads->workers,
          num_of_keys, unit_cost, lookup_key_and_set_version_fn);
  } else {
    filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys);
  }

  if (indices_counts != nullptr) {
    auto add_freq_fn = [this, value_ptrs, indices_counts]
        (int64 start, int64 limit) {
      for (int64 i = start; i < limit; i++) {
        feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]);
      }
    };
    Shard(worker_threads->num_threads, worker_threads->workers,
          num_of_keys, unit_cost, add_freq_fn);
  }
  return Status::OK();
}


Status LookupOrCreateKey(K key, void** value_ptr) {
Status s = storage_->GetOrCreate(key, value_ptr);
TF_CHECK_OK(s);
Expand Down Expand Up @@ -402,6 +356,51 @@ class EmbeddingVar : public ResourceBase {

storage_->AddToCache(keys_tensor);
}

// Batch-fetches the stored value pointers for `keys` into
// `value_ptr_list` (one slot per key), sharded lookup on the GPU path.
// Lookup-only: delegates to storage_->BatchGet, so miss behavior
// (e.g. null entries for absent keys) is defined by the storage
// implementation — NOTE(review): confirm against Storage::BatchGet.
void BatchLookupKey(const EmbeddingVarContext<GPUDevice>& ctx,
                    const K* keys,
                    void** value_ptr_list,
                    int64 num_of_keys) {
  storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys);
}

// Batch lookup-or-create for `num_of_keys` keys, filling `value_ptrs`.
//
// If `indices_as_pointer` is true the caller has already encoded the
// value pointers inside the key tensor, so each key is simply
// reinterpreted as the pointer it carries (sharded across worker
// threads). Otherwise the feature filter performs the real
// lookup-or-create.
//
// If `indices_counts` is non-null, the per-key access counts are folded
// into the feature descriptor's frequency statistics afterwards.
// Always returns Status::OK(); errors from the filter path are not
// surfaced here (matches existing caller expectations).
Status LookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& context,
                         const K* keys,
                         void** value_ptrs,
                         int64 num_of_keys,
                         int64* indices_counts,
                         bool indices_as_pointer = false) {
  // Very unreliable estimate for the per-element cost of each sharded
  // loop; shared by both Shard invocations below.
  const int64 unit_cost = 1000;
  auto worker_threads = context.worker_threads;
  if (indices_as_pointer) {
    auto lookup_key_and_set_version_fn = [keys, value_ptrs]
        (int64 start, int64 limit) {
      // int64 index: `limit` is int64, an int would truncate for
      // batches larger than 2^31 keys.
      for (int64 i = start; i < limit; i++) {
        value_ptrs[i] = reinterpret_cast<void*>(keys[i]);
      }
    };
    Shard(worker_threads->num_threads, worker_threads->workers,
          num_of_keys, unit_cost, lookup_key_and_set_version_fn);
  } else {
    filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys);
  }

  if (indices_counts != nullptr) {
    auto add_freq_fn = [this, value_ptrs, indices_counts]
        (int64 start, int64 limit) {
      for (int64 i = start; i < limit; i++) {
        feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]);
      }
    };
    Shard(worker_threads->num_threads, worker_threads->workers,
          num_of_keys, unit_cost, add_freq_fn);
  }
  return Status::OK();
}
#endif

#if GOOGLE_CUDA
Expand Down