diff --git a/build.py b/build.py
index 3fdddad5a4..be73faaf31 100644
--- a/build.py
+++ b/build.py
@@ -234,6 +234,7 @@ def dump_default_mlc_chat_config(args):
     config["local_id"] = f"{args.model}-{args.quantization.name}"
     config["conv_template"] = args.conv_template
     config["temperature"] = 0.7
+    config["repetition_penalty"] = 1.0
     config["top_p"] = 0.95
     config["mean_gen_len"] = 128
     config["shift_fill_factor"] = 0.3
diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index f17bf33ecc..c0ab5a4b0e 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 
 namespace mlc {
 namespace llm {
@@ -531,11 +532,14 @@ class LLMChat {
     auto config = config_info.get<picojson::object>();
     ICHECK(config["conv_template"].is<std::string>());
     ICHECK(config["temperature"].is<double>());
+    ICHECK(config["repetition_penalty"].is<double>());
     ICHECK(config["top_p"].is<double>());
     ICHECK(config["mean_gen_len"].is<int64_t>());
     ICHECK(config["shift_fill_factor"].is<double>());
     std::string conv_template = config["conv_template"].get<std::string>();
     this->temperature_ = config["temperature"].get<double>();
+    this->repetition_penalty_ = config["repetition_penalty"].get<double>();
+    CHECK(this->repetition_penalty_ > 0) << "Repetition penalty must be a positive number!";
     this->top_p_ = config["top_p"].get<double>();
     this->mean_gen_len_ = config["mean_gen_len"].get<int64_t>();
     this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();
@@ -718,6 +722,7 @@ class LLMChat {
       this->ResetRuntimeStats();
     }
     output_ids_.clear();
+    appeared_token_ids_.clear();
     output_message_.clear();
     encounter_stop_str_ = false;
 
@@ -757,6 +762,7 @@ class LLMChat {
 
   void DecodeStep() {
     output_ids_.push_back(next_token_);
+    appeared_token_ids_.insert(next_token_);
     output_message_ = RemoveStopStr(tokenizer_->Decode(output_ids_));
 
     tvm::runtime::NDArray input_data = GetInputTokenNDArray({next_token_});
@@ -765,11 +771,19 @@ class LLMChat {
     cur_pos_ += 1;
 
     auto tstart = std::chrono::high_resolution_clock::now();
-    if (temperature_ < 1e-6f) {
-      this->UpdateLogitsOrProbOnCPU(this->Forward(input_data, total_seq_len_));
+    if (repetition_penalty_ == 1.0f) {
+      if (temperature_ < 1e-6f) {
+        this->UpdateLogitsOrProbOnCPU(this->Forward(input_data, total_seq_len_));
+      } else {
+        this->UpdateLogitsOrProbOnCPU(
+            this->Softmax(this->Forward(input_data, total_seq_len_), temperature_));
+      }
     } else {
-      this->UpdateLogitsOrProbOnCPU(
-          this->Softmax(this->Forward(input_data, total_seq_len_), temperature_));
+      this->UpdateLogitsOrProbOnCPU(this->Forward(input_data, total_seq_len_));
+      this->ApplyRepetitionPenaltyOnCPU();
+      if (temperature_ >= 1e-6f) {
+        this->ApplySoftmaxWithTemperatureOnCPU();
+      }
     }
     TVMSynchronize(device_.device_type, device_.device_id, nullptr);
     auto tsample_start = std::chrono::high_resolution_clock::now();
@@ -908,6 +922,39 @@ class LLMChat {
     return ret;
   }
 
+  void ApplyRepetitionPenaltyOnCPU() {
+    CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
+    CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+    float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
+    for (const int32_t& token_id : this->appeared_token_ids_) {
+      if (logits_raw_data[token_id] <= 0) {
+        logits_raw_data[token_id] *= this->repetition_penalty_;
+      } else {  // logits > 0
+        logits_raw_data[token_id] /= this->repetition_penalty_;
+      }
+    }
+  }
+
+  void ApplySoftmaxWithTemperatureOnCPU() {
+    CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
+    CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+    int vocab_size = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1];
+    float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
+    float m = std::numeric_limits<float>::lowest();
+    float inv_temp = 1.0f / this->temperature_;
+    double d = 0.0;
+    for (int i = 0; i < vocab_size; ++i) {
+      float x = logits_raw_data[i] * inv_temp;
+      float m_prev = m;
+      m = std::max(m, x);
+      d = d * std::exp(m_prev - m) + std::exp(x - m);
+    }
+    for (int i = 0; i < vocab_size; ++i) {
+      float x = logits_raw_data[i] * inv_temp;
+      logits_raw_data[i] = std::exp(x - m) / d;
+    }
+  }
+
   void UpdateLogitsOrProbOnCPU(NDArray logits_or_prob) {
     if (!logits_on_cpu_.defined()) {
       logits_on_cpu_ = logits_or_prob.CopyTo(DLDevice{kDLCPU, 0});
@@ -983,12 +1030,16 @@ class LLMChat {
   double shift_fill_factor_{0.3};
   // temperature
   double temperature_{0.8};
+  // repetition penalty
+  double repetition_penalty_{1.0};
   // top_p
   double top_p_{0.95};
   // next_token
   int32_t next_token_{0};
   // output ids till now (refresh after encoding step)
   std::vector<int32_t> output_ids_;
+  // appeared token ids till now (refresh after encoding step)
+  std::unordered_set<int32_t> appeared_token_ids_;
   // output message till now (refresh after encoding step)
   std::string output_message_;
   // whether to add bos as the first token
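For context: `ApplyRepetitionPenaltyOnCPU` applies the penalized-sampling rule popularized by the CTRL paper (Keskar et al., 2019). Every token that has already appeared has its logit pushed toward negative infinity: positive logits are divided by the penalty and negative ones multiplied by it, so any penalty greater than 1.0 makes a repeated token strictly less likely. A minimal standalone sketch of the same rule; the function name and vector interface here are illustrative, not part of the patch:

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

// Shrink the logit of every previously generated token. penalty > 1.0
// discourages repeats; penalty == 1.0 is a no-op, which is why DecodeStep
// takes the fast path when repetition_penalty_ == 1.0f.
void PenalizeRepeats(std::vector<float>& logits,
                     const std::unordered_set<int32_t>& appeared,
                     float penalty) {
  for (int32_t token_id : appeared) {
    float& logit = logits[token_id];
    // Dividing a positive logit and multiplying a negative one both move the
    // value toward -inf, so both cases reduce the token's probability.
    logit = (logit > 0) ? logit / penalty : logit * penalty;
  }
}
```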
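`ApplySoftmaxWithTemperatureOnCPU` computes a numerically stable softmax in a single sweep over the vocabulary, in the style of the online softmax of Milakov & Gimelshein (2018): it keeps a running maximum `m` and a running denominator `d`, rescaling `d` by `exp(m_prev - m)` whenever the maximum grows. A self-contained sketch of the same scheme, assuming an illustrative vector-based interface:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// One-pass softmax with temperature. Produces the same result as the classic
// two-pass form exp((x_i - max_j x_j) / T) / sum_j exp((x_j - max_j x_j) / T),
// but the max and the denominator are maintained together in a single sweep.
std::vector<float> SoftmaxWithTemperature(const std::vector<float>& logits,
                                          float temperature) {
  const float inv_temp = 1.0f / temperature;
  float m = -std::numeric_limits<float>::infinity();  // running max
  double d = 0.0;                                     // running denominator
  for (float logit : logits) {
    const float x = logit * inv_temp;
    const float m_prev = m;
    m = std::max(m, x);
    // Rescale the accumulated denominator to the new max, then add this term.
    d = d * std::exp(m_prev - m) + std::exp(x - m);
  }
  std::vector<float> probs(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = static_cast<float>(std::exp(logits[i] * inv_temp - m) / d);
  }
  return probs;
}
```

Seeding the running max with negative infinity (rather than `std::numeric_limits<float>::min()`, which is a tiny positive number) keeps the max-shift exact even when every logit is negative.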