Skip to content

Commit

Permalink
[Serve] Introducing GPU sampler for CUDA
Browse files Browse the repository at this point in the history
This PR introduces the GPU sampler for CUDA only. The GPU sampler
makes use of the GPU sampling ops introduced in apache/tvm#16575.

We will follow up by benchmarking the performance of the GPU sampler
against the CPU sampler.
  • Loading branch information
MasterJH5574 committed Mar 12, 2024
1 parent 5bae24a commit 808da17
Show file tree
Hide file tree
Showing 15 changed files with 663 additions and 43 deletions.
10 changes: 5 additions & 5 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "model.h"
#include "request.h"
#include "request_state.h"
#include "sampler.h"
#include "sampler/sampler.h"

namespace mlc {
namespace llm {
Expand Down Expand Up @@ -78,13 +78,13 @@ class EngineImpl : public Engine {
this->models_.push_back(model);
this->model_workspaces_.push_back(ModelWorkspace{model->AllocEmbeddingTensor()});
}
int max_logit_processor_num_token = kv_cache_config_->max_num_sequence;
int max_num_tokens = kv_cache_config_->max_num_sequence;
if (engine_mode_->enable_speculative) {
max_logit_processor_num_token *= engine_mode_->spec_draft_length;
max_num_tokens *= engine_mode_->spec_draft_length;
}
LogitProcessor logit_processor =
this->models_[0]->CreateLogitProcessor(max_logit_processor_num_token, trace_recorder);
Sampler sampler = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_);
this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);
Sampler sampler = this->models_[0]->CreateSampler(max_num_tokens, trace_recorder);
// Step 3. Initialize engine actions that represent state transitions.
if (this->engine_mode_->enable_speculative) {
// Speculative decoding is only possible for more than one model.
Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/action.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "../engine_state.h"
#include "../event_trace_recorder.h"
#include "../model.h"
#include "../sampler.h"
#include "../sampler/sampler.h"

namespace mlc {
namespace llm {
Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/batch_decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "../../random.h"
#include "../config.h"
#include "../model.h"
#include "../sampler.h"
#include "../sampler/sampler.h"
#include "action.h"
#include "action_commons.h"

Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include "../config.h"
#include "../model.h"
#include "../sampler.h"
#include "../sampler/sampler.h"
#include "action.h"
#include "action_commons.h"

Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "../../random.h"
#include "../config.h"
#include "../model.h"
#include "../sampler.h"
#include "../sampler/sampler.h"
#include "action.h"
#include "action_commons.h"

Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "../config.h"
#include "../model.h"
#include "../sampler.h"
#include "../sampler/sampler.h"
#include "action.h"
#include "action_commons.h"

Expand Down
6 changes: 6 additions & 0 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ void FunctionTable::_InitFunctions() {
this->kv_cache_popn_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn");
this->kv_cache_get_num_available_pages_func_ =
get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages");
if (local_gpu_device.device_type == DLDeviceType::kDLCUDA) {
gpu_multinomial_from_uniform_func_ = mod->GetFunction("multinomial_from_uniform", true);
gpu_argsort_probs_func_ = mod->GetFunction("argsort_probs", true);
gpu_sample_with_top_p_func_ = mod->GetFunction("sample_with_top_p", true);
gpu_sampler_take_probs_func_ = mod->GetFunction("sampler_take_probs", true);
}
this->nd_view_func_ = get_global_func("vm.builtin.reshape");
this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of");
this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset");
Expand Down
4 changes: 4 additions & 0 deletions cpp/serve/function_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ struct FunctionTable {
PackedFunc kv_cache_attention_func_;
PackedFunc kv_cache_popn_func_;
PackedFunc kv_cache_get_num_available_pages_func_;
PackedFunc gpu_multinomial_from_uniform_func_;
PackedFunc gpu_argsort_probs_func_;
PackedFunc gpu_sample_with_top_p_func_;
PackedFunc gpu_sampler_take_probs_func_;
PackedFunc nd_view_func_;
PackedFunc nd_get_shape_func_;
PackedFunc nd_copy_embedding_to_offset_func_;
Expand Down
9 changes: 9 additions & 0 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,15 @@ class ModelImpl : public ModelObj {
std::move(trace_recorder));
}

/*!
 * \brief Create the sampler that matches the device this model runs on.
 * \param max_num_sample Forwarded to the GPU sampler factory together with the
 * vocabulary size; it is not used when the CPU sampler is created.
 * \param trace_recorder The optional event trace recorder handed over to the
 * created sampler.
 * \return A GPU sampler when the model device is CUDA, a CPU sampler otherwise.
 */
Sampler CreateSampler(int max_num_sample, Optional<EventTraceRecorder> trace_recorder) final {
  // `final` makes the compiler verify that this really overrides the pure
  // virtual `ModelObj::CreateSampler`, consistent with the sibling overrides.
  if (device_.device_type == DLDeviceType::kDLCUDA) {
    // The GPU sampler is handed the model's function table and device.
    return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_,
                                     std::move(trace_recorder));
  }
  // Every non-CUDA device falls back to the CPU sampler.
  return Sampler::CreateCPUSampler(std::move(trace_recorder));
}

void CreateKVCache(KVCacheConfig kv_cache_config) final {
IntTuple max_num_sequence{kv_cache_config->max_num_sequence};
IntTuple max_total_sequence_length{kv_cache_config->max_total_sequence_length};
Expand Down
8 changes: 8 additions & 0 deletions cpp/serve/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "event_trace_recorder.h"
#include "function_table.h"
#include "logit_processor.h"
#include "sampler/sampler.h"

namespace mlc {
namespace llm {
Expand All @@ -23,6 +24,9 @@ namespace serve {
using tvm::Device;
using namespace tvm::runtime;

// Declare the sampler class for `Model::CreateSampler`.
class Sampler;

/*!
* \brief The workspace tensors that may be shared across different
* calls to Model. For example, the prefill action uses the `embeddings`
Expand Down Expand Up @@ -144,6 +148,10 @@ class ModelObj : public Object {
virtual LogitProcessor CreateLogitProcessor(int max_num_token,
Optional<EventTraceRecorder> trace_recorder) = 0;

/*!
 * \brief Create a sampler from this model.
 * \param max_num_sample Sizing hint passed to the sampler being created
 * (presumably the maximum number of samples drawn at once -- confirm against
 * the sampler implementation).
 * \param trace_recorder The optional event trace recorder for the sampler.
 * \return The created sampler.
 */
virtual Sampler CreateSampler(int max_num_sample,
Optional<EventTraceRecorder> trace_recorder) = 0;

/*!
* \brief Estimate number of CPU units required to drive the model
* executing during TP.
Expand Down
24 changes: 8 additions & 16 deletions cpp/serve/sampler.cc → cpp/serve/sampler/cpu_sampler.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
/*!
* Copyright (c) 2023 by Contributors
* \file serve/sampler.cc
* \brief The implementation for runtime module of sampler functions.
* \file serve/sampler/cpu_sampler.cc
* \brief The implementation for CPU sampler functions.
*/
#include "sampler.h"

#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>

#include <cmath>

#include "../random.h"
#include "../../random.h"
#include "sampler.h"

namespace mlc {
namespace llm {
Expand Down Expand Up @@ -250,6 +249,8 @@ inline std::vector<TokenProbPair> ComputeTopProbs(NDArray prob, int unit_offset,

/********************* CPU Sampler *********************/

TVM_REGISTER_OBJECT_TYPE(SamplerObj);

class CPUSampler : public SamplerObj {
public:
explicit CPUSampler(Optional<EventTraceRecorder> trace_recorder)
Expand Down Expand Up @@ -430,17 +431,8 @@ class CPUSampler : public SamplerObj {
const float eps_ = 1e-5;
};

/*********************** Sampler ***********************/

TVM_REGISTER_OBJECT_TYPE(SamplerObj);

Sampler Sampler::Create(std::string sampler_kind, Optional<EventTraceRecorder> trace_recorder) {
if (sampler_kind == "cpu") {
return Sampler(make_object<CPUSampler>(std::move(trace_recorder)));
} else {
LOG(FATAL) << "Unsupported sampler_kind \"" << sampler_kind << "\"";
throw;
}
/*!
 * \brief Build a sampler backed by the CPU implementation.
 * \param trace_recorder The optional event trace recorder, moved into the
 * newly constructed CPUSampler.
 * \return A Sampler reference wrapping the new CPUSampler object.
 */
Sampler Sampler::CreateCPUSampler(Optional<EventTraceRecorder> trace_recorder) {
  auto cpu_sampler = make_object<CPUSampler>(std::move(trace_recorder));
  return Sampler(std::move(cpu_sampler));
}

} // namespace serve
Expand Down

0 comments on commit 808da17

Please sign in to comment.