Merge branch 'huikang/hkv_embedding_init_optimize' into 'main'
Huikang/hkv embedding init optimize

See merge request dl/hugectr/hugectr!1482
minseokl committed Oct 19, 2023
2 parents e3e1d47 + fd11395 commit 6171060
Showing 1 changed file with 50 additions and 27 deletions.
77 changes: 50 additions & 27 deletions sparse_operation_kit/experiment/variable/impl/hkv_variable.cu
@@ -18,6 +18,8 @@

#include "common/check.h"
#include "variable/impl/hkv_variable.h"
#define SM_NUM 108
#define NTHREAD_PER_SM 2048

namespace sok {

@@ -50,20 +52,29 @@ __global__ void generate_uniform_kernel(curandState* state, T* result, size_t n)
}

template <typename T>
__global__ void generate_uniform_kernel(curandState* state, T** result, bool* d_found, size_t dim) {
__global__ void generate_uniform_kernel(curandState* state, T** result, bool* d_found, size_t dim,
size_t num_embedding) {
auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
size_t emb_id = blockIdx.x;
size_t block_id = blockIdx.x;
size_t emb_vec_id = threadIdx.x;
/* Copy state to local memory for efficiency */
curandState localState = state[GlobalThreadId()];
/* Generate pseudo-random uniforms */
if (!d_found[emb_id]) {
for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
result[emb_id][i] = curand_uniform_double(&localState);
curandState localState;
bool load_state = false;
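// Stride over embeddings with a fixed-size grid; load the curand state lazily
// so blocks whose keys were all found skip the global-memory round trip.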
for (size_t emb_id = block_id; emb_id < num_embedding; emb_id += gridDim.x) {
if (!d_found[emb_id]) {
if (!load_state) {
localState = state[GlobalThreadId()];
load_state = true;
}
for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
result[emb_id][i] = curand_uniform_double(&localState);
}
}
}
/* Copy state back to global memory */
state[GlobalThreadId()] = localState;
if (load_state) {
state[GlobalThreadId()] = localState;
}
}

template <typename T>
@@ -100,21 +111,30 @@ __global__ void generate_normal_kernel(curandState* state, T* result, size_t n)
}

template <typename T>
__global__ void generate_normal_kernel(curandState* state, T** result, bool* d_found, size_t dim) {
__global__ void generate_normal_kernel(curandState* state, T** result, bool* d_found, size_t dim,
size_t num_embedding) {
auto id = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;

size_t emb_id = blockIdx.x;
size_t block_id = blockIdx.x;
size_t emb_vec_id = threadIdx.x;
/* Copy state to local memory for efficiency */
curandState localState = state[GlobalThreadId()];
/* Generate pseudo-random normals */
if (!d_found[emb_id]) {
for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
result[emb_id][i] = curand_normal_double(&localState);
curandState localState;
bool load_state = false;
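// Same grid-stride, lazy-state-load pattern as generate_uniform_kernel above.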
for (size_t emb_id = block_id; emb_id < num_embedding; emb_id += gridDim.x) {
if (!d_found[emb_id]) {
if (!load_state) {
localState = state[GlobalThreadId()];
load_state = true;
}
for (size_t i = emb_vec_id; i < dim; i += blockDim.x) {
result[emb_id][i] = curand_normal_double(&localState);
}
}
}

/* Copy state back to global memory */
state[GlobalThreadId()] = localState;
if (load_state) {
state[GlobalThreadId()] = localState;
}
}

static void set_curand_states(curandState** states, cudaStream_t stream = 0) {
@@ -268,12 +288,13 @@ void HKVVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType* val

int64_t dim = cols();

// Size the grid to saturate the GPU, but never launch more blocks than there is
// work (round up so a small num_keys * dim still gets at least one block).
uint32_t grid_dim = SM_NUM * (NTHREAD_PER_SM / 1024);
size_t need = (num_keys * static_cast<size_t>(dim) + 1023) / 1024;
if (need < grid_dim) grid_dim = need;
if (initializer_ == "normal" || initializer_ == "random") {
generate_normal_kernel<<<(num_keys * dim + 1024 - 1) / 1024, 1024, 0, stream>>>(curand_states_, values,
num_keys * dim);

generate_normal_kernel<<<grid_dim, 1024, 0, stream>>>(curand_states_, values, num_keys * dim);
} else if (initializer_ == "uniform") {
generate_uniform_kernel<<<(num_keys * dim + 1024 - 1) / 1024, 1024, 0, stream>>>(
curand_states_, values, num_keys * dim);
generate_uniform_kernel<<<grid_dim, 1024, 0, stream>>>(curand_states_, values, num_keys * dim);
} else {
try {
float val = std::stof(initializer_);
@@ -296,18 +317,20 @@ void HKVVariable<KeyType, ValueType>::lookup(const KeyType* keys, ValueType** va
CUDACHECK(cudaMemset(d_found, 0, num_keys * sizeof(bool)));
hkv_table_->find_or_insert(num_keys, keys, values, d_found, nullptr, stream);
//CUDACHECK(cudaStreamSynchronize(stream));

int64_t dim = cols();
// Each block initializes embedding rows in a grid-stride loop: use at least one
// warp per block, and cap the grid at the occupancy limit and at num_keys.
uint32_t block_dim = max(dim, static_cast<int64_t>(32));
uint32_t grid_dim = SM_NUM * (NTHREAD_PER_SM / block_dim);
if (num_keys < grid_dim) grid_dim = num_keys;
if (initializer_ == "normal" || initializer_ == "random") {
generate_normal_kernel<<<num_keys, max(dim, static_cast<int64_t>(32)), 0, stream>>>(
curand_states_, values, d_found, dim);
generate_normal_kernel<<<grid_dim, block_dim, 0, stream>>>(
curand_states_, values, d_found, dim, num_keys);
} else if (initializer_ == "uniform") {
generate_uniform_kernel<<<num_keys, max(dim, static_cast<int64_t>(32)), 0, stream>>>(
curand_states_, values, d_found, dim);
generate_uniform_kernel<<<grid_dim, block_dim, 0, stream>>>(
curand_states_, values, d_found, dim, num_keys);
} else {
try {
float val = std::stof(initializer_);
const_initializer_kernel<<<num_keys, max(dim, static_cast<int64_t>(32)), 0, stream>>>(
const_initializer_kernel<<<num_keys, block_dim, 0, stream>>>(
val, values, d_found, dim);
} catch (std::invalid_argument& err) {
throw std::runtime_error("Unrecognized initializer {" + initializer_ + "}");

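The updated kernels replace the old one-block-per-embedding launch with a grid-stride loop: a fixed-size grid, sized from SM_NUM and NTHREAD_PER_SM, walks over all embeddings, and each thread reads and writes its curand state at most once. The sketch below shows that launch-sizing and grid-stride pattern in isolation, using a constant initializer; the kernel name, constants, and host setup are illustrative assumptions, not the HugeCTR/SOK source.

```cuda
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

constexpr uint32_t kSmCount = 108;        // assumed SM count of the target GPU
constexpr uint32_t kThreadsPerSm = 2048;  // assumed max resident threads per SM

// Each block handles one row at a time and strides over the remaining rows, so a
// fixed-size grid covers any number of rows; rows already "found" in the table
// are skipped, mirroring the d_found check in the kernels above.
__global__ void fill_missing_rows(float** rows, const bool* found, size_t dim,
                                  size_t num_rows, float value) {
  for (size_t row = blockIdx.x; row < num_rows; row += gridDim.x) {
    if (found[row]) continue;
    for (size_t i = threadIdx.x; i < dim; i += blockDim.x) rows[row][i] = value;
  }
}

int main() {
  const size_t num_rows = 8, dim = 4;
  std::vector<float*> h_rows(num_rows);
  for (auto& p : h_rows) cudaMalloc(&p, dim * sizeof(float));
  float** d_rows = nullptr;
  bool* d_found = nullptr;
  cudaMalloc(&d_rows, num_rows * sizeof(float*));
  cudaMalloc(&d_found, num_rows * sizeof(bool));
  cudaMemcpy(d_rows, h_rows.data(), num_rows * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemset(d_found, 0, num_rows * sizeof(bool));  // nothing found -> initialize every row

  // Launch sizing: at least one warp per block, enough blocks to fill the GPU,
  // but never more blocks than there are rows.
  uint32_t block_dim = dim < 32 ? 32u : static_cast<uint32_t>(dim);
  uint32_t grid_dim = kSmCount * (kThreadsPerSm / block_dim);
  if (num_rows < grid_dim) grid_dim = static_cast<uint32_t>(num_rows);

  fill_missing_rows<<<grid_dim, block_dim>>>(d_rows, d_found, dim, num_rows, 0.5f);
  cudaDeviceSynchronize();

  float sample = 0.f;
  cudaMemcpy(&sample, h_rows[0], sizeof(float), cudaMemcpyDeviceToHost);
  printf("rows[0][0] = %f\n", sample);  // expected: 0.500000

  for (auto& p : h_rows) cudaFree(p);
  cudaFree(d_rows);
  cudaFree(d_found);
  return 0;
}
```

The real kernels add one further refinement on top of this pattern: the curandState is only fetched from (and written back to) global memory if the block actually initialized at least one row.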