gptneox & gptj int8 quantization & share context #653

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions examples/cpp/gptj/gptj_config.ini
@@ -14,6 +14,7 @@ enable_custom_all_reduce=0

tensor_para_size=1
pipeline_para_size=1
int8_mode=0 ; only 0 or 1 is supported (1 requires fp16)

model_name=gptj_6B
model_dir=../models/j6b_ckpt/
8 changes: 7 additions & 1 deletion examples/cpp/gptj/gptj_example.cc
@@ -92,6 +92,7 @@ void gptj_example(const INIReader reader)

int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size");
int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size");
int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0);

const size_t head_num = reader.GetInteger(model_name, "head_num");
const size_t size_per_head = reader.GetInteger(model_name, "size_per_head");
@@ -287,6 +288,7 @@ void gptj_example(const INIReader reader)
tensor_para.rank_,
pipeline_para.world_size_,
pipeline_para.rank_,
int8_mode,
prompt_learning_type,
prefix_prompt_table_pair); // optional if you don't need prefix prompts

@@ -336,7 +338,11 @@ void gptj_example(const INIReader reader)
&allocator,
false,
&prop,
attention_type);
attention_type,
int8_mode,
nullptr,
0,
1.0f);

int* d_output_ids;
int* d_sequence_lengths;
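For context on the example change above: the new int8_mode key is read from the [ft_instance_hyperparameter] section with a default of 0 and is then forwarded to the GptJ weights and model constructors. A minimal, standalone sketch of just that parsing step follows; the key name, section name, and config location are taken from the diff, while the include path and surrounding main() are assumptions.

#include <iostream>

#include "3rdparty/INIReader.h"  // INIReader used by the FasterTransformer examples (path assumed)

int main()
{
    INIReader reader("../examples/cpp/gptj/gptj_config.ini");
    if (reader.ParseError() < 0) {
        std::cout << "[ERROR] Can't load gptj_config.ini\n";
        return -1;
    }

    // 0 (default) keeps the existing FP16/FP32 weight path; 1 enables the INT8
    // path added by this PR. Per the config comment, 1 is only supported with fp16.
    const int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0);
    std::cout << "int8_mode = " << int8_mode << std::endl;
    return 0;
}
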
2 changes: 1 addition & 1 deletion examples/cpp/gptneox/gptneox_config.ini
@@ -4,7 +4,7 @@ enable_custom_all_reduce=0

tensor_para_size=2
pipeline_para_size=1

int8_mode=0 ; only 0 or 1 is supported (1 requires fp16)
model_name=gptneox_20B
model_dir=../models/gptneox

10 changes: 9 additions & 1 deletion examples/cpp/gptneox/gptneox_example.cc
@@ -47,6 +47,8 @@ int main(int argc, char* argv[])
ini_name = "../examples/cpp/gptneox/gptneox_config.ini";
}

std::cout << "Ini file name: " << ini_name << std::endl;

INIReader reader = INIReader(ini_name);
if (reader.ParseError() < 0) {
std::cout << "[ERROR] Can't load '" << ini_name << "'\n";
@@ -76,6 +78,7 @@ void gptneox_example(const INIReader reader)

int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size");
int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size");
int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0);

const size_t head_num = reader.GetInteger(model_name, "head_num");
const size_t size_per_head = reader.GetInteger(model_name, "size_per_head");
@@ -275,6 +278,7 @@ void gptneox_example(const INIReader reader)
pipeline_para.world_size_,
pipeline_para.rank_,
use_gptj_residual,
int8_mode,
prompt_learning_type,
prefix_prompt_table_pair);

@@ -321,7 +325,11 @@ void gptneox_example(const INIReader reader)
&allocator,
false,
&prop,
attention_type);
attention_type,
int8_mode,
nullptr,
0,
1.0f);

int* d_output_ids;
int* d_sequence_lengths;
@@ -24,12 +24,12 @@
#include <float.h>
#include <type_traits>

// #define MMHA_USE_HMMA_FOR_REDUCTION
#define MMHA_USE_HMMA_FOR_REDUCTION

// Below are knobs to extend FP32 accumulation for higher FP16 accuracy

// Does not seem to affect the accuracy that much
// #define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_FP32_ACUM_FOR_FMA

// Seems to slightly improve the accuracy
#define MMHA_USE_FP32_ACUM_FOR_OUT
@@ -389,26 +389,6 @@ struct Qk_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};

template<>
struct Qk_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
using Type = float;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct Qk_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
#ifdef ENABLE_FP8
// template<>
// struct Qk_vec_acum_fp32_<fp8_2_t> {
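The hunk above turns on two previously commented-out knobs, MMHA_USE_HMMA_FOR_REDUCTION and MMHA_USE_FP32_ACUM_FOR_FMA, so the masked multi-head attention kernel now accumulates its FP16 Q·K fused multiply-adds in FP32. The standalone kernel below is not the FasterTransformer code; it only illustrates the accuracy/performance trade-off that the FP32-accumulation macro toggles.

#include <cuda_fp16.h>

// Illustrative dot product: FP16 operands, accumulator type chosen by the macro.
// Requires sm_53+ for the half-arithmetic branch.
__global__ void dot_product(const half* q, const half* k, float* out, int n)
{
    float partial = 0.f;
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    // Widen each product to FP32 before adding: rounding error stays small.
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        partial = fmaf(__half2float(q[i]), __half2float(k[i]), partial);
    }
#else
    // Accumulate in FP16: cheaper HFMA, but error grows with sequence length.
    half acc = __float2half(0.f);
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        acc = __hfma(q[i], k[i], acc);
    }
    partial = __half2float(acc);
#endif
    atomicAdd(out, partial);  // block/warp reduction elided for brevity
}
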
6 changes: 3 additions & 3 deletions src/fastertransformer/kernels/decoding_kernels.cu
@@ -496,9 +496,9 @@ __global__ void gatherTree(gatherTreeParam param)
int tmp_len =
param.max_sequence_lengths[batch * param.beam_width + j] + param.max_sequence_length_final_step;
// also remove the length of the soft prompts, p_prompt_tuning
param.max_sequence_lengths[batch * param.beam_width + j] =
tmp_len - param.max_prefix_soft_prompt_length
- (param.max_input_length - param.max_input_without_prompt_length);
param.sequence_lengths_for_output[batch * param.beam_width + j] =
uint32_t(tmp_len - param.max_prefix_soft_prompt_length
- (param.max_input_length - param.max_input_without_prompt_length));
// update the response input length
if (update_response_input_length) {
param.response_input_lengths[batch * param.beam_width + j] = input_len - prompt_len;
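To make the arithmetic in the hunk above concrete: the adjusted length strips the soft-prompt tokens and the prompt padding from the stored length, and it is now written to the new sequence_lengths_for_output buffer instead of overwriting max_sequence_lengths in place. A worked example with made-up numbers:

#include <cassert>
#include <cstdint>

int main()
{
    const int tmp_len                         = 40;  // stored length + final-step length
    const int max_prefix_soft_prompt_length   = 8;   // soft-prompt tokens to drop
    const int max_input_length                = 20;  // padded input incl. prompt
    const int max_input_without_prompt_length = 16;  // actual user input

    const uint32_t out_len = uint32_t(tmp_len - max_prefix_soft_prompt_length
                                      - (max_input_length - max_input_without_prompt_length));

    assert(out_len == 28);  // 40 - 8 - (20 - 16)
    return 0;
}
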
1 change: 1 addition & 0 deletions src/fastertransformer/kernels/decoding_kernels.h
@@ -123,6 +123,7 @@ void invokeGatherTree(int* beams,
struct gatherTreeParam {
int* beams = nullptr;
int* max_sequence_lengths = nullptr;
uint* sequence_lengths_for_output = nullptr;
int max_sequence_length_final_step = 0;
const int* input_lengths = nullptr;
// response input lengths (used to slice the ids during postprocessing)
59 changes: 54 additions & 5 deletions src/fastertransformer/models/gptj/GptJ.cc
@@ -43,6 +43,7 @@ void GptJ<T>::initialize()
is_free_buffer_after_forward_,
is_context_qk_buf_float_,
attention_type_,
int8_mode_,
custom_all_reduce_comm_,
enable_custom_all_reduce_);

@@ -60,6 +61,7 @@ void GptJ<T>::initialize()
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_,
int8_mode_,
custom_all_reduce_comm_,
enable_custom_all_reduce_);

@@ -153,6 +155,13 @@ void GptJ<T>::allocateBuffer(
(float*)(allocator_->reMalloc(output_log_probs_buf_, sizeof(float) * batchxbeam * max_seq_len, false));
generation_should_stop_ = (bool*)(allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true));

if (shared_contexts_ratio_ > 0.0f) {
shared_contexts_idx_ = (int*)allocator_->reMalloc(shared_contexts_idx_, batch_size * sizeof(int), false);
batch_to_compact_idx_ = (int*)allocator_->reMalloc(batch_to_compact_idx_, batchxbeam * sizeof(int), false);
compact_idx_ = (int*)allocator_->reMalloc(compact_idx_, batch_size * sizeof(int), false);
compact_size_ = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false);
}

is_allocate_buffer_ = true;
}

@@ -205,6 +214,11 @@ void GptJ<T>::freeBuffer()

allocator_->free((void**)(&generation_should_stop_), true);

if (shared_contexts_ratio_ > 0.0f) {
allocator_->free((void**)(&shared_contexts_idx_));
allocator_->free((void**)(&compact_size_));
}

is_allocate_buffer_ = false;
}
}
@@ -237,8 +251,10 @@ GptJ<T>::GptJ(size_t max_batch_size,
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop,
AttentionType attention_type,
int int8_mode,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm,
int enable_custom_all_reduce):
int enable_custom_all_reduce,
float shared_contexts_ratio):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop),
head_num_(head_num),
size_per_head_(size_per_head),
@@ -252,7 +268,9 @@ GptJ<T>::GptJ(size_t max_batch_size,
prompt_learning_type_(prompt_learning_type),
hidden_units_(head_num * size_per_head),
local_head_num_(head_num / 1),
attention_type_(attention_type)
attention_type_(attention_type),
int8_mode_(int8_mode),
shared_contexts_ratio_(shared_contexts_ratio)
{
tensor_para_.world_size_ = 1;
tensor_para_.rank_ = 0;
@@ -297,8 +315,10 @@ GptJ<T>::GptJ(size_t max_batch_size,
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop,
AttentionType attention_type,
int int8_mode,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm,
int enable_custom_all_reduce):
int enable_custom_all_reduce,
float shared_contexts_ratio):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop),
head_num_(head_num),
size_per_head_(size_per_head),
@@ -315,8 +335,10 @@ GptJ<T>::GptJ(size_t max_batch_size,
pipeline_para_(pipeline_para),
local_head_num_(head_num / tensor_para.world_size_),
attention_type_(attention_type),
int8_mode_(int8_mode),
custom_all_reduce_comm_(custom_all_reduce_comm),
enable_custom_all_reduce_(enable_custom_all_reduce)
enable_custom_all_reduce_(enable_custom_all_reduce),
shared_contexts_ratio_(shared_contexts_ratio)
{
int local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_);
if (std::is_same<half, T>::value) {
@@ -345,8 +367,10 @@ GptJ<T>::GptJ(GptJ<T> const& gpt):
local_head_num_(gpt.local_head_num_),
vocab_size_padded_(gpt.vocab_size_padded_),
attention_type_(gpt.attention_type_),
int8_mode_(gpt.int8_mode_),
custom_all_reduce_comm_(gpt.custom_all_reduce_comm_),
enable_custom_all_reduce_(gpt.enable_custom_all_reduce_)
enable_custom_all_reduce_(gpt.enable_custom_all_reduce_),
shared_contexts_ratio_(gpt.shared_contexts_ratio_)
{
initialize();
}
@@ -584,6 +608,23 @@ void GptJ<T>::forward(std::unordered_map<std::string, Tensor>* output_tens
cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * max_seq_len, stream_);
}

int compact_size;
bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
if (use_shared_contexts) {
invokeFindContextDups(shared_contexts_idx_,
batch_to_compact_idx_,
compact_idx_,
compact_size_,
input_tensors->at("input_ids").getPtr<int>(),
batch_size,
beam_width,
max_input_length,
stream_);
cudaD2Hcpy(&compact_size, compact_size_, 1);
use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size;
sync_check_cuda_error();
}

// Prefix prompts
if (has_prefix_prompt_) {
cudaMemcpyAsync(prompt_learning_weight_batch_,
@@ -685,6 +726,14 @@ void GptJ<T>::forward(std::unordered_map<std::string, Tensor>* output_tens
{batch_size * beam_width},
has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : nullptr}}};

if (use_shared_contexts) {
decoder_input_tensors.insert(
{"compact_idx", Tensor(MEMORY_GPU, TYPE_INT32, {(size_t)compact_size}, compact_idx_)});
decoder_input_tensors.insert(
{"batch_to_compact_idx",
Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, batch_to_compact_idx_)});
}

std::unordered_map<std::string, Tensor> decoder_output_tensors{
{"decoder_output",
Tensor{MEMORY_GPU,
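The shared-context path added to GptJ<T>::forward deduplicates identical prompts within a batch: invokeFindContextDups builds a compact list of unique contexts plus a mapping from each request back to its compacted slot, and the compacted path is only kept when the number of unique contexts is at most shared_contexts_ratio * batch_size. The CPU-side sketch below is a reference for that bookkeeping only (the GPU kernel's exact output layout may differ, and beam tiling is omitted); it is not the FasterTransformer implementation.

#include <cassert>
#include <vector>

// Reference semantics: requests with identical input token sequences share one
// compacted context slot.
void find_context_dups(const std::vector<std::vector<int>>& input_ids,  // [batch][max_input_length]
                       std::vector<int>& shared_contexts_idx,           // [batch] representative request
                       std::vector<int>& batch_to_compact_idx,          // [batch] compacted slot per request
                       std::vector<int>& compact_idx,                   // [compact_size] requests to run
                       int&              compact_size)
{
    const int batch_size = int(input_ids.size());
    shared_contexts_idx.assign(batch_size, -1);
    batch_to_compact_idx.assign(batch_size, -1);
    compact_idx.clear();
    compact_size = 0;

    for (int i = 0; i < batch_size; ++i) {
        for (int j = 0; j < i; ++j) {
            if (input_ids[j] == input_ids[i]) {  // duplicate prompt found
                shared_contexts_idx[i]  = shared_contexts_idx[j];
                batch_to_compact_idx[i] = batch_to_compact_idx[j];
                break;
            }
        }
        if (batch_to_compact_idx[i] < 0) {  // first occurrence: open a new slot
            shared_contexts_idx[i]  = i;
            batch_to_compact_idx[i] = compact_size++;
            compact_idx.push_back(i);
        }
    }
}

int main()
{
    // Requests 0 and 2 share a prompt, so only three context passes are needed.
    std::vector<std::vector<int>> ids = {{1, 2, 3}, {4, 5, 6}, {1, 2, 3}, {7, 8, 9}};
    std::vector<int> shared, b2c, compact;
    int compact_size = 0;
    find_context_dups(ids, shared, b2c, compact, compact_size);

    assert(compact_size == 3);
    assert((b2c == std::vector<int>{0, 1, 0, 2}));
    assert((compact == std::vector<int>{0, 1, 3}));

    // Mirror of the decision in the forward() hunk above.
    const float shared_contexts_ratio = 1.0f;
    const bool  use_shared_contexts   = compact_size <= shared_contexts_ratio * int(ids.size());
    assert(use_shared_contexts);  // 3 <= 1.0 * 4
    return 0;
}
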
16 changes: 14 additions & 2 deletions src/fastertransformer/models/gptj/GptJ.h
@@ -53,11 +53,14 @@ class GptJ: public BaseLayer {
int enable_custom_all_reduce_;

AttentionType attention_type_;
const int int8_mode_ = 0;

size_t vocab_size_padded_;
const bool is_context_qk_buf_float_ =
(std::getenv("CONTEXT_ATTENTION_BMM1_HALF_ACCUM") == nullptr ||
std::string(std::getenv("CONTEXT_ATTENTION_BMM1_HALF_ACCUM")) != "ON");

float shared_contexts_ratio_;

// Prompt Learning Parameters
PromptLearningType prompt_learning_type_;
@@ -117,6 +120,11 @@ class GptJ: public BaseLayer {

bool* generation_should_stop_ = nullptr;

int* shared_contexts_idx_ = nullptr;
int* compact_idx_ = nullptr;
int* batch_to_compact_idx_ = nullptr;
int* compact_size_ = nullptr;

T* context_decoder_input_buf_;
T* context_decoder_output_buf_;
float* output_log_probs_buf_;
@@ -161,8 +169,10 @@ class GptJ: public BaseLayer {
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop = nullptr,
AttentionType attention_type = AttentionType::UNFUSED_MHA,
int int8_mode = 0,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm = nullptr,
int enable_custom_all_reduce = 0);
int enable_custom_all_reduce = 0,
float shared_contexts_ratio = 1.0f);

GptJ(size_t max_batch_size,
size_t max_seq_len,
@@ -193,8 +203,10 @@ GptJ(size_t max_batch_size,
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop = nullptr,
AttentionType attention_type = AttentionType::UNFUSED_MHA,
int int8_mode = 0,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm = nullptr,
int enable_custom_all_reduce = 0);
int enable_custom_all_reduce = 0,
float shared_contexts_ratio = 1.0f);

GptJ(GptJ<T> const& GptJ);

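Since every new constructor parameter in GptJ.h is appended with a default (int8_mode = 0, custom_all_reduce_comm = nullptr, enable_custom_all_reduce = 0, shared_contexts_ratio = 1.0f), call sites that stop at attention_type keep compiling unchanged. A toy sketch of that pattern with stand-in types, not the real GptJ signature (which takes many more arguments):

#include <memory>

struct AbstractCustomComm {};
enum class AttentionType { UNFUSED_MHA };

struct GptJLike {
    GptJLike(AttentionType                        attention_type           = AttentionType::UNFUSED_MHA,
             int                                  int8_mode                = 0,
             std::shared_ptr<AbstractCustomComm>  custom_all_reduce_comm   = nullptr,
             int                                  enable_custom_all_reduce = 0,
             float                                shared_contexts_ratio    = 1.0f)
    {
        // the real constructor builds the decoder layers here
    }
};

int main()
{
    GptJLike pre_pr(AttentionType::UNFUSED_MHA);                   // old argument list still valid
    GptJLike with_int8(AttentionType::UNFUSED_MHA, /*int8_mode=*/1,
                       nullptr, 0, /*shared_contexts_ratio=*/0.5f);
    return 0;
}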