fix t5 in sampling (#4624)
FrostML committed Feb 7, 2023
1 parent 2b168ab commit 6afc928
Showing 2 changed files with 6 additions and 9 deletions.
File 1 of 2:
@@ -1582,9 +1582,9 @@ void topK_sampling_kernel_kernelLauncher_v2(void* workspace,
   topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;
 
   if (workspace == nullptr) {
-    workspace_size = sizeof(T) * temp_log_probs_buf_size +
+    workspace_size = sizeof(float) * temp_log_probs_buf_size +
                      sizeof(int) * topk_tmp_ids_buf_size +
-                     2 * sizeof(T) * topk_tmp_val_buf_size;
+                     2 * sizeof(float) * topk_tmp_val_buf_size;
     return;
   } else {
     T* temp_log_probs = (T*)workspace;
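For context, the launcher above follows a query-then-allocate pattern: a first call with a null workspace only reports the required byte count, and a second call with the allocated buffer actually runs. The sketch below is illustrative only (the function name, arguments, and host-side allocation are assumptions, not the repository's API); it shows that sizing the reservation with sizeof(float), as this commit does, keeps it at least as large as the sizeof(T)-based figure for a half instantiation.

// Illustrative sketch only: a stand-in for the real launcher showing the
// query-then-allocate idiom. Names and arguments are hypothetical.
#include <cstddef>
#include <cstdio>
#include <cstdlib>

template <typename T>
void example_topk_launcher(void* workspace, size_t& workspace_size,
                           size_t temp_log_probs_buf_size,
                           size_t topk_tmp_ids_buf_size,
                           size_t topk_tmp_val_buf_size) {
  if (workspace == nullptr) {
    // Size with sizeof(float), as in this commit, so the reservation is
    // large enough even when T is half.
    workspace_size = sizeof(float) * temp_log_probs_buf_size +
                     sizeof(int) * topk_tmp_ids_buf_size +
                     2 * sizeof(float) * topk_tmp_val_buf_size;
    return;
  }
  T* temp_log_probs = reinterpret_cast<T*>(workspace);  // carve sub-buffers
  (void)temp_log_probs;  // ... kernels would be launched here ...
}

int main() {
  size_t workspace_size = 0;
  example_topk_launcher<float>(nullptr, workspace_size, 1024, 256, 256);    // query
  void* workspace = std::malloc(workspace_size);                            // allocate
  example_topk_launcher<float>(workspace, workspace_size, 1024, 256, 256);  // run
  std::printf("workspace bytes: %zu\n", workspace_size);
  std::free(workspace);
  return 0;
}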
File 2 of 2:
@@ -56,8 +56,6 @@ class T5DecodingSampling {
   DataType_ *decoder_buf_;
   DataType_ *decoder_normed_result_buf_;
   DataType_ *embedding_buf_;
-  DataType_ *trans_out_buf_;
-  DataType_ *lm_normed_result_buf_;
   DataType_ *logits_buf_;
   int *word_ids_buf_;
   bool *finished_buf_;
@@ -135,6 +133,7 @@ class T5DecodingSampling {
 
     args_.num_bucket_ = num_bucket;
     args_.max_distance_ = max_distance;
+    args_.tie_word_embeddings_ = tie_word_embeddings;
 
     // For models without parallel
     if (l_parallel_param_.layers_per_group == 0) {
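The tie_word_embeddings flag stored above conventionally means the output logits projection reuses the input word-embedding matrix rather than a separate lm_head weight. The sketch below is a minimal CPU-side illustration of that convention, not the repository's implementation; all names and shapes are assumptions.

#include <cstddef>
#include <vector>

// Illustrative only: with tied word embeddings, logits are computed against
// the (vocab x hidden) embedding matrix instead of a separate lm_head weight.
std::vector<float> tied_logits(const std::vector<float>& hidden_states,  // batch x hidden
                               const std::vector<float>& embedding,      // vocab x hidden
                               std::size_t batch, std::size_t hidden, std::size_t vocab) {
  std::vector<float> logits(batch * vocab, 0.0f);
  for (std::size_t b = 0; b < batch; ++b)
    for (std::size_t v = 0; v < vocab; ++v)
      for (std::size_t h = 0; h < hidden; ++h)
        logits[b * vocab + v] += hidden_states[b * hidden + h] * embedding[v * hidden + h];
  return logits;
}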
@@ -255,8 +254,8 @@ class T5DecodingSampling {
 
     size_t datatype_buf_size =
         from_tensor_size * 2 + decoder_workspace_size +
-        (cache_size * 4 + mem_cache_size * 2) * args_.decoder_layers_ +
-        decoder_normed_result_buffer_size * 3;
+        (cache_size * 2 + mem_cache_size * 2) * args_.decoder_layers_ +
+        decoder_normed_result_buffer_size;
 
     buf_ = reinterpret_cast<void *>(allocator_.malloc(
         ((sizeof(DataType_) == sizeof(half)) ? CUBLAS_WORKSPACE_SIZE : 0) +
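As a rough worked example of the element-count change above (every number below is made up and only illustrates the arithmetic, not real model sizes): the per-layer cache term drops from four copies of cache_size to two, and the normed-result term drops from three buffers to one.

// Illustrative arithmetic only; all values here are invented.
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t from_tensor_size = 4096, decoder_workspace_size = 8192;
  std::size_t cache_size = 2048, mem_cache_size = 1024;
  std::size_t decoder_layers = 12, decoder_normed_result_buffer_size = 512;

  // Element count before this commit: 4x cache_size per layer, 3 normed buffers.
  std::size_t before = from_tensor_size * 2 + decoder_workspace_size +
                       (cache_size * 4 + mem_cache_size * 2) * decoder_layers +
                       decoder_normed_result_buffer_size * 3;
  // Element count after this commit: 2x cache_size per layer, 1 normed buffer.
  std::size_t after = from_tensor_size * 2 + decoder_workspace_size +
                      (cache_size * 2 + mem_cache_size * 2) * decoder_layers +
                      decoder_normed_result_buffer_size;

  std::printf("elements before: %zu, after: %zu\n", before, after);
  return 0;
}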
@@ -287,8 +286,6 @@ class T5DecodingSampling {
           i * mem_cache_size * 2 + mem_cache_size;
     }
 
-    /* We use two-way buffer since we have to update KV buf at the end of each
-     * step. */
     K_cache_[0] = V_mem_cache_[args_.decoder_layers_ - 1] + mem_cache_size +
                   0 * cache_size * args_.decoder_layers_;
     V_cache_[0] = V_mem_cache_[args_.decoder_layers_ - 1] + mem_cache_size +
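The K_cache_[0] and V_cache_[0] assignments above carve sub-buffers out of one flat allocation by offsetting from the end of the memory caches. A minimal, generic illustration of that pointer-offset idiom follows; the layout and sizes are invented for the example and do not mirror the real buffer plan.

#include <cstddef>
#include <cstdlib>

int main() {
  // Illustrative only: carve two logical caches out of one flat allocation
  // by offsetting a typed base pointer, the idiom the assignments above use.
  const std::size_t layers = 4, cache_size = 256;
  float* base =
      static_cast<float*>(std::malloc(sizeof(float) * cache_size * layers * 2));

  float* k_cache = base;                        // first half holds K
  float* v_cache = base + cache_size * layers;  // second half holds V

  (void)k_cache;
  (void)v_cache;
  std::free(base);
  return 0;
}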
@@ -474,7 +471,7 @@ class T5DecodingSampling {
 
     // TODO(guosheng): move cache offset into for loop for pipeline parallel
     size_t cache_size = (args_.batch_size_ * args_.seq_len_ *
-                         t_parallel_param_.local_hidden_units_);  // type T
+                         args_.hidden_units_);  // type T
 
     const int local_batch = l_parallel_param_.local_batch_size;
    for (uint step = 1; step <= args_.seq_len_; ++step) {
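The last hunk sizes the self-attention cache by the full hidden dimension rather than a per-rank slice. Assuming the common convention that local_hidden_units_ is hidden_units_ divided by the tensor-parallel degree (an assumption, not something stated in this diff), the element counts compare as follows; the numbers are purely illustrative.

// Illustrative arithmetic only; parameter values are made up.
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t batch_size = 8, seq_len = 128, hidden_units = 768;
  std::size_t tensor_para_size = 2;  // assumed tensor-parallel degree
  std::size_t local_hidden_units = hidden_units / tensor_para_size;

  std::size_t cache_before = batch_size * seq_len * local_hidden_units;  // per-rank slice
  std::size_t cache_after = batch_size * seq_len * hidden_units;         // full hidden size

  std::printf("elements before: %zu, after: %zu\n", cache_before, cache_after);
  return 0;
}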
