[Decoding] Support repetition penalty (octoml#177)
The repetition penalty (introduced in [CTRL](https://arxiv.org/abs/1909.05858)) can help prevent the LLM from generating repetitive tokens.
This PR implements the repetition penalty.

Note: previously the logits softmax was performed on the GPU; this PR moves it to the CPU when the repetition penalty is enabled, so the penalty can be applied directly to the logits.
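
For reference, penalized sampling as defined in CTRL rescales the score of each already-generated token before the temperature softmax:

    p_i = exp(x_i / (T · I(i ∈ g))) / Σ_j exp(x_j / (T · I(j ∈ g))),    with I(c) = θ if c ∈ g, else 1,

where θ is the repetition penalty, T the temperature, and g the set of previously generated tokens. This PR follows the common variant that divides positive logits by θ and multiplies negative logits by θ, so a repeated token loses probability mass regardless of the sign of its logit.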
yzh119 committed May 18, 2023
1 parent 615020d commit ff81bdb
Showing 2 changed files with 56 additions and 4 deletions.
1 change: 1 addition & 0 deletions build.py
@@ -234,6 +234,7 @@ def dump_default_mlc_chat_config(args):
config["local_id"] = f"{args.model}-{args.quantization.name}"
config["conv_template"] = args.conv_template
config["temperature"] = 0.7
config["repetition_penalty"] = 1.0
config["top_p"] = 0.95
config["mean_gen_len"] = 128
config["shift_fill_factor"] = 0.3
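For illustration, with these defaults the dumped mlc-chat-config.json would contain entries along the following lines (the local_id and conv_template values below are hypothetical placeholders, and only the fields visible in this hunk are shown):

{
  "local_id": "vicuna-v1-7b-q3f16_0",
  "conv_template": "vicuna_v1.1",
  "temperature": 0.7,
  "repetition_penalty": 1.0,
  "top_p": 0.95,
  "mean_gen_len": 128,
  "shift_fill_factor": 0.3
}

A repetition_penalty of 1.0 leaves decoding unchanged; raising it above 1.0 penalizes tokens that have already been generated.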
59 changes: 55 additions & 4 deletions cpp/llm_chat.cc
@@ -24,6 +24,7 @@
#include <optional>
#include <random>
#include <string>
#include <unordered_set>

namespace mlc {
namespace llm {
@@ -531,11 +532,14 @@ class LLMChat {
auto config = config_info.get<picojson::object>();
ICHECK(config["conv_template"].is<std::string>());
ICHECK(config["temperature"].is<double>());
ICHECK(config["repetition_penalty"].is<double>());
ICHECK(config["top_p"].is<double>());
ICHECK(config["mean_gen_len"].is<int64_t>());
ICHECK(config["shift_fill_factor"].is<double>());
std::string conv_template = config["conv_template"].get<std::string>();
this->temperature_ = config["temperature"].get<double>();
this->repetition_penalty_ = config["repetition_penalty"].get<double>();
CHECK(this->repetition_penalty_ > 0) << "Repetition penalty must be a positive number!";
this->top_p_ = config["top_p"].get<double>();
this->mean_gen_len_ = config["mean_gen_len"].get<int64_t>();
this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();
@@ -718,6 +722,7 @@ class LLMChat {
this->ResetRuntimeStats();
}
output_ids_.clear();
appeared_token_ids_.clear();
output_message_.clear();
encounter_stop_str_ = false;

@@ -757,6 +762,7 @@ class LLMChat {

void DecodeStep() {
output_ids_.push_back(next_token_);
appeared_token_ids_.insert(next_token_);
output_message_ = RemoveStopStr(tokenizer_->Decode(output_ids_));

tvm::runtime::NDArray input_data = GetInputTokenNDArray({next_token_});
@@ -765,11 +771,19 @@
cur_pos_ += 1;

auto tstart = std::chrono::high_resolution_clock::now();
-    if (temperature_ < 1e-6f) {
-      this->UpdateLogitsOrProbOnCPU(this->Forward(input_data, total_seq_len_));
+    if (repetition_penalty_ == 1.0f) {
+      if (temperature_ < 1e-6f) {
+        this->UpdateLogitsOrProbOnCPU(this->Forward(input_data, total_seq_len_));
+      } else {
+        this->UpdateLogitsOrProbOnCPU(
+            this->Softmax(this->Forward(input_data, total_seq_len_), temperature_));
+      }
     } else {
-      this->UpdateLogitsOrProbOnCPU(
-          this->Softmax(this->Forward(input_data, total_seq_len_), temperature_));
+      this->UpdateLogitsOrProbOnCPU(this->Forward(input_data, total_seq_len_));
+      this->ApplyRepetitionPenaltyOnCPU();
+      if (temperature_ >= 1e-6f) {
+        this->ApplySoftmaxWithTemperatureOnCPU();
+      }
     }
TVMSynchronize(device_.device_type, device_.device_id, nullptr);
auto tsample_start = std::chrono::high_resolution_clock::now();
@@ -908,6 +922,39 @@ class LLMChat {
return ret;
}

void ApplyRepetitionPenaltyOnCPU() {
  CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
  CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
  float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
  for (const int32_t& token_id : this->appeared_token_ids_) {
    if (logits_raw_data[token_id] <= 0) {
      logits_raw_data[token_id] *= this->repetition_penalty_;
    } else {  // logits > 0
      logits_raw_data[token_id] /= this->repetition_penalty_;
    }
  }
}

void ApplySoftmaxWithTemperatureOnCPU() {
  CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
  CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
  int vocab_size = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1];
  float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
  float m = std::numeric_limits<float>::lowest();
  float inv_temp = 1.0f / this->temperature_;
  double d = 0.0;
  for (int i = 0; i < vocab_size; ++i) {
    float x = logits_raw_data[i] * inv_temp;
    float m_prev = m;
    m = std::max(m, x);
    d = d * std::exp(m_prev - m) + std::exp(x - m);
  }
  for (int i = 0; i < vocab_size; ++i) {
    float x = logits_raw_data[i] * inv_temp;
    logits_raw_data[i] = std::exp(x - m) / d;
  }
}

void UpdateLogitsOrProbOnCPU(NDArray logits_or_prob) {
if (!logits_on_cpu_.defined()) {
logits_on_cpu_ = logits_or_prob.CopyTo(DLDevice{kDLCPU, 0});
@@ -983,12 +1030,16 @@ class LLMChat {
double shift_fill_factor_{0.3};
// temperature
double temperature_{0.8};
// repetition penalty
double repetition_penalty_{1.0};
// top_p
double top_p_{0.95};
// next_token
int32_t next_token_{0};
// output ids till now (refresh after encoding step)
std::vector<int32_t> output_ids_;
// appeared token ids till now (refresh after encoding step)
std::unordered_set<int32_t> appeared_token_ids_;
// output message till now (refresh after encoding step)
std::string output_message_;
// whether to add bos as the first token
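
To make the new helpers concrete, here is a standalone toy program (not part of the commit) that applies the same sign-dependent penalty rule and a temperature softmax to a four-token vocabulary; the logit values and the penalty of 1.3 are made up for illustration:

// Toy illustration of the repetition-penalty rule used in ApplyRepetitionPenaltyOnCPU,
// followed by a temperature softmax; all values are hypothetical.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<float> logits = {2.0f, -1.0f, 0.5f, -3.0f};  // raw logits for 4 tokens
  std::unordered_set<int> appeared = {0, 1};               // tokens 0 and 1 were generated before
  const float penalty = 1.3f;                              // repetition penalty > 1 discourages repeats
  const float temperature = 0.7f;

  // Divide positive logits and multiply negative ones, so a repeated token
  // always becomes less likely, whatever the sign of its logit.
  for (int id : appeared) {
    if (logits[id] <= 0) {
      logits[id] *= penalty;  // -1.0 -> -1.3
    } else {
      logits[id] /= penalty;  //  2.0 -> ~1.54
    }
  }

  // Max-subtracted softmax over the penalized logits, scaled by 1/temperature.
  float m = logits[0] / temperature;
  for (float x : logits) m = std::max(m, x / temperature);
  double d = 0.0;
  for (float x : logits) d += std::exp(x / temperature - m);
  for (int i = 0; i < (int)logits.size(); ++i) {
    std::printf("token %d: p = %.4f\n", i, std::exp(logits[i] / temperature - m) / d);
  }
  return 0;
}

After the penalty, token 0 (the strongest already-seen token) drops from logit 2.0 to about 1.54, so part of its probability mass shifts toward tokens 2 and 3, which have not been generated yet.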